use std::io::Write;
use axum::body::{Body, Bytes};
use axum::http::header::{
HeaderValue, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, VARY,
};
use axum::http::Request;
use axum::response::Response;
use flate2::write::GzEncoder;
use flate2::Compression;
/// Smallest declared `Content-Length` (in bytes) that is considered for compression.
const MIN_SIZE: usize = 1024;
/// Largest declared `Content-Length` (in bytes) that is considered for compression;
/// also the cap passed to `axum::body::to_bytes` when the body is buffered.
const MAX_SIZE: usize = 16 * 1024 * 1024;
/// Reports whether the request's `Accept-Encoding` header permits a gzip response.
///
/// The header is read as a comma-separated list of `coding[;param]*` entries:
///
/// * An entry is a refusal when it carries a `q` parameter whose value parses to
///   zero (`q=0`, `q=0.0`, `Q = 0.000`, ...). A `q` value that does not parse as a
///   number is ignored, so the entry counts as acceptable.
/// * An explicit `gzip` entry (matched case-insensitively) decides the outcome.
/// * `*` is consulted only when `gzip` is not listed by name, because the
///   wildcard stands for codings that are not explicitly mentioned.
///
/// Returns `false` when the header is absent, is not valid visible ASCII, or
/// mentions neither `gzip` nor `*`.
pub(crate) fn client_accepts_gzip<B>(req: &Request<B>) -> bool {
    let Some(header) = req
        .headers()
        .get(ACCEPT_ENCODING)
        .and_then(|v| v.to_str().ok())
    else {
        return false;
    };

    // `Some(true)` = listed and acceptable, `Some(false)` = listed but refused,
    // `None` = not mentioned at all.
    let mut gzip_verdict: Option<bool> = None;
    let mut wildcard_verdict: Option<bool> = None;

    for entry in header.split(',') {
        let mut fields = entry.split(';');
        let coding = fields.next().unwrap_or("").trim();
        let refused = fields.any(|param| match param.split_once('=') {
            Some((name, value)) => {
                name.trim().eq_ignore_ascii_case("q")
                    && matches!(value.trim().parse::<f32>(), Ok(q) if q == 0.0)
            }
            None => false,
        });

        // If a coding is repeated, one acceptable occurrence is enough.
        if coding.eq_ignore_ascii_case("gzip") {
            gzip_verdict = Some(gzip_verdict.unwrap_or(false) || !refused);
        } else if coding == "*" {
            wildcard_verdict = Some(wildcard_verdict.unwrap_or(false) || !refused);
        }
    }

    gzip_verdict.or(wildcard_verdict).unwrap_or(false)
}
/// Reports whether a `Content-Type` header value names a media type that should
/// be gzip-compressed.
///
/// Only the `type/subtype` part before the first `;` is examined, so parameters
/// such as `charset` are ignored. Media types are case-insensitive, so the
/// comparison is done on an ASCII-lowercased copy.
///
/// Accepted: a short allow-list of structured application types plus
/// `image/svg+xml`, and every `text/*` type except `text/event-stream`.
fn is_compressible(ct: &str) -> bool {
    let essence = ct
        .split(';')
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();

    let allow_listed = matches!(
        essence.as_str(),
        "application/json"
            | "application/xml"
            | "application/javascript"
            | "application/x-ndjson"
            | "image/svg+xml"
    );

    // `text/event-stream` is carved out of `text/*`; presumably because an
    // event stream must not be buffered in full before it is sent.
    allow_listed || (essence.starts_with("text/") && essence != "text/event-stream")
}
/// Gzip-encodes the body of `resp` when the client accepts gzip and the response
/// is eligible; otherwise returns the response with its body untouched.
///
/// `accepts_gzip` is the caller's verdict on the request's `Accept-Encoding`
/// header. Every response returned from here has been passed through `add_vary`,
/// whether or not it was compressed.
///
/// On success `Content-Encoding: gzip` is set, `Content-Length` is rewritten to
/// the compressed size and `Accept-Ranges` is dropped. If compression does not
/// shrink the payload, the original bytes and headers are sent unchanged.
///
/// If the body cannot be read (stream error, or more than `MAX_SIZE` bytes
/// despite the declared length) the payload is unrecoverable, so an empty
/// `500 Internal Server Error` is returned instead of a success status with a
/// `Content-Length` that no longer matches the body.
///
/// The body is buffered in full and compressed synchronously on the calling task.
pub(crate) async fn maybe_compress(resp: Response, accepts_gzip: bool) -> Response {
    if !accepts_gzip || !is_eligible(resp.headers()) {
        return add_vary(resp);
    }

    let (mut parts, body) = resp.into_parts();
    let bytes = match axum::body::to_bytes(body, MAX_SIZE).await {
        Ok(b) => b,
        Err(_) => {
            // The body has been consumed and cannot be replayed; do not keep
            // advertising the original status and length for an empty body.
            parts.status = axum::http::StatusCode::INTERNAL_SERVER_ERROR;
            parts.headers.remove(CONTENT_LENGTH);
            return add_vary(Response::from_parts(parts, Body::empty()));
        }
    };

    let gz = match gzip(&bytes) {
        Some(g) => g,
        // Not worth it (or the encoder failed): send the original bytes as they were.
        None => return add_vary(Response::from_parts(parts, Body::from(bytes))),
    };

    parts
        .headers
        .insert(CONTENT_ENCODING, HeaderValue::from_static("gzip"));
    parts
        .headers
        .insert(CONTENT_LENGTH, HeaderValue::from(gz.len()));
    // Byte offsets into the original payload do not apply to the encoded one.
    parts.headers.remove(axum::http::header::ACCEPT_RANGES);

    add_vary(Response::from_parts(parts, Body::from(gz)))
}

/// Decides from the response headers alone whether the body should be compressed.
///
/// Requires: no existing `Content-Encoding`, no `Content-Range` (a partial
/// response must not be re-encoded), a parseable `Content-Length` within
/// `MIN_SIZE..=MAX_SIZE`, and a `Content-Type` accepted by `is_compressible`.
fn is_eligible(headers: &axum::http::HeaderMap) -> bool {
    if headers.contains_key(CONTENT_ENCODING)
        || headers.contains_key(axum::http::header::CONTENT_RANGE)
    {
        return false;
    }

    let declared_len = headers
        .get(CONTENT_LENGTH)
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse::<usize>().ok());
    if !matches!(declared_len, Some(n) if (MIN_SIZE..=MAX_SIZE).contains(&n)) {
        return false;
    }

    headers
        .get(CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(is_compressible)
        .unwrap_or(false)
}
/// Gzip-encodes `data` at the default compression level.
///
/// Returns `None` when the encoder reports an error or when the encoded output
/// is not strictly smaller than the input.
fn gzip(data: &Bytes) -> Option<Vec<u8>> {
    let original_len = data.len();
    // Start the output buffer at half the input size.
    let sink = Vec::with_capacity(original_len / 2);
    let mut encoder = GzEncoder::new(sink, Compression::default());

    if encoder.write_all(data).is_err() {
        return None;
    }

    match encoder.finish() {
        Ok(compressed) if compressed.len() < original_len => Some(compressed),
        _ => None,
    }
}
/// Ensures the response tells caches that it varies on `Accept-Encoding`.
///
/// If an existing `Vary` value already lists `accept-encoding` (in any letter
/// case) or `*`, the headers are left alone. Otherwise an `accept-encoding`
/// value is appended, so unrelated entries such as `Vary: origin` are kept
/// rather than treated as sufficient.
fn add_vary(mut resp: Response) -> Response {
    let headers = resp.headers_mut();

    let already_covered = headers.get_all(VARY).iter().any(|value| {
        value
            .to_str()
            .map(|list| {
                list.split(',').any(|field| {
                    let field = field.trim();
                    field == "*" || field.eq_ignore_ascii_case("accept-encoding")
                })
            })
            // A value that is not visible ASCII cannot be inspected; treat it
            // as not covering `accept-encoding`.
            .unwrap_or(false)
    });

    if !already_covered {
        headers.append(VARY, HeaderValue::from_static("accept-encoding"));
    }
    resp
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request as HttpRequest;
    use std::io::Read;

    /// Builds a body-less request carrying the given `Accept-Encoding` value, if any.
    fn req_with(ae: Option<&str>) -> HttpRequest<()> {
        let mut b = HttpRequest::builder();
        if let Some(v) = ae {
            b = b.header(ACCEPT_ENCODING, v);
        }
        b.body(()).unwrap()
    }

    /// Builds a response whose `Content-Length` matches `payload`.
    fn sized_response(content_type: &str, payload: Vec<u8>) -> Response {
        Response::builder()
            .header(CONTENT_TYPE, content_type)
            .header(CONTENT_LENGTH, payload.len().to_string())
            .body(Body::from(payload))
            .unwrap()
    }

    #[test]
    fn accept_encoding_parsing() {
        assert!(client_accepts_gzip(&req_with(Some("gzip, deflate"))));
        assert!(client_accepts_gzip(&req_with(Some("br, gzip;q=0.8"))));
        assert!(client_accepts_gzip(&req_with(Some("*"))));
        assert!(!client_accepts_gzip(&req_with(Some("gzip;q=0"))));
        assert!(!client_accepts_gzip(&req_with(Some("identity"))));
        assert!(!client_accepts_gzip(&req_with(None)));
    }

    #[test]
    fn compressibility_excludes_streams_and_binary() {
        assert!(is_compressible("application/json"));
        assert!(is_compressible("text/html; charset=utf-8"));
        assert!(!is_compressible("text/event-stream"));
        assert!(!is_compressible("image/png"));
        assert!(!is_compressible("application/octet-stream"));
    }

    #[tokio::test]
    async fn compresses_large_json_and_sets_headers() {
        let out = maybe_compress(sized_response("application/json", vec![b'a'; 4096]), true).await;
        assert_eq!(out.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
        assert_eq!(out.headers().get(VARY).unwrap(), "accept-encoding");
        let new_len: usize = out
            .headers()
            .get(CONTENT_LENGTH)
            .unwrap()
            .to_str()
            .unwrap()
            .parse()
            .unwrap();
        assert!(new_len < 4096, "compressed length {new_len} should shrink");
    }

    // The advertised length must match the bytes actually sent, and those
    // bytes must decode back to the original payload.
    #[tokio::test]
    async fn compressed_body_round_trips_and_matches_content_length() {
        let payload = vec![b'a'; 4096];
        let out = maybe_compress(sized_response("application/json", payload.clone()), true).await;
        let declared: usize = out
            .headers()
            .get(CONTENT_LENGTH)
            .unwrap()
            .to_str()
            .unwrap()
            .parse()
            .unwrap();
        let body = axum::body::to_bytes(out.into_body(), MAX_SIZE)
            .await
            .unwrap();
        assert_eq!(body.len(), declared);

        let mut decoded = Vec::new();
        flate2::read::GzDecoder::new(&body[..])
            .read_to_end(&mut decoded)
            .unwrap();
        assert_eq!(decoded, payload);
    }

    #[tokio::test]
    async fn skips_streaming_lengthless_body() {
        let resp = Response::builder()
            .header(CONTENT_TYPE, "text/event-stream")
            .body(Body::from("data: hi\n\n"))
            .unwrap();
        let out = maybe_compress(resp, true).await;
        assert!(out.headers().get(CONTENT_ENCODING).is_none());
    }

    // Isolates the "no Content-Length" skip from the content-type skip.
    #[tokio::test]
    async fn skips_compressible_type_without_content_length() {
        let resp = Response::builder()
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(vec![b'a'; 4096]))
            .unwrap();
        let out = maybe_compress(resp, true).await;
        assert!(out.headers().get(CONTENT_ENCODING).is_none());
    }

    // Isolates the content-type skip from the "no Content-Length" skip.
    #[tokio::test]
    async fn skips_event_stream_even_with_content_length() {
        let out = maybe_compress(sized_response("text/event-stream", vec![b'a'; 4096]), true).await;
        assert!(out.headers().get(CONTENT_ENCODING).is_none());
        assert_eq!(out.headers().get(CONTENT_LENGTH).unwrap(), "4096");
    }

    #[tokio::test]
    async fn leaves_body_alone_but_sets_vary_when_gzip_not_accepted() {
        let out = maybe_compress(sized_response("application/json", vec![b'a'; 4096]), false).await;
        assert!(out.headers().get(CONTENT_ENCODING).is_none());
        assert_eq!(out.headers().get(CONTENT_LENGTH).unwrap(), "4096");
        assert_eq!(out.headers().get(VARY).unwrap(), "accept-encoding");
    }

    #[tokio::test]
    async fn skips_small_and_already_encoded() {
        let small = Response::builder()
            .header(CONTENT_TYPE, "application/json")
            .header(CONTENT_LENGTH, "10")
            .body(Body::from("{\"a\":\"b\"}"))
            .unwrap();
        assert!(maybe_compress(small, true)
            .await
            .headers()
            .get(CONTENT_ENCODING)
            .is_none());
        let pre = Response::builder()
            .header(CONTENT_TYPE, "application/json")
            .header(CONTENT_LENGTH, "4096")
            .header(CONTENT_ENCODING, "br")
            .body(Body::from(vec![b'x'; 4096]))
            .unwrap();
        assert_eq!(
            maybe_compress(pre, true)
                .await
                .headers()
                .get(CONTENT_ENCODING)
                .unwrap(),
            "br"
        );
    }
}