use super::helpers::{calculate_backoff_delay, send_event};
use super::types::StreamError;
use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
use crate::hooks::AgentHooks;
use crate::llm::{
ChatOutcome, ChatRequest, ChatResponse, LlmProvider, StreamAccumulator, StreamDelta, Usage,
};
use crate::types::{AgentConfig, AgentError};
use futures::StreamExt;
use log::{error, warn};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::sleep;
pub(super) async fn call_llm_with_retry<P, H>(
provider: &Arc<P>,
request: ChatRequest,
config: &AgentConfig,
tx: &mpsc::Sender<AgentEventEnvelope>,
hooks: &Arc<H>,
seq: &SequenceCounter,
) -> Result<ChatResponse, AgentError>
where
P: LlmProvider,
H: AgentHooks,
{
let max_retries = config.retry.max_retries;
let mut attempt = 0u32;
loop {
let outcome = match provider.chat(request.clone()).await {
Ok(o) => o,
Err(e) => {
return Err(AgentError::new(format!("LLM error: {e}"), false));
}
};
match outcome {
ChatOutcome::Success(response) => return Ok(response),
ChatOutcome::RateLimited => {
attempt += 1;
if attempt > max_retries {
error!("Rate limited by LLM provider after {max_retries} retries");
let error_msg = format!("Rate limited after {max_retries} retries");
send_event(tx, hooks, seq, AgentEvent::error(&error_msg, true)).await;
return Err(AgentError::new(error_msg, true));
}
let delay = calculate_backoff_delay(attempt, &config.retry);
warn!(
"Rate limited, retrying after backoff (attempt={}, delay_ms={})",
attempt,
delay.as_millis()
);
sleep(delay).await;
}
ChatOutcome::InvalidRequest(msg) => {
error!("Invalid request to LLM: {msg}");
return Err(AgentError::new(format!("Invalid request: {msg}"), false));
}
ChatOutcome::ServerError(msg) => {
attempt += 1;
if attempt > max_retries {
error!("LLM server error after {max_retries} retries: {msg}");
let error_msg = format!("Server error after {max_retries} retries: {msg}");
send_event(tx, hooks, seq, AgentEvent::error(&error_msg, true)).await;
return Err(AgentError::new(error_msg, true));
}
let delay = calculate_backoff_delay(attempt, &config.retry);
warn!(
"Server error, retrying after backoff (attempt={attempt}, delay_ms={}, error={msg})",
delay.as_millis()
);
sleep(delay).await;
}
}
}
}
pub(super) async fn call_llm_streaming<P, H>(
provider: &Arc<P>,
request: ChatRequest,
config: &AgentConfig,
tx: &mpsc::Sender<AgentEventEnvelope>,
hooks: &Arc<H>,
seq: &SequenceCounter,
ids: (&str, &str),
) -> Result<ChatResponse, AgentError>
where
P: LlmProvider,
H: AgentHooks,
{
let (message_id, thinking_id) = ids;
let max_retries = config.retry.max_retries;
let mut attempt = 0u32;
loop {
let result =
process_stream(provider, &request, tx, hooks, seq, message_id, thinking_id).await;
match result {
Ok(response) => return Ok(response),
Err(StreamError::Recoverable(msg)) => {
attempt += 1;
if attempt > max_retries {
error!("Streaming error after {max_retries} retries: {msg}");
let err_msg = format!("Streaming error after {max_retries} retries: {msg}");
send_event(tx, hooks, seq, AgentEvent::error(&err_msg, true)).await;
return Err(AgentError::new(err_msg, true));
}
let delay = calculate_backoff_delay(attempt, &config.retry);
warn!(
"Streaming error, retrying (attempt={attempt}, delay_ms={}, error={msg})",
delay.as_millis()
);
sleep(delay).await;
}
Err(StreamError::Fatal(msg)) => {
error!("Streaming error (non-recoverable): {msg}");
return Err(AgentError::new(format!("Streaming error: {msg}"), false));
}
}
}
}
async fn process_stream<P, H>(
provider: &Arc<P>,
request: &ChatRequest,
tx: &mpsc::Sender<AgentEventEnvelope>,
hooks: &Arc<H>,
seq: &SequenceCounter,
message_id: &str,
thinking_id: &str,
) -> Result<ChatResponse, StreamError>
where
P: LlmProvider,
H: AgentHooks,
{
let mut stream = std::pin::pin!(provider.chat_stream(request.clone()));
let mut accumulator = StreamAccumulator::new();
let mut delta_count: u64 = 0;
log::debug!("Starting to consume LLM stream");
let mut channel_closed = false;
while let Some(result) = stream.next().await {
if delta_count > 0 && delta_count.is_multiple_of(50) {
log::debug!("Stream progress: delta_count={delta_count}");
}
match result {
Ok(delta) => {
delta_count += 1;
accumulator.apply(&delta);
match &delta {
StreamDelta::TextDelta { delta, .. } => {
if !channel_closed {
if tx.is_closed() {
log::warn!(
"Event channel closed by receiver at delta_count={delta_count} - consumer may have disconnected"
);
channel_closed = true;
} else {
send_event(
tx,
hooks,
seq,
AgentEvent::text_delta(message_id, delta.clone()),
)
.await;
}
}
}
StreamDelta::ThinkingDelta { delta, .. } => {
if !channel_closed {
if tx.is_closed() {
log::warn!(
"Event channel closed by receiver at delta_count={delta_count}"
);
channel_closed = true;
} else {
send_event(
tx,
hooks,
seq,
AgentEvent::thinking_delta(thinking_id, delta.clone()),
)
.await;
}
}
}
StreamDelta::Error {
message,
recoverable,
} => {
log::warn!(
"Stream error received delta_count={delta_count} message={message} recoverable={recoverable}"
);
return if *recoverable {
Err(StreamError::Recoverable(message.clone()))
} else {
Err(StreamError::Fatal(message.clone()))
};
}
StreamDelta::Done { .. }
| StreamDelta::Usage(_)
| StreamDelta::ToolUseStart { .. }
| StreamDelta::ToolInputDelta { .. }
| StreamDelta::SignatureDelta { .. }
| StreamDelta::RedactedThinking { .. } => {}
}
}
Err(e) => {
log::error!("Stream iteration error delta_count={delta_count} error={e}");
return Err(StreamError::Recoverable(format!("Stream error: {e}")));
}
}
}
log::debug!("Stream while loop exited normally at delta_count={delta_count}");
let usage = accumulator.usage().cloned().unwrap_or(Usage {
input_tokens: 0,
output_tokens: 0,
cached_input_tokens: 0,
});
let stop_reason = accumulator.stop_reason().copied();
let content_blocks = accumulator.into_content_blocks();
log::debug!(
"LLM stream completed successfully delta_count={delta_count} stop_reason={stop_reason:?} content_block_count={} input_tokens={} output_tokens={}",
content_blocks.len(),
usage.input_tokens,
usage.output_tokens
);
Ok(ChatResponse {
id: uuid::Uuid::new_v4().to_string(),
content: content_blocks,
model: provider.model().to_string(),
stop_reason,
usage,
})
}