use std::io::Read;
use compcol::deflate::Deflate;
use compcol::gzip::Gzip;
use compcol::io::DecoderReader;
use compcol::limit::LimitedDecoder;
use compcol::zlib::Zlib;
use compcol::Algorithm;
use crate::error::{Error, Result};
use crate::http::MAX_BODY_BYTES;
/// Maximum number of non-empty `Content-Encoding` tokens `decode_body` accepts.
/// A header listing more layers than this is rejected with `Error::BadResponse`
/// before any decompression is attempted.
const MAX_ENCODING_LAYERS: usize = 3;
/// Result of `decode_body`: the resulting bytes plus a flag telling whether
/// any recognised `Content-Encoding` layer was processed.
#[derive(Debug)]
pub(crate) struct Decoded {
/// Body bytes after the recognised layers were removed; the input bytes
/// unchanged when nothing was peeled.
pub body: Vec<u8>,
/// `true` when at least one recognised layer (including `identity`) was
/// processed, `false` otherwise.
pub decoded: bool,
}
/// Undoes the `Content-Encoding` layers applied to an HTTP body.
///
/// `content_encoding` is the raw header value: a comma-separated list of
/// tokens. Tokens are trimmed and empty ones are ignored. Layers are peeled
/// starting from the last listed token and working back to the first.
///
/// # Returns
///
/// * With no tokens at all, the body is returned untouched and
///   `decoded == false`.
/// * Otherwise the body with every recognised layer removed. `decoded` is
///   `true` when at least one recognised layer (including `identity`) was
///   processed.
/// * An unrecognised token stops the peeling: whatever has been produced so
///   far is returned as-is.
///
/// NOTE(review): when an unrecognised token sits *under* a recognised one
/// (e.g. `br, gzip`), the returned body is still encoded with the unknown
/// coding while `decoded` is `true` — confirm callers handle this.
///
/// # Errors
///
/// Returns `Error::BadResponse` when more than `MAX_ENCODING_LAYERS` tokens
/// are listed, or when a gzip/deflate layer fails to decode.
pub(crate) fn decode_body(body: Vec<u8>, content_encoding: &str) -> Result<Decoded> {
    let layers: Vec<&str> = content_encoding
        .split(',')
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .collect();
    if layers.is_empty() {
        return Ok(Decoded {
            body,
            decoded: false,
        });
    }
    if layers.len() > MAX_ENCODING_LAYERS {
        return Err(Error::BadResponse(format!(
            "too many Content-Encoding layers ({}, max {MAX_ENCODING_LAYERS})",
            layers.len()
        )));
    }

    let mut current = body;
    let mut peeled = false;
    // One output budget is shared by all layers: every decompressed layer's
    // length is charged against it before the next layer is decoded.
    let mut budget = MAX_BODY_BYTES as u64;

    for token in layers.iter().rev() {
        // `None` means the layer leaves the bytes alone (identity) and
        // therefore consumes none of the budget.
        let inflated = match Layer::parse(token) {
            Some(Layer::Identity) => None,
            Some(Layer::Gzip) => Some(gunzip(&current, budget)?),
            Some(Layer::Deflate) => Some(inflate_zlib(&current, budget)?),
            // Unknown coding: stop here and hand back what we have so far.
            None => break,
        };
        if let Some(bytes) = inflated {
            budget = budget.saturating_sub(bytes.len() as u64);
            current = bytes;
        }
        peeled = true;
    }

    Ok(Decoded {
        body: current,
        decoded: peeled,
    })
}
/// A `Content-Encoding` token that `decode_body` recognises.
enum Layer {
/// `gzip` or its `x-gzip` alias.
Gzip,
/// `deflate`; decoded as zlib first, with a raw-DEFLATE fallback.
Deflate,
/// `identity`: the body is left as it is.
Identity,
}
impl Layer {
    /// Maps a single, already-trimmed `Content-Encoding` token to a layer.
    ///
    /// Matching ignores ASCII case. Returns `None` for any token other than
    /// `gzip`, `x-gzip`, `deflate` or `identity`.
    fn parse(token: &str) -> Option<Self> {
        let named = |name: &str| token.eq_ignore_ascii_case(name);

        if named("gzip") || named("x-gzip") {
            return Some(Layer::Gzip);
        }
        if named("deflate") {
            return Some(Layer::Deflate);
        }
        if named("identity") {
            return Some(Layer::Identity);
        }
        None
    }
}
/// Decompresses `src` with algorithm `A`, collecting the output into a `Vec`.
///
/// The decoder is wrapped in a `LimitedDecoder` constructed with `budget`
/// (presumably the maximum number of output bytes allowed — confirm against
/// the `compcol` documentation). Any I/O error from the read is returned.
fn decode_with<A: Algorithm>(src: &[u8], budget: u64) -> std::io::Result<Vec<u8>> {
    // Pre-size the output at three times the input length, never above `budget`.
    let estimate = (src.len() as u64).saturating_mul(3);
    let mut out = Vec::with_capacity(estimate.min(budget) as usize);

    let limited = LimitedDecoder::new(A::decoder(), budget);
    DecoderReader::new(src, limited).read_to_end(&mut out)?;
    Ok(out)
}
fn gunzip(src: &[u8], budget: u64) -> Result<Vec<u8>> {
decode_with::<Gzip>(src, budget)
.map_err(|e| Error::BadResponse(format!("gzip decode failed: {e}")))
}
/// Decodes a `deflate`-labelled body.
///
/// The input is first decoded as zlib. If that fails, it is decoded again as
/// raw DEFLATE. Both attempts receive the same `budget`.
///
/// # Errors
///
/// When both attempts fail, returns `Error::BadResponse` carrying the error
/// from the zlib attempt; the raw-DEFLATE error is discarded.
fn inflate_zlib(src: &[u8], budget: u64) -> Result<Vec<u8>> {
    decode_with::<Zlib>(src, budget).or_else(|zlib_err| {
        decode_with::<Deflate>(src, budget)
            .map_err(|_| Error::BadResponse(format!("deflate decode failed: {zlib_err}")))
    })
}
/// Removes every `Content-Encoding` and `Content-Length` header from `headers`.
///
/// Header names are compared ignoring ASCII case; the remaining headers keep
/// their original order and values.
pub(crate) fn strip_after_decode(headers: Vec<(String, String)>) -> Vec<(String, String)> {
    const DROPPED: [&str; 2] = ["content-encoding", "content-length"];

    let mut kept = headers;
    kept.retain(|(name, _)| !DROPPED.iter().any(|dropped| name.eq_ignore_ascii_case(dropped)));
    kept
}
#[cfg(test)]
mod tests {
    use super::*;
    use compcol::vec::compress_to_vec;

    /// Gzip-compresses `data` to build a fixture.
    fn gz(data: &[u8]) -> Vec<u8> {
        compress_to_vec::<Gzip>(data).expect("gzip encode")
    }

    /// Zlib-compresses `data` to build a fixture.
    fn zlib(data: &[u8]) -> Vec<u8> {
        compress_to_vec::<Zlib>(data).expect("zlib encode")
    }

    /// Compresses `data` as raw DEFLATE to build a fixture.
    fn raw_deflate(data: &[u8]) -> Vec<u8> {
        compress_to_vec::<Deflate>(data).expect("deflate encode")
    }

    /// Extracts the message of an `Error::BadResponse`; panics on any other variant.
    fn bad_response_message(err: Error) -> String {
        match err {
            Error::BadResponse(msg) => msg,
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn decodes_gzip() {
        let result = decode_body(gz(b"hello world"), "gzip").unwrap();
        assert_eq!(result.body, b"hello world");
        assert!(result.decoded);
    }

    #[test]
    fn decodes_x_gzip_alias() {
        let result = decode_body(gz(b"abc"), "x-gzip").unwrap();
        assert_eq!(result.body, b"abc");
    }

    #[test]
    fn decodes_zlib_wrapped_deflate() {
        let result = decode_body(zlib(b"payload"), "deflate").unwrap();
        assert_eq!(result.body, b"payload");
    }

    #[test]
    fn decodes_raw_deflate_for_buggy_servers() {
        let result = decode_body(raw_deflate(b"payload"), "deflate").unwrap();
        assert_eq!(result.body, b"payload");
    }

    #[test]
    fn case_insensitive_token() {
        let result = decode_body(gz(b"x"), "GZIP").unwrap();
        assert_eq!(result.body, b"x");
    }

    #[test]
    fn identity_passes_through() {
        let result = decode_body(b"raw".to_vec(), "identity").unwrap();
        assert_eq!(result.body, b"raw");
        // `identity` counts as a processed layer, so the flag is set.
        assert!(result.decoded);
    }

    #[test]
    fn empty_encoding_is_noop() {
        let result = decode_body(b"raw".to_vec(), "").unwrap();
        assert_eq!(result.body, b"raw");
        assert!(!result.decoded);
    }

    #[test]
    fn nested_gzip_then_identity() {
        let result = decode_body(gz(b"nested"), "identity, gzip").unwrap();
        assert_eq!(result.body, b"nested");
    }

    #[test]
    fn unknown_outer_layer_returns_undecoded() {
        let wire = gz(b"inner");
        let result = decode_body(wire.clone(), "gzip, br").unwrap();
        assert_eq!(result.body, wire);
        assert!(!result.decoded);
    }

    #[test]
    fn corrupt_gzip_reports_error() {
        // Drop the last two bytes of a valid gzip stream.
        let mut truncated = gz(b"valid");
        truncated.truncate(truncated.len().saturating_sub(2));

        let err = decode_body(truncated, "gzip").unwrap_err();
        let msg = bad_response_message(err);
        assert!(msg.contains("gzip"), "got {msg:?}");
    }

    #[test]
    fn rejects_too_many_encoding_layers() {
        let err = decode_body(gz(b"x"), "gzip, gzip, gzip, gzip").unwrap_err();
        let msg = bad_response_message(err);
        assert!(msg.contains("Content-Encoding layers"), "got {msg:?}");
    }

    #[test]
    fn accepts_max_encoding_layers() {
        let triple = gz(&gz(&gz(b"deep")));
        let result = decode_body(triple, "gzip, gzip, gzip").unwrap();
        assert_eq!(result.body, b"deep");
        assert!(result.decoded);
    }

    #[test]
    fn nested_layers_share_one_budget() {
        let double = gz(&gz(b"payload"));
        let result = decode_body(double, "gzip, gzip").unwrap();
        assert_eq!(result.body, b"payload");
    }

    #[test]
    fn strip_after_decode_removes_both_headers() {
        let input = vec![
            ("Content-Type".into(), "text/html".into()),
            ("Content-Encoding".into(), "gzip".into()),
            ("Content-Length".into(), "123".into()),
            ("Server".into(), "test".into()),
        ];
        let remaining = strip_after_decode(input);
        let names: Vec<&str> = remaining.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(names, ["Content-Type", "Server"]);
    }
}