use crate::switcher::{ModelSwitcher, SwitchError};
use axum::body::Body;
use axum::http::{Request, Response, StatusCode};
use bytes::Bytes;
use futures_util::future::BoxFuture;
use http_body_util::BodyExt;
use std::task::{Context, Poll};
use tower::{Layer, Service};
use tracing::{debug, error, trace, warn};
/// Tower [`Layer`] that wraps an inner service in a [`ModelSwitcherService`].
///
/// Holds the [`ModelSwitcher`] handle that is cloned into each service the
/// layer produces.
#[derive(Clone)]
pub struct ModelSwitcherLayer {
switcher: ModelSwitcher,
}
impl ModelSwitcherLayer {
pub fn new(switcher: ModelSwitcher) -> Self {
Self { switcher }
}
}
impl<S> Layer<S> for ModelSwitcherLayer {
    type Service = ModelSwitcherService<S>;

    /// Wraps `inner` in a [`ModelSwitcherService`] that carries its own clone
    /// of this layer's switcher.
    fn layer(&self, inner: S) -> Self::Service {
        let switcher = self.switcher.clone();
        ModelSwitcherService { switcher, inner }
    }
}
/// Middleware service created by `ModelSwitcherLayer`.
///
/// Pairs a [`ModelSwitcher`] handle with the wrapped `inner` service that
/// requests are ultimately forwarded to.
#[derive(Clone)]
pub struct ModelSwitcherService<S> {
switcher: ModelSwitcher,
inner: S,
}
/// Request-handling logic of the model-switching middleware.
///
/// For each request the service buffers the body, determines the target model
/// (see `extract_model`), asks the switcher to make that model ready, and then
/// forwards the re-assembled request to the inner service. Requests that name
/// no model, or a model the switcher does not know, are forwarded untouched.
impl<S> Service<Request<Body>> for ModelSwitcherService<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Handles one request.
    ///
    /// Failures raised by this middleware are reported as JSON error
    /// responses wrapped in `Ok(..)`; only errors from the inner service are
    /// returned as `Err`:
    ///
    /// * the body cannot be read -> `400 Bad Request`;
    /// * `ensure_model_ready` fails -> status chosen by `switch_error_response`.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let switcher = self.switcher.clone();

        // `poll_ready` was driven on `self.inner`, so that exact instance is
        // the one allowed to receive `call`. A fresh clone is not guaranteed
        // to be ready. Move the ready instance into the future and leave the
        // clone behind to be polled for the next request.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(async move {
            let (parts, body) = req.into_parts();

            // The whole body is buffered in memory so the JSON payload can be
            // inspected and then replayed to the inner service.
            let body_bytes = match body.collect().await {
                Ok(collected) => collected.to_bytes(),
                Err(e) => {
                    error!(error = %e, "Failed to read request body");
                    return Ok(error_response(
                        StatusCode::BAD_REQUEST,
                        "Failed to read request body",
                    ));
                }
            };

            let model = extract_model(&parts.headers, &body_bytes);
            let Some(model) = model else {
                trace!("No model in request, passing through");
                let req = Request::from_parts(parts, Body::from(body_bytes));
                return inner.call(req).await;
            };
            debug!(model = %model, "Extracted model from request");

            if !switcher.is_registered(&model) {
                warn!(model = %model, "Model not registered with switcher, passing through");
                let req = Request::from_parts(parts, Body::from(body_bytes));
                return inner.call(req).await;
            }

            if let Err(e) = switcher.ensure_model_ready(&model).await {
                error!(model = %model, error = %e, "Failed to ensure model ready");
                return Ok(switch_error_response(e));
            }

            // Bound to a named variable (not `_`) so the value stays alive
            // until the inner call below has resolved.
            // NOTE(review): presumably this marks the request as in flight for
            // `model`; confirm against `ModelSwitcher::acquire_in_flight`.
            let _guard = switcher.acquire_in_flight(&model);
            let req = Request::from_parts(parts, Body::from(body_bytes));
            inner.call(req).await
        })
    }
}
/// Determines which model a request targets.
///
/// The `model-override` header wins when it is present and its value converts
/// to a string via `to_str`. Otherwise the body is parsed as JSON and its
/// top-level string field `"model"` is used. Returns `None` when neither
/// source yields a model name.
fn extract_model(headers: &axum::http::HeaderMap, body: &Bytes) -> Option<String> {
    let from_header = headers
        .get("model-override")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);

    // The body is parsed lazily, only when the header did not supply a model.
    from_header.or_else(|| {
        let json = serde_json::from_slice::<serde_json::Value>(body).ok()?;
        json.get("model")?.as_str().map(str::to_owned)
    })
}
/// Builds a JSON error response with the given `status`.
///
/// The body has the shape
/// `{"error": {"message": <message>, "type": "llmux_error"}}` and is sent with
/// a `Content-Type: application/json` header.
fn error_response(status: StatusCode, message: &str) -> Response<Body> {
    let payload = serde_json::json!({
        "error": {
            "message": message,
            "type": "llmux_error"
        }
    })
    .to_string();

    let builder = Response::builder()
        .header("Content-Type", "application/json")
        .status(status);

    // `status` is already a valid `StatusCode` and the header pair is a
    // well-formed constant, so the builder has nothing to reject here.
    builder.body(Body::from(payload)).unwrap()
}
/// Converts a [`SwitchError`] into a JSON error response.
///
/// Status mapping:
///
/// * `ModelNotFound` -> `404 Not Found`
/// * `NotReady` -> `503 Service Unavailable`
/// * `Timeout` -> `504 Gateway Timeout`
/// * `Orchestrator`, `Internal` -> `500 Internal Server Error`
///
/// The payload carried by the error variant is formatted into the message
/// returned to the client.
fn switch_error_response(error: SwitchError) -> Response<Body> {
    match &error {
        SwitchError::ModelNotFound(name) => {
            error_response(StatusCode::NOT_FOUND, &format!("Model not found: {name}"))
        }
        SwitchError::NotReady(name) => error_response(
            StatusCode::SERVICE_UNAVAILABLE,
            &format!("Model not ready: {name}"),
        ),
        SwitchError::Timeout => error_response(
            StatusCode::GATEWAY_TIMEOUT,
            "Request timed out waiting for model",
        ),
        SwitchError::Orchestrator(cause) => error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &format!("Orchestrator error: {cause}"),
        ),
        SwitchError::Internal(cause) => error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &format!("Internal error: {cause}"),
        ),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extract_model` on a request with the given optional
    /// `model-override` header value and raw body text.
    fn extract(override_header: Option<&'static str>, raw_body: &'static str) -> Option<String> {
        let mut headers = axum::http::HeaderMap::new();
        if let Some(value) = override_header {
            headers.insert("model-override", value.parse().unwrap());
        }
        let body = Bytes::from(raw_body);
        extract_model(&headers, &body)
    }

    #[test]
    fn test_extract_model_from_header() {
        let model = extract(Some("llama"), "{}");
        assert_eq!(model.as_deref(), Some("llama"));
    }

    #[test]
    fn test_extract_model_from_body() {
        let model = extract(None, r#"{"model": "mistral", "messages": []}"#);
        assert_eq!(model.as_deref(), Some("mistral"));
    }

    #[test]
    fn test_extract_model_header_precedence() {
        // The header must win even though the body also names a model.
        let model = extract(Some("from-header"), r#"{"model": "from-body"}"#);
        assert_eq!(model.as_deref(), Some("from-header"));
    }

    #[test]
    fn test_extract_model_none() {
        let model = extract(None, r#"{"messages": []}"#);
        assert_eq!(model, None);
    }
}