onwards 0.26.0

A flexible LLM proxy library
Documentation
//! Shared utilities for response ID override.
//!
//! When [`AppState::response_id_header`] is configured, the caller can supply a
//! response ID via a request header. These helpers extract the override value
//! and patch it into the HTTP response body, handling gzip/brotli decompression
//! and returning the body uncompressed (with `Content-Encoding` stripped).

use axum::body::Body;
use axum::http::{HeaderMap, HeaderValue, Response, header};
use std::io::Read as _;
use tracing::debug;

/// Extract a response ID override from request headers.
///
/// Looks up the header named `header_name` and normalises the value with a
/// `resp_` prefix (added if not already present).
pub fn extract_override_id(headers: &HeaderMap, header_name: &str) -> Option<String> {
    headers
        .get(header_name)
        .and_then(|v| v.to_str().ok())
        .map(|id| {
            if id.starts_with("resp_") {
                id.to_string()
            } else {
                format!("resp_{id}")
            }
        })
}

/// Patch the `id` field in a JSON response body with the given override.
///
/// Handles `Content-Encoding: gzip` and `Content-Encoding: br` transparently:
/// the body is decompressed, the `id` field is overwritten, and the response is
/// returned **uncompressed** (with `Content-Encoding` / `Transfer-Encoding`
/// stripped and `Content-Length` updated).
///
/// If the body is not JSON, cannot be decompressed, or does not contain an `id`
/// field, the response is returned unchanged.
pub async fn patch_response_body_id(response: &mut Response<Body>, override_id: String) {
    let is_json = response
        .headers()
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .is_some_and(|ct| ct.contains("application/json"));

    if !is_json {
        return;
    }

    let content_encoding = response
        .headers()
        .get("content-encoding")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_lowercase());

    let bytes = match axum::body::to_bytes(std::mem::take(response.body_mut()), usize::MAX).await {
        Ok(b) => b,
        Err(_) => return,
    };

    // Decompress if needed.
    let decompressed = match content_encoding.as_deref() {
        Some("gzip") => {
            let mut decoder = flate2::read::GzDecoder::new(&bytes[..]);
            let mut buf = Vec::new();
            if decoder.read_to_end(&mut buf).is_ok() {
                buf
            } else {
                debug!("Failed to gzip-decompress response for ID patching, passing through");
                *response.body_mut() = Body::from(bytes);
                return;
            }
        }
        Some("br") | Some("brotli") => {
            let mut buf = Vec::new();
            if brotli::Decompressor::new(&bytes[..], 4096)
                .read_to_end(&mut buf)
                .is_ok()
            {
                buf
            } else {
                debug!("Failed to brotli-decompress response for ID patching, passing through");
                *response.body_mut() = Body::from(bytes);
                return;
            }
        }
        _ => bytes.to_vec(),
    };

    // Parse, patch, rewrite.
    if let Ok(mut json) = serde_json::from_slice::<serde_json::Value>(&decompressed) {
        if json.get("id").is_some() {
            json["id"] = serde_json::Value::String(override_id);
        }
        let patched = serde_json::to_vec(&json).unwrap_or(decompressed);
        let content_length = patched.len();
        *response.body_mut() = Body::from(patched);

        // Return uncompressed — the body size changed so the original encoding
        // is no longer valid. Strip encoding headers and set Content-Length.
        response.headers_mut().remove(header::CONTENT_ENCODING);
        response.headers_mut().remove(header::TRANSFER_ENCODING);
        response
            .headers_mut()
            .insert(header::CONTENT_LENGTH, HeaderValue::from(content_length));
    } else {
        // Not valid JSON — restore the original (possibly compressed) body.
        *response.body_mut() = Body::from(bytes);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Write as _;

    // --- extract_override_id ---

    #[test]
    fn extract_override_id_adds_prefix() {
        let mut headers = HeaderMap::new();
        headers.insert("x-custom-id", HeaderValue::from_static("abc-123"));
        assert_eq!(
            extract_override_id(&headers, "x-custom-id"),
            Some("resp_abc-123".to_string())
        );
    }

    #[test]
    fn extract_override_id_preserves_existing_prefix() {
        let mut headers = HeaderMap::new();
        headers.insert("x-custom-id", HeaderValue::from_static("resp_abc-123"));
        assert_eq!(
            extract_override_id(&headers, "x-custom-id"),
            Some("resp_abc-123".to_string())
        );
    }

    #[test]
    fn extract_override_id_missing_header() {
        let headers = HeaderMap::new();
        assert_eq!(extract_override_id(&headers, "x-custom-id"), None);
    }

    #[test]
    fn extract_override_id_wrong_header_name() {
        let mut headers = HeaderMap::new();
        headers.insert("x-other", HeaderValue::from_static("abc"));
        assert_eq!(extract_override_id(&headers, "x-custom-id"), None);
    }

    // --- patch_response_body_id ---

    fn build_json_response(body: &[u8], content_type: &str) -> Response<Body> {
        Response::builder()
            .status(StatusCode::OK)
            .header("content-type", content_type)
            .body(Body::from(body.to_vec()))
            .unwrap()
    }

    fn gzip_compress(data: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    fn brotli_compress(data: &[u8]) -> Vec<u8> {
        let mut buf = Vec::new();
        {
            let mut writer = brotli::CompressorWriter::new(&mut buf, 4096, 4, 22);
            writer.write_all(data).unwrap();
        }
        buf
    }

    #[tokio::test]
    async fn patch_uncompressed_json() {
        let body = br#"{"id":"original","model":"gpt-4","status":"completed"}"#;
        let mut response = build_json_response(body, "application/json");

        patch_response_body_id(&mut response, "resp_override".to_string()).await;

        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert_eq!(json["id"], "resp_override");
        assert_eq!(json["model"], "gpt-4");
    }

    #[tokio::test]
    async fn patch_gzip_compressed_json() {
        let body = br#"{"id":"original","model":"gpt-4","status":"completed"}"#;
        let compressed = gzip_compress(body);
        let mut response = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("content-encoding", "gzip")
            .body(Body::from(compressed))
            .unwrap();

        patch_response_body_id(&mut response, "resp_patched".to_string()).await;

        // Should be decompressed after patching
        assert!(response.headers().get("content-encoding").is_none());
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert_eq!(json["id"], "resp_patched");
    }

    #[tokio::test]
    async fn patch_brotli_compressed_json() {
        let body = br#"{"id":"original","model":"gpt-4","status":"completed"}"#;
        let compressed = brotli_compress(body);
        let mut response = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("content-encoding", "br")
            .body(Body::from(compressed))
            .unwrap();

        patch_response_body_id(&mut response, "resp_br_patched".to_string()).await;

        assert!(response.headers().get("content-encoding").is_none());
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert_eq!(json["id"], "resp_br_patched");
    }

    #[tokio::test]
    async fn patch_skips_non_json() {
        let body = b"not json";
        let mut response = build_json_response(body, "text/event-stream");

        patch_response_body_id(&mut response, "resp_ignored".to_string()).await;

        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(result.as_ref(), b"not json");
    }

    #[tokio::test]
    async fn patch_preserves_body_without_id_field() {
        let body = br#"{"model":"gpt-4","status":"completed"}"#;
        let mut response = build_json_response(body, "application/json");

        patch_response_body_id(&mut response, "resp_noop".to_string()).await;

        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert!(json.get("id").is_none());
    }

    #[tokio::test]
    async fn patch_sets_content_length() {
        let body = br#"{"id":"short"}"#;
        let mut response = build_json_response(body, "application/json");

        patch_response_body_id(&mut response, "resp_much-longer-id-value".to_string()).await;

        let cl: usize = response.headers().get("content-length").unwrap().to_str().unwrap().parse().unwrap();
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(cl, result.len());
    }
}