use std::time::Duration;
use crate::cancellation::CancellationToken;
use awaken_contract::contract::content::ContentBlock;
use awaken_contract::contract::event::AgentEvent;
use awaken_contract::contract::event_sink::EventSink;
use awaken_contract::contract::executor::{
InFlightTool, InferenceExecutionError, InferenceRequest, InterruptCause, InterruptSnapshot,
LlmStreamEvent, RecoveryPlan,
};
use awaken_contract::contract::inference::{StopReason, StreamResult, TokenUsage};
use awaken_contract::contract::message::{Message, ToolCall};
use awaken_contract::contract::stream_checkpoint::{StreamCheckpoint, StreamCheckpointStore};
use futures::StreamExt;
use super::{AgentLoopError, now_ms};
use crate::engine::retry::LlmRetryPolicy;
use crate::registry::ResolvedAgent;
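/// Borrowed handle bundling a checkpoint store with the identifiers used to
/// key stream checkpoints for the current run.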
pub(super) struct CheckpointHandle<'a> {
pub store: &'a dyn StreamCheckpointStore,
pub run_id: &'a str,
pub thread_id: &'a str,
}
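/// Number of stream deltas between periodic checkpoint flushes. A smaller
/// value persists progress more often at the cost of extra store writes.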
const CHECKPOINT_FLUSH_DELTAS: usize = 4;
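/// Persists the accumulator's current partial state as a `StreamCheckpoint`.
/// Failures are logged and swallowed: checkpointing is best-effort and must
/// never abort the stream itself.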
async fn flush_checkpoint(
acc: &StreamingAccumulator,
upstream_model: &str,
handle: &CheckpointHandle<'_>,
) {
let snapshot = acc.interrupt_snapshot();
let checkpoint = StreamCheckpoint {
run_id: handle.run_id.to_string(),
thread_id: handle.thread_id.to_string(),
upstream_model: upstream_model.to_string(),
partial_text: snapshot.text.clone().unwrap_or_default(),
completed_tool_calls: snapshot.completed_tool_calls,
in_flight_tool: snapshot.in_flight_tool,
updated_at_ms: now_ms(),
};
if let Err(e) = handle.store.put(checkpoint).await {
tracing::warn!(
run_id = %handle.run_id,
error = %e,
"stream checkpoint flush failed — continuing without persistence",
);
}
}
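/// User-role message appended when retrying after a mid-stream interruption,
/// asking the model to pick up where the partial assistant message ended.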
const CONTINUATION_PROMPT: &str =
"Your previous response was interrupted mid-stream. Continue from where you left off.";
pub(super) async fn execute_streaming(
agent: &ResolvedAgent,
mut request: InferenceRequest,
sink: &dyn EventSink,
cancellation_token: Option<&CancellationToken>,
total_input_tokens: &mut u64,
total_output_tokens: &mut u64,
checkpoint: Option<CheckpointHandle<'_>>,
) -> Result<(StreamResult, Option<InFlightTool>), AgentLoopError> {
let policy = stream_retry_policy_for(agent);
let idle_timeout = idle_timeout_for(&request, &policy);
let max_retries = policy.max_stream_retries;
let mut attempt: u32 = 0;
let mut pending_resume: Option<DriveOutcome> = None;
if let Some(handle) = checkpoint.as_ref()
&& let Some(saved) = read_checkpoint(handle).await
{
let snapshot = InterruptSnapshot::from_partials(
(!saved.partial_text.is_empty()).then(|| saved.partial_text.clone()),
saved
.completed_tool_calls
.into_iter()
.map(|c| {
(
c.id,
c.name,
serde_json::to_string(&c.arguments).unwrap_or_default(),
)
})
.chain(
saved
.in_flight_tool
.into_iter()
.map(|p| (p.id, p.name, p.partial_args)),
),
saved.partial_text.len(),
);
pending_resume = Some(DriveOutcome::Interrupted {
cause: InterruptCause::ResumedFromCheckpoint,
snapshot,
});
}
loop {
let outcome = match pending_resume.take() {
Some(o) => o,
None => {
drive_one_stream(
agent,
request.clone(),
sink,
cancellation_token,
total_input_tokens,
total_output_tokens,
idle_timeout,
checkpoint.as_ref(),
)
.await
}
};
match outcome {
DriveOutcome::Completed(result) | DriveOutcome::Cancelled(result) => {
if let Some(handle) = checkpoint.as_ref() {
let _ = handle.store.delete(handle.run_id).await;
}
return Ok((result, None));
}
DriveOutcome::Error(err) => return Err(err),
DriveOutcome::Interrupted { cause, snapshot } => {
let counts_against_budget = !matches!(cause, InterruptCause::ResumedFromCheckpoint);
if counts_against_budget && attempt >= max_retries {
tracing::warn!(
attempts = attempt,
cause = %cause,
bytes_received = snapshot.bytes_received,
"stream retry budget exhausted; surfacing StreamInterrupted",
);
return Err(AgentLoopError::from(
InferenceExecutionError::StreamInterrupted {
cause,
snapshot: Box::new(snapshot),
},
));
}
match apply_recovery_plan(&mut request, sink, &cause, &snapshot).await {
RecoveryOutcome::SynthesizedToolUse { result, hint } => {
if let Some(handle) = checkpoint.as_ref() {
let _ = handle.store.delete(handle.run_id).await;
}
return Ok((result, hint));
}
RecoveryOutcome::RetryAfterPlan => {
if counts_against_budget {
let delay = stream_retry_backoff(&cause, attempt, &policy);
if !delay.is_zero() {
if let Some(token) = cancellation_token {
tokio::select! {
biased;
_ = token.cancelled() => {
return Err(AgentLoopError::from(
InferenceExecutionError::Cancelled,
));
}
_ = tokio::time::sleep(delay) => {}
}
} else {
tokio::time::sleep(delay).await;
}
}
attempt += 1;
}
continue;
}
}
}
}
}
}
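/// Loads a previously persisted checkpoint for this run, if any. Read
/// failures are logged and treated as "no checkpoint".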
async fn read_checkpoint(handle: &CheckpointHandle<'_>) -> Option<StreamCheckpoint> {
match handle.store.get(handle.run_id).await {
Ok(Some(saved)) => {
tracing::info!(
run_id = %handle.run_id,
partial_text_len = saved.partial_text.len(),
completed_tools = saved.completed_tool_calls.len(),
has_in_flight = saved.in_flight_tool.is_some(),
"restoring stream checkpoint"
);
Some(saved)
}
Ok(None) => None,
Err(e) => {
tracing::warn!(
run_id = %handle.run_id,
error = %e,
"checkpoint read failed; continuing without restore"
);
None
}
}
}
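/// Result of driving a single stream attempt in `drive_one_stream`.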
enum DriveOutcome {
Completed(StreamResult),
Cancelled(StreamResult),
Interrupted {
cause: InterruptCause,
snapshot: InterruptSnapshot,
},
Error(AgentLoopError),
}
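/// Result of applying a `RecoveryPlan` to an interrupted stream.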
enum RecoveryOutcome {
SynthesizedToolUse {
result: StreamResult,
hint: Option<InFlightTool>,
},
RetryAfterPlan,
}
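/// Translates the snapshot's recovery plan into sink events and request
/// mutations, returning whether the caller should retry or short-circuit
/// with a synthesized tool_use result.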
async fn apply_recovery_plan(
request: &mut InferenceRequest,
sink: &dyn EventSink,
cause: &InterruptCause,
snapshot: &InterruptSnapshot,
) -> RecoveryOutcome {
match snapshot.plan() {
RecoveryPlan::ContinueText { assistant_prefix } => {
push_continuation(request, assistant_prefix);
RecoveryOutcome::RetryAfterPlan
}
RecoveryPlan::SynthesizeToolUse {
completed,
cancelled_tool_hint,
} => {
if let Some(hint) = &cancelled_tool_hint {
sink.emit(AgentEvent::ToolCallCancel {
id: hint.id.clone(),
name: hint.name.clone(),
reason: cause.to_string(),
})
.await;
}
for call in &completed {
sink.emit(AgentEvent::ToolCallReady {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
})
.await;
}
let content = match snapshot.text.as_ref() {
Some(t) if !t.is_empty() => vec![ContentBlock::text(t.clone())],
_ => Vec::new(),
};
RecoveryOutcome::SynthesizedToolUse {
result: StreamResult {
content,
tool_calls: completed,
usage: None,
stop_reason: Some(StopReason::ToolUse),
has_incomplete_tool_calls: false,
},
hint: cancelled_tool_hint,
}
}
RecoveryPlan::TruncateBeforeTool {
assistant_prefix,
cancelled_tool_id,
cancelled_tool_name,
} => {
sink.emit(AgentEvent::ToolCallCancel {
id: cancelled_tool_id,
name: cancelled_tool_name,
reason: cause.to_string(),
})
.await;
push_continuation(request, assistant_prefix);
RecoveryOutcome::RetryAfterPlan
}
RecoveryPlan::WholeRestart => {
sink.emit(AgentEvent::StreamReset {
reason: cause.to_string(),
})
.await;
RecoveryOutcome::RetryAfterPlan
}
}
}
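/// Appends the partial assistant text (if any) and the continuation prompt
/// to the request so the next attempt resumes rather than starting over.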
fn push_continuation(request: &mut InferenceRequest, assistant_prefix: String) {
if !assistant_prefix.is_empty() {
request.messages.push(Message::assistant(assistant_prefix));
}
request.messages.push(Message::user(CONTINUATION_PROMPT));
}
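/// Runs a single stream attempt: opens the stream, forwards events to the
/// sink while accumulating them, enforces the idle timeout, honors
/// cancellation, and, when a checkpoint handle is present, flushes every
/// `CHECKPOINT_FLUSH_DELTAS` deltas and again on interruption. Recoverable
/// failures are classified and returned as `Interrupted` rather than
/// surfaced as hard errors.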
async fn drive_one_stream(
agent: &ResolvedAgent,
request: InferenceRequest,
sink: &dyn EventSink,
cancellation_token: Option<&CancellationToken>,
total_input_tokens: &mut u64,
total_output_tokens: &mut u64,
idle_timeout: Duration,
checkpoint: Option<&CheckpointHandle<'_>>,
) -> DriveOutcome {
let upstream_model = request.upstream_model.clone();
let mut token_stream = match agent.llm_executor.execute_stream(request).await {
Ok(s) => s,
Err(err) => {
return DriveOutcome::Error(AgentLoopError::from(err));
}
};
let mut acc = StreamingAccumulator::default();
let mut deltas_since_last_flush: usize = 0;
loop {
let next_fut = async { tokio::time::timeout(idle_timeout, token_stream.next()).await };
let event = if let Some(token) = cancellation_token {
tokio::select! {
biased;
_ = token.cancelled() => {
acc.cancelled = true;
break;
}
r = next_fut => r,
}
} else {
next_fut.await
};
let poll = match event {
Ok(p) => p,
Err(_) => {
if let Some(handle) = checkpoint {
flush_checkpoint(&acc, &upstream_model, handle).await;
}
let snapshot = acc.interrupt_snapshot();
return DriveOutcome::Interrupted {
cause: InterruptCause::IdleStall,
snapshot,
};
}
};
        let Some(event_result) = poll else {
            break;
        };
let event = match event_result {
Ok(ev) => ev,
Err(err) => {
if let Some(handle) = checkpoint {
flush_checkpoint(&acc, &upstream_model, handle).await;
}
let snapshot = acc.interrupt_snapshot();
match classify_mid_stream(&err) {
Some(cause) => {
tracing::debug!(
cause = %cause,
bytes_received = snapshot.bytes_received,
"mid-stream error captured, entering recovery"
);
return DriveOutcome::Interrupted { cause, snapshot };
}
None => return DriveOutcome::Error(AgentLoopError::from(err)),
}
}
};
let mut saw_delta = false;
match event {
            LlmStreamEvent::TextDelta(delta) => {
                saw_delta = true;
                acc.current_text.push_str(&delta);
                // Record payload size so interrupt snapshots report progress.
                acc.bytes_received += delta.len();
                sink.emit(AgentEvent::TextDelta { delta }).await;
            }
LlmStreamEvent::ReasoningDelta(delta) => {
sink.emit(AgentEvent::ReasoningDelta { delta }).await;
}
LlmStreamEvent::ToolCallStart { id, name } => {
saw_delta = true;
sink.emit(AgentEvent::ToolCallStart {
id: id.clone(),
name: name.clone(),
})
.await;
acc.tool_names.insert(id.clone(), name);
acc.current_tool_args.insert(id.clone(), String::new());
acc.tool_order.push(id);
}
            LlmStreamEvent::ToolCallDelta { id, args_delta } => {
                saw_delta = true;
                // Tool-argument bytes also count toward snapshot progress.
                acc.bytes_received += args_delta.len();
                if let Some(buf) = acc.current_tool_args.get_mut(&id) {
                    buf.push_str(&args_delta);
                }
                sink.emit(AgentEvent::ToolCallDelta { id, args_delta })
                    .await;
            }
LlmStreamEvent::ContentBlockStop => {
if !acc.current_text.is_empty() {
acc.content_blocks
.push(ContentBlock::text(std::mem::take(&mut acc.current_text)));
}
}
LlmStreamEvent::Usage(u) => {
if let Some(v) = u.prompt_tokens {
*total_input_tokens = total_input_tokens.saturating_add(v.max(0) as u64);
}
if let Some(v) = u.completion_tokens {
*total_output_tokens = total_output_tokens.saturating_add(v.max(0) as u64);
}
acc.usage = Some(u);
}
LlmStreamEvent::Stop(reason) => {
acc.stop_reason = Some(reason);
}
}
if saw_delta {
deltas_since_last_flush += 1;
if deltas_since_last_flush >= CHECKPOINT_FLUSH_DELTAS {
deltas_since_last_flush = 0;
if let Some(handle) = checkpoint {
flush_checkpoint(&acc, &upstream_model, handle).await;
}
}
}
}
let result = acc.finalize(sink).await;
if acc.cancelled {
DriveOutcome::Cancelled(result)
} else {
DriveOutcome::Completed(result)
}
}
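/// Incremental state built up while consuming stream events: text and tool
/// argument buffers, tool-call ordering, token usage, stop reason, received
/// byte count, and a cancellation flag.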
#[derive(Default)]
struct StreamingAccumulator {
content_blocks: Vec<ContentBlock>,
usage: Option<TokenUsage>,
stop_reason: Option<StopReason>,
current_text: String,
current_tool_args: std::collections::HashMap<String, String>,
tool_names: std::collections::HashMap<String, String>,
tool_order: Vec<String>,
bytes_received: usize,
cancelled: bool,
}
impl StreamingAccumulator {
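    /// Captures the current partial state (text plus ordered tool-call
    /// partials) as an `InterruptSnapshot` for recovery planning.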
fn interrupt_snapshot(&self) -> InterruptSnapshot {
let text = if self.current_text.is_empty() {
self.content_blocks
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } if !text.is_empty() => Some(text.clone()),
_ => None,
})
.reduce(|a, b| a + &b)
} else {
Some(self.current_text.clone())
};
let partials = self.tool_order.iter().map(|id| {
(
id.clone(),
self.tool_names.get(id).cloned().unwrap_or_default(),
self.current_tool_args.get(id).cloned().unwrap_or_default(),
)
});
InterruptSnapshot::from_partials(text, partials, self.bytes_received)
}
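    /// Converts the accumulated state into a `StreamResult`, emitting
    /// `ToolCallReady` for each tool call whose arguments parse as JSON.
    /// Calls whose non-empty argument buffer fails to parse are skipped and
    /// the result is flagged `has_incomplete_tool_calls`. Cancelled streams
    /// drop all tool calls and report `StopReason::EndTurn`.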
async fn finalize(&mut self, sink: &dyn EventSink) -> StreamResult {
if !self.current_text.is_empty() {
self.content_blocks
.push(ContentBlock::text(std::mem::take(&mut self.current_text)));
}
let mut tool_calls = Vec::new();
let mut has_incomplete_tool_calls = false;
if !self.cancelled {
for id in &self.tool_order {
let args_json = self.current_tool_args.get(id).cloned().unwrap_or_default();
let name = self.tool_names.get(id).cloned().unwrap_or_default();
let arguments = serde_json::from_str(&args_json).unwrap_or(serde_json::Value::Null);
if arguments.is_null() && !args_json.is_empty() {
has_incomplete_tool_calls = true;
continue;
}
tool_calls.push(ToolCall::new(id.clone(), name.clone(), arguments.clone()));
sink.emit(AgentEvent::ToolCallReady {
id: id.clone(),
name,
arguments,
})
.await;
}
}
StreamResult {
content: std::mem::take(&mut self.content_blocks),
tool_calls,
usage: self.usage.take(),
stop_reason: if self.cancelled {
Some(StopReason::EndTurn)
} else {
self.stop_reason.take()
},
has_incomplete_tool_calls,
}
}
}
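/// Maps a mid-stream error to a recoverable `InterruptCause`, or `None` for
/// permanent errors (context overflow, auth, content filter, cancellation)
/// that must not be retried.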
fn classify_mid_stream(err: &InferenceExecutionError) -> Option<InterruptCause> {
match err {
InferenceExecutionError::Provider(msg) | InferenceExecutionError::Timeout(msg) => {
Some(interpret_transport_message(msg))
}
InferenceExecutionError::RateLimited { message, .. }
| InferenceExecutionError::Overloaded { message, .. } => {
Some(interpret_transport_message(message))
}
InferenceExecutionError::StreamInterrupted { cause, .. } => Some(cause.clone()),
_ => None,
}
}
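/// Best-effort classification of a transport error message by substring
/// matching; defaults to `ConnectionReset` when nothing more specific
/// matches.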
fn interpret_transport_message(msg: &str) -> InterruptCause {
let lower = msg.to_lowercase();
if lower.contains("goaway")
|| lower.contains("go_away")
|| lower.contains("http/2 going away")
|| lower.contains("connection: close")
{
InterruptCause::GoAway
} else if lower.contains("connection reset") || lower.contains("econnreset") {
InterruptCause::ConnectionReset
} else if lower.starts_with("502")
|| lower.starts_with("503")
|| lower.contains("502 bad gateway")
|| lower.contains("503 service unavailable")
{
InterruptCause::Provider5xxMidStream(503)
} else {
InterruptCause::ConnectionReset
}
}
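/// Retry policy for streaming requests. Currently the default policy for
/// every agent; the parameter is a hook for per-agent overrides.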
fn stream_retry_policy_for(_agent: &ResolvedAgent) -> LlmRetryPolicy {
LlmRetryPolicy::default()
}
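/// Computes the per-poll idle timeout, doubling the configured base for
/// reasoning-oriented models (matched by name) or when a reasoning-effort
/// override is set, since such models can legitimately pause longer between
/// deltas.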
fn idle_timeout_for(request: &InferenceRequest, policy: &LlmRetryPolicy) -> Duration {
let base = Duration::from_secs(policy.stream_idle_timeout_secs);
let model = request.upstream_model.as_str();
let name_hits = model.contains("thinking")
|| model.contains("reasoning")
|| model.starts_with("o1")
|| model.starts_with("o3")
|| model.starts_with("o4");
let options_hits = request
.overrides
.as_ref()
.and_then(|o| o.reasoning_effort.as_ref())
.is_some();
if name_hits || options_hits {
base * 2
} else {
base
}
}
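/// Backoff before the next stream attempt: a short fixed delay for idle
/// stalls, otherwise the policy's standard delay for provider errors.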
fn stream_retry_backoff(cause: &InterruptCause, attempt: u32, policy: &LlmRetryPolicy) -> Duration {
match cause {
InterruptCause::IdleStall => Duration::from_millis(200),
_ => policy.delay_before_retry(
&InferenceExecutionError::Provider("mid-stream".into()),
attempt,
),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use crate::cancellation::CancellationToken;
use crate::registry::ResolvedAgent;
use async_trait::async_trait;
use awaken_contract::contract::content::ContentBlock;
use awaken_contract::contract::event::AgentEvent;
use awaken_contract::contract::event_sink::VecEventSink;
use awaken_contract::contract::executor::{
InferenceExecutionError, InferenceRequest, InferenceStream, LlmStreamEvent,
};
use awaken_contract::contract::inference::{StopReason, StreamResult, TokenUsage};
use awaken_contract::contract::message::Message;
struct ScriptedPerAttemptExecutor {
scripts: Vec<Vec<Result<LlmStreamEvent, InferenceExecutionError>>>,
attempt: std::sync::atomic::AtomicUsize,
}
impl ScriptedPerAttemptExecutor {
fn new(scripts: Vec<Vec<Result<LlmStreamEvent, InferenceExecutionError>>>) -> Self {
assert!(!scripts.is_empty(), "need at least one attempt script");
Self {
scripts,
attempt: std::sync::atomic::AtomicUsize::new(0),
}
}
fn attempts(&self) -> usize {
self.attempt.load(std::sync::atomic::Ordering::SeqCst)
}
}
#[async_trait]
impl awaken_contract::contract::executor::LlmExecutor for ScriptedPerAttemptExecutor {
async fn execute(
&self,
_r: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
Err(InferenceExecutionError::Provider("unused".into()))
}
fn execute_stream(
&self,
_request: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
+ Send
+ '_,
>,
> {
let n = self
.attempt
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let idx = n.min(self.scripts.len() - 1);
let events = self.scripts[idx].clone();
Box::pin(async move { Ok(Box::pin(futures::stream::iter(events)) as InferenceStream) })
}
fn name(&self) -> &str {
"scripted-per-attempt"
}
}
fn make_agent(events: Vec<Result<LlmStreamEvent, InferenceExecutionError>>) -> ResolvedAgent {
agent_with(Arc::new(ScriptedPerAttemptExecutor::new(vec![events])))
}
fn agent_with(exec: Arc<ScriptedPerAttemptExecutor>) -> ResolvedAgent {
ResolvedAgent::new("test-agent", "test-model", "system prompt", exec)
}
fn make_request() -> InferenceRequest {
InferenceRequest {
upstream_model: "test-model".into(),
messages: vec![Message::user("hello")],
tools: vec![],
system: vec![],
overrides: None,
enable_prompt_cache: false,
}
}
#[tokio::test]
async fn collects_text_deltas_into_content_blocks() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::TextDelta("Hello ".into())),
Ok(LlmStreamEvent::TextDelta("world!".into())),
Ok(LlmStreamEvent::ContentBlockStop),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
]);
let sink = VecEventSink::new();
let mut input_tokens = 0u64;
let mut output_tokens = 0u64;
let (result, _hint) = execute_streaming(
&agent,
make_request(),
&sink,
None,
&mut input_tokens,
&mut output_tokens,
None,
)
.await
.unwrap();
assert_eq!(result.content.len(), 1);
match &result.content[0] {
ContentBlock::Text { text } => assert_eq!(text, "Hello world!"),
other => panic!("expected Text block, got: {other:?}"),
}
assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
}
#[tokio::test]
async fn emits_text_delta_events_to_sink() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::TextDelta("hi".into())),
Ok(LlmStreamEvent::ContentBlockStop),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
let events = sink.take();
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::TextDelta { delta } if delta == "hi")),
"expected TextDelta event in sink"
);
}
#[tokio::test]
async fn accumulates_token_usage() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::Usage(TokenUsage {
prompt_tokens: Some(50),
completion_tokens: Some(25),
total_tokens: Some(75),
..Default::default()
})),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
]);
let sink = VecEventSink::new();
let mut input_tokens = 10u64;
let mut output_tokens = 5u64;
let (result, _hint) = execute_streaming(
&agent,
make_request(),
&sink,
None,
&mut input_tokens,
&mut output_tokens,
None,
)
.await
.unwrap();
        assert_eq!(input_tokens, 60);
        assert_eq!(output_tokens, 30);
        assert!(result.usage.is_some());
}
#[tokio::test]
async fn token_counting_handles_negative_values() {
let agent = make_agent(vec![Ok(LlmStreamEvent::Usage(TokenUsage {
prompt_tokens: Some(-5),
completion_tokens: Some(-10),
..Default::default()
}))]);
let sink = VecEventSink::new();
let mut input_tokens = 100u64;
let mut output_tokens = 50u64;
execute_streaming(
&agent,
make_request(),
&sink,
None,
&mut input_tokens,
&mut output_tokens,
None,
)
.await
.unwrap();
assert_eq!(input_tokens, 100);
assert_eq!(output_tokens, 50);
}
#[tokio::test]
async fn collects_tool_calls_from_stream() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::ToolCallStart {
id: "tc1".into(),
name: "get_weather".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc1".into(),
args_delta: r#"{"city":"#.into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc1".into(),
args_delta: r#""NYC"}"#.into(),
}),
Ok(LlmStreamEvent::ContentBlockStop),
Ok(LlmStreamEvent::Stop(StopReason::ToolUse)),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "get_weather");
assert_eq!(result.tool_calls[0].arguments["city"], "NYC");
assert!(!result.has_incomplete_tool_calls);
}
#[tokio::test]
async fn emits_tool_call_start_and_delta_events() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::ToolCallStart {
id: "tc1".into(),
name: "search".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc1".into(),
args_delta: r#"{"q":"test"}"#.into(),
}),
Ok(LlmStreamEvent::Stop(StopReason::ToolUse)),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
let events = sink.take();
assert!(events.iter().any(|e| matches!(
e,
AgentEvent::ToolCallStart { id, name } if id == "tc1" && name == "search"
)));
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::ToolCallDelta { id, .. } if id == "tc1"))
);
}
#[tokio::test]
async fn truncated_tool_call_json_marked_incomplete() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::ToolCallStart {
id: "tc1".into(),
name: "fetch".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc1".into(),
args_delta: r#"{"url":"https://exam"#.into(), }),
Ok(LlmStreamEvent::Stop(StopReason::MaxTokens)),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
assert!(result.tool_calls.is_empty());
assert!(result.has_incomplete_tool_calls);
}
#[tokio::test]
async fn multiple_tool_calls_preserve_declaration_order() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::ToolCallStart {
id: "tc1".into(),
name: "tool_a".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc1".into(),
args_delta: "{}".into(),
}),
Ok(LlmStreamEvent::ToolCallStart {
id: "tc2".into(),
name: "tool_b".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc2".into(),
args_delta: r#"{"x":1}"#.into(),
}),
Ok(LlmStreamEvent::Stop(StopReason::ToolUse)),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
assert_eq!(result.tool_calls.len(), 2);
assert_eq!(result.tool_calls[0].name, "tool_a");
assert_eq!(result.tool_calls[1].name, "tool_b");
}
#[tokio::test]
async fn cancellation_returns_end_turn_and_drops_tool_calls() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::TextDelta("partial ".into())),
Ok(LlmStreamEvent::ToolCallStart {
id: "tc1".into(),
name: "my_tool".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc1".into(),
args_delta: r#"{"key":"value"}"#.into(),
}),
Ok(LlmStreamEvent::Stop(StopReason::ToolUse)),
]);
let token = CancellationToken::new();
token.cancel();
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) = execute_streaming(
&agent,
make_request(),
&sink,
Some(&token),
&mut it,
&mut ot,
None,
)
.await
.unwrap();
assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
assert!(result.tool_calls.is_empty());
}
#[tokio::test]
async fn no_cancellation_token_processes_full_stream() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::TextDelta("complete".into())),
Ok(LlmStreamEvent::ContentBlockStop),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
assert_eq!(result.content.len(), 1);
assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
}
#[tokio::test]
async fn reasoning_deltas_emitted_to_sink() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::ReasoningDelta("thinking...".into())),
Ok(LlmStreamEvent::TextDelta("answer".into())),
Ok(LlmStreamEvent::ContentBlockStop),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
let events = sink.take();
assert!(events.iter().any(|e| matches!(
e,
AgentEvent::ReasoningDelta { delta } if delta == "thinking..."
)));
}
#[tokio::test]
async fn empty_stream_returns_empty_result() {
let agent = make_agent(vec![]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
assert!(result.content.is_empty());
assert!(result.tool_calls.is_empty());
assert!(result.usage.is_none());
assert!(result.stop_reason.is_none());
}
#[tokio::test]
async fn flushes_remaining_text_at_end_of_stream() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::TextDelta("no block stop".into())),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
assert_eq!(result.content.len(), 1);
match &result.content[0] {
ContentBlock::Text { text } => assert_eq!(text, "no block stop"),
other => panic!("expected Text, got: {other:?}"),
}
}
#[tokio::test]
async fn stream_error_propagated_as_agent_loop_error() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::TextDelta("before error".into())),
Err(InferenceExecutionError::Provider("rate limited".into())),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let err = execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap_err();
match err {
AgentLoopError::InferenceFailed(msg) => {
assert!(
msg.contains("stream interrupted"),
"expected stream-interrupt message, got: {msg}"
);
}
other => panic!("expected InferenceFailed, got: {other:?}"),
}
}
#[tokio::test]
async fn emits_tool_call_ready_event_for_complete_tool() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::ToolCallStart {
id: "tc1".into(),
name: "calculator".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc1".into(),
args_delta: r#"{"expr":"1+1"}"#.into(),
}),
Ok(LlmStreamEvent::Stop(StopReason::ToolUse)),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap();
let events = sink.take();
assert!(events.iter().any(|e| matches!(
e,
AgentEvent::ToolCallReady { id, name, .. } if id == "tc1" && name == "calculator"
)));
}
struct FailAfterNEventsExecutor {
events_before_fail: usize,
}
#[async_trait]
impl awaken_contract::contract::executor::LlmExecutor for FailAfterNEventsExecutor {
async fn execute(
&self,
_request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
Err(InferenceExecutionError::Provider("not implemented".into()))
}
fn execute_stream(
&self,
_request: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
+ Send
+ '_,
>,
> {
let n = self.events_before_fail;
Box::pin(async move {
let mut events: Vec<Result<LlmStreamEvent, InferenceExecutionError>> = Vec::new();
for i in 0..n {
events.push(Ok(LlmStreamEvent::TextDelta(format!("chunk-{i}"))));
}
events.push(Err(InferenceExecutionError::Provider(
"injected mid-stream failure".into(),
)));
Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
})
}
fn name(&self) -> &str {
"fail-after-n"
}
}
fn make_failing_agent(events_before_fail: usize) -> ResolvedAgent {
ResolvedAgent::new(
"test-agent",
"test-model",
"system prompt",
Arc::new(FailAfterNEventsExecutor { events_before_fail }),
)
}
#[tokio::test]
async fn error_after_zero_events_returns_inference_failed() {
let agent = make_failing_agent(0);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let err = execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap_err();
match err {
AgentLoopError::InferenceFailed(msg) => {
assert!(
msg.contains("stream interrupted"),
"expected stream-interrupt message, got: {msg}"
);
}
other => panic!("expected InferenceFailed, got: {other:?}"),
}
let events = sink.take();
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::StreamReset { .. })),
"expected at least one StreamReset event, got: {events:?}"
);
}
#[tokio::test]
async fn error_after_n_events_emits_partial_deltas_then_fails() {
let agent = make_failing_agent(3);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let err = execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap_err();
assert!(matches!(err, AgentLoopError::InferenceFailed(_)));
let events = sink.take();
let text_deltas: Vec<_> = events
.iter()
.filter(|e| matches!(e, AgentEvent::TextDelta { .. }))
.collect();
assert!(
text_deltas.len() >= 3,
"expected >=3 text deltas (with possible retries), got {}",
text_deltas.len()
);
}
struct ImmediateStreamFailExecutor;
#[async_trait]
impl awaken_contract::contract::executor::LlmExecutor for ImmediateStreamFailExecutor {
async fn execute(
&self,
_request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
Err(InferenceExecutionError::Provider("execute failed".into()))
}
fn execute_stream(
&self,
_request: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
+ Send
+ '_,
>,
> {
Box::pin(async move {
Err(InferenceExecutionError::Provider(
"stream creation failed".into(),
))
})
}
fn name(&self) -> &str {
"immediate-fail"
}
}
#[tokio::test]
async fn executor_stream_creation_failure_surfaces_as_error() {
let agent = ResolvedAgent::new(
"test-agent",
"test-model",
"system prompt",
Arc::new(ImmediateStreamFailExecutor),
);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let err = execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap_err();
match err {
AgentLoopError::InferenceFailed(msg) => {
assert!(msg.contains("stream creation failed"));
}
other => panic!("expected InferenceFailed, got: {other:?}"),
}
}
#[tokio::test]
async fn rate_limited_error_surfaces_correctly() {
let agent = make_agent(vec![Err(InferenceExecutionError::rate_limited(
"429 too many requests",
))]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let err = execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap_err();
match err {
AgentLoopError::InferenceFailed(msg) => {
assert!(
msg.contains("stream interrupted"),
"expected stream-interrupt message, got: {msg}"
);
}
other => panic!("expected InferenceFailed, got: {other:?}"),
}
}
#[tokio::test]
async fn timeout_error_surfaces_correctly() {
let agent = make_agent(vec![Err(InferenceExecutionError::Timeout(
"30s exceeded".into(),
))]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let err = execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap_err();
match err {
AgentLoopError::InferenceFailed(msg) => {
assert!(
msg.contains("stream interrupted"),
"expected stream-interrupt message, got: {msg}"
);
let _ = "30s exceeded"; }
other => panic!("expected InferenceFailed, got: {other:?}"),
}
}
struct HangingExecutor;
#[async_trait]
impl awaken_contract::contract::executor::LlmExecutor for HangingExecutor {
async fn execute(
&self,
_request: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
std::future::pending::<()>().await;
unreachable!()
}
fn execute_stream(
&self,
_request: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
+ Send
+ '_,
>,
> {
Box::pin(async move {
let stream = futures::stream::pending();
Ok(Box::pin(stream) as InferenceStream)
})
}
fn name(&self) -> &str {
"hanging"
}
}
#[tokio::test(start_paused = true)]
async fn hanging_executor_is_caught_by_cancellation_token() {
let agent = ResolvedAgent::new(
"test-agent",
"test-model",
"system prompt",
Arc::new(HangingExecutor),
);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let token = CancellationToken::new();
let token_clone = token.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
token_clone.cancel();
});
let (result, _hint) = execute_streaming(
&agent,
make_request(),
&sink,
Some(&token),
&mut it,
&mut ot,
None,
)
.await
.unwrap();
assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
assert!(result.content.is_empty());
assert!(result.tool_calls.is_empty());
}
#[tokio::test]
async fn error_mid_tool_call_returns_inference_error() {
let agent = make_agent(vec![
Ok(LlmStreamEvent::ToolCallStart {
id: "tc1".into(),
name: "search".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "tc1".into(),
args_delta: r#"{"q":"partial"#.into(),
}),
Err(InferenceExecutionError::Provider("connection reset".into())),
]);
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let err = execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap_err();
match err {
AgentLoopError::InferenceFailed(msg) => {
assert!(
msg.contains("stream interrupted"),
"expected stream-interrupt message, got: {msg}"
);
}
other => panic!("expected InferenceFailed, got: {other:?}"),
}
let events = sink.take();
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
);
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::StreamReset { .. }))
);
}
#[tokio::test]
async fn r1_text_only_interruption_recovers_via_continuation() {
let exec = Arc::new(ScriptedPerAttemptExecutor::new(vec![
vec![
Ok(LlmStreamEvent::TextDelta("Hello, ".into())),
Ok(LlmStreamEvent::TextDelta("this is".into())),
Err(InferenceExecutionError::Provider("connection reset".into())),
],
vec![
Ok(LlmStreamEvent::TextDelta(" the second half.".into())),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
],
]));
let agent = agent_with(exec.clone());
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.expect("R1 should succeed after one retry");
assert_eq!(exec.attempts(), 2, "expected exactly two attempts");
assert_eq!(result.text(), " the second half.");
assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
let events = sink.take();
assert!(
!events
.iter()
.any(|e| matches!(e, AgentEvent::StreamReset { .. })),
"R1 must not emit StreamReset"
);
assert!(
!events
.iter()
.any(|e| matches!(e, AgentEvent::ToolCallCancel { .. })),
"R1 must not emit ToolCallCancel"
);
}
#[tokio::test]
async fn r2_completed_tool_synthesizes_tool_use_without_another_round_trip() {
let exec = Arc::new(ScriptedPerAttemptExecutor::new(vec![
vec![
Ok(LlmStreamEvent::ToolCallStart {
id: "a".into(),
name: "search".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "a".into(),
args_delta: r#"{"q":"rust"}"#.into(),
}),
Ok(LlmStreamEvent::ToolCallStart {
id: "b".into(),
name: "fetch".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "b".into(),
args_delta: r#"{"url":"#.into(), }),
Err(InferenceExecutionError::Provider("connection reset".into())),
],
vec![Err(InferenceExecutionError::Provider(
"R2 should not retry".into(),
))],
]));
let agent = agent_with(exec.clone());
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.expect("R2 short-circuits to synthesized tool_use");
assert_eq!(exec.attempts(), 1, "R2 must not trigger a retry");
assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].id, "a");
assert_eq!(result.tool_calls[0].name, "search");
let events = sink.take();
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::ToolCallCancel { id, name, .. }
if id == "b" && name == "fetch")),
"expected ToolCallCancel for the in-flight tool"
);
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::ToolCallReady { id, .. } if id == "a")),
"expected ToolCallReady for the completed tool"
);
}
#[tokio::test]
async fn r3_text_plus_partial_tool_truncates_and_continues() {
let exec = Arc::new(ScriptedPerAttemptExecutor::new(vec![
vec![
Ok(LlmStreamEvent::TextDelta("Looking it up: ".into())),
Ok(LlmStreamEvent::ToolCallStart {
id: "t1".into(),
name: "lookup".into(),
}),
Ok(LlmStreamEvent::ToolCallDelta {
id: "t1".into(),
args_delta: r#"{"id":"#.into(),
}),
Err(InferenceExecutionError::Provider("connection reset".into())),
],
vec![
Ok(LlmStreamEvent::TextDelta("done.".into())),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
],
]));
let agent = agent_with(exec.clone());
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.expect("R3 recovers after truncation");
assert_eq!(exec.attempts(), 2);
assert_eq!(result.text(), "done.");
let events = sink.take();
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::ToolCallCancel { id, name, .. }
if id == "t1" && name == "lookup")),
"R3 must emit ToolCallCancel for the unclosed tool"
);
assert!(
!events
.iter()
.any(|e| matches!(e, AgentEvent::StreamReset { .. })),
"R3 must NOT emit StreamReset"
);
}
#[tokio::test]
async fn r4_empty_snapshot_whole_restarts_and_emits_stream_reset() {
let exec = Arc::new(ScriptedPerAttemptExecutor::new(vec![
vec![Err(InferenceExecutionError::Provider("reset".into()))],
vec![
Ok(LlmStreamEvent::TextDelta("fresh start".into())),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
],
]));
let agent = agent_with(exec.clone());
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.expect("R4 recovers after whole restart");
assert_eq!(exec.attempts(), 2);
assert_eq!(result.text(), "fresh start");
let events = sink.take();
assert!(
events
.iter()
.any(|e| matches!(e, AgentEvent::StreamReset { .. })),
"R4 must emit StreamReset"
);
}
#[tokio::test]
async fn retry_budget_exhausted_surfaces_stream_interrupted() {
let exec = Arc::new(ScriptedPerAttemptExecutor::new(vec![vec![Err(
InferenceExecutionError::Provider("reset".into()),
)]]));
let agent = agent_with(exec.clone());
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let err = execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None)
.await
.unwrap_err();
assert_eq!(
exec.attempts(),
3,
"expected 1 initial + 2 retries = 3 attempts"
);
match err {
AgentLoopError::InferenceFailed(msg) => {
assert!(
msg.contains("stream interrupted"),
"expected stream-interrupt message, got: {msg}"
);
}
other => panic!("expected InferenceFailed, got: {other:?}"),
}
}
struct StallingExecutor {
attempt: std::sync::atomic::AtomicUsize,
}
#[async_trait]
impl awaken_contract::contract::executor::LlmExecutor for StallingExecutor {
async fn execute(
&self,
_r: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
Err(InferenceExecutionError::Provider("unused".into()))
}
fn execute_stream(
&self,
_request: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
+ Send
+ '_,
>,
> {
let n = self
.attempt
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Box::pin(async move {
if n == 0 {
let hung = futures::stream::unfold((), |()| async move {
futures::future::pending::<()>().await;
None
});
let prefix: Vec<Result<LlmStreamEvent, InferenceExecutionError>> =
vec![Ok(LlmStreamEvent::TextDelta("partial".into()))];
let combined = futures::stream::iter(prefix)
.chain(hung)
.map(|r: Result<LlmStreamEvent, InferenceExecutionError>| r);
Ok(Box::pin(combined) as InferenceStream)
} else {
let events: Vec<Result<LlmStreamEvent, InferenceExecutionError>> = vec![
Ok(LlmStreamEvent::TextDelta(" done.".into())),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
];
Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
}
})
}
fn name(&self) -> &str {
"stalling"
}
}
#[tokio::test(start_paused = true)]
async fn idle_stall_triggers_recovery_and_second_attempt_succeeds() {
let exec = Arc::new(StallingExecutor {
attempt: std::sync::atomic::AtomicUsize::new(0),
});
let agent = ResolvedAgent::new("test-agent", "test-model", "system prompt", exec.clone());
let sink = VecEventSink::new();
let mut it = 0u64;
let mut ot = 0u64;
let exec_fut =
execute_streaming(&agent, make_request(), &sink, None, &mut it, &mut ot, None);
let drive = async {
tokio::time::sleep(Duration::from_millis(1)).await;
tokio::time::advance(Duration::from_secs(70)).await;
};
let (outcome, ()) = tokio::join!(exec_fut, drive);
let (result, _hint) = outcome.expect("idle-stall should recover");
assert_eq!(
exec.attempt.load(std::sync::atomic::Ordering::SeqCst),
2,
"expected 2 attempts after stall recovery"
);
assert!(result.text().contains("done"));
}
#[test]
fn idle_timeout_for_doubles_on_thinking_model_names() {
let policy = LlmRetryPolicy::default().with_stream_idle_timeout_secs(30);
let base = Duration::from_secs(30);
let plain = InferenceRequest {
upstream_model: "gpt-4o-mini".into(),
messages: vec![],
tools: vec![],
system: vec![],
overrides: None,
enable_prompt_cache: false,
};
assert_eq!(idle_timeout_for(&plain, &policy), base);
let thinking = InferenceRequest {
upstream_model: "claude-opus-4-7-thinking".into(),
..plain.clone()
};
assert_eq!(idle_timeout_for(&thinking, &policy), base * 2);
let reasoning = InferenceRequest {
upstream_model: "o1-mini".into(),
..plain.clone()
};
assert_eq!(idle_timeout_for(&reasoning, &policy), base * 2);
let o3 = InferenceRequest {
upstream_model: "o3-preview".into(),
..plain.clone()
};
assert_eq!(idle_timeout_for(&o3, &policy), base * 2);
}
#[test]
fn classify_mid_stream_maps_goaway_substring_to_goaway_cause() {
let err = InferenceExecutionError::Provider("HTTP/2 GOAWAY frame received".into());
assert!(matches!(
classify_mid_stream(&err),
Some(InterruptCause::GoAway)
));
}
#[test]
fn classify_mid_stream_maps_connection_reset_substring_to_connection_reset() {
let err = InferenceExecutionError::Provider("ECONNRESET: connection reset by peer".into());
assert!(matches!(
classify_mid_stream(&err),
Some(InterruptCause::ConnectionReset)
));
}
#[test]
fn classify_mid_stream_maps_503_substring_to_provider_5xx() {
let err = InferenceExecutionError::Provider("503 Service Unavailable".into());
assert!(matches!(
classify_mid_stream(&err),
Some(InterruptCause::Provider5xxMidStream(_))
));
}
#[test]
fn classify_mid_stream_preserves_cause_from_stream_interrupted() {
let err = InferenceExecutionError::StreamInterrupted {
cause: InterruptCause::IdleStall,
snapshot: Box::new(InterruptSnapshot {
text: None,
completed_tool_calls: vec![],
in_flight_tool: None,
bytes_received: 0,
}),
};
assert!(matches!(
classify_mid_stream(&err),
Some(InterruptCause::IdleStall)
));
}
#[test]
fn classify_mid_stream_refuses_permanent_errors() {
assert!(
classify_mid_stream(&InferenceExecutionError::ContextOverflow("x".into())).is_none()
);
assert!(classify_mid_stream(&InferenceExecutionError::Unauthorized("x".into())).is_none());
assert!(
classify_mid_stream(&InferenceExecutionError::ContentFiltered("x".into())).is_none()
);
assert!(classify_mid_stream(&InferenceExecutionError::Cancelled).is_none());
}
#[tokio::test]
async fn checkpoint_is_flushed_on_mid_stream_interruption() {
use awaken_contract::contract::stream_checkpoint::{
InMemoryStreamCheckpointStore, StreamCheckpointStore,
};
let deltas: Vec<Result<LlmStreamEvent, InferenceExecutionError>> = (0..8)
.map(|i| Ok(LlmStreamEvent::TextDelta(format!("d{i}"))))
.chain(std::iter::once(Err(InferenceExecutionError::Provider(
"reset".into(),
))))
.collect();
let exec = Arc::new(ScriptedPerAttemptExecutor::new(vec![
deltas.clone(),
deltas,
]));
let agent = agent_with(exec.clone());
let sink = VecEventSink::new();
let store: Arc<InMemoryStreamCheckpointStore> =
Arc::new(InMemoryStreamCheckpointStore::new());
let handle = CheckpointHandle {
store: store.as_ref(),
run_id: "run-checkpoint-flush",
thread_id: "thread-1",
};
let mut it = 0u64;
let mut ot = 0u64;
let _ = execute_streaming(
&agent,
make_request(),
&sink,
None,
&mut it,
&mut ot,
Some(handle),
)
.await;
let saved = store
.get("run-checkpoint-flush")
.await
.unwrap()
.expect("checkpoint must have been persisted before failure");
assert_eq!(saved.run_id, "run-checkpoint-flush");
assert_eq!(saved.thread_id, "thread-1");
assert!(
saved.partial_text.contains("d0") && saved.partial_text.contains("d7"),
"partial_text should contain all 8 deltas, got: {}",
saved.partial_text
);
}
#[tokio::test]
async fn cross_process_resume_injects_continuation_from_checkpoint() {
use awaken_contract::contract::stream_checkpoint::{
InMemoryStreamCheckpointStore, StreamCheckpoint, StreamCheckpointStore,
};
let store: Arc<InMemoryStreamCheckpointStore> =
Arc::new(InMemoryStreamCheckpointStore::new());
store
.put(StreamCheckpoint {
run_id: "run-resumed".into(),
thread_id: "thread-1".into(),
upstream_model: "test-model".into(),
partial_text: "half-written answer".into(),
completed_tool_calls: vec![],
in_flight_tool: None,
updated_at_ms: 1_000,
})
.await
.unwrap();
struct CapturingExec {
captured: Arc<std::sync::Mutex<Vec<InferenceRequest>>>,
}
#[async_trait]
impl awaken_contract::contract::executor::LlmExecutor for CapturingExec {
async fn execute(
&self,
_r: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
Err(InferenceExecutionError::Provider("unused".into()))
}
fn execute_stream(
&self,
request: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<
Output = Result<InferenceStream, InferenceExecutionError>,
> + Send
+ '_,
>,
> {
self.captured.lock().unwrap().push(request);
Box::pin(async move {
let events: Vec<Result<LlmStreamEvent, InferenceExecutionError>> = vec![
Ok(LlmStreamEvent::TextDelta(" — conclusion.".into())),
Ok(LlmStreamEvent::Stop(StopReason::EndTurn)),
];
Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
})
}
fn name(&self) -> &str {
"capturing"
}
}
let captured: Arc<std::sync::Mutex<Vec<InferenceRequest>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let exec = Arc::new(CapturingExec {
captured: captured.clone(),
});
let agent = ResolvedAgent::new("test", "test-model", "sys", exec);
let sink = VecEventSink::new();
let handle = CheckpointHandle {
store: store.as_ref(),
run_id: "run-resumed",
thread_id: "thread-1",
};
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) = execute_streaming(
&agent,
make_request(),
&sink,
None,
&mut it,
&mut ot,
Some(handle),
)
.await
.expect("resume should succeed");
let reqs = captured.lock().unwrap();
assert_eq!(reqs.len(), 1);
let last_two: Vec<_> = reqs[0]
.messages
.iter()
.rev()
.take(2)
.rev()
.cloned()
.collect();
assert_eq!(last_two.len(), 2);
assert_eq!(
last_two[0].text(),
"half-written answer",
"assistant prefix must carry saved partial text"
);
assert!(
last_two[1].text().contains("interrupted mid-stream"),
"user continuation prompt must follow the prefix, got: {}",
last_two[1].text()
);
assert_eq!(result.text(), " — conclusion.");
assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
assert!(
store.get("run-resumed").await.unwrap().is_none(),
"checkpoint must be deleted after successful resume"
);
}
#[tokio::test]
async fn cross_process_resume_with_completed_tool_checkpoint_short_circuits_to_tool_use() {
use awaken_contract::contract::stream_checkpoint::{
InMemoryStreamCheckpointStore, StreamCheckpoint, StreamCheckpointStore,
};
use serde_json::json;
let store: Arc<InMemoryStreamCheckpointStore> =
Arc::new(InMemoryStreamCheckpointStore::new());
store
.put(StreamCheckpoint {
run_id: "run-r2-resumed".into(),
thread_id: "thread-1".into(),
upstream_model: "test-model".into(),
partial_text: "thinking...".into(),
completed_tool_calls: vec![ToolCall::new("tc-1", "search", json!({"q": "rust"}))],
in_flight_tool: None,
updated_at_ms: 1_000,
})
.await
.unwrap();
struct NeverCallMe;
#[async_trait]
impl awaken_contract::contract::executor::LlmExecutor for NeverCallMe {
async fn execute(
&self,
_r: InferenceRequest,
) -> Result<StreamResult, InferenceExecutionError> {
panic!("R2 checkpoint resume must not reopen a stream");
}
fn execute_stream(
&self,
_r: InferenceRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<
Output = Result<InferenceStream, InferenceExecutionError>,
> + Send
+ '_,
>,
> {
panic!("R2 checkpoint resume must not reopen a stream");
}
fn name(&self) -> &str {
"never-call"
}
}
let agent = ResolvedAgent::new("test", "test-model", "sys", Arc::new(NeverCallMe));
let sink = VecEventSink::new();
let handle = CheckpointHandle {
store: store.as_ref(),
run_id: "run-r2-resumed",
thread_id: "thread-1",
};
let mut it = 0u64;
let mut ot = 0u64;
let (result, _hint) = execute_streaming(
&agent,
make_request(),
&sink,
None,
&mut it,
&mut ot,
Some(handle),
)
.await
.expect("R2 resume should short-circuit successfully");
assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
assert_eq!(result.tool_calls.len(), 1);
assert_eq!(result.tool_calls[0].name, "search");
assert_eq!(result.text(), "thinking...");
assert!(store.get("run-r2-resumed").await.unwrap().is_none());
let events = sink.events();
assert!(
events.iter().any(|e| matches!(
e,
AgentEvent::ToolCallReady { id, .. } if id == "tc-1"
)),
"expected ToolCallReady for the resumed tool"
);
}
#[tokio::test(start_paused = true)]
async fn cancellation_during_backoff_aborts_retry_loop_with_cancelled_error() {
use crate::cancellation::CancellationToken;
let exec = Arc::new(ScriptedPerAttemptExecutor::new(vec![
vec![Err(InferenceExecutionError::Provider("reset".into()))],
vec![Err(InferenceExecutionError::Provider(
"should-not-be-reached".into(),
))],
]));
let agent = agent_with(exec.clone());
let sink = VecEventSink::new();
let token = CancellationToken::new();
let mut it = 0u64;
let mut ot = 0u64;
let exec_fut = execute_streaming(
&agent,
make_request(),
&sink,
Some(&token),
&mut it,
&mut ot,
None,
);
let drive = async {
tokio::time::sleep(Duration::from_millis(1)).await;
token.cancel();
tokio::time::advance(Duration::from_secs(30)).await;
};
let (result, ()) = tokio::join!(exec_fut, drive);
let err = result.expect_err("cancellation must abort the retry loop");
match err {
AgentLoopError::InferenceFailed(msg) => {
assert!(
msg.contains("cancelled"),
"expected cancellation message, got: {msg}"
);
}
other => panic!("expected InferenceFailed(cancelled), got: {other:?}"),
}
assert_eq!(exec.attempts(), 1, "retry must not proceed after cancel");
}
}