llm_cost_ops/compression/middleware.rs

// HTTP compression middleware for Axum.
2
3use super::{
4    codec, CompressionAlgorithm, CompressionConfig, CompressionError,
5    CompressionMetrics, CompressionResult,
6};
7use axum::{
8    body::Body,
9    extract::Request,
10    http::{header, HeaderValue, StatusCode},
11    middleware::Next,
12    response::{IntoResponse, Response},
13};
14use http_body_util::BodyExt;
15use std::str::FromStr;
16use std::sync::Arc;
17use tower::{Layer, Service};
18use tracing::{debug, warn};
19
20/// Compression layer for Axum
21#[derive(Clone)]
22pub struct CompressionLayer {
23    config: Arc<CompressionConfig>,
24    metrics: Arc<CompressionMetrics>,
25}
26
27impl CompressionLayer {
28    /// Create a new compression layer
29    pub fn new(config: CompressionConfig) -> Self {
30        Self {
31            config: Arc::new(config),
32            metrics: super::metrics::get_metrics(),
33        }
34    }
35
36    /// Create with custom metrics
37    pub fn with_metrics(config: CompressionConfig, metrics: Arc<CompressionMetrics>) -> Self {
38        Self {
39            config: Arc::new(config),
40            metrics,
41        }
42    }
43}
44
45impl<S> Layer<S> for CompressionLayer {
46    type Service = CompressionService<S>;
47
48    fn layer(&self, inner: S) -> Self::Service {
49        CompressionService {
50            inner,
51            config: self.config.clone(),
52            metrics: self.metrics.clone(),
53        }
54    }
55}
56
57/// Compression service
58#[derive(Clone)]
59pub struct CompressionService<S> {
60    inner: S,
61    config: Arc<CompressionConfig>,
62    metrics: Arc<CompressionMetrics>,
63}
64
65impl<S> Service<Request> for CompressionService<S>
66where
67    S: Service<Request, Response = Response> + Clone + Send + 'static,
68    S::Future: Send + 'static,
69{
70    type Response = Response;
71    type Error = S::Error;
72    type Future = std::pin::Pin<
73        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
74    >;
75
76    fn poll_ready(
77        &mut self,
78        cx: &mut std::task::Context<'_>,
79    ) -> std::task::Poll<Result<(), Self::Error>> {
80        self.inner.poll_ready(cx)
81    }
82
83    fn call(&mut self, request: Request) -> Self::Future {
84        let config = self.config.clone();
85        let metrics = self.metrics.clone();
86        let mut inner = self.inner.clone();
87
88        Box::pin(async move {
89            // Handle request decompression if enabled
90            let request = if config.compress_requests {
91                match decompress_request(request, &config, &metrics).await {
92                    Ok(req) => req,
93                    Err(e) => {
94                        warn!(error = %e, "Failed to decompress request");
95                        return Ok(error_response(
96                            StatusCode::BAD_REQUEST,
97                            "Failed to decompress request body",
98                        ));
99                    }
100                }
101            } else {
102                request
103            };
104
105            // Get Accept-Encoding header before passing request
106            let accept_encoding = request
107                .headers()
108                .get(header::ACCEPT_ENCODING)
109                .and_then(|h| h.to_str().ok())
110                .map(|s| s.to_string());
111
112            // Call inner service
113            let response = inner.call(request).await?;
114
115            // Handle response compression if enabled
116            let response = if config.compress_responses {
117                if let Some(accept_encoding) = accept_encoding {
118                    compress_response(response, &accept_encoding, &config, &metrics).await
119                } else {
120                    response
121                }
122            } else {
123                response
124            };
125
126            Ok(response)
127        })
128    }
129}
130
131/// Decompress request body
132async fn decompress_request(
133    request: Request,
134    _config: &CompressionConfig,
135    metrics: &CompressionMetrics,
136) -> CompressionResult<Request> {
137    let (mut parts, body) = request.into_parts();
138
139    // Check for Content-Encoding header
140    let encoding = parts
141        .headers
142        .get(header::CONTENT_ENCODING)
143        .and_then(|h| h.to_str().ok());
144
145    let encoding = match encoding {
146        Some(e) => e,
147        None => return Ok(Request::from_parts(parts, body)), // No encoding, return as-is
148    };
149
150    // Parse algorithm
151    let algorithm = match CompressionAlgorithm::from_str(encoding) {
152        Ok(algo) => algo,
153        Err(_) => return Ok(Request::from_parts(parts, body)), // Unknown encoding, pass through
154    };
155
156    if algorithm == CompressionAlgorithm::Identity {
157        return Ok(Request::from_parts(parts, body));
158    }
159
160    // Collect body bytes
161    let bytes = body
162        .collect()
163        .await
164        .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?
165        .to_bytes();
166
167    // Decompress
168    let (decompressed, stats) = codec::decompress(&bytes, algorithm)?;
169
170    // Record metrics
171    metrics.record_decompression(&stats);
172
173    // Remove Content-Encoding header
174    parts.headers.remove(header::CONTENT_ENCODING);
175
176    // Update Content-Length
177    parts.headers.insert(
178        header::CONTENT_LENGTH,
179        HeaderValue::from_str(&decompressed.len().to_string()).unwrap(),
180    );
181
182    debug!(
183        algorithm = %algorithm,
184        original_size = stats.compressed_size,
185        decompressed_size = stats.original_size,
186        "Request decompressed"
187    );
188
189    Ok(Request::from_parts(parts, Body::from(decompressed)))
190}
191
192/// Compress response body
193async fn compress_response(
194    response: Response,
195    accept_encoding: &str,
196    config: &CompressionConfig,
197    metrics: &CompressionMetrics,
198) -> Response {
199    // Select compression algorithm
200    let algorithm = match config.select_algorithm(Some(accept_encoding)) {
201        Some(algo) if algo != CompressionAlgorithm::Identity => algo,
202        _ => return response, // No compression
203    };
204
205    let (mut parts, body) = response.into_parts();
206
207    // Check if content type should be compressed
208    let content_type = parts
209        .headers
210        .get(header::CONTENT_TYPE)
211        .and_then(|h| h.to_str().ok())
212        .unwrap_or("");
213
214    if !config.should_compress_mime_type(content_type) {
215        debug!(
216            content_type = content_type,
217            "Content type not compressible"
218        );
219        return Response::from_parts(parts, body);
220    }
221
222    // Collect body bytes
223    let bytes = match body.collect().await {
224        Ok(collected) => collected.to_bytes(),
225        Err(e) => {
226            warn!(error = %e, "Failed to collect response body");
227            return error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error");
228        }
229    };
230
231    // Check size thresholds
232    if !config.should_compress_size(bytes.len()) {
233        debug!(
234            size = bytes.len(),
235            min_size = config.min_size,
236            "Response too small to compress"
237        );
238        return Response::from_parts(parts, Body::from(bytes));
239    }
240
241    // Compress
242    let (compressed, stats) = match codec::compress(&bytes, algorithm, config.level) {
243        Ok(result) => result,
244        Err(e) => {
245            warn!(error = %e, algorithm = %algorithm, "Compression failed");
246            metrics.record_error(Some(algorithm), "compress");
247            return Response::from_parts(parts, Body::from(bytes));
248        }
249    };
250
251    // Check if compression was beneficial
252    if compressed.len() >= bytes.len() {
253        debug!(
254            original_size = bytes.len(),
255            compressed_size = compressed.len(),
256            "Compression not beneficial, using original"
257        );
258        return Response::from_parts(parts, Body::from(bytes));
259    }
260
261    // Record metrics
262    metrics.record_compression(&stats);
263
264    // Update headers
265    parts.headers.insert(
266        header::CONTENT_ENCODING,
267        HeaderValue::from_str(algorithm.as_str()).unwrap(),
268    );
269
270    parts.headers.insert(
271        header::CONTENT_LENGTH,
272        HeaderValue::from_str(&compressed.len().to_string()).unwrap(),
273    );
274
275    // Add Vary header to indicate content negotiation
276    parts
277        .headers
278        .entry(header::VARY)
279        .or_insert(HeaderValue::from_static("Accept-Encoding"));
280
281    debug!(
282        algorithm = %algorithm,
283        original_size = stats.original_size,
284        compressed_size = stats.compressed_size,
285        ratio = stats.compression_ratio,
286        savings_pct = stats.compression_percentage(),
287        "Response compressed"
288    );
289
290    Response::from_parts(parts, Body::from(compressed))
291}
292
293/// Create error response
294fn error_response(status: StatusCode, message: &str) -> Response {
295    (status, message.to_string()).into_response()
296}
297
298/// Create compression layer with default config
299pub fn compression_layer() -> CompressionLayer {
300    CompressionLayer::new(CompressionConfig::default())
301}
302
303/// Middleware function for manual application
304pub async fn compression_middleware(
305    request: Request,
306    next: Next,
307) -> Result<Response, StatusCode> {
308    let config = Arc::new(CompressionConfig::default());
309    let metrics = super::metrics::get_metrics();
310
311    // Handle request decompression
312    let request = if config.compress_requests {
313        match decompress_request(request, &config, &metrics).await {
314            Ok(req) => req,
315            Err(e) => {
316                warn!(error = %e, "Failed to decompress request");
317                return Ok(error_response(
318                    StatusCode::BAD_REQUEST,
319                    "Failed to decompress request body",
320                ));
321            }
322        }
323    } else {
324        request
325    };
326
327    // Get Accept-Encoding before passing request
328    let accept_encoding = request
329        .headers()
330        .get(header::ACCEPT_ENCODING)
331        .and_then(|h| h.to_str().ok())
332        .map(|s| s.to_string());
333
334    // Call next middleware/handler
335    let response = next.run(request).await;
336
337    // Handle response compression
338    let response = if config.compress_responses {
339        if let Some(accept_encoding) = accept_encoding {
340            compress_response(response, &accept_encoding, &config, &metrics).await
341        } else {
342            response
343        }
344    } else {
345        response
346    };
347
348    Ok(response)
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CompressionLevel;
    use axum::{routing::get, Router};
    use tower::ServiceExt;

    /// Body text long and repetitive enough to shrink under compression.
    const COMPRESSIBLE_TEXT: &str = "Hello, World! This is a test response that should be compressed. Adding more text to ensure it compresses well and the compressed size is smaller than the original.";

    async fn test_handler() -> ([(header::HeaderName, &'static str); 1], &'static str) {
        ([(header::CONTENT_TYPE, "text/plain")], COMPRESSIBLE_TEXT)
    }

    /// Response-only config offering exactly `algorithm` for `text/*` bodies.
    fn response_config(algorithm: CompressionAlgorithm) -> CompressionConfig {
        CompressionConfig {
            enabled: true,
            level: CompressionLevel::Default,
            algorithms: vec![algorithm],
            min_size: 10, // Low threshold for testing
            max_size: None,
            mime_types: vec!["text/*".to_string()],
            compress_requests: false,
            compress_responses: true,
            buffer_size: 8192,
        }
    }

    /// Send `GET /` to `test_handler` wrapped in a layer built from `config`,
    /// optionally with an `Accept-Encoding` header.
    async fn get_root(config: CompressionConfig, accept_encoding: Option<&str>) -> Response {
        let app = Router::new()
            .route("/", get(test_handler))
            .layer(CompressionLayer::new(config));

        let mut builder = Request::builder().uri("/");
        if let Some(value) = accept_encoding {
            builder = builder.header(header::ACCEPT_ENCODING, value);
        }
        let request = builder.body(Body::empty()).unwrap();

        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_compression_layer_creation() {
        // Building and dropping a layer with the defaults must not panic.
        let layer = CompressionLayer::new(CompressionConfig::default());
        drop(layer);
    }

    #[tokio::test]
    async fn test_compression_response() {
        let response = get_root(response_config(CompressionAlgorithm::Gzip), Some("gzip")).await;

        assert_eq!(
            response.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    #[tokio::test]
    async fn test_no_compression_without_accept_encoding() {
        let response = get_root(CompressionConfig::default(), None).await;

        assert!(response.headers().get(header::CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn test_compression_with_brotli() {
        let response = get_root(response_config(CompressionAlgorithm::Brotli), Some("br")).await;

        assert_eq!(
            response.headers().get(header::CONTENT_ENCODING).unwrap(),
            "br"
        );
    }
}