llmux 0.3.0

Zero-reload model switching for vLLM - manages multiple models on shared GPU
Documentation
//! Axum middleware layer for model switching
//!
//! This layer intercepts requests, extracts the model name, ensures the model
//! is ready, and then passes the request through to the inner service (onwards).

use crate::switcher::{ModelSwitcher, SwitchError};
use axum::body::Body;
use axum::http::{Request, Response, StatusCode};
use bytes::Bytes;
use futures_util::future::BoxFuture;
use http_body_util::BodyExt;
use std::task::{Context, Poll};
use tower::{Layer, Service};
use tracing::{debug, error, trace, warn};

/// Layer that adds model switching to a service
#[derive(Clone)]
pub struct ModelSwitcherLayer {
    switcher: ModelSwitcher,
}

impl ModelSwitcherLayer {
    pub fn new(switcher: ModelSwitcher) -> Self {
        Self { switcher }
    }
}

impl<S> Layer<S> for ModelSwitcherLayer {
    type Service = ModelSwitcherService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        ModelSwitcherService {
            switcher: self.switcher.clone(),
            inner,
        }
    }
}

/// Service that wraps requests with model switching
#[derive(Clone)]
pub struct ModelSwitcherService<S> {
    switcher: ModelSwitcher,
    inner: S,
}

impl<S> Service<Request<Body>> for ModelSwitcherService<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let switcher = self.switcher.clone();
        let mut inner = self.inner.clone();

        Box::pin(async move {
            // Extract the body to read the model
            let (parts, body) = req.into_parts();

            // Collect body bytes
            let body_bytes = match body.collect().await {
                Ok(collected) => collected.to_bytes(),
                Err(e) => {
                    error!(error = %e, "Failed to read request body");
                    return Ok(error_response(
                        StatusCode::BAD_REQUEST,
                        "Failed to read request body",
                    ));
                }
            };

            // Extract model name from body or header
            let model = extract_model(&parts.headers, &body_bytes);

            let Some(model) = model else {
                // No model specified - pass through (might be a non-model endpoint)
                trace!("No model in request, passing through");
                let req = Request::from_parts(parts, Body::from(body_bytes));
                return inner.call(req).await;
            };

            debug!(model = %model, "Extracted model from request");

            // Check if model is registered
            if !switcher.is_registered(&model) {
                warn!(model = %model, "Model not registered with switcher, passing through");
                // Model not in our switcher - pass through to onwards
                // (might be handled by a different backend)
                let req = Request::from_parts(parts, Body::from(body_bytes));
                return inner.call(req).await;
            }

            // Ensure model is ready
            if let Err(e) = switcher.ensure_model_ready(&model).await {
                error!(model = %model, error = %e, "Failed to ensure model ready");
                return Ok(switch_error_response(e));
            }

            // Acquire in-flight guard
            let _guard = switcher.acquire_in_flight(&model);

            // Reconstruct request with body
            let req = Request::from_parts(parts, Body::from(body_bytes));

            // Forward to inner service
            inner.call(req).await
        })
    }
}

/// Extract model name from request headers or JSON body
fn extract_model(headers: &axum::http::HeaderMap, body: &Bytes) -> Option<String> {
    // Check model-override header first
    if let Some(model) = headers.get("model-override")
        && let Ok(model_str) = model.to_str()
    {
        return Some(model_str.to_string());
    }

    // Parse JSON body for "model" field
    if let Ok(json) = serde_json::from_slice::<serde_json::Value>(body)
        && let Some(model) = json.get("model").and_then(|v| v.as_str())
    {
        return Some(model.to_string());
    }

    None
}

/// Create an error response
fn error_response(status: StatusCode, message: &str) -> Response<Body> {
    let body = serde_json::json!({
        "error": {
            "message": message,
            "type": "llmux_error"
        }
    });

    Response::builder()
        .status(status)
        .header("Content-Type", "application/json")
        .body(Body::from(body.to_string()))
        .unwrap()
}

/// Create error response from SwitchError
fn switch_error_response(error: SwitchError) -> Response<Body> {
    let (status, message) = match &error {
        SwitchError::ModelNotFound(m) => (StatusCode::NOT_FOUND, format!("Model not found: {}", m)),
        SwitchError::NotReady(m) => (
            StatusCode::SERVICE_UNAVAILABLE,
            format!("Model not ready: {}", m),
        ),
        SwitchError::Timeout => (
            StatusCode::GATEWAY_TIMEOUT,
            "Request timed out waiting for model".to_string(),
        ),
        SwitchError::Orchestrator(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Orchestrator error: {}", e),
        ),
        SwitchError::Internal(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Internal error: {}", e),
        ),
    };

    error_response(status, &message)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_model_from_header() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("model-override", "llama".parse().unwrap());

        let body = Bytes::from("{}");
        let model = extract_model(&headers, &body);

        assert_eq!(model, Some("llama".to_string()));
    }

    #[test]
    fn test_extract_model_from_body() {
        let headers = axum::http::HeaderMap::new();
        let body = Bytes::from(r#"{"model": "mistral", "messages": []}"#);

        let model = extract_model(&headers, &body);
        assert_eq!(model, Some("mistral".to_string()));
    }

    #[test]
    fn test_extract_model_header_precedence() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("model-override", "from-header".parse().unwrap());

        let body = Bytes::from(r#"{"model": "from-body"}"#);
        let model = extract_model(&headers, &body);

        // Header takes precedence
        assert_eq!(model, Some("from-header".to_string()));
    }

    #[test]
    fn test_extract_model_none() {
        let headers = axum::http::HeaderMap::new();
        let body = Bytes::from(r#"{"messages": []}"#);

        let model = extract_model(&headers, &body);
        assert_eq!(model, None);
    }
}