llm_cost_ops/compression/
middleware.rs1use super::{
4 codec, CompressionAlgorithm, CompressionConfig, CompressionError,
5 CompressionMetrics, CompressionResult,
6};
7use axum::{
8 body::Body,
9 extract::Request,
10 http::{header, HeaderValue, StatusCode},
11 middleware::Next,
12 response::{IntoResponse, Response},
13};
14use http_body_util::BodyExt;
15use std::str::FromStr;
16use std::sync::Arc;
17use tower::{Layer, Service};
18use tracing::{debug, warn};
19
20#[derive(Clone)]
22pub struct CompressionLayer {
23 config: Arc<CompressionConfig>,
24 metrics: Arc<CompressionMetrics>,
25}
26
27impl CompressionLayer {
28 pub fn new(config: CompressionConfig) -> Self {
30 Self {
31 config: Arc::new(config),
32 metrics: super::metrics::get_metrics(),
33 }
34 }
35
36 pub fn with_metrics(config: CompressionConfig, metrics: Arc<CompressionMetrics>) -> Self {
38 Self {
39 config: Arc::new(config),
40 metrics,
41 }
42 }
43}
44
45impl<S> Layer<S> for CompressionLayer {
46 type Service = CompressionService<S>;
47
48 fn layer(&self, inner: S) -> Self::Service {
49 CompressionService {
50 inner,
51 config: self.config.clone(),
52 metrics: self.metrics.clone(),
53 }
54 }
55}
56
57#[derive(Clone)]
59pub struct CompressionService<S> {
60 inner: S,
61 config: Arc<CompressionConfig>,
62 metrics: Arc<CompressionMetrics>,
63}
64
65impl<S> Service<Request> for CompressionService<S>
66where
67 S: Service<Request, Response = Response> + Clone + Send + 'static,
68 S::Future: Send + 'static,
69{
70 type Response = Response;
71 type Error = S::Error;
72 type Future = std::pin::Pin<
73 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
74 >;
75
76 fn poll_ready(
77 &mut self,
78 cx: &mut std::task::Context<'_>,
79 ) -> std::task::Poll<Result<(), Self::Error>> {
80 self.inner.poll_ready(cx)
81 }
82
83 fn call(&mut self, request: Request) -> Self::Future {
84 let config = self.config.clone();
85 let metrics = self.metrics.clone();
86 let mut inner = self.inner.clone();
87
88 Box::pin(async move {
89 let request = if config.compress_requests {
91 match decompress_request(request, &config, &metrics).await {
92 Ok(req) => req,
93 Err(e) => {
94 warn!(error = %e, "Failed to decompress request");
95 return Ok(error_response(
96 StatusCode::BAD_REQUEST,
97 "Failed to decompress request body",
98 ));
99 }
100 }
101 } else {
102 request
103 };
104
105 let accept_encoding = request
107 .headers()
108 .get(header::ACCEPT_ENCODING)
109 .and_then(|h| h.to_str().ok())
110 .map(|s| s.to_string());
111
112 let response = inner.call(request).await?;
114
115 let response = if config.compress_responses {
117 if let Some(accept_encoding) = accept_encoding {
118 compress_response(response, &accept_encoding, &config, &metrics).await
119 } else {
120 response
121 }
122 } else {
123 response
124 };
125
126 Ok(response)
127 })
128 }
129}
130
131async fn decompress_request(
133 request: Request,
134 _config: &CompressionConfig,
135 metrics: &CompressionMetrics,
136) -> CompressionResult<Request> {
137 let (mut parts, body) = request.into_parts();
138
139 let encoding = parts
141 .headers
142 .get(header::CONTENT_ENCODING)
143 .and_then(|h| h.to_str().ok());
144
145 let encoding = match encoding {
146 Some(e) => e,
147 None => return Ok(Request::from_parts(parts, body)), };
149
150 let algorithm = match CompressionAlgorithm::from_str(encoding) {
152 Ok(algo) => algo,
153 Err(_) => return Ok(Request::from_parts(parts, body)), };
155
156 if algorithm == CompressionAlgorithm::Identity {
157 return Ok(Request::from_parts(parts, body));
158 }
159
160 let bytes = body
162 .collect()
163 .await
164 .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?
165 .to_bytes();
166
167 let (decompressed, stats) = codec::decompress(&bytes, algorithm)?;
169
170 metrics.record_decompression(&stats);
172
173 parts.headers.remove(header::CONTENT_ENCODING);
175
176 parts.headers.insert(
178 header::CONTENT_LENGTH,
179 HeaderValue::from_str(&decompressed.len().to_string()).unwrap(),
180 );
181
182 debug!(
183 algorithm = %algorithm,
184 original_size = stats.compressed_size,
185 decompressed_size = stats.original_size,
186 "Request decompressed"
187 );
188
189 Ok(Request::from_parts(parts, Body::from(decompressed)))
190}
191
192async fn compress_response(
194 response: Response,
195 accept_encoding: &str,
196 config: &CompressionConfig,
197 metrics: &CompressionMetrics,
198) -> Response {
199 let algorithm = match config.select_algorithm(Some(accept_encoding)) {
201 Some(algo) if algo != CompressionAlgorithm::Identity => algo,
202 _ => return response, };
204
205 let (mut parts, body) = response.into_parts();
206
207 let content_type = parts
209 .headers
210 .get(header::CONTENT_TYPE)
211 .and_then(|h| h.to_str().ok())
212 .unwrap_or("");
213
214 if !config.should_compress_mime_type(content_type) {
215 debug!(
216 content_type = content_type,
217 "Content type not compressible"
218 );
219 return Response::from_parts(parts, body);
220 }
221
222 let bytes = match body.collect().await {
224 Ok(collected) => collected.to_bytes(),
225 Err(e) => {
226 warn!(error = %e, "Failed to collect response body");
227 return error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error");
228 }
229 };
230
231 if !config.should_compress_size(bytes.len()) {
233 debug!(
234 size = bytes.len(),
235 min_size = config.min_size,
236 "Response too small to compress"
237 );
238 return Response::from_parts(parts, Body::from(bytes));
239 }
240
241 let (compressed, stats) = match codec::compress(&bytes, algorithm, config.level) {
243 Ok(result) => result,
244 Err(e) => {
245 warn!(error = %e, algorithm = %algorithm, "Compression failed");
246 metrics.record_error(Some(algorithm), "compress");
247 return Response::from_parts(parts, Body::from(bytes));
248 }
249 };
250
251 if compressed.len() >= bytes.len() {
253 debug!(
254 original_size = bytes.len(),
255 compressed_size = compressed.len(),
256 "Compression not beneficial, using original"
257 );
258 return Response::from_parts(parts, Body::from(bytes));
259 }
260
261 metrics.record_compression(&stats);
263
264 parts.headers.insert(
266 header::CONTENT_ENCODING,
267 HeaderValue::from_str(algorithm.as_str()).unwrap(),
268 );
269
270 parts.headers.insert(
271 header::CONTENT_LENGTH,
272 HeaderValue::from_str(&compressed.len().to_string()).unwrap(),
273 );
274
275 parts
277 .headers
278 .entry(header::VARY)
279 .or_insert(HeaderValue::from_static("Accept-Encoding"));
280
281 debug!(
282 algorithm = %algorithm,
283 original_size = stats.original_size,
284 compressed_size = stats.compressed_size,
285 ratio = stats.compression_ratio,
286 savings_pct = stats.compression_percentage(),
287 "Response compressed"
288 );
289
290 Response::from_parts(parts, Body::from(compressed))
291}
292
293fn error_response(status: StatusCode, message: &str) -> Response {
295 (status, message.to_string()).into_response()
296}
297
298pub fn compression_layer() -> CompressionLayer {
300 CompressionLayer::new(CompressionConfig::default())
301}
302
303pub async fn compression_middleware(
305 request: Request,
306 next: Next,
307) -> Result<Response, StatusCode> {
308 let config = Arc::new(CompressionConfig::default());
309 let metrics = super::metrics::get_metrics();
310
311 let request = if config.compress_requests {
313 match decompress_request(request, &config, &metrics).await {
314 Ok(req) => req,
315 Err(e) => {
316 warn!(error = %e, "Failed to decompress request");
317 return Ok(error_response(
318 StatusCode::BAD_REQUEST,
319 "Failed to decompress request body",
320 ));
321 }
322 }
323 } else {
324 request
325 };
326
327 let accept_encoding = request
329 .headers()
330 .get(header::ACCEPT_ENCODING)
331 .and_then(|h| h.to_str().ok())
332 .map(|s| s.to_string());
333
334 let response = next.run(request).await;
336
337 let response = if config.compress_responses {
339 if let Some(accept_encoding) = accept_encoding {
340 compress_response(response, &accept_encoding, &config, &metrics).await
341 } else {
342 response
343 }
344 } else {
345 response
346 };
347
348 Ok(response)
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CompressionLevel;
    use axum::{routing::get, Router};
    use tower::ServiceExt;

    /// Body served by `test_handler`; long and repetitive enough that it shrinks
    /// under compression.
    const TEST_BODY: &str = "Hello, World! This is a test response that should be compressed. Adding more text to ensure it compresses well and the compressed size is smaller than the original.";

    async fn test_handler() -> ([(header::HeaderName, &'static str); 1], &'static str) {
        ([(header::CONTENT_TYPE, "text/plain")], TEST_BODY)
    }

    /// Config that compresses `text/*` responses of at least 10 bytes using only
    /// `algorithm`, and leaves request bodies alone.
    fn response_compression_config(algorithm: CompressionAlgorithm) -> CompressionConfig {
        CompressionConfig {
            enabled: true,
            level: CompressionLevel::Default,
            algorithms: vec![algorithm],
            min_size: 10,
            max_size: None,
            mime_types: vec!["text/*".to_string()],
            compress_requests: false,
            compress_responses: true,
            buffer_size: 8192,
        }
    }

    /// Sends `GET /` through a router wrapped in `CompressionLayer::new(config)`,
    /// with an `Accept-Encoding` header only when one is given.
    async fn send_get(config: CompressionConfig, accept_encoding: Option<&str>) -> Response {
        let app = Router::new()
            .route("/", get(test_handler))
            .layer(CompressionLayer::new(config));

        let mut builder = Request::builder().uri("/");
        if let Some(value) = accept_encoding {
            builder = builder.header(header::ACCEPT_ENCODING, value);
        }
        let request = builder.body(Body::empty()).unwrap();

        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_compression_layer_creation() {
        let config = CompressionConfig::default();
        let layer = CompressionLayer::new(config);
        drop(layer);
    }

    #[tokio::test]
    async fn test_compression_response() {
        let response = send_get(
            response_compression_config(CompressionAlgorithm::Gzip),
            Some("gzip"),
        )
        .await;

        assert_eq!(
            response.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );

        // The header alone does not prove the body is valid gzip of the handler output.
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        let (decoded, _) = match codec::decompress(&bytes, CompressionAlgorithm::Gzip) {
            Ok(result) => result,
            Err(e) => panic!("response body is not valid gzip: {e}"),
        };
        assert_eq!(&decoded[..], TEST_BODY.as_bytes());
    }

    #[tokio::test]
    async fn test_no_compression_without_accept_encoding() {
        let response = send_get(CompressionConfig::default(), None).await;

        assert!(response.headers().get(header::CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn test_compression_with_brotli() {
        let response = send_get(
            response_compression_config(CompressionAlgorithm::Brotli),
            Some("br"),
        )
        .await;

        assert_eq!(
            response.headers().get(header::CONTENT_ENCODING).unwrap(),
            "br"
        );
    }
}