use super::helpers::{calculate_backoff_delay, send_event};
use super::types::{LlmEventContext, LlmOutcome, LlmStreamIds, StreamError};
use crate::events::AgentEvent;
use crate::hooks::AgentHooks;
use crate::llm::{
ChatOutcome, ChatRequest, ChatResponse, LlmProvider, StreamAccumulator, StreamDelta, Usage,
};
use crate::types::{AgentConfig, AgentError};
use futures::StreamExt;
use log::{error, warn};
use std::sync::Arc;
use tokio::time::sleep;
/// Borrowed handle used to record LLM-call telemetry on an OpenTelemetry span.
///
/// Only compiled when the `otel` feature is enabled.
#[cfg(feature = "otel")]
pub(super) struct LlmSpanObserver<'a> {
/// Span that receives the events and attributes recorded by this observer.
pub(super) span: &'a mut opentelemetry::global::BoxedSpan,
/// Provider name, attached to events and metrics as `GEN_AI_PROVIDER_NAME`.
pub(super) provider_name: &'static str,
/// Requested model, attached to metrics as `GEN_AI_REQUEST_MODEL`.
pub(super) request_model: &'a str,
}
#[cfg(feature = "otel")]
impl LlmSpanObserver<'_> {
    /// Builds the attribute set shared by the per-chunk latency metrics
    /// (operation name, provider name and requested model).
    fn chat_metric_attributes(&self) -> [opentelemetry::KeyValue; 3] {
        use crate::observability::attrs;
        use opentelemetry::KeyValue;
        [
            KeyValue::new(attrs::GEN_AI_OPERATION_NAME, "chat"),
            KeyValue::new(attrs::GEN_AI_PROVIDER_NAME, self.provider_name),
            KeyValue::new(attrs::GEN_AI_REQUEST_MODEL, self.request_model.to_string()),
        ]
    }

    /// Records the arrival of the first content chunk.
    ///
    /// Adds an `llm.stream.first_chunk` span event, sets the time-to-first-chunk
    /// span attribute and records `ttfc_secs` in the `time_to_first_chunk` metric.
    fn record_first_chunk(&mut self, turn: usize, streaming: bool, ttfc_secs: f64) {
        use crate::observability::{attrs, metrics, spans};
        use opentelemetry::KeyValue;
        use opentelemetry::trace::Span;
        spans::add_event(
            self.span,
            "llm.stream.first_chunk",
            vec![
                // Saturate instead of reporting turn 0 when the value does not fit,
                // matching the other integer conversions in this impl.
                attrs::kv_i64(
                    attrs::SDK_TURN_NUMBER,
                    i64::try_from(turn).unwrap_or(i64::MAX),
                ),
                attrs::kv_bool(attrs::SDK_LLM_STREAMING, streaming),
            ],
        );
        self.span.set_attribute(KeyValue::new(
            attrs::GEN_AI_RESPONSE_TIME_TO_FIRST_CHUNK,
            ttfc_secs,
        ));
        let metrics_handle = metrics::Metrics::global();
        metrics_handle
            .time_to_first_chunk
            .record(ttfc_secs, &self.chat_metric_attributes());
    }

    /// Records the gap (`tpoc_secs`) between two consecutive content chunks in
    /// the `time_per_output_chunk` metric.
    fn record_subsequent_chunk(&self, tpoc_secs: f64) {
        use crate::observability::metrics;
        let metrics_handle = metrics::Metrics::global();
        metrics_handle
            .time_per_output_chunk
            .record(tpoc_secs, &self.chat_metric_attributes());
    }

    /// Adds an `llm.stream.completed` span event carrying the number of deltas
    /// seen and the stream duration in milliseconds (both saturated to `i64`).
    fn record_completed(&mut self, delta_count: u64, duration_ms: u64) {
        use crate::observability::{attrs, spans};
        spans::add_event(
            self.span,
            "llm.stream.completed",
            vec![
                attrs::kv_i64(
                    attrs::SDK_LLM_STREAM_DELTA_COUNT,
                    i64::try_from(delta_count).unwrap_or(i64::MAX),
                ),
                attrs::kv_i64(
                    attrs::SDK_LLM_STREAM_DURATION_MS,
                    i64::try_from(duration_ms).unwrap_or(i64::MAX),
                ),
            ],
        );
    }

    /// Adds an `llm.stream.dropped` span event with the drop reason, the number
    /// of deltas seen so far and an error-type label.
    fn record_dropped(&mut self, reason: &'static str, delta_count: u64, error_type: &str) {
        use crate::observability::{attrs, spans};
        spans::add_event(
            self.span,
            "llm.stream.dropped",
            vec![
                opentelemetry::KeyValue::new(attrs::SDK_LLM_STREAM_DROP_REASON, reason),
                attrs::kv_i64(
                    attrs::SDK_LLM_STREAM_DELTA_COUNT,
                    i64::try_from(delta_count).unwrap_or(i64::MAX),
                ),
                opentelemetry::KeyValue::new(attrs::ERROR_TYPE, error_type.to_string()),
            ],
        );
    }

    /// Records a retry: adds an `llm.retry` span event and increments the
    /// `llm_retries` counter, tagged with provider name and error type.
    fn record_retry(&mut self, attempt: u32, max_attempts: u32, delay_ms: u64, error_type: &str) {
        use crate::observability::{attrs, metrics, spans};
        spans::add_event(
            self.span,
            "llm.retry",
            vec![
                opentelemetry::KeyValue::new(attrs::GEN_AI_PROVIDER_NAME, self.provider_name),
                attrs::kv_i64(attrs::SDK_LLM_RETRY_ATTEMPT, i64::from(attempt)),
                attrs::kv_i64(attrs::SDK_LLM_RETRY_MAX_ATTEMPTS, i64::from(max_attempts)),
                attrs::kv_i64(
                    attrs::SDK_LLM_RETRY_DELAY_MS,
                    i64::try_from(delay_ms).unwrap_or(i64::MAX),
                ),
                opentelemetry::KeyValue::new(attrs::ERROR_TYPE, error_type.to_string()),
            ],
        );
        let metrics_handle = metrics::Metrics::global();
        metrics_handle.llm_retries.add(
            1,
            &[
                opentelemetry::KeyValue::new(attrs::GEN_AI_PROVIDER_NAME, self.provider_name),
                opentelemetry::KeyValue::new(attrs::ERROR_TYPE, error_type.to_string()),
            ],
        );
    }
}
/// Result of a single non-streaming provider call raced against cancellation.
enum ProviderCall {
/// The provider call returned `Ok` with this outcome.
Outcome(ChatOutcome),
/// The cancel token fired before the provider call completed.
Cancelled,
/// The provider call returned `Err`, wrapped into an `AgentError`.
Error(AgentError),
}
/// Issues one non-streaming chat request, racing it against the cancel token.
///
/// Returns [`ProviderCall::Cancelled`] when the token fires first; otherwise the
/// provider result is mapped to [`ProviderCall::Outcome`] or
/// [`ProviderCall::Error`].
async fn chat_or_cancel<P, H>(
    provider: &Arc<P>,
    request: &ChatRequest,
    event_ctx: &LlmEventContext<'_, H>,
) -> ProviderCall
where
    P: LlmProvider,
    H: AgentHooks,
{
    tokio::select! {
        // `biased` polls branches in declaration order, so cancellation wins ties.
        biased;
        () = event_ctx.cancel_token.cancelled() => {
            log::info!("LLM call cancelled (turn={})", event_ctx.turn);
            ProviderCall::Cancelled
        }
        chat_result = provider.chat(request.clone()) => chat_result.map_or_else(
            |e| ProviderCall::Error(AgentError::new(format!("LLM error: {e}"), false)),
            ProviderCall::Outcome,
        ),
    }
}
/// Sends a non-streaming chat request, retrying retryable outcomes with backoff.
///
/// Rate-limit, server-error and unrecognized outcomes are retried through
/// [`handle_retry_backoff`] until `config.retry.max_retries` is exceeded.
/// Success, invalid-request outcomes, provider errors and cancellation return
/// immediately.
///
/// Returns the final outcome together with the retry attempt counter at the
/// point of return (`0` when the first call settled the request).
pub(super) async fn call_llm_with_retry<P, H>(
    provider: &Arc<P>,
    request: ChatRequest,
    config: &AgentConfig,
    event_ctx: &LlmEventContext<'_, H>,
    #[cfg(feature = "otel")] mut span_observer: Option<LlmSpanObserver<'_>>,
) -> (LlmOutcome, u32)
where
    P: LlmProvider,
    H: AgentHooks,
{
    let max_retries = config.retry.max_retries;
    let mut attempt = 0u32;
    loop {
        let outcome = match chat_or_cancel(provider, &request, event_ctx).await {
            ProviderCall::Outcome(outcome) => outcome,
            ProviderCall::Cancelled => return (LlmOutcome::Cancelled, attempt),
            ProviderCall::Error(error) => return (LlmOutcome::Error(error), attempt),
        };
        // Terminal outcomes return here; retryable ones yield
        // (error kind label, retry reason, message used once retries run out).
        let (kind, retry_reason, failure_message) = match outcome {
            ChatOutcome::Success(response) => {
                // Close the retry sequence only if at least one retry happened.
                if attempt > 0 {
                    send_auto_retry_end_event(event_ctx, attempt, true, None).await;
                }
                return (LlmOutcome::Response(response), attempt);
            }
            ChatOutcome::InvalidRequest(msg) => {
                error!("Invalid request to LLM: {msg}");
                return (
                    LlmOutcome::Error(AgentError::new(format!("Invalid request: {msg}"), false)),
                    attempt,
                );
            }
            ChatOutcome::RateLimited => (
                "rate_limited",
                "Rate limited by LLM provider".to_string(),
                format!("Rate limited after {max_retries} retries"),
            ),
            ChatOutcome::ServerError(msg) => {
                // Format first so the owned message can be moved rather than cloned.
                let failure = format!("Server error after {max_retries} retries: {msg}");
                ("server_error", msg, failure)
            }
            _ => (
                "server_error",
                "Unrecognized provider outcome".to_string(),
                format!("Unrecognized provider outcome after {max_retries} retries"),
            ),
        };
        attempt += 1;
        match handle_retry_backoff(RetryBackoff {
            event_ctx,
            config,
            attempt,
            max_retries,
            error_kind: kind,
            retry_reason,
            failure_message,
            #[cfg(feature = "otel")]
            span_observer: span_observer.as_mut(),
            #[cfg(not(feature = "otel"))]
            _observer: std::marker::PhantomData,
        })
        .await
        {
            RetryStep::Retry => {}
            RetryStep::Cancelled => return (LlmOutcome::Cancelled, attempt),
            RetryStep::GiveUp(outcome) => return (outcome, attempt),
        }
    }
}
/// Decision produced by `handle_retry_backoff` after a retryable failure.
enum RetryStep {
/// The backoff delay elapsed; the caller should issue the request again.
Retry,
/// The cancel token fired while waiting for the backoff delay.
Cancelled,
/// Retries are exhausted; carries the outcome to return to the caller.
GiveUp(LlmOutcome),
}
/// Arguments for `handle_retry_backoff`, bundled into one struct.
struct RetryBackoff<'a, 'o, H> {
/// Event context used to emit retry/error events and to observe cancellation.
event_ctx: &'a LlmEventContext<'a, H>,
/// Agent configuration; its `retry` settings drive the backoff delay.
config: &'a AgentConfig,
/// Retry attempt number about to be performed (compared against `max_retries`).
attempt: u32,
/// Maximum number of retries allowed before giving up.
max_retries: u32,
/// Short label for the failure kind, used in logs and retry telemetry.
error_kind: &'static str,
/// Message sent with the auto-retry start event.
retry_reason: String,
/// Message reported once retries are exhausted.
failure_message: String,
/// Optional telemetry observer (only with the `otel` feature).
#[cfg(feature = "otel")]
span_observer: Option<&'a mut LlmSpanObserver<'o>>,
/// Keeps the `'o` lifetime parameter in use when the `otel` feature is disabled.
#[cfg(not(feature = "otel"))]
_observer: std::marker::PhantomData<&'o ()>,
}
async fn handle_retry_backoff<H>(params: RetryBackoff<'_, '_, H>) -> RetryStep
where
H: AgentHooks,
{
let RetryBackoff {
event_ctx,
config,
attempt,
max_retries,
error_kind,
retry_reason,
failure_message,
#[cfg(feature = "otel")]
span_observer,
#[cfg(not(feature = "otel"))]
_observer: _,
} = params;
if attempt > max_retries {
error!("LLM {error_kind} exhausted retries: {failure_message}");
send_auto_retry_end_event(event_ctx, attempt - 1, false, Some(failure_message.clone()))
.await;
if let Err(error) = send_llm_error_event(event_ctx, &failure_message).await {
return RetryStep::GiveUp(LlmOutcome::Error(error));
}
return RetryStep::GiveUp(LlmOutcome::Error(AgentError::new(failure_message, true)));
}
let delay = calculate_backoff_delay(attempt, &config.retry);
let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
warn!("LLM {error_kind}, retrying (attempt={attempt}, delay_ms={delay_ms})");
#[cfg(feature = "otel")]
if let Some(observer) = span_observer {
observer.record_retry(attempt, max_retries, delay_ms, error_kind);
}
send_auto_retry_start_event(event_ctx, attempt, max_retries, delay_ms, &retry_reason).await;
if sleep_or_cancel(delay, event_ctx.cancel_token)
.await
.is_break()
{
return RetryStep::Cancelled;
}
RetryStep::Retry
}
async fn sleep_or_cancel(
delay: std::time::Duration,
cancel_token: &tokio_util::sync::CancellationToken,
) -> std::ops::ControlFlow<()> {
tokio::select! {
biased;
() = cancel_token.cancelled() => std::ops::ControlFlow::Break(()),
() = sleep(delay) => std::ops::ControlFlow::Continue(()),
}
}
pub(super) async fn call_llm_streaming<P, H>(
provider: &Arc<P>,
request: ChatRequest,
config: &AgentConfig,
event_ctx: &LlmEventContext<'_, H>,
stream_ids: LlmStreamIds<'_>,
#[cfg(feature = "otel")] mut span_observer: Option<LlmSpanObserver<'_>>,
) -> (LlmOutcome, u32)
where
P: LlmProvider,
H: AgentHooks,
{
let max_retries = config.retry.max_retries;
let mut attempt = 0u32;
loop {
let result = process_stream(
provider,
&request,
event_ctx,
stream_ids,
#[cfg(feature = "otel")]
span_observer.as_mut(),
)
.await;
match result {
Ok(response) => {
if attempt > 0 {
send_auto_retry_end_event(event_ctx, attempt, true, None).await;
}
return (LlmOutcome::Response(response), attempt);
}
Err(StreamError::Recoverable(msg)) => {
attempt += 1;
if attempt > max_retries {
error!("Streaming error after {max_retries} retries: {msg}");
let err_msg = format!("Streaming error after {max_retries} retries: {msg}");
send_auto_retry_end_event(event_ctx, attempt - 1, false, Some(err_msg.clone()))
.await;
if let Err(error) = send_llm_error_event(event_ctx, &err_msg).await {
return (LlmOutcome::Error(error), attempt);
}
return (LlmOutcome::Error(AgentError::new(err_msg, true)), attempt);
}
let delay = calculate_backoff_delay(attempt, &config.retry);
warn!(
"Streaming error, retrying (attempt={attempt}, delay_ms={}, error={msg})",
delay.as_millis()
);
let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
#[cfg(feature = "otel")]
if let Some(observer) = span_observer.as_mut() {
observer.record_retry(attempt, max_retries, delay_ms, "stream_error");
}
send_auto_retry_start_event(event_ctx, attempt, max_retries, delay_ms, &msg).await;
if sleep_or_cancel(delay, event_ctx.cancel_token)
.await
.is_break()
{
return (LlmOutcome::Cancelled, attempt);
}
}
Err(StreamError::Fatal(msg)) => {
error!("Streaming error (non-recoverable): {msg}");
return (
LlmOutcome::Error(AgentError::new(format!("Streaming error: {msg}"), false)),
attempt,
);
}
Err(StreamError::Cancelled) => {
return (LlmOutcome::Cancelled, attempt);
}
}
}
}
/// Consumes one provider chat stream and assembles it into a `ChatResponse`.
///
/// Every successfully received delta is applied to a `StreamAccumulator` and
/// then handed to `dispatch_stream_delta`. The loop ends when the stream yields
/// `None`, after which the response is built by `finalize_stream_response`.
///
/// Returns `StreamError::Cancelled` if the cancel token fires while waiting for
/// the next item, `StreamError::Recoverable` if the stream yields an `Err`
/// item, or whatever error `dispatch_stream_delta` reports.
///
/// With the `otel` feature, first-chunk / per-chunk timings and the
/// completed / dropped events are recorded on the optional observer.
async fn process_stream<P, H>(
provider: &Arc<P>,
request: &ChatRequest,
event_ctx: &LlmEventContext<'_, H>,
stream_ids: LlmStreamIds<'_>,
#[cfg(feature = "otel")] mut span_observer: Option<&mut LlmSpanObserver<'_>>,
) -> Result<ChatResponse, StreamError>
where
P: LlmProvider,
H: AgentHooks,
{
// Pin the stream on the stack so `next()` can be called on it.
let mut stream = std::pin::pin!(provider.chat_stream(request.clone()));
let mut accumulator = StreamAccumulator::new();
// Number of deltas successfully received so far.
let mut delta_count: u64 = 0;
// Telemetry-only timing state.
#[cfg(feature = "otel")]
let stream_started_at = std::time::Instant::now();
#[cfg(feature = "otel")]
let mut first_chunk_recorded = false;
#[cfg(feature = "otel")]
let mut last_chunk_at: Option<std::time::Instant> = None;
log::debug!("Starting to consume LLM stream");
loop {
// Wait for the next stream item; `biased` polls cancellation first.
let result = tokio::select! {
biased;
() = event_ctx.cancel_token.cancelled() => {
log::info!(
"LLM stream cancelled (turn={}, delta_count={delta_count})",
event_ctx.turn
);
#[cfg(feature = "otel")]
if let Some(observer) = span_observer.as_mut() {
observer.record_dropped("cancelled", delta_count, "cancelled");
}
return Err(StreamError::Cancelled);
}
next = stream.next() => match next {
Some(result) => result,
// Stream exhausted: leave the loop and finalize the response.
None => break,
},
};
// Progress log; runs before the counter is incremented for the current item.
if delta_count > 0 && delta_count.is_multiple_of(50) {
log::debug!("Stream progress: delta_count={delta_count}");
}
let delta = match result {
Ok(d) => d,
Err(e) => {
log::error!("Stream iteration error delta_count={delta_count} error={e}");
#[cfg(feature = "otel")]
if let Some(observer) = span_observer.as_mut() {
observer.record_dropped("recoverable_error", delta_count, "stream_error");
}
// An `Err` item from the stream is always reported as recoverable.
return Err(StreamError::Recoverable(format!("Stream error: {e}")));
}
};
delta_count += 1;
accumulator.apply(&delta);
// Timing telemetry applies only to deltas accepted by `is_content_delta`.
#[cfg(feature = "otel")]
if is_content_delta(&delta) {
let now = std::time::Instant::now();
if !first_chunk_recorded {
first_chunk_recorded = true;
if let Some(observer) = span_observer.as_mut() {
let ttfc_secs = stream_started_at.elapsed().as_secs_f64();
observer.record_first_chunk(event_ctx.turn, true, ttfc_secs);
}
} else if let Some(prev) = last_chunk_at
&& let Some(observer) = span_observer.as_ref()
{
// Time elapsed since the previous content delta.
let tpoc_secs = now.duration_since(prev).as_secs_f64();
observer.record_subsequent_chunk(tpoc_secs);
}
last_chunk_at = Some(now);
}
// Dispatch the delta; any error it reports aborts the stream.
if let Some(stream_err) = dispatch_stream_delta(
&delta,
event_ctx,
stream_ids,
delta_count,
#[cfg(feature = "otel")]
span_observer.as_deref_mut(),
)
.await
{
return Err(stream_err);
}
}
log::debug!("Stream while loop exited normally at delta_count={delta_count}");
#[cfg(feature = "otel")]
if let Some(observer) = span_observer.as_mut() {
// Duration in milliseconds, saturated to `u64::MAX` if it does not fit.
let duration_ms =
u64::try_from(stream_started_at.elapsed().as_millis()).unwrap_or(u64::MAX);
observer.record_completed(delta_count, duration_ms);
}
Ok(finalize_stream_response(
accumulator,
provider.model(),
delta_count,
))
}
/// Converts a fully consumed stream accumulator into a [`ChatResponse`].
///
/// Usage falls back to all-zero counters when the accumulator has none. The
/// response id is a freshly generated v4 UUID and `model` is copied into the
/// response. `delta_count` is used for logging only.
fn finalize_stream_response(
    accumulator: StreamAccumulator,
    model: &str,
    delta_count: u64,
) -> ChatResponse {
    let zero_usage = Usage {
        input_tokens: 0,
        output_tokens: 0,
        cached_input_tokens: 0,
        cache_creation_input_tokens: 0,
    };
    let usage = accumulator.usage().cloned().unwrap_or(zero_usage);
    let stop_reason = accumulator.stop_reason().copied();
    // Consumes the accumulator, so usage and stop reason are read first.
    let content = accumulator.into_content_blocks();
    let content_block_count = content.len();

    log::debug!(
        "LLM stream completed successfully delta_count={delta_count} stop_reason={stop_reason:?} content_block_count={} input_tokens={} output_tokens={}",
        content_block_count,
        usage.input_tokens,
        usage.output_tokens
    );

    let id = uuid::Uuid::new_v4().to_string();
    ChatResponse {
        id,
        content,
        model: model.to_string(),
        stop_reason,
        usage,
    }
}
/// Reacts to a single stream delta.
///
/// Text and thinking deltas are forwarded as agent events; a failed send is
/// reported as [`StreamError::Fatal`]. An error delta is turned into a
/// recoverable or fatal [`StreamError`] according to `kind.is_recoverable()`.
/// All other deltas are ignored here.
///
/// Returns `None` when the stream should keep going.
async fn dispatch_stream_delta<H>(
    delta: &StreamDelta,
    event_ctx: &LlmEventContext<'_, H>,
    stream_ids: LlmStreamIds<'_>,
    delta_count: u64,
    #[cfg(feature = "otel")] mut span_observer: Option<&mut LlmSpanObserver<'_>>,
) -> Option<StreamError>
where
    H: AgentHooks,
{
    let LlmStreamIds {
        message_id,
        thinking_id,
    } = stream_ids;

    // Work out which event (if any) this delta should emit.
    let outbound = match delta {
        StreamDelta::TextDelta { delta: text, .. } => {
            AgentEvent::text_delta(message_id, text.clone())
        }
        StreamDelta::ThinkingDelta { delta: thought, .. } => {
            AgentEvent::thinking_delta(thinking_id, thought.clone())
        }
        StreamDelta::Error { message, kind } => {
            log::warn!(
                "Stream error received delta_count={delta_count} message={message} kind={kind:?}"
            );
            let recoverable = kind.is_recoverable();
            #[cfg(feature = "otel")]
            if let Some(observer) = span_observer.as_mut() {
                let reason = if recoverable {
                    "recoverable_error"
                } else {
                    "fatal_error"
                };
                observer.record_dropped(reason, delta_count, stream_error_kind_attr(*kind));
            }
            let stream_error = if recoverable {
                StreamError::Recoverable(message.clone())
            } else {
                StreamError::Fatal(message.clone())
            };
            return Some(stream_error);
        }
        _ => return None,
    };

    if let Err(error) = send_event(
        event_ctx.event_store,
        event_ctx.thread_id,
        event_ctx.turn,
        event_ctx.hooks,
        event_ctx.authority,
        outbound,
    )
    .await
    {
        #[cfg(feature = "otel")]
        if let Some(observer) = span_observer.as_mut() {
            observer.record_dropped("event_channel_send_failed", delta_count, "event_channel");
        }
        return Some(StreamError::Fatal(error.message));
    }
    None
}
/// Maps a stream error kind to the label recorded as the telemetry error type.
///
/// Kinds without a dedicated label are reported as `"unknown"`.
#[cfg(feature = "otel")]
const fn stream_error_kind_attr(kind: crate::llm::StreamErrorKind) -> &'static str {
    use crate::llm::StreamErrorKind as Kind;

    match kind {
        Kind::InvalidRequest => "invalid_request",
        Kind::ServerError => "server_error",
        Kind::RateLimited => "rate_limited",
        _ => "unknown",
    }
}
/// Returns `true` for the delta variants that count as content chunks for
/// chunk-timing telemetry: text, thinking, signature, redacted thinking,
/// tool-use start and tool-input deltas.
#[cfg(feature = "otel")]
const fn is_content_delta(delta: &StreamDelta) -> bool {
    matches!(
        delta,
        StreamDelta::ToolUseStart { .. }
            | StreamDelta::ToolInputDelta { .. }
            | StreamDelta::TextDelta { .. }
            | StreamDelta::ThinkingDelta { .. }
            | StreamDelta::RedactedThinking { .. }
            | StreamDelta::SignatureDelta { .. }
    )
}
async fn send_llm_error_event<H>(
event_ctx: &LlmEventContext<'_, H>,
error_msg: &str,
) -> Result<(), AgentError>
where
H: AgentHooks,
{
send_event(
event_ctx.event_store,
event_ctx.thread_id,
event_ctx.turn,
event_ctx.hooks,
event_ctx.authority,
AgentEvent::error(error_msg, true),
)
.await
}
/// Emits an `AutoRetryStart` event describing the retry about to be attempted.
///
/// Delivery is best-effort: a send failure is deliberately ignored.
async fn send_auto_retry_start_event<H>(
    event_ctx: &LlmEventContext<'_, H>,
    attempt: u32,
    max_attempts: u32,
    delay_ms: u64,
    error_message: &str,
) where
    H: AgentHooks,
{
    let retry_started = AgentEvent::AutoRetryStart {
        attempt,
        max_attempts,
        delay_ms,
        error_message: error_message.to_string(),
    };
    let _ = send_event(
        event_ctx.event_store,
        event_ctx.thread_id,
        event_ctx.turn,
        event_ctx.hooks,
        event_ctx.authority,
        retry_started,
    )
    .await;
}
/// Emits an `AutoRetryEnd` event closing a retry sequence.
///
/// `final_error` carries the failure message when `success` is `false`.
/// Delivery is best-effort: a send failure is deliberately ignored.
async fn send_auto_retry_end_event<H>(
    event_ctx: &LlmEventContext<'_, H>,
    attempt: u32,
    success: bool,
    final_error: Option<String>,
) where
    H: AgentHooks,
{
    let retry_finished = AgentEvent::AutoRetryEnd {
        attempt,
        success,
        final_error,
    };
    let _ = send_event(
        event_ctx.event_store,
        event_ctx.thread_id,
        event_ctx.turn,
        event_ctx.hooks,
        event_ctx.authority,
        retry_finished,
    )
    .await;
}