use crate::middleware::v2::{Middleware, Next, NextFuture};
use crate::request::ElifRequest;
use crate::response::ElifResponse;
use http_body_util::BodyExt;
use tower::{Layer, Service};
use tower_http::compression::{CompressionLayer, CompressionLevel};
/// Settings consumed by [`CompressionMiddleware::with_config`].
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Compression level passed to the underlying `CompressionLayer`.
    pub level: CompressionLevel,
    /// When `false`, gzip is disabled on the layer.
    pub enable_gzip: bool,
    /// When `false`, brotli (`br`) is disabled on the layer.
    pub enable_brotli: bool,
    /// When `false`, deflate is disabled on the layer.
    pub enable_deflate: bool,
}

impl Default for CompressionConfig {
    /// gzip and brotli enabled, deflate disabled, at `CompressionLevel::default()`.
    fn default() -> Self {
        let (enable_gzip, enable_brotli, enable_deflate) = (true, true, false);
        CompressionConfig {
            level: CompressionLevel::default(),
            enable_gzip,
            enable_brotli,
            enable_deflate,
        }
    }
}
/// Response-compression middleware backed by tower-http's [`CompressionLayer`].
///
/// Create it with [`CompressionMiddleware::new`] (using [`CompressionConfig::default`])
/// or [`CompressionMiddleware::with_config`], then refine it with the consuming
/// builder methods below.
pub struct CompressionMiddleware {
    layer: CompressionLayer,
}

impl CompressionMiddleware {
    /// Creates a middleware configured from [`CompressionConfig::default`].
    pub fn new() -> Self {
        Self::with_config(CompressionConfig::default())
    }

    /// Creates a middleware from an explicit [`CompressionConfig`].
    ///
    /// The layer starts from `CompressionLayer::new()` with `config.level` applied;
    /// every `enable_*` flag that is `false` disables that algorithm, while `true`
    /// leaves the algorithm as `CompressionLayer::new()` configured it.
    pub fn with_config(config: CompressionConfig) -> Self {
        let mut layer = CompressionLayer::new().quality(config.level);
        if !config.enable_gzip {
            layer = layer.no_gzip();
        }
        if !config.enable_brotli {
            layer = layer.no_br();
        }
        if !config.enable_deflate {
            layer = layer.no_deflate();
        }
        Self { layer }
    }

    /// Sets the compression level used by the layer.
    #[must_use]
    pub fn level(self, level: CompressionLevel) -> Self {
        Self {
            layer: self.layer.quality(level),
        }
    }

    /// Shorthand for `level(CompressionLevel::Fastest)`.
    #[must_use]
    pub fn fast(self) -> Self {
        self.level(CompressionLevel::Fastest)
    }

    /// Shorthand for `level(CompressionLevel::Best)`.
    #[must_use]
    pub fn best(self) -> Self {
        self.level(CompressionLevel::Best)
    }

    /// Disables gzip.
    #[must_use]
    pub fn no_gzip(self) -> Self {
        Self {
            layer: self.layer.no_gzip(),
        }
    }

    /// Disables brotli (`br`).
    #[must_use]
    pub fn no_brotli(self) -> Self {
        Self {
            layer: self.layer.no_br(),
        }
    }

    /// Disables deflate.
    #[must_use]
    pub fn no_deflate(self) -> Self {
        Self {
            layer: self.layer.no_deflate(),
        }
    }

    /// Restricts compression to gzip: gzip is (re-)enabled, brotli and deflate
    /// are disabled.
    ///
    /// Re-enabling gzip matters when an earlier call (e.g. `no_gzip`, or a config
    /// with `enable_gzip: false`) switched it off; otherwise no algorithm would
    /// remain. NOTE(review): codings other than gzip/br/deflate (e.g. zstd behind
    /// its own tower-http cargo feature) are not touched — confirm which
    /// tower-http compression features this crate enables.
    #[must_use]
    pub fn gzip_only(self) -> Self {
        Self {
            layer: self.layer.gzip(true).no_br().no_deflate(),
        }
    }

    /// Restricts compression to brotli: brotli is (re-)enabled, gzip and deflate
    /// are disabled.
    ///
    /// See [`CompressionMiddleware::gzip_only`] for why the kept algorithm is
    /// explicitly re-enabled and for the caveat about other codings.
    #[must_use]
    pub fn brotli_only(self) -> Self {
        Self {
            layer: self.layer.br(true).no_gzip().no_deflate(),
        }
    }
}
impl std::fmt::Debug for CompressionMiddleware {
    /// Prints the struct name with a fixed placeholder in place of the layer's contents.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const LAYER_PLACEHOLDER: &str = "<CompressionLayer>";
        let mut builder = formatter.debug_struct("CompressionMiddleware");
        builder.field("layer", &LAYER_PLACEHOLDER);
        builder.finish()
    }
}
impl Default for CompressionMiddleware {
fn default() -> Self {
Self::new()
}
}
impl Clone for CompressionMiddleware {
fn clone(&self) -> Self {
Self {
layer: self.layer.clone(),
}
}
}
impl Middleware for CompressionMiddleware {
fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
let layer = self.layer.clone();
Box::pin(async move {
let accept_encoding = request
.header("accept-encoding")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_owned())
.unwrap_or_default();
let wants_compression = accept_encoding.contains("gzip")
|| accept_encoding.contains("br")
|| accept_encoding.contains("deflate");
let response = next.run(request).await;
if !wants_compression {
return response;
}
let axum_response = response.into_axum_response();
let (parts, body) = axum_response.into_parts();
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(_) => {
let response =
axum::response::Response::from_parts(parts, axum::body::Body::empty());
return ElifResponse::from_axum_response(response).await;
}
};
let parts_clone = parts.clone();
let body_bytes_clone = body_bytes.clone();
let mock_request = axum::extract::Request::builder()
.uri("/")
.header("accept-encoding", &accept_encoding)
.body(axum::body::Body::empty())
.unwrap();
let service = tower::service_fn(move |_req: axum::extract::Request| {
let response_parts = parts.clone();
let response_body = body_bytes.clone();
async move {
let response = axum::response::Response::from_parts(
response_parts,
axum::body::Body::from(response_body),
);
Ok::<axum::response::Response, std::convert::Infallible>(response)
}
});
let mut compression_service = layer.layer(service);
match compression_service.call(mock_request).await {
Ok(compressed_response) => {
let (compressed_parts, compressed_body) = compressed_response.into_parts();
match compressed_body.collect().await {
Ok(collected) => {
let compressed_bytes = collected.to_bytes();
let final_response = axum::response::Response::from_parts(
compressed_parts,
axum::body::Body::from(compressed_bytes),
);
ElifResponse::from_axum_response(final_response).await
}
Err(_) => {
let original_response = axum::response::Response::from_parts(
parts_clone,
axum::body::Body::from(body_bytes_clone),
);
ElifResponse::from_axum_response(original_response).await
}
}
}
Err(_) => {
let original_response = axum::response::Response::from_parts(
parts_clone,
axum::body::Body::from(body_bytes_clone),
);
ElifResponse::from_axum_response(original_response).await
}
}
})
}
fn name(&self) -> &'static str {
"CompressionMiddleware"
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the compression middleware's configuration, builders and
    //! request handling.
    use super::*;
    use crate::request::ElifRequest;
    use crate::response::ElifResponse;

    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::default();
        assert!(config.enable_gzip);
        assert!(config.enable_brotli);
        assert!(!config.enable_deflate);
    }

    #[tokio::test]
    async fn test_compression_middleware() {
        let middleware = CompressionMiddleware::new();
        let mut headers = crate::response::headers::ElifHeaderMap::new();
        let encoding_header =
            crate::response::headers::ElifHeaderName::from_str("accept-encoding").unwrap();
        let encoding_value =
            crate::response::headers::ElifHeaderValue::from_str("gzip, br").unwrap();
        headers.insert(encoding_header, encoding_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        );
        let next = Next::new(|_req| {
            Box::pin(async move {
                let json_data = serde_json::json!({
                    "message": "Hello, World!".repeat(100),
                    "data": (0..100).collect::<Vec<i32>>()
                });
                ElifResponse::ok().json_value(json_data)
            })
        });
        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    /// A request without accept-encoding must still reach the handler and succeed.
    #[tokio::test]
    async fn test_passthrough_without_accept_encoding() {
        let middleware = CompressionMiddleware::new();
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            crate::response::headers::ElifHeaderMap::new(),
        );
        let next = Next::new(|_req| {
            Box::pin(async move {
                let json_data = serde_json::json!({ "ok": true });
                ElifResponse::ok().json_value(json_data)
            })
        });
        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    // Purely synchronous: no async runtime is needed.
    #[test]
    fn test_compression_builder_pattern() {
        let middleware = CompressionMiddleware::new().best().gzip_only();
        assert_eq!(middleware.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_compression_levels() {
        let fast = CompressionMiddleware::new().fast();
        let best = CompressionMiddleware::new().best();
        let custom = CompressionMiddleware::new().level(CompressionLevel::Precise(5));
        assert_eq!(fast.name(), "CompressionMiddleware");
        assert_eq!(best.name(), "CompressionMiddleware");
        assert_eq!(custom.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_algorithm_selection() {
        let gzip_only = CompressionMiddleware::new().gzip_only();
        let brotli_only = CompressionMiddleware::new().brotli_only();
        let no_brotli = CompressionMiddleware::new().no_brotli();
        assert_eq!(gzip_only.name(), "CompressionMiddleware");
        assert_eq!(brotli_only.name(), "CompressionMiddleware");
        assert_eq!(no_brotli.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_clone() {
        let middleware = CompressionMiddleware::new().best();
        let cloned = middleware.clone();
        assert_eq!(cloned.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_debug_output() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(
            format!("{:?}", middleware),
            "CompressionMiddleware { layer: \"<CompressionLayer>\" }"
        );
    }
}