use std::sync::Arc;
use std::time::Duration;
use agent_client_protocol_schema::{ContentBlock, TextContent};
use futures::StreamExt;
use rand::RngExt;
use serde_json::Value as JsonValue;
use tracing::Instrument;
use crate::event::{AgentEvent, LlmRequestSnapshot};
use crate::llm::{
CompletionRequest, Message, MessageContent, ProviderChunk, ProviderStream, RetryHint, Role,
StopReason as LlmStopReason, Usage,
};
use crate::session::TurnError;
use super::{TurnRunner, TurnState};
impl TurnRunner<'_> {
/// Sends `req` to the provider, re-attempting until an attempt produces a
/// terminal result.
///
/// The attempt budget is `max_llm_retries + 1` (saturating, never below 1).
/// Each pass through the loop increments `state.request_count` and runs the
/// attempt inside an `llm_call` tracing span carrying the vendor, model and
/// attempt number.
///
/// Returns the provider stream together with the 1-based number of the
/// attempt that ended the loop. If the attempt reports cancellation, an
/// empty stream is returned instead of an error.
///
/// # Errors
///
/// Returns `TurnError::Provider` when an attempt reports a permanent failure.
pub(super) async fn call_llm_with_retry(
    &self,
    req: &CompletionRequest,
    state: &mut TurnState,
) -> Result<(ProviderStream, u32), TurnError> {
    let attempt_budget = self.config.max_llm_retries.saturating_add(1).max(1);
    let vendor = self.provider.info().vendor.to_string();
    let mut attempt = 0_u32;
    loop {
        attempt += 1;
        state.request_count = state.request_count.saturating_add(1);
        let call_span = tracing::info_span!(
            "llm_call",
            vendor = %vendor,
            model = %req.model,
            attempt,
        );
        let result = self
            .call_llm_attempt(req, attempt, attempt_budget)
            .instrument(call_span)
            .await;
        let stream = match result {
            // Nothing terminal yet: go around again with the next attempt number.
            LlmAttempt::Retry => continue,
            LlmAttempt::Failed(err) => return Err(TurnError::Provider(err)),
            // Cancellation is not an error; hand back a stream with no items.
            LlmAttempt::Cancelled => empty_stream(),
            LlmAttempt::Done(stream) => stream,
        };
        return Ok((stream, attempt));
    }
}
async fn call_llm_attempt(
&self,
req: &CompletionRequest,
attempt: u32,
max_attempts: u32,
) -> LlmAttempt {
self.events
.emit(AgentEvent::LlmCallStarted {
model: req.model.clone(),
attempt,
request: Arc::new(LlmRequestSnapshot {
system: req.system.clone(),
messages: req.messages.clone(),
}),
})
.await;
match self
.provider
.complete(req.clone(), self.cancel.clone())
.await
{
Ok(stream) => {
LlmAttempt::Done(stream)
}
Err(err) => {
let hint = err.retry_hint();
let err_text = err.to_string();
self.events
.emit(AgentEvent::LlmCallFinished {
model: req.model.clone(),
attempt,
usage: Usage::default(),
error: Some(err_text),
})
.await;
if attempt >= max_attempts || matches!(hint, RetryHint::No) {
tracing::warn!(error = %err, ?hint, "llm call failed permanently");
return LlmAttempt::Failed(err);
}
if let Some(delay) = retry_delay(hint, attempt) {
tracing::info!(
?hint,
delay_ms = delay.as_millis() as u64,
"llm call failed, retrying after delay"
);
tokio::select! {
biased;
() = self.cancel.cancelled() => return LlmAttempt::Cancelled,
() = tokio::time::sleep(delay) => {}
}
} else {
tracing::info!(?hint, "llm call failed, retrying immediately");
}
LlmAttempt::Retry
}
}
}
pub(super) async fn drain_provider_stream(
&self,
stream: &mut ProviderStream,
state: &mut TurnState,
) -> Result<DrainOutcome, TurnError> {
let mut outcome = DrainOutcome::default();
loop {
tokio::select! {
biased;
() = self.cancel.cancelled() => {
outcome.cancelled = true;
return Ok(outcome);
}
next = stream.next() => match next {
None => {
if !outcome.saw_stop {
outcome.stop = LlmStopReason::EndTurn;
}
return Ok(outcome);
}
Some(Err(err)) => {
return Err(TurnError::Provider(err));
}
Some(Ok(chunk)) => {
if self.handle_chunk(chunk, &mut outcome, state).await {
return Ok(outcome);
}
}
}
}
}
}
/// Applies one provider chunk to `outcome` (and, for usage, to `state`).
///
/// - Text and thinking deltas are appended to their buffers and forwarded as
///   `AssistantText` / `AssistantThought` events.
/// - A thinking signature replaces any previously stored one.
/// - `ToolUseStart` opens a new accumulator; `ToolUseArgsDelta` appends to the
///   first accumulator with a matching id and is dropped if none matches.
/// - `Stop` records the reason and marks `saw_stop`.
/// - `Usage` is added to both the per-call and the per-turn totals.
/// - `MessageStart` and `ToolUseEnd` are ignored.
///
/// The return value tells the caller whether to stop draining; every chunk
/// kind currently returns `false`.
async fn handle_chunk(
    &self,
    chunk: ProviderChunk,
    outcome: &mut DrainOutcome,
    state: &mut TurnState,
) -> bool {
    match chunk {
        ProviderChunk::MessageStart { .. } | ProviderChunk::ToolUseEnd { .. } => {}
        ProviderChunk::TextDelta { text } => {
            outcome.text_buf.push_str(&text);
            let content = ContentBlock::Text(TextContent::new(text));
            self.events
                .emit(AgentEvent::AssistantText { content })
                .await;
        }
        ProviderChunk::ThinkingDelta { text } => {
            outcome.thinking_buf.push_str(&text);
            let content = ContentBlock::Text(TextContent::new(text));
            self.events
                .emit(AgentEvent::AssistantThought { content })
                .await;
        }
        ProviderChunk::ThinkingSignature { signature } => {
            outcome.thinking_signature = Some(signature);
        }
        ProviderChunk::ToolUseStart { id, name } => {
            let args_buf = String::new();
            outcome
                .tool_uses
                .push(ToolUseAccumulated { id, name, args_buf });
        }
        ProviderChunk::ToolUseArgsDelta { id, fragment } => {
            let target = outcome.tool_uses.iter_mut().find(|slot| slot.id == id);
            if let Some(slot) = target {
                slot.args_buf.push_str(&fragment);
            }
        }
        ProviderChunk::Stop { reason } => {
            outcome.saw_stop = true;
            outcome.stop = reason;
        }
        ProviderChunk::Usage(delta) => {
            outcome.usage = add_usage(outcome.usage, delta);
            state.usage = add_usage(state.usage, delta);
        }
    }
    false
}
}
/// Result of one provider call, as reported by `call_llm_attempt` to the
/// retry loop in `call_llm_with_retry`.
enum LlmAttempt {
/// The provider returned a stream.
Done(ProviderStream),
/// The call failed and will not be retried; carries the provider error.
Failed(crate::llm::ProviderError),
/// The cancel token fired while waiting before the next attempt.
Cancelled,
/// The call failed and the caller should make another attempt.
Retry,
}
/// Everything accumulated while draining one provider stream.
pub(super) struct DrainOutcome {
/// Set once a `ProviderChunk::Stop` chunk has been handled.
pub(super) saw_stop: bool,
/// Reason from the most recent stop chunk; `EndTurn` when none was seen.
pub(super) stop: LlmStopReason,
/// Concatenation of all text deltas.
pub(super) text_buf: String,
/// Concatenation of all thinking deltas.
pub(super) thinking_buf: String,
/// Most recent thinking signature, if any was received.
pub(super) thinking_signature: Option<String>,
/// One entry per `ToolUseStart` chunk, in arrival order.
pub(super) tool_uses: Vec<ToolUseAccumulated>,
/// Sum of all usage chunks received on this stream.
pub(super) usage: Usage,
/// Set when draining stopped because the cancel token fired.
pub(super) cancelled: bool,
}
impl Default for DrainOutcome {
    /// Builds an empty accumulator: empty buffers and tool list, default
    /// usage, both flags cleared, and a stop reason of `EndTurn`.
    fn default() -> Self {
        let text_buf = String::new();
        let thinking_buf = String::new();
        let tool_uses = Vec::new();
        Self {
            cancelled: false,
            saw_stop: false,
            stop: LlmStopReason::EndTurn,
            usage: Usage::default(),
            thinking_signature: None,
            text_buf,
            thinking_buf,
            tool_uses,
        }
    }
}
/// A tool call being assembled from streamed chunks.
pub(super) struct ToolUseAccumulated {
/// Identifier from the `ToolUseStart` chunk; used to match argument deltas.
pub(super) id: String,
/// Tool name from the `ToolUseStart` chunk.
pub(super) name: String,
/// Concatenated argument fragments; parsed as JSON by `parse_args` when
/// the assistant message is built.
pub(super) args_buf: String,
}
/// Assembles the assistant `Message` for a drained stream.
///
/// Content parts are emitted in a fixed order:
/// 1. one `Thinking` part, if there is thinking text or a signature;
/// 2. one `Text` part, if any text was produced;
/// 3. one `ToolUse` part per accumulated tool call, in arrival order.
///
/// Tool arguments that fail to parse as JSON are replaced by an empty JSON
/// object.
pub(super) fn assistant_message(outcome: &DrainOutcome) -> Message {
    let has_thinking =
        outcome.thinking_signature.is_some() || !outcome.thinking_buf.is_empty();
    let thinking = has_thinking.then(|| MessageContent::Thinking {
        text: outcome.thinking_buf.clone(),
        signature: outcome.thinking_signature.clone(),
    });
    let text = (!outcome.text_buf.is_empty()).then(|| MessageContent::Text {
        text: outcome.text_buf.clone(),
    });
    let tool_calls = outcome.tool_uses.iter().map(|call| {
        let args = match parse_args(&call.args_buf) {
            Ok(value) => value,
            Err(_) => JsonValue::Object(Default::default()),
        };
        MessageContent::ToolUse {
            id: call.id.clone(),
            name: call.name.clone(),
            args,
        }
    });
    let content: Vec<MessageContent> = thinking
        .into_iter()
        .chain(text)
        .chain(tool_calls)
        .collect();
    Message {
        role: Role::Assistant,
        content: content.into(),
    }
}
/// Parses accumulated tool-call arguments as JSON.
///
/// A buffer that is empty or only whitespace yields an empty JSON object.
///
/// # Errors
///
/// Returns the `serde_json` error rendered as a string when `buf` is not
/// valid JSON.
pub(super) fn parse_args(buf: &str) -> Result<JsonValue, String> {
    match buf.trim() {
        "" => Ok(JsonValue::Object(Default::default())),
        _ => serde_json::from_str(buf).map_err(|err| err.to_string()),
    }
}
/// Adds two `Usage` records field by field using `add_opt`, so a counter
/// stays `None` only when neither side reported it.
fn add_usage(a: Usage, b: Usage) -> Usage {
    let input_tokens = add_opt(a.input_tokens, b.input_tokens);
    let output_tokens = add_opt(a.output_tokens, b.output_tokens);
    let cache_read_input_tokens =
        add_opt(a.cache_read_input_tokens, b.cache_read_input_tokens);
    let cache_creation_input_tokens = add_opt(
        a.cache_creation_input_tokens,
        b.cache_creation_input_tokens,
    );
    Usage {
        input_tokens,
        output_tokens,
        cache_read_input_tokens,
        cache_creation_input_tokens,
    }
}
/// Total input-side tokens: plain input plus cache-read plus cache-creation.
///
/// Returns `None` when none of the three counters was reported; otherwise
/// missing counters count as zero and the sum saturates at `u64::MAX`.
pub(super) fn real_input_tokens(usage: &Usage) -> Option<u64> {
    let counters = [
        usage.input_tokens,
        usage.cache_read_input_tokens,
        usage.cache_creation_input_tokens,
    ];
    if counters.iter().all(Option::is_none) {
        return None;
    }
    let total = counters
        .iter()
        .flatten()
        .fold(0_u64, |sum, &count| sum.saturating_add(count));
    Some(total)
}
/// Adds two optional counters.
///
/// When both are present the result is their saturating sum; when only one
/// is present it is returned unchanged; when neither is, the result is `None`.
fn add_opt(a: Option<u64>, b: Option<u64>) -> Option<u64> {
    match a.zip(b) {
        Some((lhs, rhs)) => Some(lhs.saturating_add(rhs)),
        // At most one side is present here, so `or` picks it (or `None`).
        None => a.or(b),
    }
}
/// Maps a provider retry hint to the wait before the next attempt.
///
/// `No` yields `None`. `Immediate` and `AfterAction` yield a zero duration,
/// `After` yields the duration it carries, and `Backoff` yields the jittered
/// exponential delay for `attempt`.
fn retry_delay(hint: RetryHint, attempt: u32) -> Option<Duration> {
    let wait = match hint {
        RetryHint::No => return None,
        RetryHint::Immediate | RetryHint::AfterAction(_) => Duration::ZERO,
        RetryHint::After(requested) => requested,
        RetryHint::Backoff => backoff_delay(attempt),
    };
    Some(wait)
}
/// Jittered exponential backoff for the given 1-based attempt number.
///
/// The base delay is `BACKOFF_INITIAL * 2^(attempt - 1)`, with the exponent
/// limited to 20, capped at `BACKOFF_MAX`. It is then scaled by a random
/// factor in `[1 - BACKOFF_JITTER_FRAC, 1 + BACKOFF_JITTER_FRAC)` and clamped
/// to `[0, BACKOFF_MAX]` again, so jitter never pushes the delay past the cap.
fn backoff_delay(attempt: u32) -> Duration {
    let doublings = attempt.saturating_sub(1).min(20);
    let ceiling_nanos = BACKOFF_MAX.as_nanos();
    let unjittered_nanos = BACKOFF_INITIAL
        .as_nanos()
        .saturating_mul(1u128 << doublings)
        .min(ceiling_nanos);
    let jitter: f64 = rand::rng().random_range(-BACKOFF_JITTER_FRAC..BACKOFF_JITTER_FRAC);
    let scaled = (unjittered_nanos as f64 * (1.0 + jitter))
        .round()
        .clamp(0.0, ceiling_nanos as f64);
    let whole_nanos = (scaled as u128).min(u128::from(u64::MAX));
    Duration::from_nanos(whole_nanos as u64)
}
/// Backoff delay before jitter for the first attempt; doubled for each
/// further attempt.
const BACKOFF_INITIAL: Duration = Duration::from_millis(500);
/// Upper bound on the backoff delay, applied both before and after jitter.
const BACKOFF_MAX: Duration = Duration::from_secs(16);
/// Half-width of the jitter band: the delay is scaled by a random factor in
/// `[1 - BACKOFF_JITTER_FRAC, 1 + BACKOFF_JITTER_FRAC)`.
const BACKOFF_JITTER_FRAC: f64 = 0.25;
/// Builds a `ProviderStream` that yields no items. `call_llm_with_retry`
/// returns it when an attempt was cancelled.
fn empty_stream() -> ProviderStream {
Box::pin(futures::stream::empty())
}
#[cfg(test)]
mod tests;