use axum::body::Bytes;
use axum::http::HeaderMap;
use flate2::read::GzDecoder;
use std::io::Read;
use zstd::stream::read::Decoder as ZstdDecoder;
const MAX_DECOMPRESSED_SIZE: u64 = 512 * 1024 * 1024;
pub fn extract_body(headers: &HeaderMap, raw: Bytes) -> Result<String, String> {
let encoding = headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let no_encoding = encoding.is_empty();
let has_gzip_magic = raw.len() >= 2 && raw[0] == 0x1f && raw[1] == 0x8b;
let has_zstd_magic =
raw.len() >= 4 && raw[0] == 0x28 && raw[1] == 0xb5 && raw[2] == 0x2f && raw[3] == 0xfd;
let is_gzip = encoding.eq_ignore_ascii_case("gzip") || (no_encoding && has_gzip_magic);
let is_zstd = encoding.eq_ignore_ascii_case("zstd") || (no_encoding && has_zstd_magic);
let bytes = if is_gzip {
let mut decoded = Vec::new();
GzDecoder::new(&raw[..])
.take(MAX_DECOMPRESSED_SIZE)
.read_to_end(&mut decoded)
.map_err(|e| format!("gzip decode error: {e}"))?;
decoded
} else if is_zstd {
let mut decoded = Vec::new();
ZstdDecoder::new(&raw[..])
.map_err(|e| format!("zstd init error: {e}"))?
.take(MAX_DECOMPRESSED_SIZE)
.read_to_end(&mut decoded)
.map_err(|e| format!("zstd decode error: {e}"))?;
decoded
} else {
raw.to_vec()
};
String::from_utf8(bytes).map_err(|e| format!("invalid UTF-8: {e}"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Write;

    /// Gzip-compresses `data` at the default compression level.
    fn gzip(data: &[u8]) -> Vec<u8> {
        let mut enc = GzEncoder::new(Vec::new(), Compression::default());
        enc.write_all(data).unwrap();
        enc.finish().unwrap()
    }

    /// Zstd-compresses `data` at level 3.
    fn zstd_compress(data: &[u8]) -> Vec<u8> {
        zstd::encode_all(data, 3).unwrap()
    }

    /// Builds a header map whose only entry is `content-encoding: <encoding>`.
    fn encoded_as(encoding: &str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert("content-encoding", encoding.parse().unwrap());
        map
    }

    #[test]
    fn test_extract_plain_body() {
        let json = r#"{"model":"gpt-4o"}"#;
        let extracted = extract_body(&HeaderMap::new(), Bytes::from(json)).unwrap();
        assert_eq!(extracted, json);
    }

    #[test]
    fn test_extract_gzip_body() {
        let json = r#"{"model":"claude-sonnet-4-20250514","messages":[{"role":"user","content":"hello"}]}"#;
        let body = Bytes::from(gzip(json.as_bytes()));
        assert_eq!(extract_body(&encoded_as("gzip"), body).unwrap(), json);
    }

    #[test]
    fn test_extract_gzip_body_case_insensitive() {
        let text = "hello world";
        let body = Bytes::from(gzip(text.as_bytes()));
        assert_eq!(extract_body(&encoded_as("Gzip"), body).unwrap(), text);
    }

    #[test]
    fn test_extract_gzip_body_magic_bytes_no_header() {
        let json = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"test"}]}"#;
        let packed = gzip(json.as_bytes());
        // Sanity check: the encoder really emits the gzip signature we sniff for.
        assert!(packed.starts_with(&[0x1f, 0x8b]));
        let extracted = extract_body(&HeaderMap::new(), Bytes::from(packed)).unwrap();
        assert_eq!(extracted, json);
    }

    #[test]
    fn test_extract_zstd_body() {
        let json = r#"{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}]}"#;
        let body = Bytes::from(zstd_compress(json.as_bytes()));
        assert_eq!(extract_body(&encoded_as("zstd"), body).unwrap(), json);
    }

    #[test]
    fn test_extract_zstd_body_magic_bytes_no_header() {
        let json = r#"{"model":"gpt-5.4","messages":[{"role":"user","content":"test"}]}"#;
        let packed = zstd_compress(json.as_bytes());
        // Sanity check: the encoder really emits the zstd frame magic we sniff for.
        assert!(packed.starts_with(&[0x28, 0xb5, 0x2f, 0xfd]));
        let extracted = extract_body(&HeaderMap::new(), Bytes::from(packed)).unwrap();
        assert_eq!(extracted, json);
    }

    #[test]
    fn test_magic_bytes_ignored_when_encoding_header_set() {
        // zstd-framed bytes, but the header claims brotli: no sniffing may occur,
        // so the raw (non-UTF-8) frame must be rejected.
        let packed = zstd_compress(r#"{"test": true}"#.as_bytes());
        assert!(extract_body(&encoded_as("br"), Bytes::from(packed)).is_err());
    }

    #[test]
    fn test_extract_invalid_utf8() {
        let garbage = Bytes::from(vec![0xff, 0xfe, 0xfd]);
        let err = extract_body(&HeaderMap::new(), garbage).unwrap_err();
        assert!(err.contains("invalid UTF-8"));
    }

    #[test]
    fn test_extract_invalid_gzip() {
        let not_gzip = Bytes::from(vec![0x00, 0x01, 0x02, 0x03]);
        let err = extract_body(&encoded_as("gzip"), not_gzip).unwrap_err();
        assert!(err.contains("gzip decode error"));
    }
}