// http_response_compression/future.rs

1use crate::body::CompressionBody;
2use crate::codec::Codec;
3use http::{Response, header};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9pin_project! {
10    /// Future for compression service responses.
11    pub struct ResponseFuture<F> {
12        #[pin]
13        inner: F,
14        accepted_codec: Option<Codec>,
15        min_size: usize,
16    }
17}
18
19impl<F> ResponseFuture<F> {
20    pub(crate) fn new(inner: F, accepted_codec: Option<Codec>, min_size: usize) -> Self {
21        Self {
22            inner,
23            accepted_codec,
24            min_size,
25        }
26    }
27}
28
29impl<F, B, E> Future for ResponseFuture<F>
30where
31    F: Future<Output = Result<Response<B>, E>>,
32{
33    type Output = Result<Response<CompressionBody<B>>, E>;
34
35    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36        let this = self.project();
37
38        match this.inner.poll(cx) {
39            Poll::Pending => Poll::Pending,
40            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
41            Poll::Ready(Ok(response)) => {
42                let response = wrap_response(response, *this.accepted_codec, *this.min_size);
43                Poll::Ready(Ok(response))
44            }
45        }
46    }
47}
48
49/// Wraps the response body with compression if appropriate.
50fn wrap_response<B>(
51    response: Response<B>,
52    accepted_codec: Option<Codec>,
53    min_size: usize,
54) -> Response<CompressionBody<B>> {
55    let (mut parts, body) = response.into_parts();
56
57    // Determine if we should compress
58    let should_compress = accepted_codec.is_some()
59        && !has_content_encoding(&parts.headers)
60        && !has_content_range(&parts.headers)
61        && !is_uncompressible_content_type(&parts.headers)
62        && !is_below_min_size(&parts.headers, min_size);
63
64    // Check for x-accel-buffering: no header or streaming content types
65    let always_flush = parts
66        .headers
67        .get("x-accel-buffering")
68        .and_then(|v| v.to_str().ok())
69        .is_some_and(|v| v.eq_ignore_ascii_case("no"))
70        || is_streaming_content_type(&parts.headers);
71
72    let body = if should_compress {
73        let codec = accepted_codec.unwrap();
74
75        // Add Content-Encoding header
76        parts.headers.insert(
77            header::CONTENT_ENCODING,
78            header::HeaderValue::from_static(codec.content_encoding()),
79        );
80
81        // Remove Content-Length since compressed size is unknown
82        parts.headers.remove(header::CONTENT_LENGTH);
83
84        // Remove Accept-Ranges since we can't support ranges on compressed content
85        parts.headers.remove(header::ACCEPT_RANGES);
86
87        // Add Accept-Encoding to Vary header if not present
88        add_vary_accept_encoding(&mut parts.headers);
89
90        CompressionBody::compressed(body, codec, always_flush)
91    } else {
92        CompressionBody::passthrough(body)
93    };
94
95    Response::from_parts(parts, body)
96}
97
98/// Checks if Content-Encoding header is already present.
99fn has_content_encoding(headers: &header::HeaderMap) -> bool {
100    headers.contains_key(header::CONTENT_ENCODING)
101}
102
103/// Checks if Content-Range header is present (range response).
104fn has_content_range(headers: &header::HeaderMap) -> bool {
105    headers.contains_key(header::CONTENT_RANGE)
106}
107
108/// Adds Accept-Encoding to the Vary header if not already present.
109fn add_vary_accept_encoding(headers: &mut header::HeaderMap) {
110    // Check all Vary headers to see if Accept-Encoding is already present
111    for vary in headers.get_all(header::VARY) {
112        if let Ok(vary_str) = vary.to_str() {
113            let dominated = vary_str.split(',').any(|v| {
114                let v = v.trim();
115                v.eq_ignore_ascii_case("*") || v.eq_ignore_ascii_case("accept-encoding")
116            });
117            if dominated {
118                return;
119            }
120        }
121    }
122
123    // Append Accept-Encoding to Vary header
124    headers.append(
125        header::VARY,
126        header::HeaderValue::from_static("accept-encoding"),
127    );
128}
129
130/// Checks if the content type should not be compressed.
131fn is_uncompressible_content_type(headers: &header::HeaderMap) -> bool {
132    let Some(content_type) = headers
133        .get(header::CONTENT_TYPE)
134        .and_then(|v| v.to_str().ok())
135    else {
136        return false;
137    };
138
139    let ct = content_type.to_ascii_lowercase();
140
141    // Skip all images except SVG
142    if ct.starts_with("image/") {
143        return !ct.starts_with("image/svg+xml");
144    }
145
146    // Skip gRPC except grpc-web
147    if ct.starts_with("application/grpc") {
148        return !ct.starts_with("application/grpc-web");
149    }
150
151    false
152}
153
154/// Checks if the content type requires always flushing (e.g., streaming).
155fn is_streaming_content_type(headers: &header::HeaderMap) -> bool {
156    headers
157        .get(header::CONTENT_TYPE)
158        .and_then(|v| v.to_str().ok())
159        .is_some_and(|content_type| {
160            let ct = content_type.to_ascii_lowercase();
161            ct.starts_with("text/event-stream") || ct.starts_with("application/grpc-web")
162        })
163}
164
165/// Checks if Content-Length is below the minimum size.
166fn is_below_min_size(headers: &header::HeaderMap, min_size: usize) -> bool {
167    headers
168        .get(header::CONTENT_LENGTH)
169        .and_then(|v| v.to_str().ok())
170        .and_then(|v| v.parse::<usize>().ok())
171        .is_some_and(|len| len < min_size)
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    // `CompressState` is only inspected by gzip-gated tests. Gating the import
    // avoids an unused-import warning without a blanket `allow`.
    #[cfg(feature = "gzip")]
    use crate::body::CompressState;

    /// Builds a default (`200 OK`, no headers) response around `body`.
    fn make_response(body: &'static str) -> Response<&'static str> {
        Response::new(body)
    }

    /// Builds a response around `body` with `headers` inserted in order.
    /// A repeated header name replaces the earlier value.
    fn make_response_with_headers<I>(body: &'static str, headers: I) -> Response<&'static str>
    where
        I: IntoIterator<Item = (&'static str, &'static str)>,
    {
        let mut response = Response::new(body);
        for (name, value) in headers {
            response
                .headers_mut()
                .insert(name, header::HeaderValue::from_static(value));
        }
        response
    }

    /// Panics with `msg` unless the wrapped body is a passthrough.
    fn assert_passthrough<B>(wrapped: &Response<CompressionBody<B>>, msg: &str) {
        assert!(
            matches!(wrapped.body(), CompressionBody::Passthrough { .. }),
            "{}",
            msg
        );
    }

    /// Panics with `msg` unless the wrapped body is compressed.
    #[cfg(feature = "gzip")]
    fn assert_compressed<B>(wrapped: &Response<CompressionBody<B>>, msg: &str) {
        assert!(
            matches!(wrapped.body(), CompressionBody::Compressed { .. }),
            "{}",
            msg
        );
    }

    /// Panics unless the wrapped body is compressed and its always-flush flag
    /// equals `expected`.
    #[cfg(feature = "gzip")]
    fn assert_always_flush<B>(wrapped: &Response<CompressionBody<B>>, expected: bool) {
        match wrapped.body() {
            CompressionBody::Compressed { state, .. } => {
                assert_eq!(state.always_flush(), expected);
            }
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_when_accept_encoding_present() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Gzip), 0);

        // Should be compressed, with a fresh encoder state
        match wrapped.body() {
            CompressionBody::Compressed { state, .. } => {
                assert_eq!(state.state(), CompressState::Reading);
            }
            _ => panic!("Expected compressed body"),
        }

        // Should have Content-Encoding header
        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    #[test]
    fn test_no_compress_when_no_accept_encoding() {
        let wrapped = wrap_response(make_response("hello world"), None, 0);

        assert_passthrough(&wrapped, "Expected passthrough body");
        assert!(wrapped.headers().get(header::CONTENT_ENCODING).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_when_content_encoding_present() {
        let response =
            make_response_with_headers("hello world", [("content-encoding", "identity")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_png() {
        let response = make_response_with_headers("PNG data", [("content-type", "image/png")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for image/png");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_jpeg() {
        let response = make_response_with_headers("JPEG data", [("content-type", "image/jpeg")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for image/jpeg");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_gif() {
        let response = make_response_with_headers("GIF data", [("content-type", "image/gif")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for image/gif");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_webp() {
        let response = make_response_with_headers("WebP data", [("content-type", "image/webp")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for image/webp");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_image_svg() {
        let response =
            make_response_with_headers("<svg></svg>", [("content-type", "image/svg+xml")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // SVG is text-based, so it is compressed
        assert_compressed(&wrapped, "Expected compressed body for image/svg+xml");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_image_svg_with_charset() {
        let response = make_response_with_headers(
            "<svg></svg>",
            [("content-type", "image/svg+xml; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_compressed(
            &wrapped,
            "Expected compressed body for image/svg+xml with charset",
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_text_html() {
        let response = make_response_with_headers("<html></html>", [("content-type", "text/html")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_compressed(&wrapped, "Expected compressed body for text/html");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_below_min_size() {
        let response = make_response_with_headers("small", [("content-length", "5")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        // 5 < 100
        assert_passthrough(&wrapped, "Expected passthrough body below min size");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_above_min_size() {
        let response =
            make_response_with_headers("large enough content", [("content-length", "200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        // 200 >= 100
        assert_compressed(&wrapped, "Expected compressed body above min size");

        // Content-Length should be removed
        assert!(wrapped.headers().get(header::CONTENT_LENGTH).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_unknown_size() {
        // No Content-Length header means unknown size, which never trips min_size
        let wrapped = wrap_response(make_response("unknown size content"), Some(Codec::Gzip), 100);

        assert_compressed(&wrapped, "Expected compressed body for unknown size");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_when_x_accel_buffering_no() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "no")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_always_flush_by_default() {
        let wrapped = wrap_response(make_response("normal data"), Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, false);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_x_accel_buffering_case_insensitive() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "NO")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli_content_encoding() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Brotli), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "br"
        );
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_content_encoding() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Zstd), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "zstd"
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_application_grpc() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for application/grpc");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_application_grpc_with_suffix() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc+proto")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Prefix match covers the +proto suffix
        assert_passthrough(
            &wrapped,
            "Expected passthrough body for application/grpc+proto",
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_application_grpc_web() {
        let response =
            make_response_with_headers("grpc-web data", [("content-type", "application/grpc-web")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_application_grpc_web_proto() {
        let response = make_response_with_headers(
            "grpc-web data",
            [("content-type", "application/grpc-web+proto")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_text_event_stream() {
        let response =
            make_response_with_headers("event: data\n\n", [("content-type", "text/event-stream")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_text_event_stream_with_charset() {
        let response = make_response_with_headers(
            "event: data\n\n",
            [("content-type", "text/event-stream; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_range_response() {
        let response =
            make_response_with_headers("partial content", [("content-range", "bytes 0-99/200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "Expected passthrough body for range response");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_added() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Gzip), 0);

        assert_eq!(
            wrapped.headers().get(header::VARY).unwrap(),
            "accept-encoding"
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_appended() {
        let response = make_response_with_headers("hello world", [("vary", "origin")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Appending keeps the existing value and adds a second Vary line
        let vary_values: Vec<_> = wrapped
            .headers()
            .get_all(header::VARY)
            .iter()
            .map(|v| v.to_str().unwrap())
            .collect();
        assert_eq!(vary_values, vec!["origin", "accept-encoding"]);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_not_duplicated() {
        let response = make_response_with_headers("hello world", [("vary", "accept-encoding")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(
            wrapped.headers().get(header::VARY).unwrap(),
            "accept-encoding"
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_star_not_modified() {
        let response = make_response_with_headers("hello world", [("vary", "*")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(wrapped.headers().get(header::VARY).unwrap(), "*");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_accept_ranges_removed() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Accept-Ranges should be removed when compressing
        assert!(wrapped.headers().get(header::ACCEPT_RANGES).is_none());
    }

    #[test]
    fn test_accept_ranges_kept_when_not_compressing() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, None, 0);

        // Accept-Ranges should be kept when not compressing
        assert_eq!(
            wrapped.headers().get(header::ACCEPT_RANGES).unwrap(),
            "bytes"
        );
    }
}