use async_trait::async_trait;
use brotli::enc::BrotliEncoderParams;
use bytes::Bytes;
use hyper::header::{
	ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, HeaderValue, VARY,
};
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;
/// Compression quality preset used by `BrotliMiddleware`.
///
/// Each preset maps to a numeric Brotli quality level that is passed to the
/// encoder: `Fast` = 1, `Balanced` = 6, `Best` = 11.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BrotliQuality {
	/// Brotli quality level 1.
	Fast,
	/// Brotli quality level 6 (the `BrotliConfig` default).
	Balanced,
	/// Brotli quality level 11.
	Best,
}

impl BrotliQuality {
	/// Returns the numeric Brotli quality level for this preset.
	fn to_value(self) -> u32 {
		match self {
			Self::Fast => 1,
			Self::Balanced => 6,
			Self::Best => 11,
		}
	}
}
/// Tuning knobs for `BrotliMiddleware`.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate builds one
/// with [`BrotliConfig::new`] (or [`Default`]) plus the `with_*` methods.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct BrotliConfig {
	/// Bodies shorter than this many bytes are passed through uncompressed.
	pub min_length: usize,
	/// Brotli quality preset used when compressing.
	pub quality: BrotliQuality,
	/// `Content-Type` prefixes eligible for compression; for example `"text/"`
	/// matches `text/html` and `text/css`.
	pub compressible_types: Vec<String>,
	/// Brotli window size, handed to the encoder as `lgwin`.
	/// [`BrotliConfig::with_window_size`] clamps it to `10..=24`.
	pub window_size: u32,
}

impl BrotliConfig {
	/// Lower bound applied by [`BrotliConfig::with_window_size`].
	const MIN_WINDOW_SIZE: u32 = 10;
	/// Upper bound applied by [`BrotliConfig::with_window_size`].
	const MAX_WINDOW_SIZE: u32 = 24;

	/// Returns the default configuration (same as [`BrotliConfig::default`]).
	pub fn new() -> Self {
		<Self as Default>::default()
	}

	/// Sets the minimum body length, in bytes, that is eligible for compression.
	pub fn with_min_length(self, min_length: usize) -> Self {
		Self { min_length, ..self }
	}

	/// Sets the Brotli quality preset.
	pub fn with_quality(self, quality: BrotliQuality) -> Self {
		Self { quality, ..self }
	}

	/// Sets the Brotli window size, clamped into `10..=24`.
	pub fn with_window_size(self, window_size: u32) -> Self {
		let window_size = window_size.clamp(Self::MIN_WINDOW_SIZE, Self::MAX_WINDOW_SIZE);
		Self { window_size, ..self }
	}

	/// Appends `content_type` to the compressible `Content-Type` prefixes.
	pub fn with_compressible_type(mut self, content_type: String) -> Self {
		self.compressible_types.extend(std::iter::once(content_type));
		self
	}
}

impl Default for BrotliConfig {
	/// 200-byte threshold, `Balanced` quality, window size 22, and the
	/// textual prefixes `text/`, JSON, JavaScript, XML and XHTML.
	fn default() -> Self {
		const DEFAULT_COMPRESSIBLE_TYPES: [&str; 5] = [
			"text/",
			"application/json",
			"application/javascript",
			"application/xml",
			"application/xhtml+xml",
		];
		Self {
			min_length: 200,
			quality: BrotliQuality::Balanced,
			compressible_types: DEFAULT_COMPRESSIBLE_TYPES
				.iter()
				.map(|prefix| prefix.to_string())
				.collect(),
			window_size: 22,
		}
	}
}
/// Middleware that compresses eligible response bodies with Brotli
/// (`Content-Encoding: br`), configured through a [`BrotliConfig`].
pub struct BrotliMiddleware {
	config: BrotliConfig,
}

impl BrotliMiddleware {
	/// Creates a middleware that uses [`BrotliConfig::default`].
	pub fn new() -> Self {
		Self {
			config: BrotliConfig::default(),
		}
	}

	/// Creates a middleware that uses the given configuration.
	pub fn with_config(config: BrotliConfig) -> Self {
		Self { config }
	}

	/// Returns `true` if the request's `Accept-Encoding` field lines allow `br`.
	///
	/// Negotiation follows RFC 9110 §12.5.3:
	/// - An explicit `br` entry decides on its own.
	/// - Otherwise a `*` wildcard entry applies.
	/// - A quality value of `0` means "not acceptable".
	///
	/// All `Accept-Encoding` field lines are considered. Values that are not
	/// valid header text (`HeaderValue::to_str` fails) are skipped.
	fn accepts_brotli(&self, request: &Request) -> bool {
		let mut br_quality = None;
		let mut wildcard_quality = None;
		let values = request
			.headers
			.get_all(ACCEPT_ENCODING)
			.iter()
			.filter_map(|value| value.to_str().ok());
		for entry in values.flat_map(|value| value.split(',')) {
			let mut parts = entry.split(';');
			let coding = parts.next().unwrap_or("").trim();
			if coding.eq_ignore_ascii_case("br") {
				br_quality = Some(Self::quality_of(parts));
			} else if coding == "*" {
				wildcard_quality = Some(Self::quality_of(parts));
			}
		}
		// An explicit `br` entry (even `br;q=0`) overrides the wildcard.
		br_quality
			.or(wildcard_quality)
			.is_some_and(|quality| quality > 0.0)
	}

	/// Extracts the `q` parameter from the parameters of one `Accept-Encoding`
	/// entry.
	///
	/// Returns `1.0` when no `q` is present. An unparsable `q` yields `0.0`, so
	/// a malformed entry is treated as "not acceptable", which is the safe
	/// fallback because an uncompressed body is always valid.
	fn quality_of<'a>(params: impl Iterator<Item = &'a str>) -> f32 {
		params
			.filter_map(|param| param.split_once('='))
			.find(|(name, _)| name.trim().eq_ignore_ascii_case("q"))
			.map_or(1.0, |(_, value)| value.trim().parse::<f32>().unwrap_or(0.0))
	}

	/// Returns `true` if `content_type` starts with one of the configured
	/// compressible prefixes.
	///
	/// Media types are case-insensitive (RFC 9110 §8.3.1), so the prefix
	/// comparison ignores ASCII case.
	fn is_compressible(&self, content_type: &str) -> bool {
		self.config.compressible_types.iter().any(|prefix| {
			// `get` returns `None` when the header is shorter than the prefix or the
			// cut would split a UTF-8 character, so this never panics.
			content_type
				.get(..prefix.len())
				.is_some_and(|head| head.eq_ignore_ascii_case(prefix))
		})
	}

	/// Compresses `data` with the configured quality and window size.
	///
	/// # Errors
	///
	/// Returns any I/O error reported by `brotli::BrotliCompress`.
	fn compress(&self, data: &[u8]) -> std::io::Result<Vec<u8>> {
		let params = BrotliEncoderParams {
			quality: self.config.quality.to_value() as i32,
			// `window_size` is a public field, so it can bypass the clamp in
			// `BrotliConfig::with_window_size`. Apply the same 10..=24 clamp here
			// so `lgwin` is always in range and the cast to i32 cannot wrap.
			lgwin: self.config.window_size.clamp(10, 24) as i32,
			..Default::default()
		};
		let mut output = Vec::new();
		let mut reader = std::io::Cursor::new(data);
		brotli::BrotliCompress(&mut reader, &mut output, &params)?;
		Ok(output)
	}
}

impl Default for BrotliMiddleware {
	/// Same as [`BrotliMiddleware::new`].
	fn default() -> Self {
		Self::new()
	}
}
#[async_trait]
impl Middleware for BrotliMiddleware {
async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
if !self.accepts_brotli(&request) {
return handler.handle(request).await;
}
let mut response = match handler.handle(request).await {
Ok(resp) => resp,
Err(e) => Response::from(e),
};
if response.headers.contains_key(CONTENT_ENCODING) {
return Ok(response);
}
if response.body.len() < self.config.min_length {
return Ok(response);
}
let content_type = response
.headers
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !self.is_compressible(content_type) {
return Ok(response);
}
match self.compress(&response.body) {
Ok(compressed) => {
if compressed.len() < response.body.len() {
response.body = Bytes::from(compressed);
response
.headers
.insert(CONTENT_ENCODING, HeaderValue::from_static("br"));
response.headers.remove(CONTENT_LENGTH);
}
Ok(response)
}
Err(_) => {
Ok(response)
}
}
}
}
#[cfg(test)]
mod tests {
	use super::*;
	use hyper::{HeaderMap, Method, StatusCode, Version};

	/// Handler that always answers `200 OK` with a fixed body and `Content-Type`.
	struct TestHandler {
		body: String,
		content_type: String,
	}

	impl TestHandler {
		fn new(body: String, content_type: String) -> Self {
			Self { body, content_type }
		}
	}

	#[async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			let body = Bytes::from(self.body.clone());
			let content_type: HeaderValue = self.content_type.parse().unwrap();
			let mut response = Response::new(StatusCode::OK).with_body(body);
			response.headers.insert(CONTENT_TYPE, content_type);
			Ok(response)
		}
	}

	/// Builds a bodiless HTTP/1.1 `GET` request for `uri`, adding an
	/// `Accept-Encoding` header only when `accept_encoding` is `Some`.
	fn get_request(uri: &'static str, accept_encoding: Option<&str>) -> Request {
		let mut headers = HeaderMap::new();
		if let Some(encodings) = accept_encoding {
			headers.insert(ACCEPT_ENCODING, encodings.parse().unwrap());
		}
		Request::builder()
			.method(Method::GET)
			.uri(uri)
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Sends a `GET` for `uri` through `middleware` to a [`TestHandler`] that
	/// serves `body` as `content_type`, and returns the resulting response.
	async fn serve(
		middleware: &BrotliMiddleware,
		uri: &'static str,
		accept_encoding: Option<&str>,
		body: &str,
		content_type: &str,
	) -> Response {
		let handler = Arc::new(TestHandler::new(body.to_string(), content_type.to_string()));
		middleware
			.process(get_request(uri, accept_encoding), handler)
			.await
			.unwrap()
	}

	#[tokio::test]
	async fn test_brotli_compression_basic() {
		let middleware = BrotliMiddleware::with_config(BrotliConfig::default());
		let body = "This is a test body that should be compressed. ".repeat(10);
		let response = serve(&middleware, "/test", Some("br"), &body, "text/html").await;
		assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "br");
		assert!(response.body.len() < body.len());
	}

	#[tokio::test]
	async fn test_no_compression_without_accept_encoding() {
		let middleware = BrotliMiddleware::new();
		let body = "Test body".repeat(50);
		let response = serve(&middleware, "/test", None, &body, "text/html").await;
		assert!(!response.headers.contains_key(CONTENT_ENCODING));
		assert_eq!(response.body.len(), body.len());
	}

	#[tokio::test]
	async fn test_no_compression_for_small_body() {
		let config = BrotliConfig {
			min_length: 1000,
			..Default::default()
		};
		let middleware = BrotliMiddleware::with_config(config);
		let response = serve(&middleware, "/test", Some("br"), "Small body", "text/html").await;
		assert!(!response.headers.contains_key(CONTENT_ENCODING));
	}

	#[tokio::test]
	async fn test_no_compression_for_non_text_content() {
		let middleware = BrotliMiddleware::new();
		let body = "Binary data".repeat(50);
		let response = serve(&middleware, "/test", Some("br"), &body, "image/png").await;
		assert!(!response.headers.contains_key(CONTENT_ENCODING));
	}

	#[tokio::test]
	async fn test_compression_quality_levels() {
		let body = "Test compression quality. ".repeat(20);
		let qualities = [
			BrotliQuality::Fast,
			BrotliQuality::Balanced,
			BrotliQuality::Best,
		];
		for quality in qualities {
			let config = BrotliConfig {
				quality,
				..Default::default()
			};
			let middleware = BrotliMiddleware::with_config(config);
			let response = serve(&middleware, "/test", Some("br"), &body, "text/html").await;
			assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "br");
			assert!(response.body.len() < body.len());
		}
	}

	#[tokio::test]
	async fn test_json_compression() {
		let middleware = BrotliMiddleware::new();
		let body = r#"{"data": "This is JSON data that should be compressed."}"#.repeat(10);
		let response = serve(
			&middleware,
			"/api/data",
			Some("br, gzip"),
			&body,
			"application/json",
		)
		.await;
		assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "br");
		assert!(response.body.len() < body.len());
	}

	#[tokio::test]
	async fn test_javascript_compression() {
		let middleware = BrotliMiddleware::new();
		let body = "function test() { console.log('hello'); }".repeat(10);
		let response = serve(
			&middleware,
			"/script.js",
			Some("br"),
			&body,
			"application/javascript",
		)
		.await;
		assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "br");
	}

	#[tokio::test]
	async fn test_custom_compressible_types() {
		let config = BrotliConfig {
			compressible_types: vec!["application/custom".to_string()],
			..Default::default()
		};
		let middleware = BrotliMiddleware::with_config(config);
		let body = "Custom content type data. ".repeat(20);
		let response = serve(
			&middleware,
			"/custom",
			Some("br"),
			&body,
			"application/custom",
		)
		.await;
		assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "br");
	}

	#[tokio::test]
	async fn test_window_size_configuration() {
		let config = BrotliConfig {
			window_size: 18,
			..Default::default()
		};
		let middleware = BrotliMiddleware::with_config(config);
		let body = "Test window size. ".repeat(20);
		let response = serve(&middleware, "/test", Some("br"), &body, "text/html").await;
		assert_eq!(response.headers.get(CONTENT_ENCODING).unwrap(), "br");
	}
}