use std::borrow::Cow;
use std::sync::Arc;

use futures::StreamExt;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, warn};

use crate::error::AgentError;
use crate::stream::{
    AssistantMessageDelta, AssistantMessageEvent, StreamFn, StreamOptions, accumulate_message,
};
use crate::types::{
    AgentContext, AgentMessage, AssistantMessage, LlmMessage, ModelSpec, StopReason, ThinkingLevel,
};

use super::{
    AgentEvent, AgentLoopConfig, StreamResult, build_abort_message, build_error_message,
    classify_stream_error, emit,
};
pub async fn stream_with_retry(
config: &Arc<AgentLoopConfig>,
agent_context: &AgentContext,
llm_messages: &[LlmMessage],
system_prompt: &str,
api_key: Option<String>,
cancellation_token: &CancellationToken,
tx: &mpsc::Sender<AgentEvent>,
) -> StreamResult {
let primary_result = stream_with_retry_single(
&config.model,
&config.stream_fn,
config,
agent_context,
llm_messages,
system_prompt,
api_key.clone(),
cancellation_token,
tx,
)
.await;
let last_error_msg = match &primary_result {
StreamResult::Message(msg)
if msg.stop_reason != StopReason::Error || !is_fallback_eligible_error(msg) =>
{
return primary_result;
}
StreamResult::ContextOverflow | StreamResult::Aborted | StreamResult::ChannelClosed => {
return primary_result;
}
StreamResult::Message(msg) => msg.clone(),
};
let fallback = match config.fallback {
Some(ref fb) if !fb.is_empty() => fb,
_ => return StreamResult::Message(last_error_msg),
};
let mut last_result = StreamResult::Message(last_error_msg);
for (fb_model, fb_stream_fn) in fallback.models() {
if !emit(
tx,
AgentEvent::ModelFallback {
from_model: config.model.clone(),
to_model: fb_model.clone(),
},
)
.await
{
return StreamResult::ChannelClosed;
}
warn!(
from = %config.model.model_id,
to = %fb_model.model_id,
"falling back to alternate model"
);
let fb_result = stream_with_retry_single(
fb_model,
fb_stream_fn,
config,
agent_context,
llm_messages,
system_prompt,
api_key.clone(),
cancellation_token,
tx,
)
.await;
match &fb_result {
StreamResult::Message(msg)
if msg.stop_reason != StopReason::Error || !is_fallback_eligible_error(msg) =>
{
return fb_result;
}
StreamResult::ContextOverflow | StreamResult::Aborted | StreamResult::ChannelClosed => {
return fb_result;
}
StreamResult::Message(_) => {
last_result = fb_result;
}
}
}
last_result
}
fn is_fallback_eligible_error(msg: &AssistantMessage) -> bool {
let Some(error_message) = msg.error_message.as_deref() else {
return false;
};
let harness_error = classify_stream_error(error_message, StopReason::Error, msg.error_kind);
harness_error.is_retryable()
}
/// Run the retry loop for a single model: build the call context, stream one
/// attempt, and either finalize the message or classify the error and retry.
///
/// Returns on any terminal result: a completed/error message, context
/// overflow, abort, or a closed event channel.
#[allow(clippy::too_many_arguments)]
async fn stream_with_retry_single(
    model: &ModelSpec,
    stream_fn: &Arc<dyn StreamFn>,
    config: &Arc<AgentLoopConfig>,
    agent_context: &AgentContext,
    llm_messages: &[LlmMessage],
    system_prompt: &str,
    api_key: Option<String>,
    cancellation_token: &CancellationToken,
    tx: &mpsc::Sender<AgentEvent>,
) -> StreamResult {
    let llm_span = info_span!(
        "agent.llm_call",
        agent.model = %model.model_id,
        agent.tokens.input = tracing::field::Empty,
        agent.tokens.output = tracing::field::Empty,
        agent.cost.total = tracing::field::Empty,
        otel.status_code = tracing::field::Empty,
    );
    // BUGFIX: this used to hold `llm_span.enter()` across `.await` points.
    // An entered-span guard must not live across awaits — when the task
    // yields, the span stays "entered" on the worker thread and leaks into
    // whatever task runs next. `Instrument::instrument` attaches the span
    // correctly across suspension points instead. A clone of the span handle
    // is kept so fields can still be recorded from inside the future.
    let span = llm_span.clone();
    async move {
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            debug!(attempt, model_id = %model.model_id, "starting stream attempt");
            if cancellation_token.is_cancelled() {
                return StreamResult::Aborted;
            }
            // Rebuild the call context on every attempt from the current
            // message list and tool set.
            let call_context = AgentContext {
                system_prompt: system_prompt.to_string(),
                messages: llm_messages
                    .iter()
                    .map(|m| AgentMessage::Llm(m.clone()))
                    .collect(),
                tools: agent_context.tools.clone(),
            };
            let mut stream_options = config.stream_options.clone();
            stream_options.api_key = api_key.clone();
            if !emit(tx, AgentEvent::MessageStart).await {
                return StreamResult::ChannelClosed;
            }
            let attempt_result = stream_single_attempt(
                model,
                stream_fn,
                &call_context,
                &stream_options,
                cancellation_token,
                tx,
            )
            .await;
            let (events, had_error) = match attempt_result {
                StreamAttemptResult::EarlyExit(result) => return result,
                StreamAttemptResult::Collected { events, error } => (events, error),
            };
            if let Some((stop_reason, error_message, _usage, error_kind)) = had_error {
                let retry_result = handle_stream_error(
                    model,
                    config,
                    &stop_reason,
                    &error_message,
                    error_kind,
                    attempt,
                    tx,
                )
                .await;
                match retry_result {
                    StreamErrorAction::ContextOverflow => return StreamResult::ContextOverflow,
                    StreamErrorAction::Retry(delay) => {
                        tokio::time::sleep(delay).await;
                        continue;
                    }
                    StreamErrorAction::FatalError(msg) => {
                        span.record("otel.status_code", "ERROR");
                        return msg;
                    }
                    StreamErrorAction::ChannelClosed => return StreamResult::ChannelClosed,
                }
            }
            let result = finalize_stream_message(model, events, tx).await;
            // Record token/cost telemetry for successful completions.
            if let StreamResult::Message(ref msg) = result {
                span.record("agent.tokens.input", msg.usage.input);
                span.record("agent.tokens.output", msg.usage.output);
                span.record("agent.cost.total", msg.cost.total);
            }
            return result;
        }
    }
    .instrument(llm_span)
    .await
}
/// What the retry loop should do after classifying a stream error.
#[allow(clippy::large_enum_variant)]
enum StreamErrorAction {
    /// Context window overflow: the caller should signal a prune and retry.
    ContextOverflow,
    /// Transient error: sleep for the given delay, then retry the stream.
    Retry(std::time::Duration),
    /// Non-retryable failure (or abort): propagate this result to the caller.
    FatalError(StreamResult),
    /// The event channel receiver was dropped; stop all work immediately.
    ChannelClosed,
}
/// Result of one full pass over the model's event stream.
enum StreamAttemptResult {
    /// The stream was drained to completion.
    Collected {
        /// Every event received, in order, for later accumulation.
        events: Vec<AssistantMessageEvent>,
        /// Payload of an `Error` event, if any arrived (last one wins):
        /// (stop_reason, error_message, usage, error_kind).
        error: Option<(
            StopReason,
            String,
            Option<crate::types::Usage>,
            Option<crate::stream::StreamErrorKind>,
        )>,
    },
    /// The attempt ended early (cancellation or closed channel) with a
    /// final result to return as-is.
    EarlyExit(StreamResult),
}
/// Drive one streaming call against the model, forwarding each delta event to
/// `tx` as it arrives and collecting every event for later accumulation.
///
/// Returns `EarlyExit` if cancellation is observed mid-stream or the event
/// channel closes; otherwise returns all collected events plus the payload of
/// any `Error` event seen (if several arrive, the last one is kept).
async fn stream_single_attempt(
    model: &ModelSpec,
    stream_fn: &Arc<dyn StreamFn>,
    call_context: &AgentContext,
    stream_options: &StreamOptions,
    cancellation_token: &CancellationToken,
    tx: &mpsc::Sender<AgentEvent>,
) -> StreamAttemptResult {
    // Capability overrides (e.g. thinking level) apply per-call, not to the
    // caller's ModelSpec.
    let effective_model = apply_capability_overrides(model);
    let mut stream = stream_fn.stream(
        &effective_model,
        call_context,
        stream_options,
        cancellation_token.clone(),
    );
    let mut events: Vec<AssistantMessageEvent> = Vec::new();
    let mut had_error: Option<(
        StopReason,
        String,
        Option<crate::types::Usage>,
        Option<crate::stream::StreamErrorKind>,
    )> = None;
    while let Some(event) = stream.next().await {
        // Check cancellation between events so an abort is surfaced promptly
        // with a MessageEnd (send failure is ignored: we are exiting anyway).
        if cancellation_token.is_cancelled() {
            let abort_msg = build_abort_message(model);
            let _ = emit(tx, AgentEvent::MessageEnd { message: abort_msg }).await;
            return StreamAttemptResult::EarlyExit(StreamResult::Aborted);
        }
        // Forward text/thinking/tool-call deltas to listeners as they stream.
        if let Some(early_exit) = emit_delta_event(&event, tx).await {
            return StreamAttemptResult::EarlyExit(early_exit);
        }
        // Capture the error payload but keep draining — the stream decides
        // when to end, and the event itself is still pushed below.
        if let AssistantMessageEvent::Error {
            stop_reason,
            error_message,
            usage,
            error_kind,
        } = &event
        {
            had_error = Some((
                *stop_reason,
                error_message.clone(),
                usage.clone(),
                *error_kind,
            ));
        }
        events.push(event);
    }
    StreamAttemptResult::Collected {
        events,
        error: had_error,
    }
}
async fn emit_delta_event(
event: &AssistantMessageEvent,
tx: &mpsc::Sender<AgentEvent>,
) -> Option<StreamResult> {
let delta = match event {
AssistantMessageEvent::TextDelta {
content_index,
delta,
} => Some(AssistantMessageDelta::Text {
content_index: *content_index,
delta: Cow::Owned(delta.clone()),
}),
AssistantMessageEvent::ThinkingDelta {
content_index,
delta,
} => Some(AssistantMessageDelta::Thinking {
content_index: *content_index,
delta: Cow::Owned(delta.clone()),
}),
AssistantMessageEvent::ToolCallDelta {
content_index,
delta,
} => Some(AssistantMessageDelta::ToolCall {
content_index: *content_index,
delta: Cow::Owned(delta.clone()),
}),
_ => None,
};
if let Some(d) = delta
&& !emit(tx, AgentEvent::MessageUpdate { delta: d }).await
{
return Some(StreamResult::ChannelClosed);
}
None
}
/// Classify a stream error and decide how the retry loop should proceed.
///
/// Special-cased errors (context overflow, cache miss, abort) are handled
/// first; everything else is retried per the config's retry strategy or
/// surfaced as a fatal error message.
async fn handle_stream_error(
    model: &ModelSpec,
    config: &Arc<AgentLoopConfig>,
    stop_reason: &StopReason,
    error_message: &str,
    error_kind: Option<crate::stream::StreamErrorKind>,
    attempt: u32,
    tx: &mpsc::Sender<AgentEvent>,
) -> StreamErrorAction {
    let harness_error = classify_stream_error(error_message, *stop_reason, error_kind);
    match &harness_error {
        AgentError::ContextWindowOverflow { .. } => {
            warn!("context window overflow, signaling prune");
            // Best-effort notification: the overflow signal matters more than
            // whether the channel is still open.
            let end = AgentEvent::MessageEnd {
                message: build_error_message(model, &harness_error),
            };
            let _ = emit(tx, end).await;
            return StreamErrorAction::ContextOverflow;
        }
        AgentError::CacheMiss => {
            warn!("provider cache miss, resetting cache state for retry");
            // Recover from a poisoned lock rather than propagating the panic.
            config
                .cache_state
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .reset();
            // Immediate retry: a cache miss is not a transient provider fault.
            return StreamErrorAction::Retry(std::time::Duration::ZERO);
        }
        AgentError::Aborted => {
            let aborted = build_abort_message(model);
            return if emit(tx, AgentEvent::MessageEnd { message: aborted }).await {
                StreamErrorAction::FatalError(StreamResult::Aborted)
            } else {
                StreamErrorAction::ChannelClosed
            };
        }
        _ => {}
    }
    if config.retry_strategy.should_retry(&harness_error, attempt) {
        let delay = config.retry_strategy.delay(attempt);
        warn!(attempt, ?delay, error = %harness_error, "retrying after transient error");
        return StreamErrorAction::Retry(delay);
    }
    error!(error = %harness_error, "non-retryable stream error");
    let failure = build_error_message(model, &harness_error);
    let delivered = emit(
        tx,
        AgentEvent::MessageEnd {
            message: failure.clone(),
        },
    )
    .await;
    if delivered {
        StreamErrorAction::FatalError(StreamResult::Message(failure))
    } else {
        StreamErrorAction::ChannelClosed
    }
}
/// Clamp per-call model settings to the model's declared capabilities.
///
/// Currently forces `thinking_level` to `Off` for models that do not support
/// thinking. Returns `Cow::Borrowed` when no override applies so the common
/// path stays allocation-free.
fn apply_capability_overrides(model: &ModelSpec) -> Cow<'_, ModelSpec> {
    let Some(ref caps) = model.capabilities else {
        return Cow::Borrowed(model);
    };
    // FIX: previously the ModelSpec was cloned unconditionally whenever
    // capabilities were present, even when no override fired. Decide first,
    // clone only when an override actually applies.
    if !caps.supports_thinking && model.thinking_level != ThinkingLevel::Off {
        debug!(
            model_id = %model.model_id,
            "model does not support thinking — forcing thinking_level to Off"
        );
        let mut overridden = model.clone();
        overridden.thinking_level = ThinkingLevel::Off;
        return Cow::Owned(overridden);
    }
    Cow::Borrowed(model)
}
/// Strip all tools from the context when the model declares no tool-use
/// support; otherwise return the tool list unchanged.
///
/// Models with no capability record are assumed to support tools.
pub fn capability_filter_tools(
    model: &ModelSpec,
    tools: &[Arc<dyn crate::tool::AgentTool>],
) -> Vec<Arc<dyn crate::tool::AgentTool>> {
    let supports_tools = model
        .capabilities
        .as_ref()
        .map_or(true, |caps| caps.supports_tool_use);
    if supports_tools || tools.is_empty() {
        return tools.to_vec();
    }
    debug!(
        model_id = %model.model_id,
        tool_count = tools.len(),
        "model does not support tool use — stripping tools from context"
    );
    Vec::new()
}
/// Accumulate collected stream events into a final assistant message and
/// announce it with a `MessageEnd` event.
///
/// Accumulation failures are converted into an error message (sent
/// best-effort and returned); a closed channel yields `ChannelClosed`.
async fn finalize_stream_message(
    model: &ModelSpec,
    events: Vec<AssistantMessageEvent>,
    tx: &mpsc::Sender<AgentEvent>,
) -> StreamResult {
    let message = match accumulate_message(events, &model.provider, &model.model_id) {
        Ok(msg) => msg,
        Err(e) => {
            // Wrap the accumulation failure as a stream error so it carries
            // through the normal error-message path.
            let source = std::io::Error::new(std::io::ErrorKind::InvalidData, e);
            let err = AgentError::StreamError {
                source: Box::new(source),
            };
            let failure = build_error_message(model, &err);
            let _ = emit(
                tx,
                AgentEvent::MessageEnd {
                    message: failure.clone(),
                },
            )
            .await;
            return StreamResult::Message(failure);
        }
    };
    info!(
        input_tokens = message.usage.input,
        output_tokens = message.usage.output,
        total_tokens = message.usage.total,
        stop_reason = ?message.stop_reason,
        "stream completed"
    );
    let delivered = emit(
        tx,
        AgentEvent::MessageEnd {
            message: message.clone(),
        },
    )
    .await;
    if delivered {
        StreamResult::Message(message)
    } else {
        StreamResult::ChannelClosed
    }
}