use axum::body::Body;
use axum::http::{HeaderMap, HeaderValue, Response, header};
use std::io::Read as _;
use tracing::debug;
/// Reads a caller-supplied response ID from the header named `header_name`.
///
/// The value is normalized to carry the `resp_` prefix: a value that already
/// starts with `resp_` is returned as-is, otherwise the prefix is prepended.
///
/// Returns `None` when the header is absent or its value is not visible ASCII
/// (i.e. [`HeaderValue::to_str`] fails).
pub fn extract_override_id(headers: &HeaderMap, header_name: &str) -> Option<String> {
    let raw = headers.get(header_name)?.to_str().ok()?;
    let normalized = match raw.strip_prefix("resp_") {
        Some(_) => raw.to_owned(),
        None => format!("resp_{raw}"),
    };
    Some(normalized)
}
pub async fn patch_response_body_id(response: &mut Response<Body>, override_id: String) {
let is_json = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.is_some_and(|ct| ct.contains("application/json"));
if !is_json {
return;
}
let content_encoding = response
.headers()
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_lowercase());
let bytes = match axum::body::to_bytes(std::mem::take(response.body_mut()), usize::MAX).await {
Ok(b) => b,
Err(_) => return,
};
let decompressed = match content_encoding.as_deref() {
Some("gzip") => {
let mut decoder = flate2::read::GzDecoder::new(&bytes[..]);
let mut buf = Vec::new();
if decoder.read_to_end(&mut buf).is_ok() {
buf
} else {
debug!("Failed to gzip-decompress response for ID patching, passing through");
*response.body_mut() = Body::from(bytes);
return;
}
}
Some("br") | Some("brotli") => {
let mut buf = Vec::new();
if brotli::Decompressor::new(&bytes[..], 4096)
.read_to_end(&mut buf)
.is_ok()
{
buf
} else {
debug!("Failed to brotli-decompress response for ID patching, passing through");
*response.body_mut() = Body::from(bytes);
return;
}
}
_ => bytes.to_vec(),
};
if let Ok(mut json) = serde_json::from_slice::<serde_json::Value>(&decompressed) {
if json.get("id").is_some() {
json["id"] = serde_json::Value::String(override_id);
}
let patched = serde_json::to_vec(&json).unwrap_or(decompressed);
let content_length = patched.len();
*response.body_mut() = Body::from(patched);
response.headers_mut().remove(header::CONTENT_ENCODING);
response.headers_mut().remove(header::TRANSFER_ENCODING);
response
.headers_mut()
.insert(header::CONTENT_LENGTH, HeaderValue::from(content_length));
} else {
*response.body_mut() = Body::from(bytes);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Write as _;

    #[test]
    fn extract_override_id_adds_prefix() {
        let mut headers = HeaderMap::new();
        headers.insert("x-custom-id", HeaderValue::from_static("abc-123"));
        assert_eq!(
            extract_override_id(&headers, "x-custom-id"),
            Some("resp_abc-123".to_string())
        );
    }

    #[test]
    fn extract_override_id_preserves_existing_prefix() {
        let mut headers = HeaderMap::new();
        headers.insert("x-custom-id", HeaderValue::from_static("resp_abc-123"));
        assert_eq!(
            extract_override_id(&headers, "x-custom-id"),
            Some("resp_abc-123".to_string())
        );
    }

    #[test]
    fn extract_override_id_missing_header() {
        let headers = HeaderMap::new();
        assert_eq!(extract_override_id(&headers, "x-custom-id"), None);
    }

    #[test]
    fn extract_override_id_wrong_header_name() {
        let mut headers = HeaderMap::new();
        headers.insert("x-other", HeaderValue::from_static("abc"));
        assert_eq!(extract_override_id(&headers, "x-custom-id"), None);
    }

    /// Builds a 200 response with the given body and `Content-Type`.
    fn build_json_response(body: &[u8], content_type: &str) -> Response<Body> {
        Response::builder()
            .status(StatusCode::OK)
            .header("content-type", content_type)
            .body(Body::from(body.to_vec()))
            .unwrap()
    }

    /// Builds a 200 `application/json` response whose body is labelled with
    /// the given `Content-Encoding`.
    fn build_encoded_json_response(body: Vec<u8>, encoding: &str) -> Response<Body> {
        Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("content-encoding", encoding)
            .body(Body::from(body))
            .unwrap()
    }

    fn gzip_compress(data: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    fn brotli_compress(data: &[u8]) -> Vec<u8> {
        let mut buf = Vec::new();
        {
            // The writer borrows `buf`. Dropping it at the end of this scope
            // flushes the final compressed output into `buf`.
            let mut writer = brotli::CompressorWriter::new(&mut buf, 4096, 4, 22);
            writer.write_all(data).unwrap();
        }
        buf
    }

    #[tokio::test]
    async fn patch_uncompressed_json() {
        let body = br#"{"id":"original","model":"gpt-4","status":"completed"}"#;
        let mut response = build_json_response(body, "application/json");
        patch_response_body_id(&mut response, "resp_override".to_string()).await;
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert_eq!(json["id"], "resp_override");
        assert_eq!(json["model"], "gpt-4");
    }

    #[tokio::test]
    async fn patch_gzip_compressed_json() {
        let body = br#"{"id":"original","model":"gpt-4","status":"completed"}"#;
        let mut response = build_encoded_json_response(gzip_compress(body), "gzip");
        patch_response_body_id(&mut response, "resp_patched".to_string()).await;
        assert!(response.headers().get("content-encoding").is_none());
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert_eq!(json["id"], "resp_patched");
    }

    #[tokio::test]
    async fn patch_brotli_compressed_json() {
        let body = br#"{"id":"original","model":"gpt-4","status":"completed"}"#;
        let mut response = build_encoded_json_response(brotli_compress(body), "br");
        patch_response_body_id(&mut response, "resp_br_patched".to_string()).await;
        assert!(response.headers().get("content-encoding").is_none());
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert_eq!(json["id"], "resp_br_patched");
    }

    #[tokio::test]
    async fn patch_skips_non_json() {
        // The body is valid JSON with an `id`. If it comes back untouched, the
        // skip was driven by the Content-Type check and not by a parse failure.
        let body = br#"{"id":"original"}"#;
        let mut response = build_json_response(body, "text/event-stream");
        patch_response_body_id(&mut response, "resp_ignored".to_string()).await;
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(result.as_ref(), body);
    }

    #[tokio::test]
    async fn patch_passes_through_invalid_gzip() {
        let body = b"definitely not gzip".to_vec();
        let mut response = build_encoded_json_response(body.clone(), "gzip");
        patch_response_body_id(&mut response, "resp_ignored".to_string()).await;
        assert_eq!(response.headers().get("content-encoding").unwrap(), "gzip");
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(result.as_ref(), body.as_slice());
    }

    #[tokio::test]
    async fn patch_passes_through_compressed_non_json() {
        let compressed = gzip_compress(b"not json");
        let mut response = build_encoded_json_response(compressed.clone(), "gzip");
        patch_response_body_id(&mut response, "resp_ignored".to_string()).await;
        // The original compressed bytes and their encoding header must survive.
        assert_eq!(response.headers().get("content-encoding").unwrap(), "gzip");
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(result.as_ref(), compressed.as_slice());
    }

    #[tokio::test]
    async fn patch_preserves_body_without_id_field() {
        let body = br#"{"model":"gpt-4","status":"completed"}"#;
        let mut response = build_json_response(body, "application/json");
        patch_response_body_id(&mut response, "resp_noop".to_string()).await;
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert!(json.get("id").is_none());
    }

    #[tokio::test]
    async fn patch_sets_content_length() {
        let body = br#"{"id":"short"}"#;
        let mut response = build_json_response(body, "application/json");
        patch_response_body_id(&mut response, "resp_much-longer-id-value".to_string()).await;
        let cl: usize = response
            .headers()
            .get("content-length")
            .unwrap()
            .to_str()
            .unwrap()
            .parse()
            .unwrap();
        let result = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(cl, result.len());
    }
}