1use actus_reply::{ReplyData, ReplySpec};
20use http::{HeaderValue, Response, header};
21use std::borrow::Cow;
22use std::collections::HashMap;
23use std::io::Write;
24
/// Content-codings this layer can negotiate and apply to a reply body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Encoding {
    /// Brotli, advertised as `br`.
    Brotli,
    /// Gzip, advertised as `gzip`.
    Gzip,
    /// No content-coding: the body is sent as-is.
    Identity,
}
32
/// Response-compression settings: negotiates brotli or gzip against the client's
/// `Accept-Encoding` header and encodes eligible reply bodies.
#[derive(Debug, Clone)]
pub struct CompressionLayer {
    /// Bodies shorter than this many bytes are left unencoded.
    min_size: usize,
    /// Tie-breaker used when the client weights `br` and `gzip` equally.
    prefer_brotli: bool,
    /// Brotli quality level handed to the encoder.
    brotli_quality: u32,
}
40
/// Brotli quality used unless overridden with `CompressionLayer::brotli_quality`.
///
/// The layer accepts levels 0 through 11; higher levels trade speed for density.
const DEFAULT_BROTLI_QUALITY: u32 = 4;
45
46impl Default for CompressionLayer {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl CompressionLayer {
53 pub fn new() -> Self {
57 Self {
58 min_size: 1024,
59 prefer_brotli: true,
60 brotli_quality: DEFAULT_BROTLI_QUALITY,
61 }
62 }
63
64 pub fn min_size(mut self, bytes: usize) -> Self {
68 self.min_size = bytes;
69 self
70 }
71
72 pub fn prefer_gzip(mut self) -> Self {
75 self.prefer_brotli = false;
76 self
77 }
78
79 pub fn brotli_quality(mut self, q: u32) -> Self {
89 self.brotli_quality = q.min(11);
90 self
91 }
92
93 pub(crate) fn compress_reply(
99 &self,
100 data: ReplyData,
101 accept_encoding: Option<&str>,
102 ) -> ReplyData {
103 if let ReplyData::Rich(spec) = &data
110 && spec.headers.iter().any(|(k, v)| {
111 k.eq_ignore_ascii_case("cache-control")
112 && v.split(',')
113 .any(|t| t.trim().eq_ignore_ascii_case("no-transform"))
114 })
115 {
116 return data;
117 }
118
119 let enc = match negotiate(accept_encoding, self.prefer_brotli) {
120 Encoding::Identity => return data,
121 other => other,
122 };
123 match data {
124 ReplyData::Rich(mut spec) => {
128 if spec
129 .headers
130 .keys()
131 .any(|k| k.eq_ignore_ascii_case("content-encoding"))
132 {
133 return ReplyData::Rich(spec);
134 }
135 let inner = std::mem::replace(&mut spec.payload, ReplyData::Empty);
136 let (payload, encoded_as) = self.compress_payload(inner, enc);
137 spec.payload = payload;
138 if let Some(name) = encoded_as {
139 spec.headers
140 .insert("content-encoding".to_string(), name.to_string());
141 }
142 ReplyData::Rich(spec)
143 }
144 other => match self.compress_payload(other, enc) {
145 (payload, Some(name)) => ReplyData::Rich(Box::new(ReplySpec {
146 payload,
147 status: None,
148 headers: HashMap::from([("content-encoding".to_string(), name.to_string())]),
149 })),
150 (payload, None) => payload,
151 },
152 }
153 }
154
155 fn compress_payload(
160 &self,
161 payload: ReplyData,
162 enc: Encoding,
163 ) -> (ReplyData, Option<&'static str>) {
164 let name = match enc {
165 Encoding::Gzip => "gzip",
166 Encoding::Brotli => "br",
167 Encoding::Identity => return (payload, None),
168 };
169 match payload {
170 ReplyData::Json(value) => {
171 let bytes = match serde_json::to_vec(&value) {
172 Ok(b) => b,
173 Err(_) => return (ReplyData::Json(value), None),
175 };
176 let json: Cow<'static, str> = Cow::Borrowed("application/json");
177 if bytes.len() < self.min_size {
178 return (
179 ReplyData::Bytes {
180 content_type: json,
181 data: bytes,
182 },
183 None,
184 );
185 }
186 match encode(enc, &bytes, self.brotli_quality) {
187 Some(out) if out.len() < bytes.len() => (
188 ReplyData::Bytes {
189 content_type: json,
190 data: out,
191 },
192 Some(name),
193 ),
194 _ => (
195 ReplyData::Bytes {
196 content_type: json,
197 data: bytes,
198 },
199 None,
200 ),
201 }
202 }
203 ReplyData::Bytes { content_type, data } => {
204 if data.len() < self.min_size || !is_compressible(&content_type) {
205 return (ReplyData::Bytes { content_type, data }, None);
206 }
207 match encode(enc, &data, self.brotli_quality) {
208 Some(out) if out.len() < data.len() => (
209 ReplyData::Bytes {
210 content_type,
211 data: out,
212 },
213 Some(name),
214 ),
215 _ => (ReplyData::Bytes { content_type, data }, None),
216 }
217 }
218 other => (other, None),
221 }
222 }
223}
224
225fn negotiate(accept_encoding: Option<&str>, prefer_brotli: bool) -> Encoding {
235 let Some(ae) = accept_encoding else {
236 return Encoding::Identity;
237 };
238
239 let mut br_q: Option<f32> = None;
240 let mut gzip_q: Option<f32> = None;
241 let mut star_q: Option<f32> = None;
242
243 for token in ae.split(',') {
244 let mut parts = token.split(';');
245 let name = parts.next().map(str::trim).unwrap_or("");
246 let mut q: f32 = 1.0;
249 for p in parts {
250 let p = p.trim();
251 if let Some(qs) = p.strip_prefix("q=").or_else(|| p.strip_prefix("Q="))
252 && let Ok(v) = qs.parse::<f32>()
253 && (0.0..=1.0).contains(&v)
254 {
255 q = v;
256 }
257 }
258 match name.to_ascii_lowercase().as_str() {
259 "br" => br_q = Some(q),
260 "gzip" => gzip_q = Some(q),
261 "*" => star_q = Some(q),
262 _ => {}
265 }
266 }
267
268 let br = br_q.or(star_q).unwrap_or(0.0);
270 let gzip = gzip_q.or(star_q).unwrap_or(0.0);
271
272 let br_ok = br > 0.0;
273 let gzip_ok = gzip > 0.0;
274 match (br_ok, gzip_ok) {
275 (true, true) => {
276 if (br - gzip).abs() < f32::EPSILON {
280 if prefer_brotli {
281 Encoding::Brotli
282 } else {
283 Encoding::Gzip
284 }
285 } else if br > gzip {
286 Encoding::Brotli
287 } else {
288 Encoding::Gzip
289 }
290 }
291 (true, false) => Encoding::Brotli,
292 (false, true) => Encoding::Gzip,
293 (false, false) => Encoding::Identity,
294 }
295}
296
/// Returns true if a body with this `Content-Type` is worth compressing.
///
/// Only the media type before any `;` parameters is considered; it is trimmed and
/// compared case-insensitively. The following are accepted:
/// - every `text/*` type;
/// - any `+json` or `+xml` structured-syntax suffix, which covers SVG, XHTML, RSS,
///   Atom and web-app manifests;
/// - a short list of exact types: JSON, JavaScript, XML and WebAssembly.
fn is_compressible(content_type: &str) -> bool {
    const EXACT_TYPES: [&str; 4] = [
        "application/json",
        "application/javascript",
        "application/xml",
        "application/wasm",
    ];

    let essence = content_type
        .split(';')
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();

    essence.starts_with("text/")
        || essence.ends_with("+json")
        || essence.ends_with("+xml")
        || EXACT_TYPES.contains(&essence.as_str())
}
320
321fn encode(enc: Encoding, data: &[u8], brotli_quality: u32) -> Option<Vec<u8>> {
322 match enc {
323 Encoding::Gzip => {
324 let mut e = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
325 e.write_all(data).ok()?;
326 e.finish().ok()
327 }
328 Encoding::Brotli => {
329 let mut out = Vec::new();
330 {
331 let mut w = brotli::CompressorWriter::new(&mut out, 4096, brotli_quality, 22);
335 w.write_all(data).ok()?;
336 } Some(out)
338 }
339 Encoding::Identity => None,
340 }
341}
342
343pub(crate) fn tag_vary_if_encoded<B>(mut response: Response<B>) -> Response<B> {
347 if response.headers().contains_key(header::CONTENT_ENCODING) {
348 response
349 .headers_mut()
350 .append(header::VARY, HeaderValue::from_static("Accept-Encoding"));
351 }
352 response
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn negotiate_picks_the_higher_q_when_client_states_a_preference() {
        assert_eq!(
            negotiate(Some("br;q=0.8, gzip;q=1.0"), true),
            Encoding::Gzip,
            "gzip has higher q; prefer_brotli is only a tie-breaker",
        );
        assert_eq!(
            negotiate(Some("br;q=1.0, gzip;q=0.5"), false),
            Encoding::Brotli,
            "br has higher q; prefer_brotli=false doesn't override it",
        );
    }

    #[test]
    fn negotiate_uses_prefer_brotli_only_on_a_tie() {
        assert_eq!(
            negotiate(Some("br;q=0.7, gzip;q=0.7"), true),
            Encoding::Brotli,
        );
        assert_eq!(
            negotiate(Some("br;q=0.7, gzip;q=0.7"), false),
            Encoding::Gzip,
        );
        // Unweighted entries all default to q=1.0, so this is a tie as well.
        assert_eq!(negotiate(Some("gzip, deflate, br"), true), Encoding::Brotli);
        assert_eq!(negotiate(Some("gzip, deflate, br"), false), Encoding::Gzip);
    }

    #[test]
    fn negotiate_treats_q_zero_as_explicit_disallow() {
        assert_eq!(negotiate(Some("br;q=0, gzip"), true), Encoding::Gzip);
        assert_eq!(
            negotiate(Some("br;q=0, gzip;q=0"), true),
            Encoding::Identity
        );
    }

    #[test]
    fn negotiate_wildcard_applies_to_unnamed_encodings() {
        assert_eq!(negotiate(Some("*"), true), Encoding::Brotli);
        assert_eq!(negotiate(Some("*;q=0.5"), true), Encoding::Brotli);
        assert_eq!(negotiate(Some("*;q=0"), true), Encoding::Identity);
        // An explicitly named coding keeps its own weight regardless of `*`.
        assert_eq!(negotiate(Some("gzip, *;q=0"), true), Encoding::Gzip);
        assert_eq!(negotiate(Some("gzip, *;q=0.5"), true), Encoding::Gzip);
    }

    #[test]
    fn negotiate_handles_only_one_offered() {
        assert_eq!(negotiate(Some("gzip"), true), Encoding::Gzip);
        assert_eq!(negotiate(Some("br"), false), Encoding::Brotli);
    }

    #[test]
    fn negotiate_identity_only_means_no_encoding() {
        assert_eq!(negotiate(Some("identity"), true), Encoding::Identity);
    }

    #[test]
    fn negotiate_missing_header_means_no_compression() {
        assert_eq!(negotiate(None, true), Encoding::Identity);
    }

    #[test]
    fn negotiate_ignores_unknown_encodings() {
        assert_eq!(
            negotiate(Some("deflate, compress, x-gzip"), true),
            Encoding::Identity,
        );
    }

    #[test]
    fn negotiate_tolerates_whitespace_and_casing() {
        assert_eq!(
            negotiate(Some(" BR ; Q=0.9 , GZip ; q=0.5 "), true),
            Encoding::Brotli,
            "case-insensitive name + Q=; tolerated whitespace",
        );
    }

    #[test]
    fn negotiate_rejects_out_of_range_q_silently() {
        // Out-of-range q values are ignored, leaving the default weight of 1.0.
        assert_eq!(negotiate(Some("br;q=2.0"), true), Encoding::Brotli);
        assert_eq!(negotiate(Some("br;q=-1"), true), Encoding::Brotli);
    }

    #[test]
    fn is_compressible_allowlist() {
        assert!(is_compressible("application/json"));
        assert!(is_compressible("application/vnd.api+json; charset=utf-8"));
        assert!(is_compressible("text/html"));
        assert!(is_compressible("image/svg+xml"));
        assert!(!is_compressible("image/png"));
        assert!(!is_compressible("application/zip"));
        assert!(!is_compressible("application/octet-stream"));
    }

    #[test]
    fn small_json_is_buffered_but_not_encoded() {
        let out = CompressionLayer::new()
            .compress_reply(ReplyData::Json(json!({"ok": true})), Some("br"));
        match out {
            ReplyData::Bytes { content_type, .. } => assert_eq!(content_type, "application/json"),
            other => panic!("expected buffered Bytes, got {other:?}"),
        }
    }

    #[test]
    fn large_json_is_brotli_encoded_and_smaller() {
        let big = json!({ "rows": (0..2000).map(|i| json!({"id": i, "name": "User Name"})).collect::<Vec<_>>() });
        let original_len = serde_json::to_vec(&big).unwrap().len();
        assert!(original_len > 10_000);
        let out = CompressionLayer::new().compress_reply(ReplyData::Json(big), Some("br, gzip"));
        match out {
            ReplyData::Rich(spec) => {
                assert_eq!(
                    spec.headers.get("content-encoding").map(String::as_str),
                    Some("br")
                );
                match &spec.payload {
                    ReplyData::Bytes { data, .. } => assert!(data.len() < original_len / 2),
                    other => panic!("expected Bytes payload, got {other:?}"),
                }
            }
            other => panic!("expected Rich(compressed), got {other:?}"),
        }
    }

    #[test]
    fn no_accept_encoding_leaves_json_alone() {
        let out = CompressionLayer::new().compress_reply(ReplyData::Json(json!({"a": 1})), None);
        assert!(matches!(out, ReplyData::Json(_)));
    }

    #[test]
    fn does_not_double_encode_an_already_encoded_reply() {
        let big = json!({ "rows": (0..2000).map(|i| json!({"id": i})).collect::<Vec<_>>() });
        let raw = serde_json::to_vec(&big).unwrap();
        let pre = ReplyData::Rich(Box::new(ReplySpec {
            payload: ReplyData::Bytes {
                content_type: "application/json".into(),
                data: raw.clone(),
            },
            status: None,
            headers: HashMap::from([("content-encoding".to_string(), "gzip".to_string())]),
        }));
        let out = CompressionLayer::new().compress_reply(pre, Some("br"));
        match out {
            ReplyData::Rich(spec) => {
                assert_eq!(
                    spec.headers.get("content-encoding").map(String::as_str),
                    Some("gzip")
                );
                // The pre-encoded body must pass through byte-for-byte.
                match &spec.payload {
                    ReplyData::Bytes { data, .. } => assert_eq!(data, &raw),
                    other => panic!("expected untouched Bytes payload, got {other:?}"),
                }
            }
            other => panic!("expected Rich, got {other:?}"),
        }
    }

    #[test]
    fn tag_vary_appends_only_when_content_encoding_present() {
        let with_ce = Response::builder()
            .header(header::CONTENT_ENCODING, "br")
            .body(())
            .unwrap();
        let tagged = tag_vary_if_encoded(with_ce);
        assert_eq!(
            tagged.headers().get(header::VARY).unwrap(),
            "Accept-Encoding"
        );

        let without = Response::builder().body(()).unwrap();
        let untagged = tag_vary_if_encoded(without);
        assert!(untagged.headers().get(header::VARY).is_none());
    }

    /// Builds a rich reply whose JSON payload is large enough to be compressed,
    /// carrying the given headers.
    fn big_compressible_rich(headers: HashMap<String, String>) -> ReplyData {
        let big = json!({ "rows": (0..2000).map(|i| json!({"id": i})).collect::<Vec<_>>() });
        ReplyData::Rich(Box::new(ReplySpec {
            payload: ReplyData::Json(big),
            status: None,
            headers,
        }))
    }

    #[test]
    fn no_transform_directive_skips_compression_entirely() {
        let pre = big_compressible_rich(HashMap::from([(
            "Cache-Control".into(),
            "no-transform".into(),
        )]));
        let out = CompressionLayer::new().compress_reply(pre, Some("br, gzip"));
        match out {
            ReplyData::Rich(spec) => {
                assert!(
                    !spec
                        .headers
                        .keys()
                        .any(|k| k.eq_ignore_ascii_case("content-encoding")),
                    "no-transform forbids compression; no Content-Encoding should be set",
                );
                assert!(
                    matches!(spec.payload, ReplyData::Json(_)),
                    "payload should be untouched (still Json, not lifted to Bytes)",
                );
            }
            other => panic!("expected Rich passing through unchanged, got {other:?}"),
        }
    }

    #[test]
    fn no_transform_is_case_insensitive_and_robust_to_other_directives() {
        for header_name in ["cache-control", "Cache-Control", "CACHE-CONTROL"] {
            for value in [
                "no-transform",
                "no-cache, no-transform",
                "private, no-transform, max-age=0",
                " no-transform ",
                "no-cache, NO-TRANSFORM",
            ] {
                let pre =
                    big_compressible_rich(HashMap::from([(header_name.into(), value.into())]));
                let out = CompressionLayer::new().compress_reply(pre, Some("br"));
                match out {
                    ReplyData::Rich(spec) => assert!(
                        !spec
                            .headers
                            .keys()
                            .any(|k| k.eq_ignore_ascii_case("content-encoding")),
                        "no-transform should suppress compression for header `{header_name}: {value}`",
                    ),
                    other => panic!("expected Rich, got {other:?}"),
                }
            }
        }
    }

    #[test]
    fn other_cache_control_directives_do_not_disable_compression() {
        for value in ["no-cache", "no-store", "private", "max-age=0"] {
            let pre =
                big_compressible_rich(HashMap::from([("Cache-Control".into(), value.into())]));
            let out = CompressionLayer::new().compress_reply(pre, Some("br"));
            match out {
                ReplyData::Rich(spec) => assert_eq!(
                    spec.headers.get("content-encoding").map(String::as_str),
                    Some("br"),
                    "compression should still run for header `Cache-Control: {value}`",
                ),
                other => panic!("expected Rich, got {other:?}"),
            }
        }
    }

    #[test]
    fn no_transform_only_applies_to_rich_replies() {
        // A bare Json reply has no headers, so there is nothing to forbid compression.
        let big = json!({ "rows": (0..2000).map(|i| json!({"id": i})).collect::<Vec<_>>() });
        let out = CompressionLayer::new().compress_reply(ReplyData::Json(big), Some("br"));
        match out {
            ReplyData::Rich(spec) => {
                assert_eq!(
                    spec.headers.get("content-encoding").map(String::as_str),
                    Some("br"),
                );
            }
            other => panic!("expected Rich (compressed), got {other:?}"),
        }
    }

    #[test]
    fn quality_setting_changes_brotli_output() {
        let payload = json!({ "rows": (0..2000).map(|i| json!({"id": i})).collect::<Vec<_>>() });
        let bytes = serde_json::to_vec(&payload).unwrap();
        let fast = encode(Encoding::Brotli, &bytes, 0).unwrap();
        let best = encode(Encoding::Brotli, &bytes, 11).unwrap();
        assert_ne!(
            fast, best,
            "quality 0 and quality 11 should produce different brotli outputs",
        );
        assert!(best.len() <= fast.len());
    }

    #[test]
    fn quality_clamps_to_eleven() {
        let layer = CompressionLayer::new().brotli_quality(99);
        // Observe the clamp directly; the encode round-trip below would pass even
        // if the out-of-range value were stored unchanged.
        assert_eq!(layer.brotli_quality, 11);
        assert_eq!(CompressionLayer::new().brotli_quality(7).brotli_quality, 7);
        let payload = json!({"x": "y".repeat(2000)});
        let out = layer.compress_reply(ReplyData::Json(payload), Some("br"));
        assert!(matches!(out, ReplyData::Rich(_)));
    }
}