use actus_reply::{ReplyData, ReplySpec};
use http::{HeaderValue, Response, header};
use std::borrow::Cow;
use std::collections::HashMap;
use std::io::Write;
/// Content codings this module can negotiate for and apply to a payload.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Encoding {
/// Brotli; announced to the client as `Content-Encoding: br`.
Brotli,
/// Gzip; announced to the client as `Content-Encoding: gzip`.
Gzip,
/// No content coding: the payload is passed through as it is.
Identity,
}
/// Response-compression settings: a size threshold, the brotli-vs-gzip
/// tie-break preference, and the brotli quality level.
#[derive(Clone, Debug)]
pub struct CompressionLayer {
/// Payloads shorter than this many bytes are never compressed.
min_size: usize,
/// When the client weights brotli and gzip equally, pick brotli if `true`
/// and gzip if `false`.
prefer_brotli: bool,
/// Quality level handed to the brotli encoder; the builder caps it at 11.
brotli_quality: u32,
}
/// Brotli quality that `CompressionLayer::new` starts with.
const DEFAULT_BROTLI_QUALITY: u32 = 4;
impl Default for CompressionLayer {
fn default() -> Self {
Self::new()
}
}
impl CompressionLayer {
    /// Creates a layer with the defaults: a 1024-byte minimum payload size,
    /// brotli preferred on ties, and brotli quality
    /// [`DEFAULT_BROTLI_QUALITY`].
    pub fn new() -> Self {
        Self {
            min_size: 1024,
            prefer_brotli: true,
            brotli_quality: DEFAULT_BROTLI_QUALITY,
        }
    }

    /// Sets the smallest payload size, in bytes, that will be compressed.
    #[must_use]
    pub fn min_size(mut self, bytes: usize) -> Self {
        self.min_size = bytes;
        self
    }

    /// Breaks ties in favour of gzip when the client weights gzip and brotli
    /// equally. A client that states a strict preference still gets it.
    #[must_use]
    pub fn prefer_gzip(mut self) -> Self {
        self.prefer_brotli = false;
        self
    }

    /// Sets the brotli quality level; values above 11 are clamped to 11.
    #[must_use]
    pub fn brotli_quality(mut self, q: u32) -> Self {
        self.brotli_quality = q.min(11);
        self
    }

    /// Compresses `data` according to the request's `Accept-Encoding` value.
    ///
    /// The reply is returned untouched when:
    /// * it is a `Rich` reply whose `Cache-Control` header carries
    ///   `no-transform`,
    /// * negotiation yields no usable coding (including a missing header), or
    /// * it is a `Rich` reply that already carries a `Content-Encoding`
    ///   header.
    ///
    /// Otherwise the payload goes through [`Self::compress_payload`]. When a
    /// coding was actually applied, the result is a `Rich` reply carrying a
    /// `content-encoding` header; a bare payload is wrapped in a new `Rich`
    /// reply for that purpose.
    pub(crate) fn compress_reply(
        &self,
        data: ReplyData,
        accept_encoding: Option<&str>,
    ) -> ReplyData {
        // `no-transform` forbids any change to the payload, so this check
        // comes before negotiation and before a Json payload is serialized.
        if let ReplyData::Rich(spec) = &data
            && spec.headers.iter().any(|(k, v)| {
                k.eq_ignore_ascii_case("cache-control")
                    && v.split(',')
                        .any(|t| t.trim().eq_ignore_ascii_case("no-transform"))
            })
        {
            return data;
        }

        let enc = match negotiate(accept_encoding, self.prefer_brotli) {
            Encoding::Identity => return data,
            other => other,
        };

        match data {
            ReplyData::Rich(mut spec) => {
                // Never stack a second coding on an already-encoded body.
                if spec
                    .headers
                    .keys()
                    .any(|k| k.eq_ignore_ascii_case("content-encoding"))
                {
                    return ReplyData::Rich(spec);
                }
                let inner = std::mem::replace(&mut spec.payload, ReplyData::Empty);
                let (payload, encoded_as) = self.compress_payload(inner, enc);
                spec.payload = payload;
                if let Some(name) = encoded_as {
                    // A Content-Length supplied for the unencoded body no
                    // longer matches the bytes being sent; drop it rather
                    // than advertise a wrong length.
                    spec.headers
                        .retain(|k, _| !k.eq_ignore_ascii_case("content-length"));
                    spec.headers
                        .insert("content-encoding".to_string(), name.to_string());
                }
                ReplyData::Rich(spec)
            }
            other => match self.compress_payload(other, enc) {
                (payload, Some(name)) => ReplyData::Rich(Box::new(ReplySpec {
                    payload,
                    status: None,
                    headers: HashMap::from([("content-encoding".to_string(), name.to_string())]),
                })),
                (payload, None) => payload,
            },
        }
    }

    /// Applies `enc` to a single payload.
    ///
    /// Returns the resulting payload together with the `Content-Encoding`
    /// token that was applied, or `None` when the payload was left unencoded.
    ///
    /// * `Json` is serialized to `Bytes` with content type
    ///   `application/json` whenever serialization succeeds, even if the
    ///   result is then left unencoded. If serialization fails the original
    ///   `Json` value is handed back.
    /// * `Bytes` is encoded only when it is at least `min_size` long and its
    ///   content type passes [`is_compressible`].
    /// * Every other variant is returned as it is.
    fn compress_payload(
        &self,
        payload: ReplyData,
        enc: Encoding,
    ) -> (ReplyData, Option<&'static str>) {
        let name = match enc {
            Encoding::Gzip => "gzip",
            Encoding::Brotli => "br",
            Encoding::Identity => return (payload, None),
        };
        match payload {
            ReplyData::Json(value) => match serde_json::to_vec(&value) {
                Ok(bytes) => {
                    self.encode_bytes(Cow::Borrowed("application/json"), bytes, enc, name)
                }
                Err(_) => (ReplyData::Json(value), None),
            },
            // The cheap length test runs first so tiny payloads never pay for
            // the content-type inspection.
            ReplyData::Bytes { content_type, data }
                if data.len() >= self.min_size && is_compressible(&content_type) =>
            {
                self.encode_bytes(content_type, data, enc, name)
            }
            other => (other, None),
        }
    }

    /// Encodes `data` with `enc` and keeps the result only when it is
    /// strictly smaller than the input.
    ///
    /// Input shorter than `min_size`, an encoder failure, or output that did
    /// not shrink all yield the original bytes with no coding token.
    fn encode_bytes(
        &self,
        content_type: Cow<'static, str>,
        data: Vec<u8>,
        enc: Encoding,
        name: &'static str,
    ) -> (ReplyData, Option<&'static str>) {
        if data.len() >= self.min_size
            && let Some(out) = encode(enc, &data, self.brotli_quality)
            && out.len() < data.len()
        {
            return (
                ReplyData::Bytes {
                    content_type,
                    data: out,
                },
                Some(name),
            );
        }
        (ReplyData::Bytes { content_type, data }, None)
    }
}
/// Chooses the content coding for an `Accept-Encoding` header value.
///
/// Only `br`, `gzip` and the `*` wildcard are recognised (names are matched
/// ASCII-case-insensitively); every other coding is ignored. A coding with no
/// parseable `q`/`Q` parameter in the range `0.0..=1.0` has weight 1.0, and a
/// weight of 0 means "not acceptable". A coding that is not named inherits
/// the wildcard's weight, if a wildcard is present.
///
/// The coding with the higher weight wins; `prefer_brotli` decides only when
/// both weights are equal. A missing header, or no acceptable coding, yields
/// [`Encoding::Identity`].
fn negotiate(accept_encoding: Option<&str>, prefer_brotli: bool) -> Encoding {
    let Some(header) = accept_encoding else {
        return Encoding::Identity;
    };

    let (mut br_q, mut gzip_q, mut star_q) = (None::<f32>, None::<f32>, None::<f32>);
    for entry in header.split(',') {
        let mut fields = entry.split(';').map(str::trim);
        let coding = fields.next().unwrap_or("");
        // The last well-formed, in-range q parameter wins; anything malformed
        // or out of range is skipped, leaving the default weight of 1.0.
        let weight = fields
            .filter_map(|param| {
                param
                    .strip_prefix("q=")
                    .or_else(|| param.strip_prefix("Q="))
            })
            .filter_map(|raw| raw.parse::<f32>().ok())
            .filter(|w| (0.0..=1.0).contains(w))
            .next_back()
            .unwrap_or(1.0);

        if coding.eq_ignore_ascii_case("br") {
            br_q = Some(weight);
        } else if coding.eq_ignore_ascii_case("gzip") {
            gzip_q = Some(weight);
        } else if coding == "*" {
            star_q = Some(weight);
        }
    }

    // An unnamed coding falls back to the wildcard weight, then to 0.
    let weight_of = |named: Option<f32>| named.or(star_q).unwrap_or(0.0);
    let (br, gzip) = (weight_of(br_q), weight_of(gzip_q));
    let tie_break = if prefer_brotli {
        Encoding::Brotli
    } else {
        Encoding::Gzip
    };

    match (br > 0.0, gzip > 0.0) {
        (false, false) => Encoding::Identity,
        (true, false) => Encoding::Brotli,
        (false, true) => Encoding::Gzip,
        (true, true) if (br - gzip).abs() < f32::EPSILON => tie_break,
        (true, true) if br > gzip => Encoding::Brotli,
        (true, true) => Encoding::Gzip,
    }
}
/// Reports whether a payload with this `Content-Type` is on the compression
/// allowlist.
///
/// Anything after the first `;` is ignored, surrounding whitespace is
/// trimmed, and the comparison is ASCII-case-insensitive. A type qualifies
/// when it is `text/*`, ends in `+json` or `+xml`, or is one of a few exact
/// `application/*` types.
fn is_compressible(content_type: &str) -> bool {
    // Exact media types not already covered by the `text/` prefix or the
    // `+json` / `+xml` suffix rules (those rules cover e.g. `image/svg+xml`,
    // `application/manifest+json`, `application/atom+xml`).
    const EXACT: [&str; 4] = [
        "application/json",
        "application/javascript",
        "application/xml",
        "application/wasm",
    ];

    let media_type = content_type
        .split_once(';')
        .map_or(content_type, |(head, _params)| head)
        .trim()
        .to_ascii_lowercase();

    media_type.starts_with("text/")
        || media_type.ends_with("+json")
        || media_type.ends_with("+xml")
        || EXACT.contains(&media_type.as_str())
}
/// Compresses `data` with the requested coding.
///
/// Returns `None` for [`Encoding::Identity`] or when writing to the encoder
/// fails. `brotli_quality` is used only by the brotli branch.
fn encode(enc: Encoding, data: &[u8], brotli_quality: u32) -> Option<Vec<u8>> {
    // Buffer-size and window-size arguments handed to the brotli writer.
    const BROTLI_BUFFER_BYTES: usize = 4096;
    const BROTLI_LG_WINDOW: u32 = 22;

    match enc {
        Encoding::Identity => None,
        Encoding::Gzip => {
            let level = flate2::Compression::default();
            let mut gz = flate2::write::GzEncoder::new(Vec::new(), level);
            gz.write_all(data).ok()?;
            gz.finish().ok()
        }
        Encoding::Brotli => {
            let mut compressed = Vec::new();
            let mut br = brotli::CompressorWriter::new(
                &mut compressed,
                BROTLI_BUFFER_BYTES,
                brotli_quality,
                BROTLI_LG_WINDOW,
            );
            br.write_all(data).ok()?;
            // Drop the writer before handing the buffer out: this releases
            // its borrow of `compressed` and lets it finish the stream.
            drop(br);
            Some(compressed)
        }
    }
}
/// Adds `Vary: Accept-Encoding` to a response that carries a
/// `Content-Encoding` header, so caches key the stored body on the request's
/// `Accept-Encoding`.
///
/// Responses without `Content-Encoding` are returned unchanged. Nothing is
/// appended when an existing `Vary` header already lists `Accept-Encoding`
/// (compared case-insensitively) or is the `*` wildcard, which avoids
/// emitting duplicate entries.
pub(crate) fn tag_vary_if_encoded<B>(mut response: Response<B>) -> Response<B> {
    if !response.headers().contains_key(header::CONTENT_ENCODING) {
        return response;
    }

    // `Vary` may be split over several header lines, each holding a
    // comma-separated list; look at every listed field name.
    let already_covered = response
        .headers()
        .get_all(header::VARY)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .any(|field| field == "*" || field.eq_ignore_ascii_case("accept-encoding"));

    if !already_covered {
        response
            .headers_mut()
            .append(header::VARY, HeaderValue::from_static("Accept-Encoding"));
    }
    response
}
/// Unit tests for negotiation, the content-type allowlist, payload
/// compression, `no-transform` handling and the `Vary` tagging helper.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn negotiate_picks_the_higher_q_when_client_states_a_preference() {
        assert_eq!(
            negotiate(Some("br;q=0.8, gzip;q=1.0"), true),
            Encoding::Gzip,
            "gzip has higher q; prefer_brotli is only a tie-breaker",
        );
        assert_eq!(
            negotiate(Some("br;q=1.0, gzip;q=0.5"), false),
            Encoding::Brotli,
            "br has higher q; prefer_brotli=false doesn't override it",
        );
    }

    #[test]
    fn negotiate_uses_prefer_brotli_only_on_a_tie() {
        assert_eq!(
            negotiate(Some("br;q=0.7, gzip;q=0.7"), true),
            Encoding::Brotli,
        );
        assert_eq!(
            negotiate(Some("br;q=0.7, gzip;q=0.7"), false),
            Encoding::Gzip,
        );
        // No q parameters at all: both default to 1.0, which is also a tie.
        assert_eq!(negotiate(Some("gzip, deflate, br"), true), Encoding::Brotli);
        assert_eq!(negotiate(Some("gzip, deflate, br"), false), Encoding::Gzip);
    }

    #[test]
    fn negotiate_treats_q_zero_as_explicit_disallow() {
        assert_eq!(negotiate(Some("br;q=0, gzip"), true), Encoding::Gzip);
        assert_eq!(
            negotiate(Some("br;q=0, gzip;q=0"), true),
            Encoding::Identity
        );
    }

    #[test]
    fn negotiate_wildcard_applies_to_unnamed_encodings() {
        assert_eq!(negotiate(Some("*"), true), Encoding::Brotli);
        assert_eq!(negotiate(Some("*;q=0.5"), true), Encoding::Brotli);
        assert_eq!(negotiate(Some("*;q=0"), true), Encoding::Identity);
        assert_eq!(negotiate(Some("gzip, *;q=0"), true), Encoding::Gzip);
        assert_eq!(negotiate(Some("gzip, *;q=0.5"), true), Encoding::Gzip);
    }

    #[test]
    fn negotiate_handles_only_one_offered() {
        assert_eq!(negotiate(Some("gzip"), true), Encoding::Gzip);
        assert_eq!(negotiate(Some("br"), false), Encoding::Brotli);
    }

    #[test]
    fn negotiate_identity_only_means_no_encoding() {
        assert_eq!(negotiate(Some("identity"), true), Encoding::Identity);
    }

    #[test]
    fn negotiate_missing_header_means_no_compression() {
        assert_eq!(negotiate(None, true), Encoding::Identity);
    }

    #[test]
    fn negotiate_ignores_unknown_encodings() {
        assert_eq!(
            negotiate(Some("deflate, compress, x-gzip"), true),
            Encoding::Identity,
        );
    }

    #[test]
    fn negotiate_tolerates_whitespace_and_casing() {
        assert_eq!(
            negotiate(Some(" BR ; Q=0.9 , GZip ; q=0.5 "), true),
            Encoding::Brotli,
            "case-insensitive name + Q=; tolerated whitespace",
        );
    }

    #[test]
    fn negotiate_rejects_out_of_range_q_silently() {
        // An out-of-range q is ignored, so the coding keeps the default 1.0.
        assert_eq!(negotiate(Some("br;q=2.0"), true), Encoding::Brotli);
        assert_eq!(negotiate(Some("br;q=-1"), true), Encoding::Brotli);
    }

    #[test]
    fn is_compressible_allowlist() {
        assert!(is_compressible("application/json"));
        assert!(is_compressible("application/vnd.api+json; charset=utf-8"));
        assert!(is_compressible("text/html"));
        assert!(is_compressible("image/svg+xml"));
        assert!(!is_compressible("image/png"));
        assert!(!is_compressible("application/zip"));
        assert!(!is_compressible("application/octet-stream"));
    }

    #[test]
    fn small_json_is_buffered_but_not_encoded() {
        let out = CompressionLayer::new()
            .compress_reply(ReplyData::Json(json!({"ok": true})), Some("br"));
        match out {
            ReplyData::Bytes { content_type, .. } => assert_eq!(content_type, "application/json"),
            other => panic!("expected buffered Bytes, got {other:?}"),
        }
    }

    #[test]
    fn large_json_is_brotli_encoded_and_smaller() {
        let big = json!({ "rows": (0..2000).map(|i| json!({"id": i, "name": "User Name"})).collect::<Vec<_>>() });
        let original_len = serde_json::to_vec(&big).unwrap().len();
        assert!(original_len > 10_000);
        let out = CompressionLayer::new().compress_reply(ReplyData::Json(big), Some("br, gzip"));
        match out {
            ReplyData::Rich(spec) => {
                assert_eq!(
                    spec.headers.get("content-encoding").map(String::as_str),
                    Some("br")
                );
                match &spec.payload {
                    ReplyData::Bytes { data, .. } => assert!(data.len() < original_len / 2),
                    other => panic!("expected Bytes payload, got {other:?}"),
                }
            }
            other => panic!("expected Rich(compressed), got {other:?}"),
        }
    }

    #[test]
    fn no_accept_encoding_leaves_json_alone() {
        let out = CompressionLayer::new().compress_reply(ReplyData::Json(json!({"a": 1})), None);
        assert!(matches!(out, ReplyData::Json(_)));
    }

    #[test]
    fn does_not_double_encode_an_already_encoded_reply() {
        let big = json!({ "rows": (0..2000).map(|i| json!({"id": i})).collect::<Vec<_>>() });
        let raw = serde_json::to_vec(&big).unwrap();
        let pre = ReplyData::Rich(Box::new(ReplySpec {
            payload: ReplyData::Bytes {
                content_type: "application/json".into(),
                data: raw.clone(),
            },
            status: None,
            headers: HashMap::from([("content-encoding".to_string(), "gzip".to_string())]),
        }));
        let out = CompressionLayer::new().compress_reply(pre, Some("br"));
        match out {
            ReplyData::Rich(spec) => {
                assert_eq!(
                    spec.headers.get("content-encoding").map(String::as_str),
                    Some("gzip")
                );
                // The header check alone would not notice a re-encoded body.
                match &spec.payload {
                    ReplyData::Bytes { data, .. } => assert_eq!(
                        data, &raw,
                        "an already-encoded payload must be passed through byte-for-byte",
                    ),
                    other => panic!("expected Bytes payload, got {other:?}"),
                }
            }
            other => panic!("expected Rich, got {other:?}"),
        }
    }

    #[test]
    fn tag_vary_appends_only_when_content_encoding_present() {
        let with_ce = Response::builder()
            .header(header::CONTENT_ENCODING, "br")
            .body(())
            .unwrap();
        let tagged = tag_vary_if_encoded(with_ce);
        assert_eq!(
            tagged.headers().get(header::VARY).unwrap(),
            "Accept-Encoding"
        );
        let without = Response::builder().body(()).unwrap();
        let untagged = tag_vary_if_encoded(without);
        assert!(untagged.headers().get(header::VARY).is_none());
    }

    /// Builds a `Rich` reply with the given headers whose JSON payload is
    /// large and repetitive enough to be compressed.
    fn big_compressible_rich(headers: HashMap<String, String>) -> ReplyData {
        let big = json!({ "rows": (0..2000).map(|i| json!({"id": i})).collect::<Vec<_>>() });
        ReplyData::Rich(Box::new(ReplySpec {
            payload: ReplyData::Json(big),
            status: None,
            headers,
        }))
    }

    #[test]
    fn no_transform_directive_skips_compression_entirely() {
        let pre = big_compressible_rich(HashMap::from([(
            "Cache-Control".into(),
            "no-transform".into(),
        )]));
        let out = CompressionLayer::new().compress_reply(pre, Some("br, gzip"));
        match out {
            ReplyData::Rich(spec) => {
                assert!(
                    !spec
                        .headers
                        .keys()
                        .any(|k| k.eq_ignore_ascii_case("content-encoding")),
                    "no-transform forbids compression; no Content-Encoding should be set",
                );
                assert!(
                    matches!(spec.payload, ReplyData::Json(_)),
                    "payload should be untouched (still Json, not lifted to Bytes)",
                );
            }
            other => panic!("expected Rich passing through unchanged, got {other:?}"),
        }
    }

    #[test]
    fn no_transform_is_case_insensitive_and_robust_to_other_directives() {
        for header_name in ["cache-control", "Cache-Control", "CACHE-CONTROL"] {
            for value in [
                "no-transform",
                "no-cache, no-transform",
                "private, no-transform, max-age=0",
                "  no-transform  ",
                "no-cache, NO-TRANSFORM",
            ] {
                let pre =
                    big_compressible_rich(HashMap::from([(header_name.into(), value.into())]));
                let out = CompressionLayer::new().compress_reply(pre, Some("br"));
                match out {
                    ReplyData::Rich(spec) => assert!(
                        !spec
                            .headers
                            .keys()
                            .any(|k| k.eq_ignore_ascii_case("content-encoding")),
                        "no-transform should suppress compression for header `{header_name}: {value}`",
                    ),
                    other => panic!("expected Rich, got {other:?}"),
                }
            }
        }
    }

    #[test]
    fn other_cache_control_directives_do_not_disable_compression() {
        for value in ["no-cache", "no-store", "private", "max-age=0"] {
            let pre =
                big_compressible_rich(HashMap::from([("Cache-Control".into(), value.into())]));
            let out = CompressionLayer::new().compress_reply(pre, Some("br"));
            match out {
                ReplyData::Rich(spec) => assert_eq!(
                    spec.headers.get("content-encoding").map(String::as_str),
                    Some("br"),
                    "compression should still run for header `Cache-Control: {value}`",
                ),
                other => panic!("expected Rich, got {other:?}"),
            }
        }
    }

    #[test]
    fn no_transform_only_applies_to_rich_replies() {
        let big = json!({ "rows": (0..2000).map(|i| json!({"id": i})).collect::<Vec<_>>() });
        let out = CompressionLayer::new().compress_reply(ReplyData::Json(big), Some("br"));
        match out {
            ReplyData::Rich(spec) => {
                assert_eq!(
                    spec.headers.get("content-encoding").map(String::as_str),
                    Some("br"),
                );
            }
            other => panic!("expected Rich (compressed), got {other:?}"),
        }
    }

    #[test]
    fn quality_setting_changes_brotli_output() {
        let payload = json!({ "rows": (0..2000).map(|i| json!({"id": i})).collect::<Vec<_>>() });
        let bytes = serde_json::to_vec(&payload).unwrap();
        let fast = encode(Encoding::Brotli, &bytes, 0).unwrap();
        let best = encode(Encoding::Brotli, &bytes, 11).unwrap();
        assert_ne!(
            fast, best,
            "quality 0 and quality 11 should produce different brotli outputs",
        );
        assert!(best.len() <= fast.len());
    }

    #[test]
    fn quality_clamps_to_eleven() {
        let layer = CompressionLayer::new().brotli_quality(99);
        // Check the stored value directly: a successful compression alone
        // would not prove that the clamp happened.
        assert_eq!(layer.brotli_quality, 11);
        let payload = json!({"x": "y".repeat(2000)});
        let out = layer.compress_reply(ReplyData::Json(payload), Some("br"));
        assert!(matches!(out, ReplyData::Rich(_)));
    }
}