use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use crate::types::{AgentMessage, LlmMessage, StopReason};
use super::stream::stream_with_retry;
use super::turn::{
build_snapshot, emit_turn_end_and_agent_end, handle_cancellation, run_context_transformers,
};
use super::{
AgentEvent, AgentLoopConfig, LoopState, StreamResult, TurnEndReason, TurnOutcome,
build_error_message, emit,
};
/// Result of [`attempt_overflow_recovery`] (also produced by [`overflow_error`]).
pub(super) enum OverflowRecoveryResult {
    /// The context was compacted and the re-issued stream returned something
    /// other than `StreamResult::ContextOverflow`; holds that retried result.
    Recovered(Box<StreamResult>),
    /// Recovery did not succeed: no transformer was configured, this was a
    /// repeat overflow in the same turn, nothing was compacted, the retry
    /// overflowed again, or the token was cancelled. Holds the `TurnOutcome`
    /// produced by the corresponding failure path.
    Failed(TurnOutcome),
}
/// Attempts a single recovery from a context-window overflow by compacting the
/// conversation with the configured context transformers and re-streaming.
///
/// Recovery is refused (and [`overflow_error`] is used to end the turn) when:
/// - `state.overflow_recovery_attempted` is already set for this turn,
/// - neither `transform_context` nor `async_transform_context` is configured, or
/// - the transformers report that nothing was compacted.
///
/// While the transformers run, `state.overflow_signal` is set to `true`; it is
/// reset to `false` right afterwards. If the cancellation token has fired by the
/// time compaction finishes, the turn is ended through `handle_cancellation`.
///
/// Otherwise the compacted `state.context_messages` are converted with
/// `config.convert_to_llm` (entries mapping to `None` are skipped) and passed to
/// `stream_with_retry`. A second `StreamResult::ContextOverflow` ends the turn
/// with [`overflow_error`]; any other result is returned as
/// [`OverflowRecoveryResult::Recovered`].
// NOTE(review): the signature has seven parameters; the lint allowance is kept
// unchanged from the original.
#[allow(clippy::too_many_arguments)]
pub(super) async fn attempt_overflow_recovery(
    config: &Arc<AgentLoopConfig>,
    state: &mut LoopState,
    system_prompt: &str,
    agent_context: &crate::types::AgentContext,
    api_key: Option<String>,
    cancellation_token: &CancellationToken,
    tx: &mpsc::Sender<AgentEvent>,
) -> OverflowRecoveryResult {
    // Only one compaction-and-retry is allowed per turn.
    if state.overflow_recovery_attempted {
        debug!("second overflow in same turn — surfacing error");
        return overflow_error(config, state, tx).await;
    }

    // Without any transformer there is nothing that could shrink the context.
    let has_transformer =
        config.transform_context.is_some() || config.async_transform_context.is_some();
    if !has_transformer {
        debug!("no context transformer configured — cannot recover from overflow");
        return overflow_error(config, state, tx).await;
    }

    state.overflow_recovery_attempted = true;

    // The overflow signal is raised only for the duration of the transformer pass.
    state.overflow_signal = true;
    let compacted =
        run_context_transformers(config, &mut state.context_messages, true, tx).await;
    state.overflow_signal = false;

    if !compacted {
        debug!("transformers ran but no compaction occurred — surfacing error");
        return overflow_error(config, state, tx).await;
    }

    if cancellation_token.is_cancelled() {
        let cancelled_outcome = handle_cancellation(config, state, tx).await;
        return OverflowRecoveryResult::Failed(cancelled_outcome);
    }

    // Rebuild the LLM-facing message list from the compacted context.
    let mut llm_messages: Vec<LlmMessage> = Vec::new();
    for message in state.context_messages.iter() {
        if let Some(converted) = (config.convert_to_llm)(message) {
            llm_messages.push(converted);
        }
    }

    let retried = stream_with_retry(
        config,
        agent_context,
        &llm_messages,
        system_prompt,
        api_key,
        cancellation_token,
        tx,
    )
    .await;

    match retried {
        StreamResult::ContextOverflow => {
            debug!("retry after compaction still overflowed — surfacing error");
            overflow_error(config, state, tx).await
        }
        other => OverflowRecoveryResult::Recovered(Box::new(other)),
    }
}
pub(super) async fn overflow_error(
config: &Arc<AgentLoopConfig>,
state: &mut LoopState,
tx: &mpsc::Sender<AgentEvent>,
) -> OverflowRecoveryResult {
let error = crate::error::AgentError::ContextWindowOverflow {
model: config.model.model_id.clone(),
};
let error_msg = build_error_message(&config.model, &error);
let msg_for_event = error_msg.clone();
state
.context_messages
.push(AgentMessage::Llm(LlmMessage::Assistant(error_msg)));
let _ = emit(
tx,
AgentEvent::MessageEnd {
message: msg_for_event.clone(),
},
)
.await;
let snapshot = build_snapshot(state, StopReason::Error, None);
OverflowRecoveryResult::Failed(
emit_turn_end_and_agent_end(
msg_for_event,
vec![],
TurnEndReason::Error,
snapshot,
state,
tx,
)
.await,
)
}