use crate::middleware::Middleware;
use crate::{Request, Response, Result};
use async_compression::tokio::write::{BrotliEncoder, GzipEncoder};
use bytes::Bytes;
use http::{header, HeaderValue};
use tokio::io::AsyncWriteExt;
use tracing::debug;
use super::Next;
/// Compression effort requested from the underlying encoder.
///
/// Each variant converts to the `async_compression::Level` variant of the
/// same name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompressionLevel {
    /// Favour speed over compression ratio.
    Fastest,
    /// The encoder's default trade-off.
    Default,
    /// Favour compression ratio over speed.
    Best,
    /// An explicit numeric quality, forwarded unchanged to the encoder.
    ///
    /// NOTE(review): the valid range is algorithm-specific and is not
    /// validated here; confirm against the encoder documentation.
    Precise(i32),
}
/// Maps each middleware-level setting onto the encoder level of the same name.
impl From<CompressionLevel> for async_compression::Level {
    fn from(level: CompressionLevel) -> Self {
        match level {
            // The numeric quality is forwarded untouched.
            CompressionLevel::Precise(quality) => Self::Precise(quality),
            CompressionLevel::Best => Self::Best,
            CompressionLevel::Default => Self::Default,
            CompressionLevel::Fastest => Self::Fastest,
        }
    }
}
/// Middleware that compresses response bodies with gzip or brotli.
///
/// Instances are configured through `Default`/`new` and the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct CompressionMiddleware {
/// Minimum response body length in bytes; shorter bodies are returned
/// uncompressed.
threshold: usize,
/// Compression level handed to the gzip/brotli encoder.
level: CompressionLevel,
/// Whether gzip may be chosen as the response encoding.
enable_gzip: bool,
/// Whether brotli may be chosen as the response encoding.
enable_brotli: bool,
/// Content types eligible for compression, matched as prefixes of the
/// response `Content-Type` value.
compressible_types: Vec<String>,
}
impl Default for CompressionMiddleware {
    /// Builds the stock configuration: a 1024-byte threshold, the default
    /// compression level, both gzip and brotli enabled, and a list of common
    /// textual content types.
    fn default() -> Self {
        // Textual formats that are compressed out of the box.
        const DEFAULT_TYPES: &[&str] = &[
            "text/plain",
            "text/html",
            "text/css",
            "text/javascript",
            "application/javascript",
            "application/json",
            "application/xml",
            "text/xml",
            "application/rss+xml",
            "application/atom+xml",
            "image/svg+xml",
        ];

        let compressible_types = DEFAULT_TYPES
            .iter()
            .map(|&mime| String::from(mime))
            .collect();

        Self {
            threshold: 1024,
            level: CompressionLevel::Default,
            enable_gzip: true,
            enable_brotli: true,
            compressible_types,
        }
    }
}
impl CompressionMiddleware {
    /// Creates a middleware with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum body size in bytes; smaller responses are not compressed.
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.threshold = threshold;
        self
    }

    /// Sets the compression level handed to the encoders.
    pub fn with_level(mut self, level: CompressionLevel) -> Self {
        self.level = level;
        self
    }

    /// Enables or disables gzip as a selectable encoding.
    pub fn with_gzip(mut self, enabled: bool) -> Self {
        self.enable_gzip = enabled;
        self
    }

    /// Enables or disables brotli as a selectable encoding.
    pub fn with_brotli(mut self, enabled: bool) -> Self {
        self.enable_brotli = enabled;
        self
    }

    /// Replaces the list of content-type prefixes eligible for compression.
    ///
    /// Entries are matched as ASCII-case-insensitive prefixes of the response
    /// `Content-Type`, so `"text/html"` also matches `"text/html; charset=utf-8"`.
    pub fn with_compressible_types(mut self, types: Vec<&str>) -> Self {
        self.compressible_types = types.into_iter().map(String::from).collect();
        self
    }

    /// Returns `true` when `content_type` starts with one of the configured
    /// compressible prefixes. A missing content type is never compressible.
    fn is_compressible(&self, content_type: Option<&str>) -> bool {
        // ASCII-case-insensitive `starts_with` that does not allocate. MIME
        // types are case-insensitive, and both the header value and the
        // configured entries may use any casing.
        fn has_prefix_ignore_case(value: &str, prefix: &str) -> bool {
            value
                .as_bytes()
                .get(..prefix.len())
                .map_or(false, |head| head.eq_ignore_ascii_case(prefix.as_bytes()))
        }

        content_type.map_or(false, |ct| {
            self.compressible_types
                .iter()
                .any(|allowed| has_prefix_ignore_case(ct, allowed))
        })
    }

    /// Picks the response encoding from an `Accept-Encoding` header value.
    ///
    /// Candidates are ranked by descending quality value; ties keep the order
    /// in which they appear in the header. Entries with `q=0` are treated as
    /// "not acceptable" and skipped. Returns `None` when the header is absent
    /// or no enabled encoding is acceptable.
    fn negotiate_encoding(&self, accept_encoding: Option<&str>) -> Option<Encoding> {
        let accept = accept_encoding?.to_ascii_lowercase();

        let mut candidates: Vec<(Encoding, f32)> = accept
            .split(',')
            .filter_map(|part| self.parse_encoding_with_quality(part.trim()))
            // `q=0` explicitly forbids an encoding; NaN also fails this test,
            // so the comparison below never sees an unordered value.
            .filter(|&(_, quality)| quality > 0.0)
            .collect();

        // Stable sort: equal weights keep header order. The fallback keeps the
        // comparator total so untrusted input can never cause a panic here.
        candidates.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        candidates
            .into_iter()
            .map(|(encoding, _)| encoding)
            .find(|encoding| match encoding {
                Encoding::Brotli => self.enable_brotli,
                Encoding::Gzip => self.enable_gzip,
            })
    }

    /// Parses one (already lowercased) `Accept-Encoding` element such as
    /// `"gzip;q=0.8"` into an encoding and its quality value.
    ///
    /// Returns `None` for codings this middleware cannot produce. A missing or
    /// unparsable weight counts as `1.0`; parsed weights are clamped to
    /// `0.0..=1.0`. The wildcard `*` resolves to gzip when gzip is enabled and
    /// to brotli otherwise.
    ///
    /// NOTE(review): an explicit exclusion combined with a wildcard
    /// (e.g. `gzip;q=0, *`) is not modelled; the wildcard may still resolve to
    /// the excluded coding.
    fn parse_encoding_with_quality(&self, part: &str) -> Option<(Encoding, f32)> {
        let mut pieces = part.split(';');

        let encoding = match pieces.next()?.trim() {
            "br" => Encoding::Brotli,
            "gzip" => Encoding::Gzip,
            "*" => {
                if self.enable_gzip {
                    Encoding::Gzip
                } else {
                    Encoding::Brotli
                }
            }
            _ => return None,
        };

        // Look at every parameter, not just the first, for the weight.
        let quality = pieces
            .find_map(|param| param.trim().strip_prefix("q="))
            .map_or(1.0, |raw| match raw.trim().parse::<f32>() {
                // "nan" parses successfully, so it must be rejected explicitly.
                Ok(q) if !q.is_nan() => q.clamp(0.0, 1.0),
                _ => 1.0,
            });

        Some((encoding, quality))
    }

    /// Compresses `data` with the given encoding at the configured level.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::Internal` when writing to or finalizing the
    /// encoder fails.
    async fn compress_data(&self, data: &Bytes, encoding: Encoding) -> Result<Bytes> {
        match encoding {
            Encoding::Gzip => {
                let mut encoder = GzipEncoder::with_quality(Vec::new(), self.level.clone().into());
                encoder.write_all(data).await.map_err(|e| {
                    crate::Error::Internal(format!("Gzip compression failed: {}", e))
                })?;
                // `shutdown` flushes the trailer; the stream is incomplete without it.
                encoder.shutdown().await.map_err(|e| {
                    crate::Error::Internal(format!("Gzip finalization failed: {}", e))
                })?;
                Ok(Bytes::from(encoder.into_inner()))
            }
            Encoding::Brotli => {
                let mut encoder =
                    BrotliEncoder::with_quality(Vec::new(), self.level.clone().into());
                encoder.write_all(data).await.map_err(|e| {
                    crate::Error::Internal(format!("Brotli compression failed: {}", e))
                })?;
                // `shutdown` finalizes the stream before the buffer is taken.
                encoder.shutdown().await.map_err(|e| {
                    crate::Error::Internal(format!("Brotli finalization failed: {}", e))
                })?;
                Ok(Bytes::from(encoder.into_inner()))
            }
        }
    }
}
/// Content codings this middleware can apply to a response body.
#[derive(Debug, Clone, Copy)]
enum Encoding {
/// gzip, sent as the `gzip` content-coding token.
Gzip,
/// Brotli, sent as the `br` content-coding token.
Brotli,
}
impl Encoding {
fn as_str(&self) -> &'static str {
match self {
Encoding::Gzip => "gzip",
Encoding::Brotli => "br",
}
}
}
#[async_trait::async_trait]
impl Middleware for CompressionMiddleware {
    /// Runs the rest of the chain and compresses the response body when the
    /// client accepts an enabled encoding.
    ///
    /// The response is returned untouched when:
    /// - no enabled encoding is acceptable according to `Accept-Encoding`,
    /// - the body is shorter than the configured threshold,
    /// - a `Content-Encoding` header is already present,
    /// - the content type is not in the compressible list,
    /// - compression fails or does not shrink the body.
    ///
    /// The negotiated encoding is also exposed to downstream handlers through
    /// the `x-negotiated-encoding` request header.
    async fn handle(&self, mut req: Request, next: Next) -> Response {
        let negotiated = self.negotiate_encoding(req.header("accept-encoding"));

        let encoding = match negotiated {
            Some(encoding) => encoding,
            None => {
                // Never compress for a client that did not ask for it. Also drop
                // any client-supplied copy of the internal header so downstream
                // handlers cannot be misled by it.
                req.headers.remove("x-negotiated-encoding");
                debug!("No acceptable encoding negotiated; response left uncompressed");
                return next.run(req).await;
            }
        };

        req.headers.insert(
            "x-negotiated-encoding",
            HeaderValue::from_static(encoding.as_str()),
        );
        debug!("Negotiated encoding: {}", encoding.as_str());

        // The decision is kept in `encoding`, so the request can be moved
        // instead of cloned.
        let mut res = next.run(req).await;

        let original_size = res.body.len();
        if original_size < self.threshold {
            debug!(
                "Response too small for compression: {} bytes",
                original_size
            );
            return res;
        }

        if res.headers.contains_key(header::CONTENT_ENCODING) {
            debug!("Response already has Content-Encoding header");
            return res;
        }

        let content_type = res
            .headers
            .get(header::CONTENT_TYPE)
            .and_then(|ct| ct.to_str().ok());
        if !self.is_compressible(content_type) {
            debug!("Content type not compressible: {:?}", content_type);
            return res;
        }

        let compressed_body = match self.compress_data(&res.body, encoding).await {
            Ok(body) => body,
            Err(e) => {
                // Best effort: a failed compression must not fail the request.
                debug!("Compression failed: {}, returning uncompressed", e);
                return res;
            }
        };

        let compressed_size = compressed_body.len();
        if compressed_size >= original_size {
            debug!(
                "Compression not beneficial: {} -> {} bytes",
                original_size, compressed_size
            );
            return res;
        }

        res.body = compressed_body;
        res.headers.insert(
            header::CONTENT_ENCODING,
            HeaderValue::from_static(encoding.as_str()),
        );
        // Keep an explicit Content-Length in sync with the new body; a stale
        // length would truncate or stall the response on the wire.
        if res.headers.contains_key(header::CONTENT_LENGTH) {
            res.headers
                .insert(header::CONTENT_LENGTH, HeaderValue::from(compressed_size));
        }
        // Append rather than insert so `Vary` values set by other layers survive.
        res.headers
            .append(header::VARY, HeaderValue::from_static("Accept-Encoding"));

        debug!(
            "Compressed response: {} -> {} bytes ({}% reduction, {})",
            original_size,
            compressed_size,
            ((original_size - compressed_size) * 100) / original_size,
            encoding.as_str()
        );

        res
    }
}
/// Ready-made configurations for common deployment scenarios.
impl CompressionMiddleware {
    /// Preset for API servers: a 512-byte threshold and only JSON, XML and
    /// plain-text payloads are compressed.
    pub fn for_api() -> Self {
        let api_types = vec![
            "application/json",
            "application/xml",
            "text/xml",
            "text/plain",
        ];
        let base = Self::new().with_threshold(512);
        base.with_compressible_types(api_types)
    }

    /// Preset for web content: a 1024-byte threshold at the default level.
    pub fn for_web() -> Self {
        let base = Self::new().with_threshold(1024);
        base.with_level(CompressionLevel::Default)
    }

    /// Preset favouring ratio: best compression level, 2048-byte threshold.
    pub fn high_compression() -> Self {
        let base = Self::new().with_level(CompressionLevel::Best);
        base.with_threshold(2048)
    }

    /// Preset favouring speed: fastest compression level, 512-byte threshold.
    pub fn fast_compression() -> Self {
        let base = Self::new().with_level(CompressionLevel::Fastest);
        base.with_threshold(512)
    }
}