#![forbid(unsafe_code)]
use bytes::Bytes;
use http::StatusCode;
use http_body_util::{BodyExt, Full};
use oxihttp_core::OxiHttpError;
/// A content coding this middleware can apply to a response body.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CompressionAlgorithm {
    /// The `gzip` content coding.
    Gzip,
    /// The `deflate` content coding (the body is produced with `oxiarc_deflate::zlib_compress`).
    Deflate,
}

impl CompressionAlgorithm {
    /// Returns the lowercase token naming this coding in `Accept-Encoding` and
    /// `Content-Encoding` headers.
    pub fn as_str(&self) -> &'static str {
        match *self {
            CompressionAlgorithm::Deflate => "deflate",
            CompressionAlgorithm::Gzip => "gzip",
        }
    }
}
/// Tunables for [`Compression`].
#[derive(Clone, Debug)]
pub struct CompressionConfig {
    /// Bodies shorter than this many bytes are sent uncompressed.
    pub min_size: usize,
    /// Algorithms the server is willing to use, in order of preference.
    pub algorithms: Vec<CompressionAlgorithm>,
    /// Compression level handed to the encoder. The `Compression::level` builder clamps it to
    /// `1..=9`; assigning this field directly performs no validation.
    pub level: u8,
}

impl Default for CompressionConfig {
    /// Gzip preferred over deflate, a 1024-byte size threshold, and level 6.
    fn default() -> Self {
        let algorithms = vec![CompressionAlgorithm::Gzip, CompressionAlgorithm::Deflate];
        CompressionConfig {
            algorithms,
            level: 6,
            min_size: 1024,
        }
    }
}
/// Response-compression middleware, configured through its builder methods.
#[derive(Clone, Debug)]
pub struct Compression {
    /// Active settings; see [`CompressionConfig`].
    config: CompressionConfig,
}
impl Compression {
    /// Creates a compressor with [`CompressionConfig::default`] settings.
    pub fn new() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }

    /// Sets the minimum body size, in bytes; smaller bodies are passed through uncompressed.
    #[must_use]
    pub fn min_size(mut self, n: usize) -> Self {
        self.config.min_size = n;
        self
    }

    /// Sets the compression level, clamped to `1..=9`.
    #[must_use]
    pub fn level(mut self, n: u8) -> Self {
        self.config.level = n.clamp(1, 9);
        self
    }

    /// Replaces the algorithm list. The order is the server's preference: when the client
    /// accepts several algorithms, the earliest one in `algos` is used.
    #[must_use]
    pub fn algorithms(mut self, algos: Vec<CompressionAlgorithm>) -> Self {
        self.config.algorithms = algos;
        self
    }

    /// Compresses `resp` with the most preferred algorithm the client accepts.
    ///
    /// The response is returned unchanged when any of these holds:
    /// - its status is not 2xx;
    /// - its status is `204 No Content`, which must not carry content;
    /// - its status is `206 Partial Content`, whose `Content-Range` describes the unencoded bytes;
    /// - it already has a `Content-Encoding` header;
    /// - `accept_encoding` is `None`, or allows none of the configured algorithms;
    /// - the body is shorter than the configured minimum size.
    ///
    /// Otherwise the body is compressed and `Content-Encoding` and `Content-Length` are
    /// replaced. `Accept-Encoding` is added to `Vary` unless it (or `*`) is already listed;
    /// existing `Vary` values are kept.
    ///
    /// # Errors
    ///
    /// Returns [`OxiHttpError::Body`] if collecting the body or compressing it fails.
    pub async fn apply(
        &self,
        accept_encoding: Option<&str>,
        resp: hyper::Response<Full<Bytes>>,
    ) -> Result<hyper::Response<Full<Bytes>>, OxiHttpError> {
        let status = resp.status();
        // 204 responses must not carry content. A 206 body is a byte range of the unencoded
        // representation, so compressing it would make `Content-Range` wrong.
        let is_compressible_status = status.is_success()
            && status != StatusCode::NO_CONTENT
            && status != StatusCode::PARTIAL_CONTENT;
        if !is_compressible_status {
            return Ok(resp);
        }
        // Never double-encode a body that already has an encoding.
        if resp.headers().contains_key(http::header::CONTENT_ENCODING) {
            return Ok(resp);
        }
        let chosen = match accept_encoding.and_then(|ae| negotiate(ae, &self.config.algorithms)) {
            Some(algo) => algo,
            None => return Ok(resp),
        };

        let (mut parts, body) = resp.into_parts();
        let raw_bytes = body
            .collect()
            .await
            .map_err(|e| OxiHttpError::Body(e.to_string()))?
            .to_bytes();
        if raw_bytes.len() < self.config.min_size {
            return Ok(hyper::Response::from_parts(parts, Full::new(raw_bytes)));
        }

        let compressed_vec = match chosen {
            CompressionAlgorithm::Gzip => {
                oxiarc_deflate::gzip_compress(&raw_bytes, self.config.level)
                    .map_err(|e| OxiHttpError::Body(e.to_string()))?
            }
            CompressionAlgorithm::Deflate => {
                oxiarc_deflate::zlib_compress(&raw_bytes, self.config.level)
                    .map_err(|e| OxiHttpError::Body(e.to_string()))?
            }
        };
        let compressed_bytes = Bytes::from(compressed_vec);

        // `insert` drops every previous value for the name. A stale `Content-Length`
        // describing the uncompressed body therefore cannot survive.
        let headers = &mut parts.headers;
        headers.insert(
            http::header::CONTENT_ENCODING,
            http::HeaderValue::from_static(chosen.as_str()),
        );
        headers.insert(
            http::header::CONTENT_LENGTH,
            http::HeaderValue::from(compressed_bytes.len()),
        );
        Self::add_vary_accept_encoding(headers);

        Ok(hyper::Response::from_parts(parts, Full::new(compressed_bytes)))
    }

    /// Adds `Accept-Encoding` to `Vary` unless it, or `*`, is already listed.
    ///
    /// Other `Vary` values (e.g. `Origin`) are preserved, so caches keep keying on them.
    fn add_vary_accept_encoding(headers: &mut http::HeaderMap) {
        let already_covered = headers
            .get_all(http::header::VARY)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .map(str::trim)
            .any(|field| field == "*" || field.eq_ignore_ascii_case("accept-encoding"));
        if !already_covered {
            headers.append(
                http::header::VARY,
                http::HeaderValue::from_static("Accept-Encoding"),
            );
        }
    }
}
impl Default for Compression {
fn default() -> Self {
Self::new()
}
}
/// Picks the algorithm to use for a response, given the client's `Accept-Encoding` value.
///
/// `preferred` is the server's preference order and decides the winner. The client's
/// q-values only determine whether a coding is acceptable (`q > 0`), not which acceptable
/// coding is chosen.
///
/// A coding's weight comes from its own entry when the header names it explicitly, and from
/// the `*` wildcard otherwise. So `gzip;q=0, *` rejects gzip while still accepting every other
/// coding. Coding names and the `q` parameter name are compared case-insensitively. A missing
/// or unparsable q-value counts as `1.0`. If a coding is listed more than once, its highest
/// weight is used.
///
/// Returns `None` when none of the `preferred` algorithms is acceptable.
fn negotiate(
    accept_encoding: &str,
    preferred: &[CompressionAlgorithm],
) -> Option<CompressionAlgorithm> {
    // `(coding, q)` pairs in header order; empty list elements are skipped.
    let entries: Vec<(String, f32)> = accept_encoding
        .split(',')
        .filter_map(|element| {
            let mut params = element.split(';');
            let coding = params.next()?.trim().to_ascii_lowercase();
            if coding.is_empty() {
                return None;
            }
            let q = params
                .filter_map(|param| param.split_once('='))
                .find(|(key, _)| key.trim().eq_ignore_ascii_case("q"))
                .and_then(|(_, value)| value.trim().parse::<f32>().ok())
                .unwrap_or(1.0);
            Some((coding, q))
        })
        .collect();

    // Highest weight given to `coding` by an explicit entry, if the header lists it.
    let weight_of = |coding: &str| {
        entries
            .iter()
            .filter(|(name, _)| name == coding)
            .map(|&(_, q)| q)
            .reduce(f32::max)
    };
    let wildcard = weight_of("*");

    preferred
        .iter()
        .find(|algo| matches!(weight_of(algo.as_str()).or(wildcard), Some(q) if q > 0.0))
        .cloned()
}

#[cfg(test)]
mod tests {
    use super::*;

    fn preferred() -> Vec<CompressionAlgorithm> {
        vec![CompressionAlgorithm::Gzip, CompressionAlgorithm::Deflate]
    }

    #[test]
    fn test_negotiate_gzip_preferred() {
        assert_eq!(
            negotiate("gzip, deflate", &preferred()),
            Some(CompressionAlgorithm::Gzip)
        );
    }

    #[test]
    fn test_negotiate_deflate_only() {
        assert_eq!(
            negotiate("deflate", &preferred()),
            Some(CompressionAlgorithm::Deflate)
        );
    }

    #[test]
    fn test_negotiate_none() {
        let gzip_only = vec![CompressionAlgorithm::Gzip];
        assert_eq!(negotiate("br, zstd", &gzip_only), None);
    }

    #[test]
    fn test_negotiate_gzip_q0_deflate_fallback() {
        assert_eq!(
            negotiate("gzip;q=0, deflate", &preferred()),
            Some(CompressionAlgorithm::Deflate)
        );
    }

    #[test]
    fn test_negotiate_wildcard() {
        let gzip_only = vec![CompressionAlgorithm::Gzip];
        assert_eq!(negotiate("*", &gzip_only), Some(CompressionAlgorithm::Gzip));
    }

    #[test]
    fn test_negotiate_empty_accept() {
        assert_eq!(negotiate("", &preferred()), None);
    }

    #[test]
    fn test_negotiate_q_ordering() {
        assert_eq!(
            negotiate("deflate;q=0.9, gzip;q=0.8", &preferred()),
            Some(CompressionAlgorithm::Gzip)
        );
    }

    #[test]
    fn test_negotiate_explicit_q0_overrides_wildcard() {
        assert_eq!(
            negotiate("gzip;q=0, *", &preferred()),
            Some(CompressionAlgorithm::Deflate)
        );
        let gzip_only = vec![CompressionAlgorithm::Gzip];
        assert_eq!(negotiate("gzip;q=0, *", &gzip_only), None);
    }

    #[test]
    fn test_negotiate_wildcard_q0_rejects_unlisted() {
        assert_eq!(negotiate("*;q=0", &preferred()), None);
        assert_eq!(
            negotiate("deflate, *;q=0", &preferred()),
            Some(CompressionAlgorithm::Deflate)
        );
    }

    #[test]
    fn test_negotiate_case_insensitive() {
        assert_eq!(
            negotiate("GZIP;Q=0, Deflate", &preferred()),
            Some(CompressionAlgorithm::Deflate)
        );
    }

    #[test]
    fn test_negotiate_whitespace_around_params() {
        assert_eq!(
            negotiate(" gzip ; q = 0 , deflate ", &preferred()),
            Some(CompressionAlgorithm::Deflate)
        );
    }

    #[test]
    fn test_negotiate_all_rejected() {
        assert_eq!(negotiate("gzip;q=0, deflate;q=0", &preferred()), None);
    }
}