use thiserror::Error;
/// Every way the compression and decompression helpers in this module can fail.
#[derive(Debug, Error)]
pub enum CompressionError {
    /// More coding layers were requested (or listed in a header) than
    /// [`MAX_CHAIN_LAYERS`] allows; the payload is that cap.
    #[error("compression chain exceeded the {0}-layer safety cap")]
    ChainTooDeep(usize),

    /// An I/O failure reported by the gzip encoder or decoder.
    #[error("gzip encoder error: {0}")]
    Gzip(std::io::Error),

    /// An I/O failure reported by the deflate encoder or decoder.
    #[error("deflate encoder error: {0}")]
    Deflate(std::io::Error),

    /// An I/O failure reported by the brotli encoder or decoder.
    #[error("brotli encoder error: {0}")]
    Brotli(std::io::Error),

    /// A decode step would have produced more than the configured byte cap.
    #[error(
        "decompression bomb: output exceeded {cap_bytes}-byte cap \
         ({observed_bytes} bytes produced) — aborted before OOM"
    )]
    DecompressionBomb {
        /// The limit that was exceeded.
        cap_bytes: usize,
        /// Bytes seen when decoding stopped. The streaming decoders stop
        /// reading at `cap_bytes + 1`, so for them this is a lower bound.
        observed_bytes: usize,
    },
}
/// Maximum number of content-coding layers that `chain` will apply, or that
/// `decompress` will strip (counting only recognised tokens), in one call.
pub const MAX_CHAIN_LAYERS: usize = 16;
/// Maximum number of bytes any single decode step may produce; the value is
/// taken from `wafrift_types::MAX_RESPONSE_BODY_BYTES`.
pub const DECOMPRESSED_BODY_MAX_BYTES: usize = wafrift_types::MAX_RESPONSE_BODY_BYTES;
/// One HTTP content coding that this module can apply or strip.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Algorithm {
    /// `gzip`; the legacy `x-gzip` spelling is also accepted when parsing.
    Gzip,
    /// `deflate`.
    Deflate,
    /// Brotli, whose registered token is `br`.
    Brotli,
    /// `identity`: bytes pass through untouched.
    Identity,
}

impl Algorithm {
    /// The `Content-Encoding` token emitted for this coding.
    #[must_use]
    pub fn content_encoding(self) -> &'static str {
        match self {
            Algorithm::Gzip => "gzip",
            Algorithm::Deflate => "deflate",
            Algorithm::Brotli => "br",
            Algorithm::Identity => "identity",
        }
    }

    /// Parses a single element of a `Content-Encoding` list.
    ///
    /// Leading and trailing whitespace is ignored and matching is
    /// ASCII-case-insensitive. Unknown codings, including the empty string,
    /// yield `None`.
    #[must_use]
    pub fn from_token(token: &str) -> Option<Self> {
        // Every accepted spelling, with the coding it denotes.
        const SPELLINGS: [(&str, Algorithm); 5] = [
            ("gzip", Algorithm::Gzip),
            ("x-gzip", Algorithm::Gzip),
            ("deflate", Algorithm::Deflate),
            ("br", Algorithm::Brotli),
            ("identity", Algorithm::Identity),
        ];

        let wanted = token.trim();
        SPELLINGS
            .iter()
            .find(|(name, _)| wanted.eq_ignore_ascii_case(name))
            .map(|&(_, algo)| algo)
    }
}
/// An encoded payload together with the `Content-Encoding` value that
/// describes how it was encoded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressedBody {
    /// The encoded bytes.
    pub body: Vec<u8>,
    /// Header value, either a single token or a `", "`-separated list.
    pub content_encoding: String,
}
/// Encodes `body` with one `algo` and labels it with that coding's
/// `Content-Encoding` token.
///
/// # Errors
///
/// Returns the encoder failure of the selected algorithm, wrapped in the
/// matching [`CompressionError`] variant.
pub fn compress(body: &[u8], algo: Algorithm) -> Result<CompressedBody, CompressionError> {
    compress_bytes(body, algo).map(|encoded| CompressedBody {
        body: encoded,
        content_encoding: algo.content_encoding().to_owned(),
    })
}
/// Applies a single coding layer to `body` and returns the encoded bytes.
///
/// `Identity` returns a copy of the input. gzip and deflate use flate2's
/// default compression level. Brotli uses `CompressorWriter::new(_, 4096, 6,
/// 22)`, presumably buffer size 4096, quality 6 and lgwin 22 (TODO confirm
/// against the brotli crate docs).
fn compress_bytes(body: &[u8], algo: Algorithm) -> Result<Vec<u8>, CompressionError> {
    use std::io::Write;

    let level = flate2::Compression::default();
    match algo {
        Algorithm::Identity => Ok(body.to_vec()),
        Algorithm::Gzip => {
            let mut encoder = flate2::write::GzEncoder::new(Vec::new(), level);
            encoder
                .write_all(body)
                .and_then(|()| encoder.finish())
                .map_err(CompressionError::Gzip)
        }
        Algorithm::Deflate => {
            // NOTE(review): flate2's DeflateEncoder emits a raw RFC 1951 stream,
            // while RFC 9110 §8.4.1.2 defines the `deflate` coding as zlib-wrapped
            // (RFC 1950). Confirm that raw output is intended here.
            let mut encoder = flate2::write::DeflateEncoder::new(Vec::new(), level);
            encoder
                .write_all(body)
                .and_then(|()| encoder.finish())
                .map_err(CompressionError::Deflate)
        }
        Algorithm::Brotli => {
            let mut out = Vec::new();
            {
                let mut writer = brotli::CompressorWriter::new(&mut out, 4096, 6, 22);
                writer.write_all(body).map_err(CompressionError::Brotli)?;
                writer.flush().map_err(CompressionError::Brotli)?;
                // The writer is dropped at the end of this scope. Its Drop
                // finalises the brotli stream into `out`, which is why `out`
                // is only read afterwards.
            }
            Ok(out)
        }
    }
}
/// Applies several content codings to `body` and builds the matching
/// `Content-Encoding` header value.
///
/// RFC 9110 §8.4 requires the header to list codings in the order they were
/// applied. `algos[0]` is therefore the innermost layer and is applied first,
/// and the last element is the outermost layer. The header repeats `algos` in
/// the same order, joined with `", "`. An empty `algos` returns the body
/// unchanged, labelled `identity`.
///
/// # Errors
///
/// * [`CompressionError::ChainTooDeep`] if `algos` has more than
///   [`MAX_CHAIN_LAYERS`] entries.
/// * Any encoder error raised while applying a layer.
pub fn chain(body: &[u8], algos: &[Algorithm]) -> Result<CompressedBody, CompressionError> {
    if algos.len() > MAX_CHAIN_LAYERS {
        return Err(CompressionError::ChainTooDeep(MAX_CHAIN_LAYERS));
    }
    if algos.is_empty() {
        return Ok(CompressedBody {
            body: body.to_vec(),
            content_encoding: Algorithm::Identity.content_encoding().to_string(),
        });
    }
    // Apply codings in listed order so the header faithfully describes them.
    let mut current = body.to_vec();
    for algo in algos {
        current = compress_bytes(&current, *algo)?;
    }
    let content_encoding = algos
        .iter()
        .map(|a| a.content_encoding())
        .collect::<Vec<_>>()
        .join(", ");
    Ok(CompressedBody {
        body: current,
        content_encoding,
    })
}

/// Removes every recognised content coding listed in `blob.content_encoding`.
///
/// The header is split on `,` and each element is parsed with
/// [`Algorithm::from_token`]. Unrecognised tokens are skipped rather than
/// rejected. Layers are removed in reverse listing order, last applied
/// first, as RFC 9110 §8.4 requires; this mirrors [`chain`].
///
/// # Errors
///
/// * [`CompressionError::ChainTooDeep`] if more than [`MAX_CHAIN_LAYERS`]
///   recognised codings are listed.
/// * [`CompressionError::DecompressionBomb`] if any layer exceeds
///   [`DECOMPRESSED_BODY_MAX_BYTES`].
/// * Any decoder error raised while stripping a layer.
pub fn decompress(blob: &CompressedBody) -> Result<Vec<u8>, CompressionError> {
    let algos: Vec<Algorithm> = blob
        .content_encoding
        .split(',')
        .filter_map(Algorithm::from_token)
        .collect();
    if algos.len() > MAX_CHAIN_LAYERS {
        return Err(CompressionError::ChainTooDeep(MAX_CHAIN_LAYERS));
    }
    // The last-listed coding is the outermost layer, so it is undone first.
    let mut current = blob.body.clone();
    for algo in algos.iter().rev() {
        current = decompress_bytes(&current, *algo)?;
    }
    Ok(current)
}
/// Reads `reader` to the end without buffering more than
/// [`DECOMPRESSED_BODY_MAX_BYTES`].
///
/// At most `cap + 1` bytes are pulled from `reader`. Getting that extra byte
/// proves the source would exceed the cap, and in that case a
/// [`CompressionError::DecompressionBomb`] is returned instead of the data.
/// `map_io` wraps read errors in the caller's algorithm-specific variant.
fn drain_capped<R: std::io::Read>(
    reader: R,
    map_io: fn(std::io::Error) -> CompressionError,
) -> Result<Vec<u8>, CompressionError> {
    use std::io::Read;

    let cap = DECOMPRESSED_BODY_MAX_BYTES;
    let mut drained = Vec::with_capacity(8 * 1024);
    reader
        .take((cap as u64) + 1)
        .read_to_end(&mut drained)
        .map_err(map_io)?;

    match drained.len() {
        produced if produced > cap => Err(CompressionError::DecompressionBomb {
            cap_bytes: cap,
            observed_bytes: produced,
        }),
        _ => Ok(drained),
    }
}
/// Strips a single content-coding layer from `bytes`.
///
/// Every branch enforces [`DECOMPRESSED_BODY_MAX_BYTES`]. The streaming
/// decoders go through [`drain_capped`]. `identity` checks the input length
/// directly, because its output is the input.
///
/// For `deflate`, RFC 9110 §8.4.1.2 specifies a zlib-wrapped (RFC 1950)
/// stream, but some senders (including this module's own encoder) emit raw
/// RFC 1951 deflate. The zlib header is sniffed so both forms decode.
///
/// # Errors
///
/// * [`CompressionError::DecompressionBomb`] if the output would exceed the cap.
/// * [`CompressionError::Gzip`] / [`CompressionError::Deflate`] /
///   [`CompressionError::Brotli`] if the corresponding decoder fails.
fn decompress_bytes(bytes: &[u8], algo: Algorithm) -> Result<Vec<u8>, CompressionError> {
    match algo {
        Algorithm::Identity => {
            if bytes.len() > DECOMPRESSED_BODY_MAX_BYTES {
                return Err(CompressionError::DecompressionBomb {
                    cap_bytes: DECOMPRESSED_BODY_MAX_BYTES,
                    observed_bytes: bytes.len(),
                });
            }
            Ok(bytes.to_vec())
        }
        Algorithm::Gzip => {
            drain_capped(flate2::read::GzDecoder::new(bytes), CompressionError::Gzip)
        }
        // RFC-conformant `deflate`: zlib header + deflate stream + Adler-32.
        Algorithm::Deflate if has_zlib_header(bytes) => drain_capped(
            flate2::read::ZlibDecoder::new(bytes),
            CompressionError::Deflate,
        ),
        // Raw deflate, as produced by non-conformant senders and by `compress`.
        Algorithm::Deflate => drain_capped(
            flate2::read::DeflateDecoder::new(bytes),
            CompressionError::Deflate,
        ),
        Algorithm::Brotli => drain_capped(
            brotli::Decompressor::new(bytes, 4096),
            CompressionError::Brotli,
        ),
    }
}

/// Returns `true` when `bytes` starts with a well-formed RFC 1950 zlib header.
///
/// Three checks must hold: compression method 8 (deflate), a window field
/// (CINFO) of at most 7, and the first two bytes, read big-endian, forming a
/// multiple of 31 (FCHECK). A raw deflate stream from a normal encoder cannot
/// satisfy all three, because method nibble 8 would require a non-final stored
/// block with non-zero padding bits.
fn has_zlib_header(bytes: &[u8]) -> bool {
    match *bytes {
        [cmf, flg, ..] => {
            let method = cmf & 0x0F;
            let window = cmf >> 4;
            let check = (u16::from(cmf) << 8) | u16::from(flg);
            method == 8 && window <= 7 && check % 31 == 0
        }
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for token parsing, single- and multi-layer round trips, the
    //! layer cap, and the decompression-bomb guards.
    use super::*;

    #[test]
    fn content_encoding_tokens_match_rfc_registry() {
        assert_eq!(Algorithm::Gzip.content_encoding(), "gzip");
        assert_eq!(Algorithm::Deflate.content_encoding(), "deflate");
        assert_eq!(Algorithm::Brotli.content_encoding(), "br");
        assert_eq!(Algorithm::Identity.content_encoding(), "identity");
    }

    #[test]
    fn from_token_is_case_insensitive_and_trim_tolerant() {
        for spelling in ["gzip", "GZIP", "Gzip", " gzip ", "\tgzip"] {
            assert_eq!(Algorithm::from_token(spelling), Some(Algorithm::Gzip));
        }
    }

    #[test]
    fn from_token_accepts_x_gzip_alias() {
        assert_eq!(Algorithm::from_token("x-gzip"), Some(Algorithm::Gzip));
        assert_eq!(Algorithm::from_token("X-GZIP"), Some(Algorithm::Gzip));
    }

    #[test]
    fn from_token_rejects_unknown_codings() {
        assert_eq!(Algorithm::from_token(""), None);
        assert_eq!(Algorithm::from_token("snappy"), None);
        assert_eq!(Algorithm::from_token("lz4"), None);
        assert_eq!(Algorithm::from_token("compress"), None);
    }

    #[test]
    fn gzip_round_trip_preserves_payload() {
        let original = b"' OR 1=1--";
        let compressed = compress(original, Algorithm::Gzip).expect("gzip");
        assert_eq!(compressed.content_encoding, "gzip");
        assert_ne!(compressed.body.as_slice(), original);
        let recovered = decompress(&compressed).expect("decompress");
        assert_eq!(recovered, original);
    }

    #[test]
    fn deflate_round_trip_preserves_payload() {
        let original = b"<script>alert(1)</script>";
        let compressed = compress(original, Algorithm::Deflate).expect("deflate");
        assert_eq!(compressed.content_encoding, "deflate");
        let recovered = decompress(&compressed).expect("decompress");
        assert_eq!(recovered, original);
    }

    #[test]
    fn brotli_round_trip_preserves_payload() {
        let original = b"http://127.0.0.1:9000/admin?cmd=id";
        let compressed = compress(original, Algorithm::Brotli).expect("brotli");
        assert_eq!(compressed.content_encoding, "br");
        let recovered = decompress(&compressed).expect("decompress");
        assert_eq!(recovered, original);
    }

    #[test]
    fn identity_is_passthrough_with_identity_header() {
        let original = b"plain text";
        let compressed = compress(original, Algorithm::Identity).expect("identity");
        assert_eq!(compressed.body, original);
        assert_eq!(compressed.content_encoding, "identity");
    }

    #[test]
    fn chain_with_one_algo_matches_single_compress() {
        let original = b"single layer";
        let chained = chain(original, &[Algorithm::Gzip]).expect("chain");
        let single = compress(original, Algorithm::Gzip).expect("compress");
        assert_eq!(chained, single);
    }

    #[test]
    fn chain_with_two_algos_round_trips() {
        let original = b"' UNION SELECT username,password FROM users --";
        let chained = chain(original, &[Algorithm::Gzip, Algorithm::Brotli]).expect("chain");
        assert_eq!(chained.content_encoding, "gzip, br");
        let recovered = decompress(&chained).expect("decompress");
        assert_eq!(recovered, original);
    }

    #[test]
    fn chain_empty_algos_returns_identity_body() {
        let original = b"unchanged";
        let chained = chain(original, &[]).expect("empty chain");
        assert_eq!(chained.body, original);
        assert_eq!(chained.content_encoding, "identity");
    }

    #[test]
    fn chain_above_cap_returns_too_deep_error() {
        let too_many: Vec<Algorithm> = (0..MAX_CHAIN_LAYERS + 1).map(|_| Algorithm::Gzip).collect();
        let result = chain(b"payload", &too_many);
        match result {
            Err(CompressionError::ChainTooDeep(cap)) => assert_eq!(cap, MAX_CHAIN_LAYERS),
            other => panic!("expected ChainTooDeep error, got {other:?}"),
        }
    }

    #[test]
    fn chain_at_exactly_cap_succeeds() {
        let just_enough: Vec<Algorithm> =
            (0..MAX_CHAIN_LAYERS).map(|_| Algorithm::Identity).collect();
        let chained = chain(b"x", &just_enough).expect("at-cap chain ok");
        assert_eq!(chained.body, b"x");
    }

    #[test]
    fn chain_with_identity_in_the_middle_is_transparent() {
        let original = b"middle identity";
        let with_id = chain(
            original,
            &[Algorithm::Gzip, Algorithm::Identity, Algorithm::Brotli],
        )
        .expect("chain with identity");
        let without =
            chain(original, &[Algorithm::Gzip, Algorithm::Brotli]).expect("chain without identity");
        assert_eq!(
            with_id.body, without.body,
            "identity must be byte-transparent"
        );
        assert_eq!(with_id.content_encoding, "gzip, identity, br");
        let recovered = decompress(&with_id).expect("decompress with id");
        assert_eq!(recovered, original);
    }

    #[test]
    fn empty_body_compresses_and_round_trips() {
        for algo in [
            Algorithm::Gzip,
            Algorithm::Deflate,
            Algorithm::Brotli,
            Algorithm::Identity,
        ] {
            let compressed =
                compress(b"", algo).unwrap_or_else(|e| panic!("empty body with {algo:?}: {e}"));
            let recovered = decompress(&compressed)
                .unwrap_or_else(|e| panic!("empty body decode with {algo:?}: {e}"));
            assert_eq!(recovered, Vec::<u8>::new());
        }
    }

    #[test]
    fn one_byte_body_round_trips_under_every_algorithm() {
        for algo in [
            Algorithm::Gzip,
            Algorithm::Deflate,
            Algorithm::Brotli,
            Algorithm::Identity,
        ] {
            let original = &[0xAB_u8][..];
            let compressed = compress(original, algo).expect("compress");
            let recovered = decompress(&compressed).expect("decompress");
            assert_eq!(recovered, original);
        }
    }

    #[test]
    fn large_body_64_kib_round_trips_without_oom() {
        let original: Vec<u8> = (0..(64 * 1024)).map(|i| (i % 251) as u8).collect();
        for algo in [Algorithm::Gzip, Algorithm::Deflate, Algorithm::Brotli] {
            let compressed = compress(&original, algo).expect("compress");
            assert!(
                compressed.body.len() < original.len(),
                "{algo:?} should compress this pattern, got {} >= {}",
                compressed.body.len(),
                original.len()
            );
            let recovered = decompress(&compressed).expect("decompress");
            assert_eq!(recovered, original);
        }
    }

    #[test]
    fn incompressible_body_does_not_panic_on_brotli() {
        let mut original = vec![0u8; 1024];
        for (i, b) in original.iter_mut().enumerate() {
            *b = ((i.wrapping_mul(2654435769)) & 0xFF) as u8;
        }
        let compressed = compress(&original, Algorithm::Brotli).expect("brotli");
        let recovered = decompress(&compressed).expect("decompress");
        assert_eq!(recovered, original);
    }

    #[test]
    fn decompress_with_unknown_coding_token_skips_it() {
        let body = b"hello";
        let compressed = compress(body, Algorithm::Gzip).unwrap();
        let with_unknown = CompressedBody {
            content_encoding: format!("snappy, {}", compressed.content_encoding),
            body: compressed.body,
        };
        let recovered = decompress(&with_unknown).expect("permissive decompress");
        assert_eq!(recovered, body);
    }

    #[test]
    fn decompress_rejects_more_than_max_chain_layers() {
        let header = std::iter::repeat_n("gzip", MAX_CHAIN_LAYERS + 1)
            .collect::<Vec<_>>()
            .join(", ");
        let blob = CompressedBody {
            content_encoding: header,
            body: Vec::new(),
        };
        match decompress(&blob) {
            Err(CompressionError::ChainTooDeep(cap)) => assert_eq!(cap, MAX_CHAIN_LAYERS),
            other => panic!("expected ChainTooDeep, got {other:?}"),
        }
    }

    #[test]
    fn decompress_layer_cap_counts_recognised_codings_only() {
        let body = b"hello world";
        let compressed = compress(body, Algorithm::Gzip).unwrap();
        let mut tokens: Vec<String> = std::iter::repeat_n("snappy", MAX_CHAIN_LAYERS + 5)
            .map(str::to_string)
            .collect();
        tokens.push(compressed.content_encoding.clone());
        let blob = CompressedBody {
            content_encoding: tokens.join(", "),
            body: compressed.body,
        };
        let recovered = decompress(&blob).expect("unknown-padded header is a 1-layer decode");
        assert_eq!(recovered, body);
    }

    #[test]
    fn round_trip_property_holds_across_a_variety_of_payloads() {
        let corpus: &[&[u8]] = &[
            b"",
            b"x",
            b"' OR 1=1--",
            b"<script>alert(document.cookie)</script>",
            b"http://127.0.0.1/admin",
            b"; cat /etc/passwd",
            b"\x00\x01\x02\x03\xff\xfe",
            b"the quick brown fox jumps over the lazy dog the quick brown fox",
        ];
        for payload in corpus {
            for algo in [
                Algorithm::Gzip,
                Algorithm::Deflate,
                Algorithm::Brotli,
                Algorithm::Identity,
            ] {
                let c = compress(payload, algo)
                    .unwrap_or_else(|e| panic!("{algo:?} on {payload:?}: {e}"));
                let r = decompress(&c)
                    .unwrap_or_else(|e| panic!("decompress {algo:?} on {payload:?}: {e}"));
                assert_eq!(r, *payload, "{algo:?} round-trip mismatch on {payload:?}");
            }
        }
    }

    #[test]
    fn identity_decompress_rejects_oversize_input() {
        let oversized = vec![0u8; DECOMPRESSED_BODY_MAX_BYTES + 1];
        let err = super::decompress_bytes(&oversized, Algorithm::Identity)
            .expect_err("identity decompress must refuse > cap input");
        match err {
            CompressionError::DecompressionBomb {
                cap_bytes,
                observed_bytes,
            } => {
                assert_eq!(cap_bytes, DECOMPRESSED_BODY_MAX_BYTES);
                assert_eq!(observed_bytes, DECOMPRESSED_BODY_MAX_BYTES + 1);
            }
            other => panic!("expected DecompressionBomb, got {other:?}"),
        }
    }

    #[test]
    fn gzip_decompress_under_cap_succeeds() {
        use std::io::Write;
        let mut enc = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        enc.write_all(&vec![0u8; 1024 * 1024]).expect("compress");
        let compressed = enc.finish().expect("gzip finish");
        let ok = super::decompress_bytes(&compressed, Algorithm::Gzip).expect("under cap");
        assert_eq!(ok.len(), 1024 * 1024);
    }

    /// Calls `drain_capped` itself, rather than re-testing std `Read::take`.
    /// `io::repeat` never ends, so only the cap can stop the read.
    #[test]
    fn drain_capped_returns_bomb_error_on_over_cap_source() {
        match super::drain_capped(std::io::repeat(b'A'), CompressionError::Gzip) {
            Err(CompressionError::DecompressionBomb {
                cap_bytes,
                observed_bytes,
            }) => {
                assert_eq!(cap_bytes, DECOMPRESSED_BODY_MAX_BYTES);
                assert_eq!(observed_bytes, DECOMPRESSED_BODY_MAX_BYTES + 1);
            }
            // Report only the length on success, to avoid printing a huge Vec.
            other => panic!(
                "expected DecompressionBomb, got {:?}",
                other.map(|v| v.len())
            ),
        }
    }

    /// Boundary case: a source producing exactly `cap` bytes must be accepted.
    #[test]
    fn drain_capped_accepts_source_exactly_at_cap() {
        use std::io::Read;
        let at_cap = std::io::repeat(b'A').take(DECOMPRESSED_BODY_MAX_BYTES as u64);
        let out = super::drain_capped(at_cap, CompressionError::Gzip)
            .expect("a source of exactly cap bytes is allowed");
        assert_eq!(out.len(), DECOMPRESSED_BODY_MAX_BYTES);
    }
}