use crate::switcher::{InFlightGuard, ModelSwitcher, SwitchError};
use axum::body::Body;
use axum::http::{Request, Response, StatusCode};
use bytes::Bytes;
use futures_util::future::BoxFuture;
use http_body::Frame;
use http_body_util::BodyExt;
use metrics::{counter, histogram};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tower::{Layer, Service};
use tracing::{debug, error, trace, warn};
/// Tower [`Layer`] that wraps an inner service in a [`ModelSwitcherService`].
///
/// Every wrapped service receives its own clone of the shared [`ModelSwitcher`].
#[derive(Clone)]
pub struct ModelSwitcherLayer {
    switcher: ModelSwitcher,
}

impl ModelSwitcherLayer {
    /// Builds a layer around `switcher`; the switcher is cloned into each service it wraps.
    pub fn new(switcher: ModelSwitcher) -> Self {
        ModelSwitcherLayer { switcher }
    }
}

impl<S> Layer<S> for ModelSwitcherLayer {
    type Service = ModelSwitcherService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let switcher = self.switcher.clone();
        ModelSwitcherService { switcher, inner }
    }
}
#[derive(Clone)]
pub struct ModelSwitcherService<S> {
switcher: ModelSwitcher,
inner: S,
}
impl<S> Service<Request<Body>> for ModelSwitcherService<S>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let switcher = self.switcher.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let (parts, body) = req.into_parts();
let body_bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
error!(error = %e, "Failed to read request body");
return Ok(error_response(
StatusCode::BAD_REQUEST,
"Failed to read request body",
));
}
};
let model = extract_model(&parts.headers, &body_bytes);
let Some(model) = model else {
trace!("No model in request, passing through");
let req = Request::from_parts(parts, Body::from(body_bytes));
return inner.call(req).await;
};
debug!(model = %model, "Extracted model from request");
if !switcher.is_registered(&model) {
warn!(model = %model, "Model not registered with switcher, passing through");
let req = Request::from_parts(parts, Body::from(body_bytes));
return inner.call(req).await;
}
let queue_start = Instant::now();
let guard = loop {
if let Err(e) = switcher.ensure_model_ready(&model).await {
error!(model = %model, error = %e, "Failed to ensure model ready");
return Ok(switch_error_response(e));
}
match switcher.acquire_in_flight(&model) {
Some(guard) => break guard,
None => {
debug!(model = %model, "Model draining, re-entering ensure_model_ready");
continue;
}
}
};
let queue_wait = queue_start.elapsed();
let waited = queue_wait > Duration::from_millis(1);
histogram!("llmux_request_queue_wait_seconds", "model" => model.clone())
.record(queue_wait.as_secs_f64());
counter!("llmux_requests_total", "model" => model.clone(), "waited" => if waited { "true" } else { "false" })
.increment(1);
let req = Request::from_parts(parts, Body::from(body_bytes));
let response = inner.call(req).await?;
let (resp_parts, body) = response.into_parts();
let guarded = GuardedBody {
inner: body,
_guard: Some(guard),
};
Ok(Response::from_parts(resp_parts, Body::new(guarded)))
})
}
}
/// Determines which model a request targets.
///
/// The `model-override` header takes precedence when it is present and its value converts
/// to a `&str` (i.e. `HeaderValue::to_str` succeeds). Otherwise the body is parsed as JSON
/// and its top-level `"model"` field is used if it is a string.
///
/// The body is only parsed when the header does not supply a model. Returns `None` when
/// neither source yields one.
fn extract_model(headers: &axum::http::HeaderMap, body: &Bytes) -> Option<String> {
    let from_header = headers
        .get("model-override")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);

    from_header.or_else(|| {
        let json: serde_json::Value = serde_json::from_slice(body).ok()?;
        json.get("model")?.as_str().map(str::to_owned)
    })
}
fn error_response(status: StatusCode, message: &str) -> Response<Body> {
let body = serde_json::json!({
"error": {
"message": message,
"type": "llmux_error"
}
});
Response::builder()
.status(status)
.header("Content-Type", "application/json")
.body(Body::from(body.to_string()))
.unwrap()
}
/// Maps a [`SwitchError`] to the JSON error response returned to the client.
///
/// Status codes:
/// - `ModelNotFound` → 404
/// - `NotReady` and `ManualModeRejected` → 503
/// - `Timeout` → 504
/// - `Orchestrator` and `Internal` → 500
fn switch_error_response(error: SwitchError) -> Response<Body> {
    match &error {
        SwitchError::ModelNotFound(model) => {
            error_response(StatusCode::NOT_FOUND, &format!("Model not found: {model}"))
        }
        SwitchError::NotReady(model) => error_response(
            StatusCode::SERVICE_UNAVAILABLE,
            &format!("Model not ready: {model}"),
        ),
        SwitchError::Timeout => error_response(
            StatusCode::GATEWAY_TIMEOUT,
            "Request timed out waiting for model",
        ),
        SwitchError::Orchestrator(cause) => error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &format!("Orchestrator error: {cause}"),
        ),
        SwitchError::Internal(cause) => error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &format!("Internal error: {cause}"),
        ),
        SwitchError::ManualModeRejected {
            requested, active, ..
        } => error_response(
            StatusCode::SERVICE_UNAVAILABLE,
            &format!("Manual mode active: model {requested} not available (active: {active})"),
        ),
    }
}
/// Response body wrapper that carries a model's in-flight guard.
///
/// The guard is released as soon as the inner body reports completion (end of stream or an
/// error), or when the wrapper is dropped, whichever comes first. Frames, end-of-stream
/// state and size hints are forwarded unchanged from the inner body.
struct GuardedBody {
    inner: Body,
    // `Some` until the body finishes. It is cleared early so a finished-but-not-yet-dropped
    // body does not keep holding the guard (presumably released on drop — see
    // `switcher::InFlightGuard`).
    _guard: Option<InFlightGuard>,
}

impl http_body::Body for GuardedBody {
    type Data = Bytes;
    type Error = axum::Error;

    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.inner).poll_frame(cx);
        if matches!(polled, Poll::Ready(None | Some(Err(_)))) {
            // Nothing more will be streamed from this body; let go of the guard now.
            this._guard = None;
        }
        polled
    }

    fn is_end_stream(&self) -> bool {
        self.inner.is_end_stream()
    }

    fn size_hint(&self) -> http_body::SizeHint {
        self.inner.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_model_from_header() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("model-override", "llama".parse().unwrap());
        let body = Bytes::from("{}");
        let model = extract_model(&headers, &body);
        assert_eq!(model, Some("llama".to_string()));
    }

    #[test]
    fn test_extract_model_from_body() {
        let headers = axum::http::HeaderMap::new();
        let body = Bytes::from(r#"{"model": "mistral", "messages": []}"#);
        let model = extract_model(&headers, &body);
        assert_eq!(model, Some("mistral".to_string()));
    }

    #[test]
    fn test_extract_model_header_precedence() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("model-override", "from-header".parse().unwrap());
        let body = Bytes::from(r#"{"model": "from-body"}"#);
        let model = extract_model(&headers, &body);
        assert_eq!(model, Some("from-header".to_string()));
    }

    #[test]
    fn test_extract_model_none() {
        let headers = axum::http::HeaderMap::new();
        let body = Bytes::from(r#"{"messages": []}"#);
        let model = extract_model(&headers, &body);
        assert_eq!(model, None);
    }

    #[test]
    fn test_extract_model_empty_body() {
        let headers = axum::http::HeaderMap::new();
        let body = Bytes::new();
        assert_eq!(extract_model(&headers, &body), None);
    }

    #[test]
    fn test_extract_model_invalid_json_body() {
        let headers = axum::http::HeaderMap::new();
        let body = Bytes::from("not json at all");
        assert_eq!(extract_model(&headers, &body), None);
    }

    #[test]
    fn test_extract_model_non_string_model_field() {
        let headers = axum::http::HeaderMap::new();
        let body = Bytes::from(r#"{"model": 42}"#);
        assert_eq!(extract_model(&headers, &body), None);
    }

    #[test]
    fn test_extract_model_non_visible_ascii_header_falls_back_to_body() {
        let mut headers = axum::http::HeaderMap::new();
        // 0xFF is a legal header byte but not visible ASCII, so `to_str` fails.
        headers.insert(
            "model-override",
            axum::http::HeaderValue::from_bytes(b"\xff").unwrap(),
        );
        let body = Bytes::from(r#"{"model": "from-body"}"#);
        assert_eq!(
            extract_model(&headers, &body),
            Some("from-body".to_string())
        );
    }

    #[test]
    fn test_error_response_status_and_content_type() {
        let response = error_response(StatusCode::BAD_REQUEST, "boom");
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            response
                .headers()
                .get(axum::http::header::CONTENT_TYPE)
                .unwrap(),
            "application/json"
        );
    }

    #[test]
    fn test_switch_error_timeout_maps_to_gateway_timeout() {
        let response = switch_error_response(SwitchError::Timeout);
        assert_eq!(response.status(), StatusCode::GATEWAY_TIMEOUT);
    }
}