llmux 0.7.7

Zero-reload model switching for vLLM - manages multiple models on shared GPU
Documentation
//! Axum middleware layer for model switching
//!
//! This layer intercepts requests, extracts the model name, ensures the model
//! is ready, and then passes the request through to the inner service (onwards).

use crate::switcher::{InFlightGuard, ModelSwitcher, SwitchError};
use axum::body::Body;
use axum::http::{Request, Response, StatusCode};
use bytes::Bytes;
use futures_util::future::BoxFuture;
use http_body::Frame;
use http_body_util::BodyExt;
use metrics::{counter, histogram};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tower::{Layer, Service};
use tracing::{debug, error, trace, warn};

/// Tower [`Layer`] that wraps a service in a [`ModelSwitcherService`].
///
/// The layer only stores a [`ModelSwitcher`] handle; each wrapped service
/// receives its own clone of that handle.
#[derive(Clone)]
pub struct ModelSwitcherLayer {
    /// Switcher handle cloned into every service this layer produces.
    switcher: ModelSwitcher,
}

impl ModelSwitcherLayer {
    /// Creates a layer backed by the given `switcher`.
    pub fn new(switcher: ModelSwitcher) -> Self {
        ModelSwitcherLayer { switcher }
    }
}

impl<S> Layer<S> for ModelSwitcherLayer {
    type Service = ModelSwitcherService<S>;

    /// Wraps `inner` in a [`ModelSwitcherService`] that holds a clone of this
    /// layer's switcher.
    fn layer(&self, inner: S) -> Self::Service {
        let switcher = self.switcher.clone();
        ModelSwitcherService { switcher, inner }
    }
}

/// Service that wraps requests with model switching
///
/// Produced by [`ModelSwitcherLayer`]; pairs a [`ModelSwitcher`] handle with
/// the inner service that requests are forwarded to.
#[derive(Clone)]
pub struct ModelSwitcherService<S> {
    /// Handle used to check registration, wait for readiness and track
    /// in-flight requests for the requested model.
    switcher: ModelSwitcher,
    /// Downstream service that receives the (reconstructed) request.
    inner: S,
}

impl<S> Service<Request<Body>> for ModelSwitcherService<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Handles one request.
    ///
    /// The request body is buffered so the model name can be read from the
    /// `model-override` header or the JSON `model` field. Requests without a
    /// model, or naming a model the switcher does not know, are forwarded
    /// unchanged. Otherwise the request waits until the switcher reports the
    /// model ready and an in-flight guard has been acquired, queue metrics
    /// are recorded, and the request is forwarded with the guard attached to
    /// the response body.
    ///
    /// Body-read failures and switcher errors are answered with a JSON error
    /// response (`Ok`); only errors from the inner service surface as `Err`.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let switcher = self.switcher.clone();

        // `poll_ready` was driven on `self.inner`, so that exact instance is
        // the one that must serve this request. Take it out and leave a fresh
        // clone behind for the next `poll_ready`/`call` cycle; calling the
        // clone instead would use a service that was never polled ready.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(async move {
            // Split the request so the body can be consumed while the
            // head (method, URI, headers) is kept for reconstruction.
            let (parts, body) = req.into_parts();

            // Buffer the whole body; the model name may live in its JSON.
            let body_bytes = match body.collect().await {
                Ok(collected) => collected.to_bytes(),
                Err(e) => {
                    error!(error = %e, "Failed to read request body");
                    return Ok(error_response(
                        StatusCode::BAD_REQUEST,
                        "Failed to read request body",
                    ));
                }
            };

            // Extract model name from header or body
            let model = extract_model(&parts.headers, &body_bytes);

            let Some(model) = model else {
                // No model specified - pass through (might be a non-model endpoint)
                trace!("No model in request, passing through");
                let req = Request::from_parts(parts, Body::from(body_bytes));
                return inner.call(req).await;
            };

            debug!(model = %model, "Extracted model from request");

            // Check if model is registered
            if !switcher.is_registered(&model) {
                warn!(model = %model, "Model not registered with switcher, passing through");
                // Model not in our switcher - pass through to onwards
                // (might be handled by a different backend)
                let req = Request::from_parts(parts, Body::from(body_bytes));
                return inner.call(req).await;
            }

            // Ensure model is ready and acquire in-flight guard.
            // If the model starts draining between the ready check and the
            // guard acquisition, loop back through ensure_model_ready so the
            // request waits for the next switch instead of getting a 503.
            let queue_start = Instant::now();
            let guard = loop {
                if let Err(e) = switcher.ensure_model_ready(&model).await {
                    error!(model = %model, error = %e, "Failed to ensure model ready");
                    return Ok(switch_error_response(e));
                }

                match switcher.acquire_in_flight(&model) {
                    Some(guard) => break guard,
                    None => {
                        debug!(model = %model, "Model draining, re-entering ensure_model_ready");
                        continue;
                    }
                }
            };
            let queue_wait = queue_start.elapsed();

            // Record request metrics. A request counts as having "waited"
            // when readiness + guard acquisition took longer than 1 ms.
            let waited = queue_wait > Duration::from_millis(1);
            histogram!("llmux_request_queue_wait_seconds", "model" => model.clone())
                .record(queue_wait.as_secs_f64());
            counter!("llmux_requests_total", "model" => model.clone(), "waited" => if waited { "true" } else { "false" })
                .increment(1);

            // Reconstruct request with the buffered body
            let req = Request::from_parts(parts, Body::from(body_bytes));

            // Forward to inner service, then wrap the response body so the
            // in-flight guard travels with the (possibly streamed) body
            // instead of being dropped as soon as the headers arrive.
            let response = inner.call(req).await?;
            let (resp_parts, body) = response.into_parts();
            let guarded = GuardedBody {
                inner: body,
                _guard: Some(guard),
            };
            Ok(Response::from_parts(resp_parts, Body::new(guarded)))
        })
    }
}

/// Determines which model a request targets.
///
/// The `model-override` header wins when it is present and its value can be
/// read as a string. Otherwise the body is parsed as JSON and the string
/// value of its top-level `"model"` field is used. Returns `None` when
/// neither source yields a model name.
fn extract_model(headers: &axum::http::HeaderMap, body: &Bytes) -> Option<String> {
    let from_header = headers
        .get("model-override")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);

    from_header.or_else(|| {
        // Fall back to the JSON body; any parse failure simply means "no model".
        let json: serde_json::Value = serde_json::from_slice(body).ok()?;
        json.get("model")?.as_str().map(str::to_owned)
    })
}

/// Create an error response
fn error_response(status: StatusCode, message: &str) -> Response<Body> {
    let body = serde_json::json!({
        "error": {
            "message": message,
            "type": "llmux_error"
        }
    });

    Response::builder()
        .status(status)
        .header("Content-Type", "application/json")
        .body(Body::from(body.to_string()))
        .unwrap()
}

/// Create error response from SwitchError
fn switch_error_response(error: SwitchError) -> Response<Body> {
    let (status, message) = match &error {
        SwitchError::ModelNotFound(m) => (StatusCode::NOT_FOUND, format!("Model not found: {}", m)),
        SwitchError::NotReady(m) => (
            StatusCode::SERVICE_UNAVAILABLE,
            format!("Model not ready: {}", m),
        ),
        SwitchError::Timeout => (
            StatusCode::GATEWAY_TIMEOUT,
            "Request timed out waiting for model".to_string(),
        ),
        SwitchError::Orchestrator(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Orchestrator error: {}", e),
        ),
        SwitchError::Internal(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Internal error: {}", e),
        ),
        SwitchError::ManualModeRejected {
            requested, active, ..
        } => (
            StatusCode::SERVICE_UNAVAILABLE,
            format!(
                "Manual mode active: model {} not available (active: {})",
                requested, active
            ),
        ),
    };

    error_response(status, &message)
}

/// Response body wrapper that holds an [`InFlightGuard`] while the body is
/// being streamed. For streaming responses, this keeps the request counted
/// as in flight after the response headers have been sent, until the body
/// itself is finished.
///
/// The guard is released at the earliest of:
/// - the inner body reporting end of stream (`poll_frame` yields `None`),
/// - the inner body yielding an error,
/// - this wrapper being dropped (e.g. the client disconnected).
struct GuardedBody {
    /// The response body being forwarded.
    inner: Body,
    /// In-flight guard for the request; cleared once the stream has ended so
    /// a finished response does not keep counting as in flight.
    _guard: Option<InFlightGuard>,
}

impl http_body::Body for GuardedBody {
    type Data = Bytes;
    type Error = axum::Error;

    /// Polls the inner body and releases the guard once it terminates.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.inner).poll_frame(cx);

        // `None` marks the end of the stream and an error aborts it; either
        // way no further frames will be produced, so stop holding the guard
        // instead of waiting for the body object itself to be dropped.
        if matches!(polled, Poll::Ready(None | Some(Err(_)))) {
            this._guard = None;
        }

        polled
    }

    /// Forwards the inner body's end-of-stream hint.
    fn is_end_stream(&self) -> bool {
        self.inner.is_end_stream()
    }

    /// Forwards the inner body's size hint.
    fn size_hint(&self) -> http_body::SizeHint {
        self.inner.size_hint()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map carrying a single `model-override` entry.
    fn override_headers(value: &str) -> axum::http::HeaderMap {
        let mut map = axum::http::HeaderMap::new();
        map.insert("model-override", value.parse().unwrap());
        map
    }

    #[test]
    fn test_extract_model_from_header() {
        let headers = override_headers("llama");
        let payload = Bytes::from("{}");

        let found = extract_model(&headers, &payload);

        assert_eq!(found.as_deref(), Some("llama"));
    }

    #[test]
    fn test_extract_model_from_body() {
        let no_headers = axum::http::HeaderMap::new();
        let payload = Bytes::from(r#"{"model": "mistral", "messages": []}"#);

        let found = extract_model(&no_headers, &payload);

        assert_eq!(found.as_deref(), Some("mistral"));
    }

    #[test]
    fn test_extract_model_header_precedence() {
        let headers = override_headers("from-header");
        let payload = Bytes::from(r#"{"model": "from-body"}"#);

        let found = extract_model(&headers, &payload);

        // The header must win over the body's "model" field.
        assert_eq!(found.as_deref(), Some("from-header"));
    }

    #[test]
    fn test_extract_model_none() {
        let no_headers = axum::http::HeaderMap::new();
        let payload = Bytes::from(r#"{"messages": []}"#);

        let found = extract_model(&no_headers, &payload);

        assert!(found.is_none());
    }
}