http_response_compression/
future.rs

1use crate::body::CompressionBody;
2use crate::codec::Codec;
3use http::{Response, header};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9pin_project! {
10    /// Future for compression service responses.
11    pub struct ResponseFuture<F> {
12        #[pin]
13        inner: F,
14        accepted_codec: Option<Codec>,
15        min_size: usize,
16    }
17}
18
19impl<F> ResponseFuture<F> {
20    pub(crate) fn new(inner: F, accepted_codec: Option<Codec>, min_size: usize) -> Self {
21        Self {
22            inner,
23            accepted_codec,
24            min_size,
25        }
26    }
27}
28
29impl<F, B, E> Future for ResponseFuture<F>
30where
31    F: Future<Output = Result<Response<B>, E>>,
32{
33    type Output = Result<Response<CompressionBody<B>>, E>;
34
35    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36        let this = self.project();
37
38        match this.inner.poll(cx) {
39            Poll::Pending => Poll::Pending,
40            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
41            Poll::Ready(Ok(response)) => {
42                let response = wrap_response(response, *this.accepted_codec, *this.min_size);
43                Poll::Ready(Ok(response))
44            }
45        }
46    }
47}
48
49/// Wraps the response body with compression if appropriate.
50fn wrap_response<B>(
51    response: Response<B>,
52    accepted_codec: Option<Codec>,
53    min_size: usize,
54) -> Response<CompressionBody<B>> {
55    let (mut parts, body) = response.into_parts();
56
57    // Determine if we should compress
58    let should_compress = accepted_codec.is_some()
59        && !has_content_encoding(&parts.headers)
60        && !has_content_range(&parts.headers)
61        && !is_uncompressible_content_type(&parts.headers)
62        && !is_below_min_size(&parts.headers, min_size);
63
64    // Check for x-accel-buffering: no header or streaming content types
65    let always_flush = parts
66        .headers
67        .get("x-accel-buffering")
68        .and_then(|v| v.to_str().ok())
69        .is_some_and(|v| v.eq_ignore_ascii_case("no"))
70        || is_streaming_content_type(&parts.headers);
71
72    let body = if should_compress {
73        let codec = accepted_codec.unwrap();
74
75        // Add Content-Encoding header
76        parts.headers.insert(
77            header::CONTENT_ENCODING,
78            header::HeaderValue::from_static(codec.content_encoding()),
79        );
80
81        // Remove Content-Length since compressed size is unknown
82        parts.headers.remove(header::CONTENT_LENGTH);
83
84        // Remove Accept-Ranges since we can't support ranges on compressed content
85        parts.headers.remove(header::ACCEPT_RANGES);
86
87        // Add Accept-Encoding to Vary header if not present
88        add_vary_accept_encoding(&mut parts.headers);
89
90        CompressionBody::compressed(body, codec, always_flush)
91    } else {
92        CompressionBody::passthrough(body)
93    };
94
95    Response::from_parts(parts, body)
96}
97
98/// Checks if Content-Encoding header is already present.
99fn has_content_encoding(headers: &header::HeaderMap) -> bool {
100    headers.contains_key(header::CONTENT_ENCODING)
101}
102
103/// Checks if Content-Range header is present (range response).
104fn has_content_range(headers: &header::HeaderMap) -> bool {
105    headers.contains_key(header::CONTENT_RANGE)
106}
107
108/// Adds Accept-Encoding to the Vary header if not already present.
109fn add_vary_accept_encoding(headers: &mut header::HeaderMap) {
110    // Check all Vary headers to see if Accept-Encoding is already present
111    for vary in headers.get_all(header::VARY) {
112        if let Ok(vary_str) = vary.to_str() {
113            let dominated = vary_str.split(',').any(|v| {
114                let v = v.trim();
115                v.eq_ignore_ascii_case("*") || v.eq_ignore_ascii_case("accept-encoding")
116            });
117            if dominated {
118                return;
119            }
120        }
121    }
122
123    // Append Accept-Encoding to Vary header
124    headers.append(
125        header::VARY,
126        header::HeaderValue::from_static("accept-encoding"),
127    );
128}
129
130/// Checks if the content type should not be compressed.
131fn is_uncompressible_content_type(headers: &header::HeaderMap) -> bool {
132    let Some(content_type) = headers
133        .get(header::CONTENT_TYPE)
134        .and_then(|v| v.to_str().ok())
135    else {
136        return false;
137    };
138
139    let ct = content_type.to_ascii_lowercase();
140
141    // Skip all images except SVG
142    if ct.starts_with("image/") {
143        return !ct.starts_with("image/svg+xml");
144    }
145
146    // Skip gRPC except grpc-web
147    if ct.starts_with("application/grpc") {
148        return !ct.starts_with("application/grpc-web");
149    }
150
151    false
152}
153
154/// Checks if the content type requires always flushing (e.g., streaming).
155fn is_streaming_content_type(headers: &header::HeaderMap) -> bool {
156    headers
157        .get(header::CONTENT_TYPE)
158        .and_then(|v| v.to_str().ok())
159        .is_some_and(|content_type| {
160            let ct = content_type.to_ascii_lowercase();
161            ct.starts_with("text/event-stream") || ct.starts_with("application/grpc-web")
162        })
163}
164
165/// Checks if Content-Length is below the minimum size.
166fn is_below_min_size(headers: &header::HeaderMap, min_size: usize) -> bool {
167    headers
168        .get(header::CONTENT_LENGTH)
169        .and_then(|v| v.to_str().ok())
170        .and_then(|v| v.parse::<usize>().ok())
171        .is_some_and(|len| len < min_size)
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::CompressState;

    fn make_response(body: &'static str) -> Response<&'static str> {
        Response::new(body)
    }

    fn make_response_with_headers<I>(body: &'static str, headers: I) -> Response<&'static str>
    where
        I: IntoIterator<Item = (&'static str, &'static str)>,
    {
        let mut response = Response::new(body);
        for (name, value) in headers {
            response
                .headers_mut()
                .insert(name, header::HeaderValue::from_static(value));
        }
        response
    }

    /// Panics with `msg` unless the wrapped body is the passthrough variant.
    fn assert_passthrough<B>(response: &Response<crate::body::CompressionBody<B>>, msg: &str) {
        match response.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("{msg}"),
        }
    }

    /// Panics with `msg` unless the wrapped body is compressed; returns its
    /// `always_flush` flag so callers can assert on it.
    fn assert_compressed<B>(
        response: &Response<crate::body::CompressionBody<B>>,
        msg: &str,
    ) -> bool {
        match response.body() {
            crate::body::CompressionBody::Compressed { state, .. } => state.always_flush(),
            _ => panic!("{msg}"),
        }
    }

    /// Collects every `Vary` value in order (not just the first one), so
    /// tests can detect accidental duplicates.
    fn vary_values<B>(response: &Response<B>) -> Vec<&str> {
        response
            .headers()
            .get_all(header::VARY)
            .iter()
            .map(|v| v.to_str().unwrap())
            .collect()
    }

    #[test]
    fn test_compress_when_accept_encoding_present() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Should be compressed, starting in the reading state
        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert_eq!(state.state(), CompressState::Reading);
            }
            _ => panic!("Expected compressed body"),
        }

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    #[test]
    fn test_no_compress_when_no_accept_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, None, 0);

        assert_passthrough(&wrapped, "Expected passthrough body");
        assert!(wrapped.headers().get(header::CONTENT_ENCODING).is_none());
    }

    #[test]
    fn test_no_compress_when_content_encoding_present() {
        let response =
            make_response_with_headers("hello world", [("content-encoding", "identity")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body");
        // The existing encoding must be left exactly as the handler set it.
        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "identity"
        );
    }

    #[test]
    fn test_no_compress_image_png() {
        let response = make_response_with_headers("PNG data", [("content-type", "image/png")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for image/png");
    }

    #[test]
    fn test_no_compress_image_jpeg() {
        let response = make_response_with_headers("JPEG data", [("content-type", "image/jpeg")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for image/jpeg");
    }

    #[test]
    fn test_no_compress_image_gif() {
        let response = make_response_with_headers("GIF data", [("content-type", "image/gif")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for image/gif");
    }

    #[test]
    fn test_no_compress_image_webp() {
        let response = make_response_with_headers("WebP data", [("content-type", "image/webp")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for image/webp");
    }

    #[test]
    fn test_compress_image_svg() {
        let response =
            make_response_with_headers("<svg></svg>", [("content-type", "image/svg+xml")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // SVG is text-based, so it is compressed
        assert_compressed(&wrapped, "Expected compressed body for image/svg+xml");
    }

    #[test]
    fn test_compress_image_svg_with_charset() {
        let response = make_response_with_headers(
            "<svg></svg>",
            [("content-type", "image/svg+xml; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_compressed(
            &wrapped,
            "Expected compressed body for image/svg+xml with charset",
        );
    }

    #[test]
    fn test_compress_text_html() {
        let response = make_response_with_headers("<html></html>", [("content-type", "text/html")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_compressed(&wrapped, "Expected compressed body for text/html");
    }

    #[test]
    fn test_no_compress_below_min_size() {
        let response = make_response_with_headers("small", [("content-length", "5")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        // 5 < 100
        assert_passthrough(&wrapped, "Expected passthrough body below min size");
    }

    #[test]
    fn test_compress_above_min_size() {
        let response =
            make_response_with_headers("large enough content", [("content-length", "200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        // 200 >= 100
        assert_compressed(&wrapped, "Expected compressed body above min size");
        // Content-Length should be removed
        assert!(wrapped.headers().get(header::CONTENT_LENGTH).is_none());
    }

    #[test]
    fn test_compress_at_exact_min_size() {
        // The threshold is exclusive: a length equal to min_size is not "below" it.
        let response = make_response_with_headers("boundary", [("content-length", "100")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        assert_compressed(&wrapped, "Expected compressed body at exactly min size");
    }

    #[test]
    fn test_compress_unknown_size() {
        // No Content-Length header means unknown size; the min_size check does not apply
        let response = make_response("unknown size content");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        assert_compressed(&wrapped, "Expected compressed body for unknown size");
    }

    #[test]
    fn test_always_flush_when_x_accel_buffering_no() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "no")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(assert_compressed(&wrapped, "Expected compressed body"));
    }

    #[test]
    fn test_no_always_flush_by_default() {
        let response = make_response("normal data");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(!assert_compressed(&wrapped, "Expected compressed body"));
    }

    #[test]
    fn test_x_accel_buffering_case_insensitive() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "NO")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(assert_compressed(&wrapped, "Expected compressed body"));
    }

    #[test]
    fn test_brotli_content_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Brotli), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "br"
        );
    }

    #[test]
    fn test_zstd_content_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Zstd), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "zstd"
        );
    }

    #[test]
    fn test_no_compress_application_grpc() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for application/grpc");
    }

    #[test]
    fn test_no_compress_application_grpc_with_suffix() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc+proto")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Prefix match also catches suffixed gRPC content types
        assert_passthrough(
            &wrapped,
            "Expected passthrough body for application/grpc+proto",
        );
    }

    #[test]
    fn test_compress_application_grpc_web() {
        let response =
            make_response_with_headers("grpc-web data", [("content-type", "application/grpc-web")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(assert_compressed(&wrapped, "Expected compressed body"));
    }

    #[test]
    fn test_compress_application_grpc_web_proto() {
        let response = make_response_with_headers(
            "grpc-web data",
            [("content-type", "application/grpc-web+proto")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(assert_compressed(&wrapped, "Expected compressed body"));
    }

    #[test]
    fn test_always_flush_text_event_stream() {
        let response =
            make_response_with_headers("event: data\n\n", [("content-type", "text/event-stream")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(assert_compressed(&wrapped, "Expected compressed body"));
    }

    #[test]
    fn test_always_flush_text_event_stream_with_charset() {
        let response = make_response_with_headers(
            "event: data\n\n",
            [("content-type", "text/event-stream; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(assert_compressed(&wrapped, "Expected compressed body"));
    }

    #[test]
    fn test_no_compress_range_response() {
        let response =
            make_response_with_headers("partial content", [("content-range", "bytes 0-99/200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for range response");
        assert!(wrapped.headers().get(header::CONTENT_ENCODING).is_none());
    }

    #[test]
    fn test_vary_header_added() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(vary_values(&wrapped), vec!["accept-encoding"]);
    }

    #[test]
    fn test_vary_header_appended() {
        let response = make_response_with_headers("hello world", [("vary", "origin")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // With append, there will be two Vary headers
        assert_eq!(vary_values(&wrapped), vec!["origin", "accept-encoding"]);
    }

    #[test]
    fn test_vary_header_not_duplicated() {
        let response = make_response_with_headers("hello world", [("vary", "accept-encoding")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Check ALL values: `get` alone would miss a wrongly appended duplicate.
        assert_eq!(vary_values(&wrapped), vec!["accept-encoding"]);
    }

    #[test]
    fn test_vary_header_list_not_duplicated() {
        let response =
            make_response_with_headers("hello world", [("vary", "Origin, Accept-Encoding")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // accept-encoding inside a comma-separated list (any case) already counts.
        assert_eq!(vary_values(&wrapped), vec!["Origin, Accept-Encoding"]);
    }

    #[test]
    fn test_vary_header_star_not_modified() {
        let response = make_response_with_headers("hello world", [("vary", "*")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(vary_values(&wrapped), vec!["*"]);
    }

    #[test]
    fn test_accept_ranges_removed() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Accept-Ranges should be removed when compressing
        assert!(wrapped.headers().get(header::ACCEPT_RANGES).is_none());
    }

    #[test]
    fn test_accept_ranges_kept_when_not_compressing() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, None, 0);

        // Accept-Ranges should be kept when not compressing
        assert_eq!(
            wrapped.headers().get(header::ACCEPT_RANGES).unwrap(),
            "bytes"
        );
    }
}