use crate::{CompressionAlgorithm, CompressionConfig};
use armature_core::{Error, HttpRequest, HttpResponse, Middleware};
use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
/// HTTP middleware that compresses response bodies according to a [`CompressionConfig`].
///
/// The algorithm comes from the configured [`CompressionAlgorithm`]; in `Auto` mode it is
/// negotiated from the request's `Accept-Encoding` header.
#[derive(Clone, Debug)]
pub struct CompressionMiddleware {
    /// Settings consulted for algorithm choice, compression level, and size /
    /// content-type / already-encoded filtering.
    config: CompressionConfig,
}
impl CompressionMiddleware {
    /// Creates a middleware that uses [`CompressionConfig::default`].
    pub fn new() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }

    /// Creates a middleware that uses the given configuration.
    pub fn with_config(config: CompressionConfig) -> Self {
        Self { config }
    }

    /// Returns the configuration this middleware was built with.
    pub fn config(&self) -> &CompressionConfig {
        &self.config
    }

    /// Chooses the compression algorithm for a response.
    ///
    /// With [`CompressionAlgorithm::Auto`] the algorithm is negotiated from the request's
    /// `Accept-Encoding` value, and no compression is used when that header is absent.
    /// Any other configured algorithm is returned as-is.
    fn select_algorithm(&self, accept_encoding: Option<&str>) -> CompressionAlgorithm {
        match self.config.algorithm {
            CompressionAlgorithm::Auto => match accept_encoding {
                Some(encoding) => CompressionAlgorithm::select_from_accept_encoding(encoding),
                None => CompressionAlgorithm::None,
            },
            // NOTE(review): a fixed algorithm is used even when the client's Accept-Encoding
            // does not list it (pinned by `test_select_algorithm_specific`); confirm intended.
            algo => algo,
        }
    }

    /// Decides whether `response` is eligible for compression.
    ///
    /// Returns `false` for:
    /// - error statuses (>= 400) and `206 Partial Content`;
    /// - empty bodies and bodies rejected by the configured size check;
    /// - bodies that already carry a content coding, unless `compress_encoded` is set;
    /// - content types rejected by the config.
    ///
    /// A response without a `Content-Type` header is considered compressible. Header names
    /// are matched ignoring ASCII case.
    fn should_compress(&self, response: &HttpResponse) -> bool {
        // A 206 body is a byte range of the representation; compressing it would leave
        // Content-Range describing bytes the client never receives.
        if response.status >= 400 || response.status == 206 || response.body.is_empty() {
            return false;
        }
        if !self.config.should_compress_size(response.body.len()) {
            return false;
        }
        if !self.config.compress_encoded {
            if let Some(encoding) = header_value(response, "Content-Encoding") {
                if is_applied_encoding(encoding) {
                    return false;
                }
            }
        }
        match header_value(response, "Content-Type") {
            Some(content_type) => self.config.should_compress_content_type(content_type),
            None => true,
        }
    }

    /// Compresses the body of `response` with `algorithm` and updates its headers.
    ///
    /// The response is returned unchanged when compression fails (a warning is logged) or
    /// when the output is not smaller than the input. Otherwise:
    /// - the body is replaced;
    /// - `Content-Encoding` and `Content-Length` are rewritten, dropping any
    ///   differently-cased duplicates;
    /// - `Accept-Encoding` is merged into `Vary`.
    fn compress_response(
        &self,
        mut response: HttpResponse,
        algorithm: CompressionAlgorithm,
    ) -> HttpResponse {
        let level = self.config.effective_level();
        let compressed = match algorithm.compress(&response.body, level) {
            Ok(compressed) => compressed,
            Err(e) => {
                tracing::warn!("Compression failed: {}", e);
                return response;
            }
        };
        // Only worth sending if it actually shrinks the payload.
        if compressed.len() >= response.body.len() {
            return response;
        }
        response.body = compressed;

        if let Some(encoding) = algorithm.encoding_name() {
            // Content-Encoding lists codings in the order they were applied (RFC 9110 §8.4),
            // so a body that was already encoded gets the new coding appended, not replaced.
            let value = match header_value(&response, "Content-Encoding") {
                Some(existing) if is_applied_encoding(existing) => {
                    format!("{}, {}", existing.trim(), encoding)
                }
                _ => encoding.to_string(),
            };
            set_header(&mut response, "Content-Encoding", value);
        }

        let length = response.body.len().to_string();
        set_header(&mut response, "Content-Length", length);

        let vary = vary_with_accept_encoding(header_value(&response, "Vary").unwrap_or(""));
        set_header(&mut response, "Vary", vary);

        response
    }
}

/// Returns the value of a header named `name`, comparing names ignoring ASCII case.
///
/// Header field names are case-insensitive (RFC 9110 §5.1), but the header map is keyed by
/// whatever spelling the handler used, so an exact-key `get` would miss e.g. `content-type`.
fn header_value<'a>(response: &'a HttpResponse, name: &str) -> Option<&'a str> {
    response
        .headers
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case(name))
        .map(|(_, value)| value.as_str())
}

/// Replaces every spelling of header `name` with a single entry using exactly `name`.
fn set_header(response: &mut HttpResponse, name: &str, value: String) {
    response
        .headers
        .retain(|key, _| !key.eq_ignore_ascii_case(name));
    response.headers.insert(name.to_string(), value);
}

/// Returns `true` if a `Content-Encoding` value names a coding other than `identity`.
///
/// Content codings are case-insensitive (RFC 9110 §8.4.1), so `Identity` counts as none.
fn is_applied_encoding(value: &str) -> bool {
    let value = value.trim();
    !value.is_empty() && !value.eq_ignore_ascii_case("identity")
}

/// Returns a `Vary` value that covers `Accept-Encoding`, keeping the existing entries.
///
/// The existing value is returned unchanged if it already lists `Accept-Encoding` (any
/// case) or `*`, since `*` already covers every request field.
fn vary_with_accept_encoding(existing: &str) -> String {
    let covered = existing
        .split(',')
        .map(str::trim)
        .any(|field| field == "*" || field.eq_ignore_ascii_case("Accept-Encoding"));
    if covered {
        existing.to_string()
    } else if existing.trim().is_empty() {
        "Accept-Encoding".to_string()
    } else {
        format!("{}, Accept-Encoding", existing)
    }
}
impl Default for CompressionMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for CompressionMiddleware {
    /// Runs the rest of the chain, then compresses the response when an algorithm was
    /// negotiated and the response passes `should_compress`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `next`. A compression failure does not fail the
    /// request; the uncompressed response is returned instead.
    async fn handle(
        &self,
        req: HttpRequest,
        next: Box<
            dyn FnOnce(
                    HttpRequest,
                )
                    -> Pin<Box<dyn Future<Output = Result<HttpResponse, Error>> + Send>>
                + Send,
        >,
    ) -> Result<HttpResponse, Error> {
        // Header names are case-insensitive (e.g. HTTP/2 field names are lowercase). Match
        // any spelling rather than only `Accept-Encoding` / `accept-encoding`. The value is
        // cloned because `req` is moved into `next`.
        let accept_encoding = req
            .headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("Accept-Encoding"))
            .map(|(_, value)| value.clone());
        let response = next(req).await?;
        let algorithm = self.select_algorithm(accept_encoding.as_deref());
        if algorithm == CompressionAlgorithm::None || !self.should_compress(&response) {
            return Ok(response);
        }
        Ok(self.compress_response(response, algorithm))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 200 response with the given body and `Content-Type`, plus a matching
    /// `Content-Length`.
    fn create_response(body: &str, content_type: &str) -> HttpResponse {
        HttpResponse::new(200)
            .with_header("Content-Type".to_string(), content_type.to_string())
            .with_header("Content-Length".to_string(), body.len().to_string())
            .with_body(body.as_bytes().to_vec())
    }

    #[test]
    fn test_middleware_creation() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(middleware.config().algorithm, CompressionAlgorithm::Auto);
    }

    #[test]
    fn test_should_compress() {
        let middleware = CompressionMiddleware::new();

        let response = create_response(&"x".repeat(1000), "application/json");
        assert!(middleware.should_compress(&response));

        let response = create_response("small", "application/json");
        assert!(!middleware.should_compress(&response));

        let response = create_response(&"x".repeat(1000), "image/png");
        assert!(!middleware.should_compress(&response));

        let mut response = create_response(&"x".repeat(1000), "application/json");
        response.status = 500;
        assert!(!middleware.should_compress(&response));

        // `create_response("")` already yields an empty body.
        let response = create_response("", "application/json");
        assert!(!middleware.should_compress(&response));
    }

    #[test]
    fn test_select_algorithm_auto() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(
            middleware.select_algorithm(None),
            CompressionAlgorithm::None
        );
        #[cfg(feature = "gzip")]
        {
            assert_eq!(
                middleware.select_algorithm(Some("gzip")),
                CompressionAlgorithm::Gzip
            );
        }
        #[cfg(feature = "brotli")]
        {
            assert_eq!(
                middleware.select_algorithm(Some("br")),
                CompressionAlgorithm::Brotli
            );
        }
        #[cfg(all(feature = "gzip", feature = "brotli"))]
        {
            assert_eq!(
                middleware.select_algorithm(Some("gzip, br")),
                CompressionAlgorithm::Brotli
            );
        }
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn test_select_algorithm_specific() {
        let config = CompressionConfig::builder().gzip().build();
        let middleware = CompressionMiddleware::with_config(config);
        assert_eq!(
            middleware.select_algorithm(Some("br")),
            CompressionAlgorithm::Gzip
        );
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn test_compress_response() {
        let middleware = CompressionMiddleware::with_config(
            CompressionConfig::builder().gzip().min_size(10).build(),
        );
        let body = "Hello, World! ".repeat(100);
        let response = create_response(&body, "text/plain");
        let compressed = middleware.compress_response(response, CompressionAlgorithm::Gzip);
        assert_eq!(
            compressed.headers.get("Content-Encoding"),
            Some(&"gzip".to_string())
        );
        assert!(
            compressed
                .headers
                .get("Vary")
                .unwrap()
                .contains("Accept-Encoding")
        );
        assert!(compressed.body.len() < body.len());
        // Content-Length must describe the compressed body, not the original one.
        assert_eq!(
            compressed.headers.get("Content-Length"),
            Some(&compressed.body.len().to_string())
        );
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn test_vary_header_appended() {
        let middleware = CompressionMiddleware::with_config(
            CompressionConfig::builder().gzip().min_size(10).build(),
        );
        let body = "Hello, World! ".repeat(100);
        let mut response = create_response(&body, "text/plain");
        response
            .headers
            .insert("Vary".to_string(), "Origin".to_string());
        let compressed = middleware.compress_response(response, CompressionAlgorithm::Gzip);
        let vary = compressed.headers.get("Vary").unwrap();
        assert!(vary.contains("Origin"));
        assert!(vary.contains("Accept-Encoding"));
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn test_vary_header_not_duplicated() {
        let middleware = CompressionMiddleware::with_config(
            CompressionConfig::builder().gzip().min_size(10).build(),
        );
        let body = "Hello, World! ".repeat(100);
        let mut response = create_response(&body, "text/plain");
        response
            .headers
            .insert("Vary".to_string(), "Accept-Encoding".to_string());
        let compressed = middleware.compress_response(response, CompressionAlgorithm::Gzip);
        assert_eq!(
            compressed.headers.get("Vary"),
            Some(&"Accept-Encoding".to_string())
        );
    }
}