http_response_compression/
future.rs

1use crate::body::CompressionBody;
2use crate::codec::Codec;
3use http::{Response, header};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
pin_project! {
    /// Future for compression service responses.
    ///
    /// Polls the wrapped future and, when it resolves to `Ok(response)`, hands the response to
    /// `wrap_response` together with the stored codec and minimum size. `Pending` and `Err`
    /// results are forwarded unchanged.
    pub struct ResponseFuture<F> {
        // The wrapped future; marked `#[pin]` so it can be polled through `Pin<&mut Self>`.
        #[pin]
        inner: F,
        // Codec to compress the body with; `None` results in a passthrough body.
        accepted_codec: Option<Codec>,
        // Responses declaring a `Content-Length` below this value are not compressed.
        min_size: usize,
    }
}
18
19impl<F> ResponseFuture<F> {
20    pub(crate) fn new(inner: F, accepted_codec: Option<Codec>, min_size: usize) -> Self {
21        Self {
22            inner,
23            accepted_codec,
24            min_size,
25        }
26    }
27}
28
29impl<F, B, E> Future for ResponseFuture<F>
30where
31    F: Future<Output = Result<Response<B>, E>>,
32{
33    type Output = Result<Response<CompressionBody<B>>, E>;
34
35    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36        let this = self.project();
37
38        match this.inner.poll(cx) {
39            Poll::Pending => Poll::Pending,
40            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
41            Poll::Ready(Ok(response)) => {
42                let response = wrap_response(response, *this.accepted_codec, *this.min_size);
43                Poll::Ready(Ok(response))
44            }
45        }
46    }
47}
48
49/// Wraps the response body with compression if appropriate.
50fn wrap_response<B>(
51    response: Response<B>,
52    accepted_codec: Option<Codec>,
53    min_size: usize,
54) -> Response<CompressionBody<B>> {
55    let (mut parts, body) = response.into_parts();
56
57    // Determine if we should compress
58    let dominated_codec = accepted_codec.filter(|_| {
59        !has_content_encoding(&parts.headers)
60            && !has_content_range(&parts.headers)
61            && !is_uncompressible_content_type(&parts.headers)
62            && !is_below_min_size(&parts.headers, min_size)
63    });
64
65    let body = if let Some(codec) = dominated_codec {
66        // Check for x-accel-buffering: no header or streaming content types
67        let always_flush = parts
68            .headers
69            .get("x-accel-buffering")
70            .and_then(|v| v.to_str().ok())
71            .is_some_and(|v| v.eq_ignore_ascii_case("no"))
72            || is_streaming_content_type(&parts.headers);
73
74        // Add Content-Encoding header
75        parts.headers.insert(
76            header::CONTENT_ENCODING,
77            header::HeaderValue::from_static(codec.content_encoding()),
78        );
79
80        // Remove Content-Length since compressed size is unknown
81        parts.headers.remove(header::CONTENT_LENGTH);
82
83        // Remove Accept-Ranges since we can't support ranges on compressed content
84        parts.headers.remove(header::ACCEPT_RANGES);
85
86        // Add Accept-Encoding to Vary header if not present
87        add_vary_accept_encoding(&mut parts.headers);
88
89        CompressionBody::compressed(body, codec, always_flush)
90    } else {
91        CompressionBody::passthrough(body)
92    };
93
94    Response::from_parts(parts, body)
95}
96
97/// Checks if Content-Encoding header is already present.
98fn has_content_encoding(headers: &header::HeaderMap) -> bool {
99    headers.contains_key(header::CONTENT_ENCODING)
100}
101
102/// Checks if Content-Range header is present (range response).
103fn has_content_range(headers: &header::HeaderMap) -> bool {
104    headers.contains_key(header::CONTENT_RANGE)
105}
106
107/// Adds Accept-Encoding to the Vary header if not already present.
108fn add_vary_accept_encoding(headers: &mut header::HeaderMap) {
109    // Check all Vary headers to see if Accept-Encoding is already present
110    for vary in headers.get_all(header::VARY) {
111        if let Ok(vary_str) = vary.to_str() {
112            let dominated = vary_str.split(',').any(|v| {
113                let v = v.trim();
114                v.eq_ignore_ascii_case("*") || v.eq_ignore_ascii_case("accept-encoding")
115            });
116            if dominated {
117                return;
118            }
119        }
120    }
121
122    // Append Accept-Encoding to Vary header
123    headers.append(
124        header::VARY,
125        header::HeaderValue::from_static("accept-encoding"),
126    );
127}
128
129/// Checks if the content type should not be compressed.
130fn is_uncompressible_content_type(headers: &header::HeaderMap) -> bool {
131    let Some(content_type) = headers
132        .get(header::CONTENT_TYPE)
133        .and_then(|v| v.to_str().ok())
134    else {
135        return false;
136    };
137
138    // Skip all images except SVG
139    if content_type.starts_with("image/") {
140        return !content_type.starts_with("image/svg+xml");
141    }
142
143    // Skip gRPC except grpc-web
144    if content_type.starts_with("application/grpc") {
145        return !content_type.starts_with("application/grpc-web");
146    }
147
148    false
149}
150
151/// Checks if the content type requires always flushing (e.g., streaming).
152fn is_streaming_content_type(headers: &header::HeaderMap) -> bool {
153    headers
154        .get(header::CONTENT_TYPE)
155        .and_then(|v| v.to_str().ok())
156        .is_some_and(|ct| {
157            ct.starts_with("text/event-stream") || ct.starts_with("application/grpc-web")
158        })
159}
160
161/// Checks if Content-Length is below the minimum size.
162fn is_below_min_size(headers: &header::HeaderMap, min_size: usize) -> bool {
163    headers
164        .get(header::CONTENT_LENGTH)
165        .and_then(|v| v.to_str().ok())
166        .and_then(|v| v.parse::<usize>().ok())
167        .is_some_and(|len| len < min_size)
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    // `CompressState` is only referenced by a test gated on the `gzip` feature.
    #[allow(unused_imports)]
    use crate::body::CompressState;

    /// Builds a response around `body` with no headers set.
    fn make_response(body: &'static str) -> Response<&'static str> {
        Response::new(body)
    }

    /// Builds a response around `body` carrying the given `(name, value)` headers.
    fn make_response_with_headers<I>(body: &'static str, headers: I) -> Response<&'static str>
    where
        I: IntoIterator<Item = (&'static str, &'static str)>,
    {
        let mut response = Response::new(body);
        let map = response.headers_mut();
        for (name, value) in headers {
            map.insert(name, header::HeaderValue::from_static(value));
        }
        response
    }

    /// Panics with `message` unless the body was passed through uncompressed.
    fn expect_passthrough<B>(response: &Response<CompressionBody<B>>, message: &str) {
        match response.body() {
            CompressionBody::Passthrough { .. } => {}
            _ => panic!("{}", message),
        }
    }

    /// Panics with `message` unless the body was wrapped for compression.
    #[cfg(feature = "gzip")]
    fn expect_compressed<B>(response: &Response<CompressionBody<B>>, message: &str) {
        match response.body() {
            CompressionBody::Compressed { .. } => {}
            _ => panic!("{}", message),
        }
    }

    /// Returns the always-flush flag of a compressed body; panics on any other body.
    #[cfg(feature = "gzip")]
    fn always_flush_of<B>(response: &Response<CompressionBody<B>>) -> bool {
        match response.body() {
            CompressionBody::Compressed { state, .. } => state.always_flush(),
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_when_accept_encoding_present() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Gzip), 0);

        // A freshly wrapped compressed body starts out in the reading state.
        match wrapped.body() {
            CompressionBody::Compressed { state, .. } => {
                assert_eq!(state.state(), CompressState::Reading);
            }
            _ => panic!("Expected compressed body"),
        }

        // The chosen encoding must be advertised.
        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    #[test]
    fn test_no_compress_when_no_accept_encoding() {
        let wrapped = wrap_response(make_response("hello world"), None, 0);

        expect_passthrough(&wrapped, "Expected passthrough body");

        // No encoding header may be added when nothing was compressed.
        assert!(!wrapped.headers().contains_key(header::CONTENT_ENCODING));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_when_content_encoding_present() {
        let response =
            make_response_with_headers("hello world", [("content-encoding", "identity")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_passthrough(&wrapped, "Expected passthrough body");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_png() {
        let response = make_response_with_headers("PNG data", [("content-type", "image/png")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_passthrough(&wrapped, "Expected passthrough body for image/png");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_jpeg() {
        let response = make_response_with_headers("JPEG data", [("content-type", "image/jpeg")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_passthrough(&wrapped, "Expected passthrough body for image/jpeg");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_gif() {
        let response = make_response_with_headers("GIF data", [("content-type", "image/gif")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_passthrough(&wrapped, "Expected passthrough body for image/gif");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_webp() {
        let response = make_response_with_headers("WebP data", [("content-type", "image/webp")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_passthrough(&wrapped, "Expected passthrough body for image/webp");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_image_svg() {
        let response =
            make_response_with_headers("<svg></svg>", [("content-type", "image/svg+xml")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // SVG is text-based, so it is the one image type that gets compressed.
        expect_compressed(&wrapped, "Expected compressed body for image/svg+xml");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_image_svg_with_charset() {
        let response = make_response_with_headers(
            "<svg></svg>",
            [("content-type", "image/svg+xml; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_compressed(
            &wrapped,
            "Expected compressed body for image/svg+xml with charset",
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_text_html() {
        let response = make_response_with_headers("<html></html>", [("content-type", "text/html")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_compressed(&wrapped, "Expected compressed body for text/html");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_below_min_size() {
        let response = make_response_with_headers("small", [("content-length", "5")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        // 5 < 100, so the body is left alone.
        expect_passthrough(&wrapped, "Expected passthrough body below min size");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_above_min_size() {
        let response =
            make_response_with_headers("large enough content", [("content-length", "200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        // 200 >= 100, so the body is compressed.
        expect_compressed(&wrapped, "Expected compressed body above min size");

        // The original length no longer applies.
        assert!(!wrapped.headers().contains_key(header::CONTENT_LENGTH));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_unknown_size() {
        // Without a Content-Length header the min_size check cannot trigger.
        let wrapped = wrap_response(make_response("unknown size content"), Some(Codec::Gzip), 100);

        expect_compressed(&wrapped, "Expected compressed body for unknown size");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_when_x_accel_buffering_no() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "no")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_always_flush_by_default() {
        let wrapped = wrap_response(make_response("normal data"), Some(Codec::Gzip), 0);

        assert!(!always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_x_accel_buffering_case_insensitive() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "NO")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli_content_encoding() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Brotli), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "br"
        );
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_content_encoding() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Zstd), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "zstd"
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_application_grpc() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_passthrough(&wrapped, "Expected passthrough body for application/grpc");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_application_grpc_with_suffix() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc+proto")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // The content type is matched by prefix, so suffixed variants are skipped too.
        expect_passthrough(
            &wrapped,
            "Expected passthrough body for application/grpc+proto",
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_application_grpc_web() {
        let response =
            make_response_with_headers("grpc-web data", [("content-type", "application/grpc-web")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_application_grpc_web_proto() {
        let response = make_response_with_headers(
            "grpc-web data",
            [("content-type", "application/grpc-web+proto")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_text_event_stream() {
        let response =
            make_response_with_headers("event: data\n\n", [("content-type", "text/event-stream")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_text_event_stream_with_charset() {
        let response = make_response_with_headers(
            "event: data\n\n",
            [("content-type", "text/event-stream; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_range_response() {
        let response =
            make_response_with_headers("partial content", [("content-range", "bytes 0-99/200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        expect_passthrough(&wrapped, "Expected passthrough body for range response");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_added() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Gzip), 0);

        assert_eq!(
            wrapped.headers().get(header::VARY).unwrap(),
            "accept-encoding"
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_appended() {
        let response = make_response_with_headers("hello world", [("vary", "origin")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Appending yields a second Vary entry rather than rewriting the first.
        let vary_values: Vec<_> = wrapped
            .headers()
            .get_all(header::VARY)
            .iter()
            .map(|v| v.to_str().unwrap())
            .collect();
        assert_eq!(vary_values, vec!["origin", "accept-encoding"]);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_not_duplicated() {
        let response = make_response_with_headers("hello world", [("vary", "accept-encoding")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(
            wrapped.headers().get(header::VARY).unwrap(),
            "accept-encoding"
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_star_not_modified() {
        let response = make_response_with_headers("hello world", [("vary", "*")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(wrapped.headers().get(header::VARY).unwrap(), "*");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_accept_ranges_removed() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Compressing drops Accept-Ranges.
        assert!(!wrapped.headers().contains_key(header::ACCEPT_RANGES));
    }

    #[test]
    fn test_accept_ranges_kept_when_not_compressing() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, None, 0);

        // A passthrough response keeps Accept-Ranges.
        assert_eq!(
            wrapped.headers().get(header::ACCEPT_RANGES).unwrap(),
            "bytes"
        );
    }
}
622        );
623    }
624}