elif_http/middleware/utils/
compression.rs

1//! # Compression Middleware
2//!
3//! Provides response compression using tower-http's battle-tested CompressionLayer.
4//! This is an adapter to make it work with the V2 middleware pattern.
5
6use crate::middleware::v2::{Middleware, Next, NextFuture};
7use crate::request::ElifRequest;
8use crate::response::ElifResponse;
9use http_body_util::BodyExt;
10use tower::{Layer, Service};
11use tower_http::compression::{CompressionLayer, CompressionLevel};
12
/// Configuration for the compression middleware.
///
/// Passed to `CompressionMiddleware::with_config`, which applies `level` as the
/// layer quality and switches off every algorithm whose flag is `false`.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Compression level handed to the tower-http layer.
    pub level: CompressionLevel,
    /// Whether gzip may be used for responses.
    pub enable_gzip: bool,
    /// Whether brotli may be used for responses.
    pub enable_brotli: bool,
    /// Whether deflate may be used for responses.
    pub enable_deflate: bool,
}
25
26impl Default for CompressionConfig {
27    fn default() -> Self {
28        Self {
29            level: CompressionLevel::default(),
30            enable_gzip: true,
31            enable_brotli: true,
32            enable_deflate: false, // Less common, disabled by default
33        }
34    }
35}
36
/// Middleware that compresses HTTP responses by delegating to tower-http.
///
/// Wraps a configured `CompressionLayer`; the builder methods on this type
/// consume the middleware and return a new one with the adjusted layer.
pub struct CompressionMiddleware {
    /// The tower-http layer carrying the chosen quality and enabled algorithms.
    layer: CompressionLayer,
}
41
42impl CompressionMiddleware {
43    /// Create new compression middleware with default configuration
44    pub fn new() -> Self {
45        let config = CompressionConfig::default();
46        Self::with_config(config)
47    }
48
49    /// Create compression middleware with custom configuration
50    pub fn with_config(config: CompressionConfig) -> Self {
51        let mut layer = CompressionLayer::new().quality(config.level);
52
53        // Enable/disable compression algorithms based on config
54        if !config.enable_gzip {
55            layer = layer.no_gzip();
56        }
57        if !config.enable_brotli {
58            layer = layer.no_br();
59        }
60        if !config.enable_deflate {
61            layer = layer.no_deflate();
62        }
63
64        Self { layer }
65    }
66
67    /// Set compression level (consuming)
68    pub fn level(self, level: CompressionLevel) -> Self {
69        Self {
70            layer: self.layer.quality(level),
71        }
72    }
73
74    /// Set fast compression (level 1)
75    pub fn fast(self) -> Self {
76        self.level(CompressionLevel::Fastest)
77    }
78
79    /// Set best compression (level 9)
80    pub fn best(self) -> Self {
81        self.level(CompressionLevel::Best)
82    }
83
84    /// Disable gzip compression
85    pub fn no_gzip(self) -> Self {
86        Self {
87            layer: self.layer.no_gzip(),
88        }
89    }
90
91    /// Disable brotli compression
92    pub fn no_brotli(self) -> Self {
93        Self {
94            layer: self.layer.no_br(),
95        }
96    }
97
98    /// Disable deflate compression
99    pub fn no_deflate(self) -> Self {
100        Self {
101            layer: self.layer.no_deflate(),
102        }
103    }
104
105    /// Enable only gzip compression
106    pub fn gzip_only(self) -> Self {
107        Self {
108            layer: self.layer.no_br().no_deflate(),
109        }
110    }
111
112    /// Enable only brotli compression
113    pub fn brotli_only(self) -> Self {
114        Self {
115            layer: self.layer.no_gzip().no_deflate(),
116        }
117    }
118}
119
120impl std::fmt::Debug for CompressionMiddleware {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("CompressionMiddleware")
123            .field("layer", &"<CompressionLayer>")
124            .finish()
125    }
126}
127
128impl Default for CompressionMiddleware {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl Clone for CompressionMiddleware {
135    fn clone(&self) -> Self {
136        Self {
137            layer: self.layer.clone(),
138        }
139    }
140}
141
142impl Middleware for CompressionMiddleware {
143    fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
144        let layer = self.layer.clone();
145
146        Box::pin(async move {
147            // Check if the client accepts compression from the original request
148            let accept_encoding = request
149                .header("accept-encoding")
150                .and_then(|h| h.to_str().ok())
151                .map(|s| s.to_owned())
152                .unwrap_or_default();
153
154            let wants_compression = accept_encoding.contains("gzip")
155                || accept_encoding.contains("br")
156                || accept_encoding.contains("deflate");
157
158            // First get the response from the next handler
159            let response = next.run(request).await;
160
161            if !wants_compression {
162                // Client doesn't want compression, return as-is
163                return response;
164            }
165
166            let axum_response = response.into_axum_response();
167            let (parts, body) = axum_response.into_parts();
168
169            // Read the response body to compress it
170            let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
171                Ok(bytes) => bytes,
172                Err(_) => {
173                    // Can't read body, return as-is
174                    let response =
175                        axum::response::Response::from_parts(parts, axum::body::Body::empty());
176                    return ElifResponse::from_axum_response(response).await;
177                }
178            };
179
180            // Store copies for fallback use
181            let parts_clone = parts.clone();
182            let body_bytes_clone = body_bytes.clone();
183
184            // Create a mock request for the compression service
185            let mock_request = axum::extract::Request::builder()
186                .uri("/")
187                .header("accept-encoding", &accept_encoding)
188                .body(axum::body::Body::empty())
189                .unwrap();
190
191            // Create a service that returns our response body
192            let service = tower::service_fn(move |_req: axum::extract::Request| {
193                let response_parts = parts.clone();
194                let response_body = body_bytes.clone();
195                async move {
196                    let response = axum::response::Response::from_parts(
197                        response_parts,
198                        axum::body::Body::from(response_body),
199                    );
200                    Ok::<axum::response::Response, std::convert::Infallible>(response)
201                }
202            });
203
204            // Apply compression layer
205            let mut compression_service = layer.layer(service);
206
207            // Call the compression service
208            match compression_service.call(mock_request).await {
209                Ok(compressed_response) => {
210                    // Extract the compressed response
211                    let (compressed_parts, compressed_body) = compressed_response.into_parts();
212
213                    // Convert CompressionBody to bytes
214                    match compressed_body.collect().await {
215                        Ok(collected) => {
216                            // Get the compressed bytes
217                            let compressed_bytes = collected.to_bytes();
218
219                            // Create final response with compressed body
220                            let final_response = axum::response::Response::from_parts(
221                                compressed_parts,
222                                axum::body::Body::from(compressed_bytes),
223                            );
224
225                            // Convert back to ElifResponse
226                            ElifResponse::from_axum_response(final_response).await
227                        }
228                        Err(_) => {
229                            // Fallback: return original response if compression fails
230                            let original_response = axum::response::Response::from_parts(
231                                parts_clone,
232                                axum::body::Body::from(body_bytes_clone),
233                            );
234                            ElifResponse::from_axum_response(original_response).await
235                        }
236                    }
237                }
238                Err(_) => {
239                    // Fallback: return original response if compression service fails
240                    let original_response = axum::response::Response::from_parts(
241                        parts_clone,
242                        axum::body::Body::from(body_bytes_clone),
243                    );
244                    ElifResponse::from_axum_response(original_response).await
245                }
246            }
247        })
248    }
249
250    fn name(&self) -> &'static str {
251        "CompressionMiddleware"
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::ElifRequest;
    use crate::response::ElifResponse;

    /// Builds a GET request for `/api/data`, optionally carrying an
    /// `accept-encoding` header with the given value.
    fn request_with_encoding(accept_encoding: Option<&'static str>) -> ElifRequest {
        let mut headers = crate::response::headers::ElifHeaderMap::new();
        if let Some(value) = accept_encoding {
            let encoding_header =
                crate::response::headers::ElifHeaderName::from_str("accept-encoding").unwrap();
            let encoding_value =
                crate::response::headers::ElifHeaderValue::from_str(value).unwrap();
            headers.insert(encoding_header, encoding_value);
        }
        ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        )
    }

    /// Terminal handler returning a JSON payload large enough to be worth
    /// compressing.
    fn json_handler() -> Next {
        Next::new(|_req| {
            Box::pin(async move {
                let json_data = serde_json::json!({
                    "message": "Hello, World!".repeat(100), // Make it large enough to compress
                    "data": (0..100).collect::<Vec<i32>>()
                });
                ElifResponse::ok().json_value(json_data)
            })
        })
    }

    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::default();
        assert!(config.enable_gzip);
        assert!(config.enable_brotli);
        assert!(!config.enable_deflate);
    }

    #[tokio::test]
    async fn test_compression_middleware() {
        let middleware = CompressionMiddleware::new();
        let request = request_with_encoding(Some("gzip, br"));

        let response = middleware.handle(request, json_handler()).await;

        // Response should be successful
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    /// Without an `accept-encoding` header the middleware hands the downstream
    /// response straight back.
    #[tokio::test]
    async fn test_passthrough_without_accept_encoding() {
        let middleware = CompressionMiddleware::new();
        let request = request_with_encoding(None);

        let response = middleware.handle(request, json_handler()).await;

        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    // Building the middleware is synchronous, so no async runtime is needed.
    #[test]
    fn test_compression_builder_pattern() {
        let middleware = CompressionMiddleware::new()
            .best() // Maximum compression
            .gzip_only(); // Only gzip

        // Test that it builds without errors
        assert_eq!(middleware.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_compression_levels() {
        let fast = CompressionMiddleware::new().fast();
        let best = CompressionMiddleware::new().best();
        let custom = CompressionMiddleware::new().level(CompressionLevel::Precise(5));

        // All should build without errors
        assert_eq!(fast.name(), "CompressionMiddleware");
        assert_eq!(best.name(), "CompressionMiddleware");
        assert_eq!(custom.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_algorithm_selection() {
        let gzip_only = CompressionMiddleware::new().gzip_only();
        let brotli_only = CompressionMiddleware::new().brotli_only();
        let no_brotli = CompressionMiddleware::new().no_brotli();

        // All should build without errors
        assert_eq!(gzip_only.name(), "CompressionMiddleware");
        assert_eq!(brotli_only.name(), "CompressionMiddleware");
        assert_eq!(no_brotli.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_clone() {
        let middleware = CompressionMiddleware::new().best();
        let cloned = middleware.clone();

        assert_eq!(cloned.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_debug_output() {
        let rendered = format!("{:?}", CompressionMiddleware::new());

        assert_eq!(
            rendered,
            "CompressionMiddleware { layer: \"<CompressionLayer>\" }"
        );
    }
}