use async_trait::async_trait;
use bytes::Bytes;
use flate2::Compression;
use flate2::write::GzEncoder;
use hyper::header::{
    ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, ETAG, HeaderValue, VARY,
};
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::io::Write;
use std::sync::Arc;
/// Tuning knobs for the gzip middleware.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct GZipConfig {
    /// Smallest body size, in bytes, that is considered for compression.
    pub min_length: usize,
    /// Level passed to `flate2::Compression::new` (0 = none, 9 = best).
    pub compression_level: u32,
    /// `Content-Type` prefixes eligible for compression (e.g. `"text/"`).
    pub compressible_types: Vec<String>,
}

impl Default for GZipConfig {
    /// A 200-byte threshold, level 6, and a handful of text-like media types.
    fn default() -> Self {
        let compressible_types = [
            "text/",
            "application/json",
            "application/javascript",
            "application/xml",
            "application/xhtml+xml",
        ]
        .into_iter()
        .map(String::from)
        .collect();

        Self {
            compression_level: 6,
            min_length: 200,
            compressible_types,
        }
    }
}
/// Middleware that gzip-compresses eligible response bodies.
///
/// Which responses are eligible is governed by the wrapped [`GZipConfig`].
#[derive(Debug, Clone)]
pub struct GZipMiddleware {
    config: GZipConfig,
}
impl GZipMiddleware {
    /// Creates a middleware using [`GZipConfig::default`].
    pub fn new() -> Self {
        Self {
            config: GZipConfig::default(),
        }
    }

    /// Creates a middleware with a caller-supplied configuration.
    pub fn with_config(config: GZipConfig) -> Self {
        Self { config }
    }

    /// Returns `true` when the request's `Accept-Encoding` allows a gzip response.
    ///
    /// Follows RFC 9110 §12.5.3:
    /// - codings are matched case-insensitively, and `x-gzip` is an alias of `gzip`;
    /// - an explicit `gzip` entry takes precedence over the `*` wildcard;
    /// - a q-value of `0` means "not acceptable".
    ///
    /// Every `Accept-Encoding` header line is considered. Header values that
    /// are not visible ASCII are skipped. A request without the header yields
    /// `false`.
    fn accepts_gzip(&self, request: &Request) -> bool {
        let mut gzip_q: Option<f32> = None;
        let mut wildcard_q: Option<f32> = None;

        let codings = request
            .headers
            .get_all(ACCEPT_ENCODING)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .filter_map(Self::parse_coding);

        for (coding, q) in codings {
            let slot = if coding.eq_ignore_ascii_case("gzip")
                || coding.eq_ignore_ascii_case("x-gzip")
            {
                &mut gzip_q
            } else if coding == "*" {
                &mut wildcard_q
            } else {
                continue;
            };
            // If an entry is repeated, the most permissive q-value wins.
            *slot = Some(slot.map_or(q, |prev| prev.max(q)));
        }

        gzip_q.or(wildcard_q).is_some_and(|q| q > 0.0)
    }

    /// Splits one `Accept-Encoding` list element (e.g. `"gzip;q=0.8"`) into its
    /// trimmed content-coding and q-value.
    ///
    /// Returns `None` for empty elements, such as the middle of `"gzip,,br"`.
    /// A missing `q` parameter means `1.0`. An unparseable one is treated as
    /// `0.0`, so malformed input never enables compression.
    fn parse_coding(entry: &str) -> Option<(&str, f32)> {
        let mut parts = entry.split(';');
        let coding = parts.next()?.trim();
        if coding.is_empty() {
            return None;
        }
        let mut q = 1.0;
        for param in parts {
            if let Some((name, value)) = param.split_once('=')
                && name.trim().eq_ignore_ascii_case("q")
            {
                q = value.trim().parse::<f32>().unwrap_or(0.0);
            }
        }
        Some((coding, q))
    }

    /// Decides whether a response body is worth compressing.
    ///
    /// The body must be at least `min_length` bytes long. Its `Content-Type`
    /// must also start with one of `compressible_types`. The prefix match is
    /// ASCII case-insensitive because media types are case-insensitive, and
    /// trailing parameters such as `; charset=utf-8` are tolerated.
    fn should_compress(&self, content_type: &str, body_len: usize) -> bool {
        if body_len < self.config.min_length {
            return false;
        }
        self.config.compressible_types.iter().any(|prefix| {
            content_type
                .get(..prefix.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
        })
    }

    /// Gzip-encodes `body` at the configured compression level.
    ///
    /// flate2 documents levels on a 0-9 scale, so larger configured values are
    /// clamped to `Compression::best()` instead of being passed through.
    ///
    /// # Errors
    /// Returns `Error::Internal` if the encoder fails to write or finish the stream.
    fn compress_body(&self, body: &[u8]) -> Result<Vec<u8>> {
        let level = self
            .config
            .compression_level
            .min(Compression::best().level());
        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
        encoder
            .write_all(body)
            .map_err(|e| reinhardt_core::exception::Error::Internal(e.to_string()))?;
        encoder
            .finish()
            .map_err(|e| reinhardt_core::exception::Error::Internal(e.to_string()))
    }
}
impl Default for GZipMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for GZipMiddleware {
async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
let accepts_gzip = self.accepts_gzip(&request);
let mut response = match handler.handle(request).await {
Ok(resp) => resp,
Err(e) => Response::from(e),
};
if !accepts_gzip || response.headers.contains_key(CONTENT_ENCODING) {
return Ok(response);
}
let content_type = response
.headers
.get(CONTENT_TYPE)
.and_then(|ct| ct.to_str().ok())
.unwrap_or("");
let body_len = response.body.len();
if !self.should_compress(content_type, body_len) {
return Ok(response);
}
let compressed = self.compress_body(&response.body)?;
if compressed.len() < body_len {
response.body = Bytes::from(compressed);
response
.headers
.insert(CONTENT_ENCODING, HeaderValue::from_static("gzip"));
if let Ok(len_value) = response.body.len().to_string().parse() {
response.headers.insert(CONTENT_LENGTH, len_value);
}
}
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::read::GzDecoder;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use std::io::Read;

    /// Handler that returns a fixed body with a fixed `Content-Type`.
    struct TestHandler {
        response_body: String,
        content_type: &'static str,
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            let mut response = Response::new(StatusCode::OK)
                .with_body(Bytes::from(self.response_body.clone()));
            response
                .headers
                .insert(CONTENT_TYPE, self.content_type.parse().unwrap());
            Ok(response)
        }
    }

    /// Builds a `GET /test` request, optionally carrying an `Accept-Encoding` header.
    fn build_request(accept_encoding: Option<&str>) -> Request {
        let mut headers = HeaderMap::new();
        if let Some(value) = accept_encoding {
            headers.insert(ACCEPT_ENCODING, value.parse().unwrap());
        }
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_gzip_compression() {
        let middleware = GZipMiddleware::new();
        let long_content = "This is a test response that should be compressed because it's long enough and is text/html content type. ".repeat(5);
        let handler = Arc::new(TestHandler {
            response_body: long_content.clone(),
            content_type: "text/html",
        });
        let request = build_request(Some("gzip, deflate"));
        let response = middleware.process(request, handler).await.unwrap();
        assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "gzip");
        assert!(!response.body.is_empty());
        // The encoded body must decode back to exactly what the handler produced.
        let mut decoded = String::new();
        GzDecoder::new(&response.body[..])
            .read_to_string(&mut decoded)
            .unwrap();
        assert_eq!(decoded, long_content);
    }

    #[tokio::test]
    async fn test_no_gzip_if_client_doesnt_accept() {
        let middleware = GZipMiddleware::new();
        // Long enough to pass `min_length`, so only Accept-Encoding can prevent compression.
        let body = "This is a test response. ".repeat(20);
        let handler = Arc::new(TestHandler {
            response_body: body.clone(),
            content_type: "text/html",
        });
        let request = build_request(None);
        let response = middleware.process(request, handler).await.unwrap();
        assert!(!response.headers.contains_key(CONTENT_ENCODING));
        assert_eq!(response.body, Bytes::from(body));
    }

    #[tokio::test]
    async fn test_no_gzip_for_small_response() {
        let config = GZipConfig {
            min_length: 1000,
            ..Default::default()
        };
        let middleware = GZipMiddleware::with_config(config);
        let handler = Arc::new(TestHandler {
            response_body: "Small response".to_string(),
            content_type: "text/html",
        });
        let request = build_request(Some("gzip"));
        let response = middleware.process(request, handler).await.unwrap();
        assert!(!response.headers.contains_key(CONTENT_ENCODING));
    }

    #[tokio::test]
    async fn test_no_gzip_for_non_compressible_type() {
        let middleware = GZipMiddleware::new();
        // Long enough to pass `min_length`, so only the content type can prevent compression.
        let handler = Arc::new(TestHandler {
            response_body: "This is a long response that could be compressed but is an image"
                .repeat(5),
            content_type: "image/png",
        });
        let request = build_request(Some("gzip"));
        let response = middleware.process(request, handler).await.unwrap();
        assert!(!response.headers.contains_key(CONTENT_ENCODING));
    }

    #[tokio::test]
    async fn test_gzip_non_200_response() {
        struct NotFoundHandler;
        #[async_trait]
        impl Handler for NotFoundHandler {
            async fn handle(&self, _request: Request) -> Result<Response> {
                let content = "Not found page with enough content to compress".repeat(10);
                let mut response =
                    Response::new(StatusCode::NOT_FOUND).with_body(Bytes::from(content));
                response
                    .headers
                    .insert(CONTENT_TYPE, "text/html".parse().unwrap());
                Ok(response)
            }
        }
        let middleware = GZipMiddleware::new();
        let handler = Arc::new(NotFoundHandler);
        let request = build_request(Some("gzip"));
        let response = middleware.process(request, handler).await.unwrap();
        assert_eq!(response.status, StatusCode::NOT_FOUND);
        assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "gzip");
    }

    #[tokio::test]
    async fn test_no_compress_already_compressed() {
        struct CompressedHandler;
        #[async_trait]
        impl Handler for CompressedHandler {
            async fn handle(&self, _request: Request) -> Result<Response> {
                // Long enough to pass `min_length`, so only the existing
                // Content-Encoding can prevent re-compression.
                let mut response = Response::new(StatusCode::OK)
                    .with_body(Bytes::from("Already compressed content".repeat(20)));
                response
                    .headers
                    .insert(CONTENT_TYPE, "text/html".parse().unwrap());
                response
                    .headers
                    .insert(CONTENT_ENCODING, "deflate".parse().unwrap());
                Ok(response)
            }
        }
        let middleware = GZipMiddleware::new();
        let handler = Arc::new(CompressedHandler);
        let request = build_request(Some("gzip"));
        let response = middleware.process(request, handler).await.unwrap();
        assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "deflate");
        assert_eq!(
            response.body,
            Bytes::from("Already compressed content".repeat(20))
        );
    }

    #[tokio::test]
    async fn test_gzip_json_content() {
        struct JsonHandler;
        #[async_trait]
        impl Handler for JsonHandler {
            async fn handle(&self, _request: Request) -> Result<Response> {
                let json_data = r#"{"key": "value", "data": "This is a JSON response that should be compressed"}"#.repeat(5);
                let mut response = Response::new(StatusCode::OK).with_body(Bytes::from(json_data));
                response
                    .headers
                    .insert(CONTENT_TYPE, "application/json".parse().unwrap());
                Ok(response)
            }
        }
        let middleware = GZipMiddleware::new();
        let handler = Arc::new(JsonHandler);
        let request = build_request(Some("gzip, deflate"));
        let response = middleware.process(request, handler).await.unwrap();
        assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "gzip");
    }

    #[tokio::test]
    async fn test_content_length_updated() {
        struct LongHandler;
        #[async_trait]
        impl Handler for LongHandler {
            async fn handle(&self, _request: Request) -> Result<Response> {
                let content =
                    "This is a test response that should be compressed because it's long enough"
                        .repeat(3);
                let mut response = Response::new(StatusCode::OK).with_body(Bytes::from(content));
                response
                    .headers
                    .insert(CONTENT_TYPE, "text/html".parse().unwrap());
                Ok(response)
            }
        }
        let middleware = GZipMiddleware::new();
        let handler = Arc::new(LongHandler);
        let request = build_request(Some("gzip"));
        let response = middleware.process(request, handler).await.unwrap();
        assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "gzip");
        let content_length: usize = response
            .headers
            .get(CONTENT_LENGTH)
            .unwrap()
            .to_str()
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(content_length, response.body.len());
    }
}