use axum::{
body::Body,
extract::State,
http::{HeaderMap, HeaderName, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
use bytes::Bytes;
use futures_util::StreamExt;
use serde::Serialize;
use serde_json::Value;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{error, info, warn};
use oxllm_core::router::{AdaptivePriorityStrategy, RoutingStrategy};
use oxllm_core::state::{AppState, CircuitState, ProviderState};
use oxllm_core::telemetry::{TelemetryClient, TelemetryEvent};
/// One entry of the model listing serialized by `list_models`.
#[derive(Serialize)]
struct ModelObject {
    /// Identifier of the model as exposed to clients.
    id: String,
    /// Object-type tag emitted verbatim in the JSON output.
    object: &'static str,
    /// Creation timestamp reported for the model.
    created: u64,
    /// Owner label emitted verbatim in the JSON output.
    owned_by: &'static str,
}
/// JSON envelope wrapping the list of models returned by `list_models`.
#[derive(Serialize)]
struct ModelsResponse {
    /// Object-type tag emitted verbatim in the JSON output.
    object: &'static str,
    /// The model entries contained in this listing.
    data: Vec<ModelObject>,
}
/// Lists the virtual models that currently have at least one usable provider.
///
/// A virtual model is reported when any of its targets names a known provider
/// whose circuit is not in an unexpired `Open` state and whose rate-limit
/// deadline (if one is set) has already passed. Targets that name a provider
/// missing from `app_state.providers` are ignored.
pub async fn list_models(State(app_state): State<Arc<AppState>>) -> impl IntoResponse {
    // Fixed `created` value reported for every virtual model.
    const CREATED_AT: u64 = 1717070400;

    // A single timestamp is used for every deadline comparison in this call.
    let checked_at = Instant::now();
    let mut models = Vec::new();

    'virtual_models: for (vm_name, targets) in &app_state.virtual_models {
        for target in targets {
            let Some(provider) = app_state
                .providers
                .iter()
                .find(|candidate| candidate.name == target.provider)
            else {
                continue;
            };

            // Copy each value out so the read guard is released immediately.
            let circuit = *provider.circuit.read().await;
            let circuit_blocked =
                matches!(circuit, CircuitState::Open { until } if checked_at < until);

            let limited_until = *provider.rate_limited_until.read().await;
            let rate_limited = limited_until.map_or(false, |until| checked_at < until);

            if circuit_blocked || rate_limited {
                continue;
            }

            // The first usable provider is enough to advertise the model.
            models.push(ModelObject {
                id: vm_name.clone(),
                object: "model",
                created: CREATED_AT,
                owned_by: "oxllm-virtual",
            });
            continue 'virtual_models;
        }
    }

    Json(ModelsResponse {
        object: "list",
        data: models,
    })
}
/// Reports the circuit state, consecutive failure count and rate-limit flag of
/// every configured provider as a JSON array.
pub async fn get_status(State(app_state): State<Arc<AppState>>) -> impl IntoResponse {
    /// Per-provider row of the status report.
    #[derive(Serialize)]
    struct ProviderStatus {
        name: String,
        circuit: String,
        failures: u32,
        rate_limited: bool,
    }

    // One timestamp for the whole report keeps all rows mutually consistent.
    let snapshot_at = Instant::now();
    let mut report = Vec::new();

    for provider in &app_state.providers {
        // Copy each value out so no read guard outlives its statement.
        let circuit_state = *provider.circuit.read().await;
        let limited_until = *provider.rate_limited_until.read().await;
        let failures = *provider.consecutive_failures.read().await;

        let circuit = match circuit_state {
            CircuitState::Open { until } => {
                // Saturates to zero once the cooldown deadline has passed.
                let remaining_secs = until.saturating_duration_since(snapshot_at).as_secs();
                format!("Open (Cooldown: {remaining_secs}s left)")
            }
            CircuitState::HalfOpen => String::from("Half-Open (Probing)"),
            CircuitState::Closed => String::from("Closed (Healthy)"),
        };

        report.push(ProviderStatus {
            name: provider.name.clone(),
            circuit,
            failures,
            rate_limited: limited_until.map_or(false, |until| snapshot_at < until),
        });
    }

    Json(report)
}
pub async fn create_embeddings(
State((app_state, telemetry)): State<(Arc<AppState>, TelemetryClient)>,
headers: HeaderMap,
body: Bytes, ) -> impl IntoResponse {
let mut payload: Value = match serde_json::from_slice(&body) {
Ok(p) => p,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Invalid JSON payload: {}", e),
)
.into_response()
},
};
let requested_model = match payload.get("model").and_then(|m| m.as_str()) {
Some(m) => m,
None => return (StatusCode::BAD_REQUEST, "Missing required 'model' field").into_response(),
};
let candidates = app_state.resolve_candidates(requested_model);
if candidates.is_empty() {
return (
StatusCode::BAD_REQUEST,
format!("Invalid or unmapped virtual model: {}", requested_model),
)
.into_response();
}
let (trace_id, parent_span_id) = extract_traceparent(&headers);
let strategy = AdaptivePriorityStrategy;
let mut attempts = 0;
let start_time = Instant::now();
let candidate_states: Vec<&ProviderState> = candidates.iter().map(|(p, _)| *p).collect();
while let Some(selected) = strategy.select(&candidate_states).await {
attempts += 1;
let target_model = candidates
.iter()
.find(|(p, _)| p.name == selected.name)
.map(|(_, m)| m.clone())
.unwrap();
payload["model"] = Value::String(target_model.clone());
let rewritten_body = Bytes::from(serde_json::to_vec(&payload).unwrap());
let endpoint_url = match selected.base_url.join("embeddings") {
Ok(url) => url,
Err(e) => {
error!("Invalid base URL path join for {}: {}", selected.name, e);
continue;
},
};
let mut req = app_state
.http_client
.post(endpoint_url.as_str())
.body(rewritten_body)
.timeout(Duration::from_secs(5)) .header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", selected.api_key));
if let Some(traceparent) = headers.get("traceparent") {
req = req.header("traceparent", traceparent);
}
info!(
"Embedding request routing to {} (attempt {})",
selected.name, attempts
);
let res = req.send().await;
match res {
Ok(res) if res.status().is_success() => {
let status_code = res.status().as_u16();
let upstream_headers = res.headers().clone();
let res_body = match res.bytes().await {
Ok(b) => b,
Err(e) => {
warn!(
"Failed to read success response body from {}: {}",
selected.name, e
);
let target_provider_state = app_state
.providers
.iter()
.find(|p| p.name == selected.name)
.unwrap();
strategy
.feedback(
target_provider_state,
false,
selected.is_probe,
Some(status_code),
None,
)
.await;
continue;
},
};
let target_provider_state = app_state
.providers
.iter()
.find(|p| p.name == selected.name)
.unwrap();
strategy
.feedback(
target_provider_state,
true,
selected.is_probe,
Some(status_code),
None,
)
.await;
telemetry.emit(TelemetryEvent::RecordTransaction {
operation: "embeddings".to_string(),
provider: selected.name.clone(),
model: target_model,
input_tokens: 0, output_tokens: 0,
duration: start_time.elapsed(),
attempts,
failure_reason: None,
trace_id: trace_id.clone(),
parent_span_id: parent_span_id.clone(),
});
let mut response = Response::new(Body::from(res_body));
*response.status_mut() = StatusCode::OK;
copy_response_headers(&upstream_headers, response.headers_mut());
return response.into_response();
},
Ok(res) => {
let status_code = res.status().as_u16();
warn!(
"Embedding request upstream {} failed with status {}",
selected.name, status_code
);
let retry_after = extract_retry_after(res.headers());
let target_provider_state = app_state
.providers
.iter()
.find(|p| p.name == selected.name)
.unwrap();
strategy
.feedback(
target_provider_state,
false,
selected.is_probe,
Some(status_code),
retry_after,
)
.await;
},
Err(e) => {
println!(
"Embedding request upstream {} connection failed: {:?}",
selected.name, e
);
warn!(
"Embedding request upstream {} connection failed: {}",
selected.name, e
);
let target_provider_state = app_state
.providers
.iter()
.find(|p| p.name == selected.name)
.unwrap();
strategy
.feedback(target_provider_state, false, selected.is_probe, None, None)
.await;
},
}
}
(
StatusCode::BAD_GATEWAY,
"All upstream embeddings providers failed or are rate-limited",
)
.into_response()
}
pub async fn create_chat_completions(
State((app_state, telemetry)): State<(Arc<AppState>, TelemetryClient)>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let mut payload: Value = match serde_json::from_slice(&body) {
Ok(p) => p,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Invalid JSON payload: {}", e),
)
.into_response()
},
};
let requested_model = match payload.get("model").and_then(|m| m.as_str()) {
Some(m) => m,
None => return (StatusCode::BAD_REQUEST, "Missing required 'model' field").into_response(),
};
let candidates = app_state.resolve_candidates(requested_model);
if candidates.is_empty() {
return (
StatusCode::BAD_REQUEST,
format!("Invalid or unmapped virtual model: {}", requested_model),
)
.into_response();
}
let is_streaming = payload
.get("stream")
.and_then(|s| s.as_bool())
.unwrap_or(false);
let (trace_id, parent_span_id) = extract_traceparent(&headers);
let strategy = AdaptivePriorityStrategy;
let mut attempts = 0;
let start_time = Instant::now();
let candidate_states: Vec<&ProviderState> = candidates.iter().map(|(p, _)| *p).collect();
while let Some(selected) = strategy.select(&candidate_states).await {
attempts += 1;
let target_model = candidates
.iter()
.find(|(p, _)| p.name == selected.name)
.map(|(_, m)| m.clone())
.unwrap();
payload["model"] = Value::String(target_model.clone());
let rewritten_body = Bytes::from(serde_json::to_vec(&payload).unwrap());
let endpoint_url = match selected.base_url.join("chat/completions") {
Ok(url) => url,
Err(e) => {
error!("Invalid base URL path join for {}: {}", selected.name, e);
continue;
},
};
let mut req = app_state
.http_client
.post(endpoint_url.as_str())
.body(rewritten_body)
.timeout(Duration::from_secs(5))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", selected.api_key));
if let Some(traceparent) = headers.get("traceparent") {
req = req.header("traceparent", traceparent);
}
info!(
"Chat request routing to {} (attempt {})",
selected.name, attempts
);
let res = req.send().await;
match res {
Ok(res) if res.status().is_success() => {
let status_code = res.status().as_u16();
let upstream_headers = res.headers().clone();
let target_provider_state = app_state
.providers
.iter()
.find(|p| p.name == selected.name)
.unwrap();
strategy
.feedback(
target_provider_state,
true,
selected.is_probe,
Some(status_code),
None,
)
.await;
if is_streaming {
let reqwest_stream = res.bytes_stream();
let axum_stream = reqwest_stream.map(|chunk_res| {
chunk_res
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
});
telemetry.emit(TelemetryEvent::RecordTransaction {
operation: "chat".to_string(),
provider: selected.name.clone(),
model: target_model,
input_tokens: 0,
output_tokens: 0,
duration: start_time.elapsed(),
attempts,
failure_reason: None,
trace_id: trace_id.clone(),
parent_span_id: parent_span_id.clone(),
});
let mut response = Response::new(Body::from_stream(axum_stream));
*response.status_mut() = StatusCode::OK;
copy_response_headers(&upstream_headers, response.headers_mut());
return response.into_response();
} else {
let res_body = match res.bytes().await {
Ok(b) => b,
Err(e) => {
warn!(
"Failed to read success response body from {}: {}",
selected.name, e
);
strategy
.feedback(
target_provider_state,
false,
selected.is_probe,
Some(status_code),
None,
)
.await;
continue;
},
};
telemetry.emit(TelemetryEvent::RecordTransaction {
operation: "chat".to_string(),
provider: selected.name.clone(),
model: target_model,
input_tokens: 0,
output_tokens: 0,
duration: start_time.elapsed(),
attempts,
failure_reason: None,
trace_id: trace_id.clone(),
parent_span_id: parent_span_id.clone(),
});
let mut response = Response::new(Body::from(res_body));
*response.status_mut() = StatusCode::OK;
copy_response_headers(&upstream_headers, response.headers_mut());
return response.into_response();
}
},
Ok(res) => {
let status_code = res.status().as_u16();
println!(
"Chat completions upstream {} failed with status {}",
selected.name, status_code
);
warn!(
"Chat completions upstream {} failed with status {}",
selected.name, status_code
);
let retry_after = extract_retry_after(res.headers());
let target_provider_state = app_state
.providers
.iter()
.find(|p| p.name == selected.name)
.unwrap();
strategy
.feedback(
target_provider_state,
false,
selected.is_probe,
Some(status_code),
retry_after,
)
.await;
},
Err(e) => {
println!(
"Chat completions upstream {} connection failed: {:?}",
selected.name, e
);
warn!(
"Chat completions upstream {} connection failed: {}",
selected.name, e
);
let target_provider_state = app_state
.providers
.iter()
.find(|p| p.name == selected.name)
.unwrap();
strategy
.feedback(target_provider_state, false, selected.is_probe, None, None)
.await;
},
}
}
(
StatusCode::BAD_GATEWAY,
"All upstream chat completions providers failed or are rate-limited",
)
.into_response()
}
/// Splits the `traceparent` request header on `-` and returns its second and
/// third fields as owned strings.
///
/// Returns `(None, None)` when the header is absent, is not visible ASCII, or
/// has fewer than three `-`-separated fields. The field contents are not
/// validated.
fn extract_traceparent(headers: &HeaderMap) -> (Option<String>, Option<String>) {
    let Some(raw) = headers.get("traceparent").and_then(|v| v.to_str().ok()) else {
        return (None, None);
    };

    // Skip the leading field; only the two fields that follow it are used.
    let mut fields = raw.split('-').skip(1);
    match (fields.next(), fields.next()) {
        (Some(second), Some(third)) => (Some(second.to_string()), Some(third.to_string())),
        _ => (None, None),
    }
}
fn copy_response_headers(src: &HeaderMap, dest: &mut HeaderMap) {
for (key, value) in src.iter() {
let name_str = key.as_str();
if name_str.starts_with("x-")
|| name_str == "content-type"
|| name_str == "cache-control"
|| name_str == "openai-version"
{
if let Ok(name) = HeaderName::from_bytes(name_str.as_bytes()) {
if let Ok(val) = HeaderValue::from_bytes(value.as_bytes()) {
dest.insert(name, val);
}
}
}
}
}
/// Reads the `retry-after` header as a whole number of seconds.
///
/// Returns `None` when the header is absent, is not visible ASCII, or does not
/// parse as a `u64` (any non-integer form of the header is therefore ignored).
fn extract_retry_after(headers: &HeaderMap) -> Option<Duration> {
    let raw = headers.get("retry-after")?.to_str().ok()?;
    raw.parse::<u64>().ok().map(Duration::from_secs)
}