use crate::error::AppError;
use crate::handlers::AppState;
use crate::metrics::Metrics;
use crate::middleware::RequestId;
use crate::models::ModelSelector;
use crate::shared::query::record_routing_metrics;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use super::find_endpoint_by_name;
use axum::{
Extension, Json,
extract::State,
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
},
};
use futures::stream::{self, StreamExt};
use std::convert::Infallible;
use std::time::Duration;
use super::types::{
ChatCompletionChunk, ChatCompletionRequest, ModelChoice, TimestampResult, current_timestamp,
};
/// Serializes a chunk into the JSON payload of a single SSE `data:` event.
///
/// Serializing `ChatCompletionChunk` is not expected to fail; a failure indicates a bug in the
/// struct definition. Rather than aborting the stream, the failure is logged and a small JSON
/// error object identifying the request and chunk is returned in place of the chunk.
fn serialize_chunk(chunk: &ChatCompletionChunk, request_id: &RequestId) -> String {
    serde_json::to_string(chunk).unwrap_or_else(|e| {
        tracing::error!(
            request_id = %request_id,
            error = %e,
            chunk_id = %chunk.id,
            "BUG: ChatCompletionChunk serialization failed. This indicates a bug in the \
             struct definition (e.g., non-serializable field added). Returning error event."
        );
        // Build the fallback through serde_json rather than string interpolation so the ids are
        // escaped and the payload stays valid JSON whatever characters they contain.
        serde_json::json!({
            "error": "serialization_failed",
            "request_id": request_id.to_string(),
            "chunk_id": chunk.id.to_string()
        })
        .to_string()
    })
}
/// Axum handler for streaming (Server-Sent Events) chat completions.
///
/// Resolves the endpoint that will serve the request, then returns an SSE response whose body
/// is produced by `create_sse_stream`:
///
/// * `ModelChoice::Specific(name)`: the endpoint is looked up by name with
///   `find_endpoint_by_name`; the model selector is not consulted.
/// * `ModelChoice::Auto`: the router chooses a tier from the prompt and request metadata.
/// * `ModelChoice::Fast` / `Balanced` / `Deep`: the requested tier is used as-is.
///
/// For the last two cases an endpoint of the chosen tier is obtained from the model selector.
/// Routing metrics are recorded in every case; the routing duration is `0.0` unless the router
/// actually ran. Request-level `max_tokens` / `temperature` override the endpoint defaults when
/// present. An SSE keep-alive comment is sent every 15 seconds.
///
/// # Errors
///
/// Returns `AppError` if the named model cannot be resolved, routing fails, no endpoint of the
/// chosen tier is available (`AppError::RoutingFailed`), or the agent options cannot be built
/// (`ModelQueryError::AgentOptionsConfigError`). Failures after the response has started are
/// reported in-band on the event stream instead.
pub async fn handler(
    State(state): State<AppState>,
    Extension(request_id): Extension<RequestId>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<Response, AppError> {
    tracing::debug!(
        request_id = %request_id,
        model = ?request.model(),
        messages_count = request.messages().len(),
        "Received streaming chat completions request"
    );

    let prompt = request.to_prompt_string();
    let request_temperature = request.temperature();
    let request_max_tokens = request.max_tokens();

    let (endpoint, target_tier) = match request.model() {
        ModelChoice::Specific(name) => {
            let (tier, endpoint) = find_endpoint_by_name(state.config(), name)?;
            tracing::info!(
                request_id = %request_id,
                model_name = %name,
                endpoint_name = %endpoint.name(),
                target_tier = ?tier,
                "Specific model selection - streaming directly to endpoint"
            );
            let decision =
                crate::router::RoutingDecision::new(tier, crate::router::RoutingStrategy::Rule);
            record_routing_metrics(&state, &decision, 0.0, request_id);
            (endpoint, tier)
        }
        choice => {
            // `Some(tier)` when the client named a tier explicitly; `None` when the router decides.
            let requested_tier = match choice {
                ModelChoice::Fast => Some(crate::router::TargetModel::Fast),
                ModelChoice::Balanced => Some(crate::router::TargetModel::Balanced),
                ModelChoice::Deep => Some(crate::router::TargetModel::Deep),
                ModelChoice::Auto => None,
                ModelChoice::Specific(_) => unreachable!("handled above"),
            };

            let decision = match requested_tier {
                Some(tier) => {
                    let direct =
                        crate::router::RoutingDecision::new(tier, crate::router::RoutingStrategy::Rule);
                    tracing::info!(
                        request_id = %request_id,
                        target_tier = ?tier,
                        "Direct tier selection (streaming)"
                    );
                    record_routing_metrics(&state, &direct, 0.0, request_id);
                    direct
                }
                None => {
                    let metadata = request.to_route_metadata();
                    let started_at = std::time::Instant::now();
                    let routed = state
                        .router()
                        .route(&prompt, &metadata, state.selector())
                        .await?;
                    let routing_duration_ms = started_at.elapsed().as_secs_f64() * 1000.0;
                    tracing::info!(
                        request_id = %request_id,
                        target_tier = ?routed.target(),
                        routing_strategy = ?routed.strategy(),
                        routing_duration_ms = %routing_duration_ms,
                        "Routing decision made (auto, streaming)"
                    );
                    record_routing_metrics(&state, &routed, routing_duration_ms, request_id);
                    routed
                }
            };

            let tier = decision.target();
            // Nothing has failed yet for this request, so no endpoints are excluded.
            let excluded = crate::models::ExclusionSet::new();
            let endpoint = match state.selector().select(tier, &excluded).await {
                Some(found) => found.clone(),
                None => {
                    return Err(AppError::RoutingFailed(format!(
                        "No available healthy endpoints for tier {:?}",
                        tier
                    )));
                }
            };
            (endpoint, tier)
        }
    };

    // Request-level settings win; otherwise fall back to the endpoint's configured defaults.
    let effective_max_tokens = request_max_tokens.unwrap_or(endpoint.max_tokens() as u32);
    let effective_temperature =
        request_temperature.map_or(endpoint.temperature() as f32, |t| t as f32);

    let options = open_agent::AgentOptions::builder()
        .model(endpoint.name())
        .base_url(endpoint.base_url())
        .max_tokens(effective_max_tokens)
        .temperature(effective_temperature)
        .build()
        .map_err(|e| {
            tracing::error!(
                request_id = %request_id,
                endpoint_name = %endpoint.name(),
                max_tokens = effective_max_tokens,
                temperature = effective_temperature,
                error = %e,
                "Failed to build AgentOptions for streaming"
            );
            AppError::ModelQuery(crate::error::ModelQueryError::AgentOptionsConfigError {
                endpoint: endpoint.base_url().to_string(),
                details: e.to_string(),
            })
        })?;

    let completion_id = format!("chatcmpl-{}", uuid::Uuid::new_v4().simple());
    let TimestampResult {
        timestamp: created,
        warning: clock_warning,
    } = current_timestamp(Some(state.metrics().as_ref()), Some(&request_id));
    if let Some(warning) = clock_warning {
        tracing::warn!(
            request_id = %request_id,
            warning = %warning,
            "Clock error during streaming"
        );
    }

    let response_model = endpoint.name().to_string();
    let timeout_seconds = state.config().timeout_for_tier(target_tier);
    tracing::info!(
        request_id = %request_id,
        completion_id = %completion_id,
        model = %response_model,
        endpoint_name = %endpoint.name(),
        timeout_seconds = timeout_seconds,
        "Starting streaming response"
    );

    // NOTE(review): axum renders keep-alive text as an SSE comment, so ":" goes on the wire as
    // "::" — still a valid comment line.
    let keep_alive = KeepAlive::new().interval(Duration::from_secs(15)).text(":");
    let events = create_sse_stream(
        prompt,
        options,
        completion_id,
        response_model,
        created,
        request_id,
        endpoint.name().to_string(),
        target_tier,
        timeout_seconds,
        state.selector_arc(),
        state.metrics(),
    );
    Ok(Sse::new(events).keep_alive(keep_alive).into_response())
}
#[allow(clippy::too_many_arguments)] fn create_sse_stream(
prompt: String,
options: open_agent::AgentOptions,
completion_id: String,
model: String,
created: i64,
request_id: RequestId,
endpoint_name: String,
target_tier: crate::router::TargetModel,
timeout_seconds: u64,
selector: Arc<ModelSelector>,
metrics: Arc<Metrics>,
) -> impl futures::Stream<Item = Result<Event, Infallible>> {
stream::once(async move {
let timeout_duration = Duration::from_secs(timeout_seconds);
let query_result = tokio::time::timeout(
timeout_duration,
open_agent::query(&prompt, &options),
)
.await;
let model_stream = match query_result {
Ok(Ok(s)) => s,
Ok(Err(e)) => {
tracing::error!(
request_id = %request_id,
endpoint_name = %endpoint_name,
error = %e,
"Failed to start streaming query"
);
if let Err(health_err) = selector.health_checker().mark_failure(&endpoint_name).await
{
tracing::warn!(
request_id = %request_id,
endpoint_name = %endpoint_name,
error = %health_err,
"Health tracking failed for streaming query failure"
);
metrics.health_tracking_failure(&endpoint_name, health_err.error_type());
}
let error_chunk = ChatCompletionChunk::content(
&completion_id,
&model,
created,
&format!("[Error: Failed to start model query. Request ID: {}. Please retry.]", request_id),
);
return stream::iter(vec![
Ok(Event::default().data(serialize_chunk(&error_chunk, &request_id))),
Ok(Event::default().data("[DONE]")),
])
.boxed();
}
Err(_elapsed) => {
tracing::error!(
request_id = %request_id,
endpoint_name = %endpoint_name,
timeout_seconds = timeout_seconds,
"Streaming query timed out waiting for initial connection"
);
if let Err(health_err) = selector.health_checker().mark_failure(&endpoint_name).await
{
tracing::warn!(
request_id = %request_id,
endpoint_name = %endpoint_name,
error = %health_err,
"Health tracking failed for streaming query timeout"
);
metrics.health_tracking_failure(&endpoint_name, health_err.error_type());
}
let error_chunk = ChatCompletionChunk::content(
&completion_id,
&model,
created,
&format!("[Error: Request timed out. Request ID: {}. Please retry.]", request_id),
);
return stream::iter(vec![
Ok(Event::default().data(serialize_chunk(&error_chunk, &request_id))),
Ok(Event::default().data("[DONE]")),
])
.boxed();
}
};
let initial = ChatCompletionChunk::initial(&completion_id, &model, created);
let initial_event = Ok(Event::default().data(serialize_chunk(&initial, &request_id)));
let request_id_for_finish = request_id;
let error_occurred = Arc::new(AtomicBool::new(false));
let content_stream = model_stream
.filter_map({
let completion_id = completion_id.clone();
let model = model.clone();
let request_id = request_id;
let endpoint_name = endpoint_name.clone();
let error_occurred = error_occurred.clone();
let metrics = metrics.clone();
move |result| {
let completion_id = completion_id.clone();
let model = model.clone();
let request_id = request_id;
let endpoint_name = endpoint_name.clone();
let error_occurred = error_occurred.clone();
let metrics = metrics.clone();
async move {
match result {
Ok(block) => {
use open_agent::ContentBlock;
match block {
ContentBlock::Text(text_block) => {
let chunk = ChatCompletionChunk::content(
&completion_id,
&model,
created,
&text_block.text,
);
Some(Ok(Event::default().data(
serialize_chunk(&chunk, &request_id),
)))
}
other_block => {
tracing::warn!(
request_id = %request_id,
endpoint_name = %endpoint_name,
block_type = ?other_block,
"Received non-text content block, skipping (not supported - text blocks only)"
);
None
}
}
}
Err(e) => {
error_occurred.store(true, Ordering::SeqCst);
metrics.mid_stream_failure(&endpoint_name);
tracing::error!(
request_id = %request_id,
endpoint_name = %endpoint_name,
error = %e,
"Stream error during content delivery - notifying client"
);
let error_chunk = ChatCompletionChunk::content(
&completion_id,
&model,
created,
&format!("[Stream Error: Content may be incomplete. Request ID: {}. Please retry.]", request_id),
);
Some(Ok(Event::default().data(
serialize_chunk(&error_chunk, &request_id),
)))
}
}
}
}
})
.boxed();
let finish_chunk = ChatCompletionChunk::finish(&completion_id, &model, created);
let finish_events = {
let error_occurred = error_occurred.clone();
let request_id = request_id_for_finish;
stream::once(async move {
if error_occurred.load(Ordering::SeqCst) {
tracing::debug!(
request_id = %request_id,
"Skipping finish chunk due to stream error"
);
vec![Ok(Event::default().data("[DONE]"))]
} else {
vec![
Ok(Event::default().data(serialize_chunk(&finish_chunk, &request_id))),
Ok(Event::default().data("[DONE]")),
]
}
})
.flat_map(stream::iter)
};
let success_tracker = {
let selector = selector.clone();
let metrics = metrics.clone();
let endpoint_name = endpoint_name.clone();
let request_id = request_id_for_finish;
let error_occurred = error_occurred.clone();
stream::once(async move {
if error_occurred.load(Ordering::SeqCst) {
tracing::debug!(
request_id = %request_id,
endpoint_name = %endpoint_name,
"Skipping health/metrics tracking due to stream error"
);
} else {
let tier_enum = match target_tier {
crate::router::TargetModel::Fast => crate::metrics::Tier::Fast,
crate::router::TargetModel::Balanced => crate::metrics::Tier::Balanced,
crate::router::TargetModel::Deep => crate::metrics::Tier::Deep,
};
if let Err(e) = metrics.record_model_invocation(tier_enum) {
metrics.metrics_recording_failure("record_model_invocation");
tracing::error!(
request_id = %request_id,
error = %e,
tier = ?tier_enum,
"Metrics recording failed. Observability degraded but stream continues."
);
}
if let Err(e) = selector.health_checker().mark_success(&endpoint_name).await {
tracing::warn!(
request_id = %request_id,
endpoint_name = %endpoint_name,
error = %e,
"Health tracking failed for successful streaming completion"
);
metrics.health_tracking_failure(&endpoint_name, e.error_type());
} else {
tracing::debug!(
request_id = %request_id,
endpoint_name = %endpoint_name,
"Marked endpoint healthy after successful stream completion"
);
}
}
None::<Result<Event, Infallible>>
})
.filter_map(|x| async { x })
};
stream::iter(vec![initial_event])
.chain(content_stream)
.chain(finish_events)
.chain(success_tracker)
.boxed()
})
.flatten()
}