use std::sync::Arc;
use futures::stream::StreamExt;
use crate::agent::recovery::{RecoveryPolicy, classify_error};
use super::message::{DeltaPhase, StreamEvent};
use super::stream::StreamFn;
#[cfg(test)]
use super::tool::AbortSignal;
/// Text of the user-role message that `retrying_stream_fn` appends to the conversation
/// (at most once per invocation) before retrying after a timeout-looking error.
const STALL_RECOVERY_NUDGE: &str = "Your previous response did not complete in time — you may have been stuck in a long reasoning loop. Stop deliberating: state your conclusion concisely and take the next concrete action (a tool call, or your final answer) now.";
/// Returns `true` when `msg` mentions a timeout, compared ASCII-case-insensitively.
///
/// Recognised spellings are `"timeout"` and `"timed out"`.
fn is_timeout_error(msg: &str) -> bool {
    const TIMEOUT_MARKERS: [&str; 2] = ["timeout", "timed out"];
    let lowered = msg.to_ascii_lowercase();
    TIMEOUT_MARKERS
        .iter()
        .any(|marker| lowered.contains(marker))
}
/// Wraps `inner` in a retry loop driven by `policy` and returns the wrapping [`StreamFn`].
///
/// For every invocation of the returned function:
///
/// * Events from the inner stream are forwarded to the caller. `Start`, `Delta` and
///   `Retry` events are passed through; `Done` is forwarded and ends the stream.
/// * An `Error` that arrives before any text/thinking delta has been forwarded is
///   classified with `classify_error`. If `policy.should_retry` accepts it, the error
///   is not forwarded; instead the wrapper waits for
///   `policy.backoff_duration_for_msg`, emits a `StreamEvent::Retry`, and calls
///   `inner` again. Every other `Error` is forwarded and ends the stream.
/// * The first retry triggered by a timeout-looking error (see `is_timeout_error`)
///   appends `STALL_RECOVERY_NUDGE` as a user message to the context used for the
///   following attempts.
/// * `opts.signal` is checked before each attempt and around the backoff wait; if it
///   is cancelled, a terminal `Error` event is emitted and the stream ends.
/// * If the inner stream ends without `Done` or `Error`, the wrapper ends as well
///   without emitting a terminal event.
pub fn retrying_stream_fn(inner: StreamFn, policy: RecoveryPolicy) -> StreamFn {
    let policy = Arc::new(policy);
    Arc::new(move |ctx, opts: super::stream::StreamOptions| {
        let inner = inner.clone();
        let policy = policy.clone();
        let signal_outer = opts.signal.clone();
        // Rebound as mutable so the stall nudge can be appended before a retry.
        let mut ctx = ctx;
        Box::pin(async_stream::stream! {
            // Number of retries already performed during this invocation.
            let mut attempts: usize = 0;
            // The stall nudge is injected at most once per invocation.
            let mut stall_nudged = false;
            loop {
                if signal_outer.is_cancelled() {
                    yield StreamEvent::Error {
                        error: "operation aborted before stream started".to_string(),
                    };
                    return;
                }
                let mut inner_stream = inner(ctx.clone(), opts.clone());
                // Set once a text/thinking delta has been forwarded; errors after that
                // point are surfaced to the caller instead of being retried.
                let mut committed = false;
                let mut retry_msg: Option<String> = None;
                while let Some(evt) = inner_stream.next().await {
                    match &evt {
                        StreamEvent::Error { error } => {
                            if !committed {
                                let kind = classify_error(error);
                                if policy.should_retry(attempts, kind) {
                                    retry_msg = Some(error.clone());
                                    break;
                                }
                            }
                            yield evt;
                            return;
                        }
                        StreamEvent::Delta { phase, .. } => {
                            if is_content_delta(*phase) {
                                committed = true;
                            }
                            yield evt;
                        }
                        StreamEvent::Done { .. } => {
                            yield evt;
                            return;
                        }
                        // Neither event is terminal. In particular, a `Retry` reported by
                        // the inner stream must not end this stream: the inner attempt is
                        // still running and its `Done`/`Error` has yet to arrive.
                        StreamEvent::Start { .. } | StreamEvent::Retry { .. } => {
                            yield evt;
                        }
                    }
                }
                match retry_msg {
                    Some(err_msg) => {
                        let backoff = policy.backoff_duration_for_msg(attempts, &err_msg);
                        // Skip the wait when the caller has already aborted; the check
                        // below then reports the abort immediately.
                        if !signal_outer.is_cancelled() {
                            tokio::time::sleep(backoff).await;
                        }
                        if signal_outer.is_cancelled() {
                            yield StreamEvent::Error {
                                error: "operation aborted during retry backoff".to_string(),
                            };
                            return;
                        }
                        attempts += 1;
                        // Computed before `err_msg` is moved into the `Retry` event.
                        let was_timeout = is_timeout_error(&err_msg);
                        yield StreamEvent::Retry {
                            attempt: attempts as u32,
                            delay_ms: backoff.as_millis() as u64,
                            error: err_msg,
                        };
                        if was_timeout && !stall_nudged {
                            stall_nudged = true;
                            ctx.messages.push(serde_json::json!({
                                "role": "user",
                                "content": STALL_RECOVERY_NUDGE,
                            }));
                        }
                    }
                    None => {
                        // The inner stream ended without a terminal event and without a
                        // retryable error: there is nothing left to forward.
                        return;
                    }
                }
            }
        })
    })
}
/// Returns `true` for the delta phases that start or extend text or thinking output.
///
/// All other phases (including the `*End` and tool-call phases) yield `false`.
fn is_content_delta(phase: DeltaPhase) -> bool {
    let is_text = matches!(phase, DeltaPhase::TextStart | DeltaPhase::TextDelta);
    let is_thinking = matches!(
        phase,
        DeltaPhase::ThinkingStart | DeltaPhase::ThinkingDelta
    );
    is_text || is_thinking
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::agent_loop::message::{AssistantMessage, ContentBlock, StopReason};
use crate::agent::agent_loop::stream::LlmContext;
use crate::agent::recovery::RecoveryPolicy;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Polls `s` to completion and returns every event it produced, in order.
async fn drain(
    mut s: std::pin::Pin<Box<dyn futures::Stream<Item = StreamEvent> + Send>>,
) -> Vec<StreamEvent> {
    let mut collected = Vec::new();
    loop {
        match s.next().await {
            Some(event) => collected.push(event),
            None => break collected,
        }
    }
}
fn canned_stream_fn(events: Vec<Vec<StreamEvent>>) -> StreamFn {
let counter = Arc::new(AtomicUsize::new(0));
let events = Arc::new(Mutex::new(events));
Arc::new(move |_ctx, _opts| {
let n = counter.fetch_add(1, Ordering::SeqCst);
let attempts = events.lock().unwrap();
let attempt_events = attempts.get(n).cloned().unwrap_or_default();
Box::pin(futures::stream::iter(attempt_events))
})
}
/// Builds a scripted `StreamFn` together with a counter of how often it was invoked.
///
/// The n-th call (0-based) replays `events[n]`; calls beyond the end of the script
/// produce an empty stream.
fn counted_canned(events: Vec<Vec<StreamEvent>>) -> (StreamFn, Arc<AtomicUsize>) {
    let calls = Arc::new(AtomicUsize::new(0));
    let calls_in_factory = Arc::clone(&calls);
    let script = Arc::new(Mutex::new(events));
    let factory: StreamFn = Arc::new(move |_ctx, _opts| {
        let index = calls_in_factory.fetch_add(1, Ordering::SeqCst);
        let batch = script
            .lock()
            .unwrap()
            .get(index)
            .cloned()
            .unwrap_or_default();
        Box::pin(futures::stream::iter(batch))
    });
    (factory, calls)
}
/// Minimal context for the tests: empty system prompt and a single user message "hi".
fn ctx() -> LlmContext {
    let greeting = serde_json::json!({"role": "user", "content": "hi"});
    LlmContext {
        system_prompt: String::new(),
        messages: vec![greeting],
    }
}
/// Assistant message with no content blocks and `StopReason::Stop`.
fn empty_assistant() -> AssistantMessage {
    let no_blocks = Vec::new();
    AssistantMessage::new(no_blocks, StopReason::Stop)
}
/// Assistant message holding a single text block with `text` and `StopReason::Stop`.
fn assistant_with(text: &str) -> AssistantMessage {
    let block = ContentBlock::Text {
        text: text.to_owned(),
    };
    AssistantMessage::new(vec![block], StopReason::Stop)
}
// A successful inner stream must be forwarded unchanged: Start followed by Done.
#[tokio::test]
async fn passthrough_when_no_errors() {
    let only_attempt = vec![
        StreamEvent::Start {
            partial: empty_assistant(),
        },
        StreamEvent::Done {
            reason: StopReason::Stop,
            message: assistant_with("hello"),
            usage: None,
        },
    ];
    let wrapped = retrying_stream_fn(
        canned_stream_fn(vec![only_attempt]),
        RecoveryPolicy::default(),
    );
    let opts = crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new());
    let events = drain(wrapped(ctx(), opts)).await;
    assert_eq!(events.len(), 2);
    assert!(matches!(events[0], StreamEvent::Start { .. }));
    assert!(matches!(events[1], StreamEvent::Done { .. }));
}
#[tokio::test]
async fn retries_on_network_error() {
let policy = RecoveryPolicy::default();
let (factory, counter) = counted_canned(vec![
vec![StreamEvent::Error {
error: "connection timed out".to_string(),
}],
vec![
StreamEvent::Start {
partial: empty_assistant(),
},
StreamEvent::Done {
reason: StopReason::Stop,
message: assistant_with("after retry"),
usage: None,
},
],
]);
let wrapped = retrying_stream_fn(factory, policy);
tokio::time::pause();
let drain_task = tokio::spawn(async move {
drain(wrapped(
ctx(),
crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new()),
))
.await
});
tokio::time::advance(std::time::Duration::from_secs(10)).await;
let events = drain_task.await.unwrap();
assert_eq!(counter.load(Ordering::SeqCst), 2);
let kinds: Vec<&str> = events
.iter()
.map(|e| match e {
StreamEvent::Start { .. } => "start",
StreamEvent::Delta { .. } => "delta",
StreamEvent::Done { .. } => "done",
StreamEvent::Error { .. } => "error",
StreamEvent::Retry { .. } => "retry",
})
.collect();
assert_eq!(kinds, vec!["retry", "start", "done"]);
}
#[tokio::test]
async fn does_not_retry_auth_error() {
let (factory, counter) = counted_canned(vec![vec![StreamEvent::Error {
error: "401 unauthorized: invalid api key".to_string(),
}]]);
let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
let events = drain(wrapped(
ctx(),
crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new()),
))
.await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert_eq!(events.len(), 1);
assert!(matches!(events[0], StreamEvent::Error { .. }));
}
#[tokio::test]
async fn does_not_retry_context_length_error() {
let (factory, counter) = counted_canned(vec![vec![StreamEvent::Error {
error: "context length exceeded: prompt is too long".to_string(),
}]]);
let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
let events = drain(wrapped(
ctx(),
crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new()),
))
.await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert!(matches!(events[0], StreamEvent::Error { .. }));
}
// Once a text delta has been forwarded, a later error is surfaced instead of retried.
#[tokio::test]
async fn does_not_retry_after_content_committed() {
    let interrupted_attempt = vec![
        StreamEvent::Start {
            partial: empty_assistant(),
        },
        StreamEvent::Delta {
            partial: assistant_with("partial "),
            phase: DeltaPhase::TextStart,
        },
        StreamEvent::Error {
            error: "connection reset".to_string(),
        },
    ];
    let (factory, calls) = counted_canned(vec![interrupted_attempt]);
    let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
    let opts = crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new());
    let events = drain(wrapped(ctx(), opts)).await;
    assert_eq!(calls.load(Ordering::SeqCst), 1);
    assert!(matches!(events[0], StreamEvent::Start { .. }));
    let second_is_text_start = matches!(
        events[1],
        StreamEvent::Delta {
            phase: DeltaPhase::TextStart,
            ..
        }
    );
    assert!(second_is_text_start);
    assert!(matches!(events[2], StreamEvent::Error { .. }));
}
#[tokio::test]
async fn retries_on_rate_limit_with_retry_after() {
let (factory, counter) = counted_canned(vec![
vec![StreamEvent::Error {
error: "rate limit hit. retry-after-ms: 50".to_string(),
}],
vec![StreamEvent::Done {
reason: StopReason::Stop,
message: assistant_with("ok"),
usage: None,
}],
]);
let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
tokio::time::pause();
let task = tokio::spawn(async move {
drain(wrapped(
ctx(),
crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new()),
))
.await
});
tokio::time::advance(std::time::Duration::from_secs(5)).await;
let events = task.await.unwrap();
assert_eq!(counter.load(Ordering::SeqCst), 2);
assert!(matches!(events.last(), Some(StreamEvent::Done { .. })));
}
// Six consecutive failing attempts: the test expects all six to be consumed and the
// final event to be the surfaced error.
#[tokio::test]
async fn surfaces_error_after_max_retries() {
    let always_failing: Vec<Vec<StreamEvent>> = std::iter::repeat_with(|| {
        vec![StreamEvent::Error {
            error: "network: connection timed out".to_string(),
        }]
    })
    .take(6)
    .collect();
    let (factory, calls) = counted_canned(always_failing);
    let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
    tokio::time::pause();
    let task = tokio::spawn(async move {
        let opts = crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new());
        drain(wrapped(ctx(), opts)).await
    });
    tokio::time::advance(std::time::Duration::from_secs(600)).await;
    let events = task.await.unwrap();
    assert_eq!(calls.load(Ordering::SeqCst), 6);
    assert!(matches!(events.last(), Some(StreamEvent::Error { .. })));
}
// A signal cancelled up front yields an error without ever calling the inner factory.
#[tokio::test]
async fn aborted_before_attempt_emits_error() {
    let (factory, calls) = counted_canned(Vec::new());
    let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
    let cancelled = AbortSignal::new();
    cancelled.cancel();
    let opts = crate::agent::agent_loop::StreamOptions::from_signal(cancelled);
    let events = drain(wrapped(ctx(), opts)).await;
    assert_eq!(calls.load(Ordering::SeqCst), 0);
    assert!(matches!(events[0], StreamEvent::Error { .. }));
}
#[tokio::test]
async fn aborted_during_backoff_emits_error() {
let (factory, counter) = counted_canned(vec![
vec![StreamEvent::Error {
error: "network glitch".to_string(),
}],
vec![StreamEvent::Done {
reason: StopReason::Stop,
message: assistant_with("never seen"),
usage: None,
}],
]);
let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
let signal = AbortSignal::new();
let signal_clone = signal.clone();
tokio::time::pause();
let task = tokio::spawn(async move {
drain(wrapped(
ctx(),
crate::agent::agent_loop::StreamOptions::from_signal(signal_clone),
))
.await
});
for _ in 0..5 {
tokio::task::yield_now().await;
}
signal.cancel();
tokio::time::advance(std::time::Duration::from_secs(600)).await;
let events = task.await.unwrap();
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert!(matches!(events.last(), Some(StreamEvent::Error { .. })));
}
// Table-driven check of which delta phases count as content.
#[test]
fn is_content_delta_classifies_phases() {
    let expectations = [
        (DeltaPhase::TextStart, true),
        (DeltaPhase::TextDelta, true),
        (DeltaPhase::ThinkingStart, true),
        (DeltaPhase::ThinkingDelta, true),
        (DeltaPhase::TextEnd, false),
        (DeltaPhase::ThinkingEnd, false),
        (DeltaPhase::ToolCallStart, false),
        (DeltaPhase::ToolCallDelta, false),
        (DeltaPhase::ToolCallEnd, false),
    ];
    for (phase, expect_content) in expectations {
        if expect_content {
            assert!(is_content_delta(phase), "{phase:?} should be content");
        } else {
            assert!(!is_content_delta(phase), "{phase:?} should NOT be content");
        }
    }
}
fn recording_factory(first_error: &str) -> (StreamFn, Arc<Mutex<Vec<Vec<serde_json::Value>>>>) {
let seen: Arc<Mutex<Vec<Vec<serde_json::Value>>>> = Arc::new(Mutex::new(Vec::new()));
let seen2 = seen.clone();
let counter = Arc::new(AtomicUsize::new(0));
let first_error = first_error.to_string();
let factory: StreamFn = Arc::new(move |ctx: LlmContext, _opts| {
seen2.lock().unwrap().push(ctx.messages.clone());
let n = counter.fetch_add(1, Ordering::SeqCst);
let events = if n == 0 {
vec![StreamEvent::Error {
error: first_error.clone(),
}]
} else {
vec![StreamEvent::Done {
reason: StopReason::Stop,
message: empty_assistant(),
usage: None,
}]
};
Box::pin(futures::stream::iter(events))
});
(factory, seen)
}
/// Returns `true` if any message in `msgs` has a string `content` equal to
/// `STALL_RECOVERY_NUDGE`.
fn has_stall_nudge(msgs: &[serde_json::Value]) -> bool {
    msgs.iter()
        .filter_map(|msg| msg.get("content")?.as_str())
        .any(|content| content == STALL_RECOVERY_NUDGE)
}
// A timeout on the first attempt must cause the stall nudge to be appended to the
// context that the retry attempt receives, and only to that one.
#[tokio::test]
async fn timeout_retry_injects_stall_nudge() {
    let (factory, seen) = recording_factory("request timeout after 120s");
    let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
    // Pause the clock so the retry backoff does not sleep in real time: with a paused
    // clock the runtime advances time to the pending timer as soon as it is idle.
    tokio::time::pause();
    let _ = drain(wrapped(
        ctx(),
        crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new()),
    ))
    .await;
    let recorded = seen.lock().unwrap();
    assert_eq!(recorded.len(), 2, "one timeout retry → two inner calls");
    assert!(!has_stall_nudge(&recorded[0]), "no nudge on first attempt");
    assert!(
        has_stall_nudge(&recorded[1]),
        "stall nudge reinjected on the timeout retry"
    );
}
// A retry caused by a non-timeout error must not add the stall nudge to the context.
#[tokio::test]
async fn non_timeout_retry_does_not_nudge() {
    let (factory, seen) = recording_factory("503 service unavailable");
    let wrapped = retrying_stream_fn(factory, RecoveryPolicy::default());
    // Pause the clock so the retry backoff does not sleep in real time: with a paused
    // clock the runtime advances time to the pending timer as soon as it is idle.
    tokio::time::pause();
    let _ = drain(wrapped(
        ctx(),
        crate::agent::agent_loop::StreamOptions::from_signal(AbortSignal::new()),
    ))
    .await;
    let recorded = seen.lock().unwrap();
    assert_eq!(recorded.len(), 2);
    assert!(
        !has_stall_nudge(&recorded[1]),
        "no nudge for a non-timeout error"
    );
}
}