#![warn(missing_docs)]
use crate::{
context::HttpResponse, next::Next, req::HttpRequest, res::ResponseBody, types::MiddlewareOutput,
};
use flate2::{write::GzEncoder, Compression};
use std::io::Write;
/// Settings for the gzip response-compression middleware.
///
/// The [`Default`] implementation uses a 1024-byte threshold and compression level 6.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionConfig {
    /// Minimum response body size, in bytes, required before compression is attempted.
    ///
    /// Bodies shorter than this are passed through uncompressed.
    pub threshold: usize,
    /// Compression level handed to the gzip encoder.
    ///
    /// Values above 9 are treated as 9.
    pub level: u8,
}
impl Default for CompressionConfig {
    /// Returns the stock configuration: a 1024-byte threshold and compression level 6.
    fn default() -> Self {
        const DEFAULT_THRESHOLD: usize = 1024;
        const DEFAULT_LEVEL: u8 = 6;

        Self {
            level: DEFAULT_LEVEL,
            threshold: DEFAULT_THRESHOLD,
        }
    }
}
/// Builds the gzip response-compression middleware.
///
/// `config` supplies the size threshold and compression level; `None` falls back to
/// [`CompressionConfig::default`].
///
/// The returned middleware compresses the response body only when all of the following hold:
///
/// * the request's `Accept-Encoding` header permits gzip,
/// * the response carries no `Content-Encoding` header yet,
/// * the body can be turned into bytes and is at least `threshold` bytes long,
/// * the response has a `Content-Type` header naming a compressible media type.
///
/// When it compresses, it replaces the body, sets `Content-Encoding: gzip` and
/// `Vary: Accept-Encoding`, removes `Content-Length`, and returns `(req, Some(res))`.
/// In every other case, including a compression failure, the request and the untouched
/// response are handed to `next`.
pub(crate) fn compression(
    config: Option<CompressionConfig>,
) -> impl Fn(HttpRequest, HttpResponse, Next) -> MiddlewareOutput + Send + Sync + 'static {
    // Both settings are `Copy`; reading them once avoids cloning the config on every request.
    let config = config.unwrap_or_default();
    let (threshold, level) = (config.threshold, config.level);

    move |req: HttpRequest, mut res, next| {
        Box::pin(async move {
            let accepts_gzip = match req.headers.get("Accept-Encoding") {
                Some(value) => accepts_gzip_encoding(value),
                None => false,
            };
            if !accepts_gzip {
                return next.call(req, res).await;
            }

            // Never stack a second encoding on a body that is already encoded.
            let already_encoded = res.headers.get("Content-Encoding").is_some()
                || res.headers.get("content-encoding").is_some();
            if already_encoded {
                return next.call(req, res).await;
            }

            let body_bytes = match get_response_body_bytes(&res) {
                Some(bytes) => bytes,
                None => return next.call(req, res).await,
            };
            if body_bytes.len() < threshold {
                return next.call(req, res).await;
            }

            // A response without a `Content-Type` header is treated as not compressible
            // rather than panicking on the missing header.
            let compressible = match res.headers.get("Content-Type") {
                Some(content_type) => should_compress_content_type(content_type),
                None => false,
            };
            if !compressible {
                return next.call(req, res).await;
            }

            let compressed_body = match compress_data(&body_bytes, level) {
                Ok(bytes) => bytes,
                // Compression is best-effort: on failure the original body is sent as-is.
                Err(_) => return next.call(req, res).await,
            };
            if set_response_body(&mut res, compressed_body).is_err() {
                return next.call(req, res).await;
            }

            res = res
                .set_header("Content-Encoding", "gzip")
                .set_header("Vary", "Accept-Encoding");
            // The old length describes the uncompressed body and would now be wrong.
            res.headers.remove("Content-Length");
            (req, Some(res))
        })
    }
}
/// Reports whether a `Content-Type` value names a media type considered worth compressing.
///
/// The value is lowercased and then matched by prefix, so parameters that follow the media
/// type (for example `; charset=utf-8`) do not affect the result. Every `text/*` type
/// matches, along with a fixed list of JSON, JavaScript, XML and SVG types.
pub(crate) fn should_compress_content_type(content_type: &str) -> bool {
    const COMPRESSIBLE_PREFIXES: &[&str] = &[
        "text/",
        "application/json",
        "application/javascript",
        "application/xml",
        "application/rss+xml",
        "application/atom+xml",
        "application/xhtml+xml",
        "image/svg+xml",
    ];

    let normalized = content_type.to_lowercase();
    COMPRESSIBLE_PREFIXES
        .iter()
        .copied()
        .any(|prefix| normalized.starts_with(prefix))
}
/// Gzip-compresses `data` and returns the complete compressed stream.
///
/// `level` is capped at 9 before being passed to the encoder.
///
/// # Errors
///
/// Returns the I/O error reported by the encoder while writing or finishing the stream.
pub(crate) fn compress_data(data: &[u8], level: u8) -> Result<Vec<u8>, std::io::Error> {
    let capped_level = u32::from(level.min(9));
    let mut gz = GzEncoder::new(Vec::new(), Compression::new(capped_level));
    gz.write_all(data)?;
    gz.finish()
}
/// Copies the response body out as raw bytes, whatever its variant.
///
/// Text and HTML bodies yield their string bytes, binary bodies are copied as-is, and JSON
/// bodies are serialized with `serde_json`. Returns `None` only when that JSON
/// serialization fails.
pub(crate) fn get_response_body_bytes(response: &HttpResponse) -> Option<Vec<u8>> {
    let bytes = match &response.body {
        // Serialization is the only fallible conversion; a failure becomes `None`.
        ResponseBody::JSON(value) => return serde_json::to_vec(value).ok(),
        ResponseBody::TEXT(plain) => plain.as_bytes().to_vec(),
        ResponseBody::HTML(markup) => markup.as_bytes().to_vec(),
        ResponseBody::BINARY(raw) => raw.to_vec(),
    };
    Some(bytes)
}
/// Replaces the response body with `compressed_body`, stored as a binary body.
///
/// The current implementation always returns `Ok(())`.
pub(crate) fn set_response_body(
    response: &mut HttpResponse,
    compressed_body: Vec<u8>,
) -> Result<(), ()> {
    let binary_body = ResponseBody::BINARY(compressed_body.into());
    response.body = binary_body;
    Ok(())
}
/// Reports whether an `Accept-Encoding` header value permits a gzip-encoded response.
///
/// The header is split on `,` into codings, each optionally followed by `;`-separated
/// parameters. A coding without a `q` parameter has weight `1.0`; a `q` value that does not
/// parse as a number counts as `0.0`. Coding names and the `q` parameter name are matched
/// case-insensitively.
///
/// An explicit `gzip` entry decides the outcome. The `*` wildcard is consulted only when
/// `gzip` is not listed, so `gzip;q=0, *` is a refusal. When a coding is listed more than
/// once, its highest weight is used.
pub(crate) fn accepts_gzip_encoding(header: &str) -> bool {
    let mut gzip_q: Option<f32> = None;
    let mut wildcard_q: Option<f32> = None;

    for entry in header.split(',') {
        let mut parts = entry.split(';');
        let coding = parts.next().unwrap_or_default().trim();

        let is_gzip = coding.eq_ignore_ascii_case("gzip");
        if !is_gzip && coding != "*" {
            continue;
        }

        // If `q` appears more than once on one coding, the last occurrence wins.
        let mut q = 1.0_f32;
        for param in parts {
            if let Some((name, value)) = param.trim().split_once('=') {
                if name.eq_ignore_ascii_case("q") {
                    q = value.parse::<f32>().unwrap_or(0.0);
                }
            }
        }

        let slot = if is_gzip { &mut gzip_q } else { &mut wildcard_q };
        *slot = Some((*slot).map_or(q, |best| best.max(q)));
    }

    // The wildcard only covers codings that were not named explicitly.
    matches!(gzip_q.or(wildcard_q), Some(q) if q > 0.0)
}