use crate::error::AppError;
use crate::handlers::AppState;
use crate::middleware::RequestId;
use crate::shared::query::{
QueryConfig, SamplingParams, execute_query_with_retry, query_model, record_routing_metrics,
};
use axum::{
Extension, Json,
extract::State,
http::{HeaderName, HeaderValue},
response::{IntoResponse, Response},
};
use super::extractor::OpenAiJson;
use super::find_endpoint_by_name;
use super::types::{
ChatCompletion, ChatCompletionRequest, ModelChoice, TimestampResult, current_timestamp,
};
/// Name of the response header used to surface non-fatal warnings to the client.
///
/// The value is passed to `HeaderName::from_static`, which (per the `http` crate
/// documentation) panics on names that are not valid lowercase header names, so this
/// string must stay lowercase.
pub const X_OCTOROUTE_WARNING: &str = "x-octoroute-warning";
/// Serializes `body` as a JSON response and, when `warnings` is non-empty, attaches
/// them to the response in the `x-octoroute-warning` header.
///
/// The warnings are joined with `"; "` and then made header-safe:
/// - every control character (newline, NUL, DEL, ...) becomes a space,
/// - every non-ASCII character becomes `?`,
/// - anything longer than 500 bytes is cut to 497 bytes followed by `...`.
///
/// If the sanitized text is still rejected by `HeaderValue::from_str`, a generic
/// fallback value carrying only the number of warnings is used instead, and the
/// rejected text is logged.
fn build_response_with_warnings<T: serde::Serialize>(body: T, warnings: &[String]) -> Response {
    let mut response = Json(body).into_response();
    if warnings.is_empty() {
        return response;
    }

    let mut sanitized: String = warnings
        .join("; ")
        .chars()
        .map(|ch| match ch {
            // A space is never a control character, so no special case is needed for it.
            c if c.is_control() => ' ',
            c if !c.is_ascii() => '?',
            c => c,
        })
        .collect();

    // Every char is single-byte ASCII at this point, so byte offset 497 is always a
    // valid char boundary and `truncate` cannot panic.
    if sanitized.len() > 500 {
        sanitized.truncate(497);
        sanitized.push_str("...");
    }

    let header_value = match HeaderValue::from_str(&sanitized) {
        Ok(value) => value,
        Err(_) => {
            // Defensive path: the sanitizer above is expected to leave only bytes that
            // `HeaderValue::from_str` accepts.
            tracing::warn!(
                original_warning = %sanitized,
                warning_length = sanitized.len(),
                warnings_count = warnings.len(),
                "Warning header contains invalid HTTP characters even after sanitization, using fallback. \
                 Original warning logged for debugging."
            );
            let fallback = format!(
                "health-tracking-degraded; warnings-count={}",
                warnings.len()
            );
            HeaderValue::from_str(&fallback)
                .unwrap_or_else(|_| HeaderValue::from_static("health-tracking-degraded"))
        }
    };

    response
        .headers_mut()
        .insert(HeaderName::from_static(X_OCTOROUTE_WARNING), header_value);
    response
}
/// Chat-completions request handler.
///
/// Flow, as implemented below:
/// 1. Streaming requests (`request.stream()`) are handed off to `super::streaming::handler`.
/// 2. `ModelChoice::Specific(name)`: the named endpoint is looked up and queried directly
///    through `query_model`; health tracking and invocation metrics are updated here.
/// 3. `ModelChoice::Auto`: the router picks a tier and the routing time is measured.
/// 4. `ModelChoice::Fast | Balanced | Deep`: the tier is taken from the request as-is.
/// 5. For cases 3 and 4 the query goes through `execute_query_with_retry`.
///
/// Non-fatal problems (health-tracking failure after a successful query, clock warnings
/// from `current_timestamp`, warnings returned by `execute_query_with_retry`) do not fail
/// the request; they are collected and attached via `build_response_with_warnings`.
///
/// # Errors
///
/// Propagates errors from `find_endpoint_by_name`, `query_model`, the router's `route`,
/// `execute_query_with_retry`, and whatever the streaming handler returns.
pub async fn handler(
State(state): State<AppState>,
Extension(request_id): Extension<RequestId>,
OpenAiJson(request): OpenAiJson<ChatCompletionRequest>,
) -> Result<Response, AppError> {
tracing::debug!(
request_id = %request_id,
model = ?request.model(),
messages_count = request.messages().len(),
stream = request.stream(),
"Received chat completions request"
);
// Streaming requests are fully delegated; nothing below runs for them.
if request.stream() {
return super::streaming::handler(State(state), Extension(request_id), Json(request)).await;
}
let prompt = request.to_prompt_string();
// Counted in Unicode scalar values, not bytes. Passed to `ChatCompletion::new` below
// (presumably for usage accounting; confirm against `ChatCompletion::new`).
let prompt_chars = prompt.chars().count();
let sampling_params = SamplingParams {
temperature: request.temperature(),
max_tokens: request.max_tokens(),
};
// --- Path 1: caller named a specific model/endpoint -------------------------------
if let ModelChoice::Specific(name) = request.model() {
let (tier, endpoint) = find_endpoint_by_name(state.config(), name)?;
tracing::info!(
request_id = %request_id,
model_name = %name,
endpoint_name = %endpoint.name(),
target_tier = ?tier,
"Specific model selection - querying endpoint directly"
);
// No routing took place, so the decision is recorded as rule-based with 0.0 duration.
let decision =
crate::router::RoutingDecision::new(tier, crate::router::RoutingStrategy::Rule);
record_routing_metrics(&state, &decision, 0.0, request_id);
let timeout_seconds = state.config().timeout_for_tier(tier);
// NOTE(review): the two literal `1` arguments look like "attempt" and "max attempts"
// (i.e. a single try, no retry); confirm against `query_model`'s signature.
let content = match query_model(
&endpoint,
&prompt,
timeout_seconds,
request_id,
1,
1,
Some(&sampling_params),
)
.await
{
Ok(content) => content,
Err(e) => {
// Record the endpoint failure. If health tracking itself fails, log and count
// it, but still return the original query error to the caller.
if let Err(health_err) = state
.selector()
.health_checker()
.mark_failure(endpoint.name())
.await
{
tracing::warn!(
request_id = %request_id,
endpoint_name = %endpoint.name(),
query_error = %e,
health_error = %health_err,
"Health tracking failed while marking endpoint failure"
);
state
.metrics()
.health_tracking_failure(endpoint.name(), health_err.error_type());
}
return Err(e);
}
};
// Translate the router's tier type into the metrics module's tier type.
let tier_enum = match tier {
crate::router::TargetModel::Fast => crate::metrics::Tier::Fast,
crate::router::TargetModel::Balanced => crate::metrics::Tier::Balanced,
crate::router::TargetModel::Deep => crate::metrics::Tier::Deep,
};
// A metrics failure is logged and counted but never fails the request.
if let Err(e) = state.metrics().record_model_invocation(tier_enum) {
state
.metrics()
.metrics_recording_failure("record_model_invocation");
tracing::error!(
request_id = %request_id,
error = %e,
tier = ?tier_enum,
"Metrics recording failed. Observability degraded but request continues."
);
}
let mut warnings: Vec<String> = Vec::new();
// The query succeeded; a health-tracking failure here becomes a client-visible warning.
if let Err(e) = state
.selector()
.health_checker()
.mark_success(endpoint.name())
.await
{
tracing::warn!(
request_id = %request_id,
endpoint_name = %endpoint.name(),
error = %e,
"Health tracking failed for specific model query"
);
state
.metrics()
.health_tracking_failure(endpoint.name(), e.error_type());
warnings.push(format!(
"Health tracking failed: {} (endpoint health state may be stale)",
e
));
}
// `current_timestamp` may return a warning alongside the timestamp; forward it.
let TimestampResult {
timestamp: created,
warning: clock_warning,
} = current_timestamp(Some(state.metrics().as_ref()), Some(&request_id));
if let Some(w) = clock_warning {
warnings.push(w);
}
// The endpoint name is reported back as the response's model.
let response =
ChatCompletion::new(content, endpoint.name().to_string(), prompt_chars, created);
tracing::info!(
request_id = %request_id,
model = %response.model,
response_length = response.choices[0].message.content().len(),
warnings_count = warnings.len(),
"Chat completion successful (specific model)"
);
return Ok(build_response_with_warnings(response, &warnings));
}
// --- Paths 2 and 3: choose a tier (router or explicit), then query with retry ------
let decision = match request.model() {
ModelChoice::Auto => {
let metadata = request.to_route_metadata();
// Time only the routing call itself; reported in milliseconds.
let routing_start = std::time::Instant::now();
let decision = state
.router()
.route(&prompt, &metadata, state.selector())
.await?;
let routing_duration_ms = routing_start.elapsed().as_secs_f64() * 1000.0;
tracing::info!(
request_id = %request_id,
target_tier = ?decision.target(),
routing_strategy = ?decision.strategy(),
routing_duration_ms = %routing_duration_ms,
"Routing decision made (auto)"
);
record_routing_metrics(&state, &decision, routing_duration_ms, request_id);
decision
}
ModelChoice::Fast | ModelChoice::Balanced | ModelChoice::Deep => {
// Explicit tier: map the request's choice one-to-one onto the router's tier type.
let tier = match request.model() {
ModelChoice::Fast => crate::router::TargetModel::Fast,
ModelChoice::Balanced => crate::router::TargetModel::Balanced,
ModelChoice::Deep => crate::router::TargetModel::Deep,
_ => unreachable!("outer match arm guarantees Fast/Balanced/Deep"),
};
let decision =
crate::router::RoutingDecision::new(tier, crate::router::RoutingStrategy::Rule);
tracing::info!(
request_id = %request_id,
target_tier = ?tier,
"Direct tier selection (no routing)"
);
// No routing work was done, so the duration is recorded as 0.0.
record_routing_metrics(&state, &decision, 0.0, request_id);
decision
}
// The `if let ModelChoice::Specific(..)` block above always returns.
ModelChoice::Specific(_) => unreachable!("handled above"),
};
let config = QueryConfig::default();
let result = execute_query_with_retry(
&state,
&decision,
&prompt,
request_id,
&config,
Some(&sampling_params),
)
.await?;
// Report the endpoint that actually served the request as the response's model.
let response_model = result.endpoint.name().to_string();
// Start from the warnings returned by `execute_query_with_retry`, then add any clock warning.
let mut warnings = result.warnings;
let TimestampResult {
timestamp: created,
warning: clock_warning,
} = current_timestamp(Some(state.metrics().as_ref()), Some(&request_id));
if let Some(w) = clock_warning {
warnings.push(w);
}
let response = ChatCompletion::new(result.content, response_model, prompt_chars, created);
tracing::info!(
request_id = %request_id,
model = %response.model,
response_length = response.choices[0].message.content().len(),
warnings_count = warnings.len(),
"Chat completion successful"
);
Ok(build_response_with_warnings(response, &warnings))
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use axum::http::HeaderName;

    // Non-ASCII fixtures are written as `\u{...}` escapes so they cannot be damaged by a
    // bad re-encoding of this source file.
    /// U+00E9 LATIN SMALL LETTER E WITH ACUTE (2 bytes in UTF-8).
    const E_ACUTE: &str = "\u{E9}";
    /// U+4E2D, a CJK ideograph (3 bytes in UTF-8).
    const CJK: &str = "\u{4E2D}";
    /// U+1F991, an emoji outside the BMP (4 bytes in UTF-8).
    const EMOJI: &str = "\u{1F991}";

    /// Builds a real response through `build_response_with_warnings` and returns the text
    /// of the warning header, or `None` when the header is absent.
    fn warning_header(warnings: &[String]) -> Option<String> {
        let response =
            build_response_with_warnings(serde_json::json!({"test": "value"}), warnings);
        let value = response
            .headers()
            .get(HeaderName::from_static(X_OCTOROUTE_WARNING))
            .map(|raw| {
                raw.to_str()
                    .expect("warning header should contain only visible ASCII")
                    .to_owned()
            });
        value
    }

    /// Header text produced by the production code for `warnings` (empty if no header).
    ///
    /// This deliberately goes through `build_response_with_warnings` instead of
    /// re-implementing its sanitizing/truncation rules, so the tests below fail if the
    /// production logic regresses.
    fn truncate_warning_for_header(warnings: &[String]) -> String {
        warning_header(warnings).unwrap_or_default()
    }

    #[test]
    fn test_warning_truncation_ascii_only() {
        let long_warning = "a".repeat(600);
        let result = truncate_warning_for_header(&[long_warning]);
        assert_eq!(result.len(), 500);
        assert!(result.ends_with("..."));
        assert!(result.is_char_boundary(result.len()));
    }

    #[test]
    fn test_warning_truncation_sanitizes_multibyte_to_placeholder() {
        // The multi-byte character sits right before the truncation point.
        let prefix = "x".repeat(495);
        let suffix = "y".repeat(100);
        let warning = format!("{}{}{}", prefix, EMOJI, suffix);
        let result = truncate_warning_for_header(&[warning]);
        assert!(
            std::str::from_utf8(result.as_bytes()).is_ok(),
            "Truncated warning must be valid UTF-8"
        );
        assert!(
            result.ends_with("..."),
            "Truncated warning should end with ..."
        );
        assert_eq!(
            result.len(),
            500,
            "Truncated warning should be exactly 500 bytes"
        );
        assert!(
            result.is_ascii(),
            "Result should be ASCII after sanitization"
        );
        assert_eq!(
            &result[495..497],
            "?y",
            "Emoji should collapse to a single '?' placeholder"
        );
    }

    #[test]
    fn test_warning_truncation_chinese_characters_become_placeholders() {
        let chinese = CJK.repeat(600);
        let result = truncate_warning_for_header(&[chinese]);
        assert!(
            std::str::from_utf8(result.as_bytes()).is_ok(),
            "Truncated text must be valid UTF-8"
        );
        assert!(result.ends_with("..."));
        assert_eq!(result.len(), 500);
        assert!(
            result.is_ascii(),
            "Result should be ASCII after sanitization"
        );
        assert!(
            result[..497].chars().all(|c| c == '?'),
            "Chinese chars should become '?'"
        );
    }

    #[test]
    fn test_warning_truncation_mixed_multibyte_sanitizes_non_ascii() {
        let mixed = format!(
            "{}{}{}{}{}",
            "a".repeat(200),
            E_ACUTE.repeat(150),
            CJK.repeat(100),
            EMOJI.repeat(50),
            "z".repeat(50)
        );
        assert!(
            mixed.chars().count() > 500,
            "Test setup: need >500 chars to trigger truncation"
        );
        let result = truncate_warning_for_header(&[mixed]);
        assert!(
            std::str::from_utf8(result.as_bytes()).is_ok(),
            "Truncated mixed text must be valid UTF-8"
        );
        assert!(result.ends_with("..."));
        assert_eq!(result.len(), 500);
        assert!(
            result.is_ascii(),
            "Result should be ASCII after sanitization"
        );
        assert!(
            result.starts_with(&"a".repeat(200)),
            "First 200 should be 'a'"
        );
        assert!(
            result[200..201].chars().all(|c| c == '?'),
            "After 'a's should be '?' from sanitized U+00E9"
        );
    }

    #[test]
    fn test_warning_under_limit_not_truncated() {
        let short = "This is a short warning";
        let result = truncate_warning_for_header(&[short.to_string()]);
        assert_eq!(result, short);
        assert!(!result.ends_with("..."));
    }

    #[test]
    fn test_warning_exactly_500_chars_not_truncated() {
        let exactly_500 = "a".repeat(500);
        let result = truncate_warning_for_header(std::slice::from_ref(&exactly_500));
        assert_eq!(result, exactly_500);
        assert!(!result.ends_with("..."));
    }

    #[test]
    fn test_warning_501_chars_gets_truncated() {
        let chars_501 = "a".repeat(501);
        let result = truncate_warning_for_header(&[chars_501]);
        assert!(result.ends_with("..."));
        assert_eq!(result.chars().count(), 500);
    }

    #[test]
    fn test_build_response_with_valid_warning_header() {
        let header = warning_header(&["valid warning message".to_string()]);
        assert!(header.is_some(), "Warning header should be present");
        assert_eq!(header.unwrap(), "valid warning message");
    }

    #[test]
    fn test_build_response_with_newline_sanitizes() {
        let header = warning_header(&["warning with\nnewline".to_string()]);
        assert!(header.is_some(), "Warning header should still be present");
        assert_eq!(header.unwrap(), "warning with newline");
    }

    #[test]
    fn test_build_response_with_control_char_sanitizes() {
        let header = warning_header(&["warning with\x00null".to_string()]);
        assert!(header.is_some());
        assert_eq!(header.unwrap(), "warning with null");
    }

    #[test]
    fn test_build_response_empty_warnings_no_header() {
        let header = warning_header(&[]);
        assert!(header.is_none(), "No warning header when warnings empty");
    }

    #[test]
    fn test_build_response_with_non_ascii_sanitizes() {
        // "Health check failed for <two CJK ideographs>-server <emoji>"
        let warnings =
            vec!["Health check failed for \u{5317}\u{4EAC}-server \u{1F991}".to_string()];
        let header = warning_header(&warnings);
        assert!(header.is_some(), "Warning header should be present");
        let value = header.unwrap();
        assert_ne!(
            value, "health-tracking-degraded",
            "Should not fall back to generic message, should sanitize non-ASCII"
        );
        assert!(
            value.contains("Health check failed"),
            "Should preserve ASCII content. Got: {}",
            value
        );
        assert_eq!(
            value, "Health check failed for ??-server ?",
            "Each non-ASCII char should become exactly one '?'"
        );
    }

    #[test]
    fn test_build_response_with_emoji_sanitizes() {
        // Red, yellow and green circle emoji between the ASCII words.
        let warnings =
            vec!["Error \u{1F534} warning \u{1F7E1} info \u{1F7E2}".to_string()];
        let header = warning_header(&warnings);
        assert!(header.is_some());
        let value = header.unwrap();
        assert!(
            value.contains("Error"),
            "Should preserve 'Error'. Got: {}",
            value
        );
        assert!(
            value.contains("warning"),
            "Should preserve 'warning'. Got: {}",
            value
        );
        assert!(
            value.contains("info"),
            "Should preserve 'info'. Got: {}",
            value
        );
        assert_eq!(
            value, "Error ? warning ? info ?",
            "Each emoji should become exactly one '?'"
        );
    }
}