1use crate::body::CompressionBody;
2use crate::codec::Codec;
3use http::{Response, header};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
pin_project! {
    /// Future that drives an inner response future and, once that future
    /// yields a successful response, passes the response through
    /// `wrap_response` so its body becomes a `CompressionBody`.
    pub struct ResponseFuture<F> {
        // The wrapped future. `#[pin]` makes the projection hand out a pinned
        // reference so it can be polled in place.
        #[pin]
        inner: F,
        // Codec forwarded to `wrap_response`; when this is `None` the body is
        // always passed through uncompressed.
        accepted_codec: Option<Codec>,
        // Threshold forwarded to `wrap_response`, where it is compared against
        // the response's declared `Content-Length`.
        min_size: usize,
    }
}
18
19impl<F> ResponseFuture<F> {
20 pub(crate) fn new(inner: F, accepted_codec: Option<Codec>, min_size: usize) -> Self {
21 Self {
22 inner,
23 accepted_codec,
24 min_size,
25 }
26 }
27}
28
29impl<F, B, E> Future for ResponseFuture<F>
30where
31 F: Future<Output = Result<Response<B>, E>>,
32{
33 type Output = Result<Response<CompressionBody<B>>, E>;
34
35 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36 let this = self.project();
37
38 match this.inner.poll(cx) {
39 Poll::Pending => Poll::Pending,
40 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
41 Poll::Ready(Ok(response)) => {
42 let response = wrap_response(response, *this.accepted_codec, *this.min_size);
43 Poll::Ready(Ok(response))
44 }
45 }
46 }
47}
48
49fn wrap_response<B>(
51 response: Response<B>,
52 accepted_codec: Option<Codec>,
53 min_size: usize,
54) -> Response<CompressionBody<B>> {
55 let (mut parts, body) = response.into_parts();
56
57 let should_compress = accepted_codec.is_some()
59 && !has_content_encoding(&parts.headers)
60 && !has_content_range(&parts.headers)
61 && !is_uncompressible_content_type(&parts.headers)
62 && !is_below_min_size(&parts.headers, min_size);
63
64 let always_flush = parts
66 .headers
67 .get("x-accel-buffering")
68 .and_then(|v| v.to_str().ok())
69 .is_some_and(|v| v.eq_ignore_ascii_case("no"))
70 || is_streaming_content_type(&parts.headers);
71
72 let body = if should_compress {
73 let codec = accepted_codec.unwrap();
74
75 parts.headers.insert(
77 header::CONTENT_ENCODING,
78 header::HeaderValue::from_static(codec.content_encoding()),
79 );
80
81 parts.headers.remove(header::CONTENT_LENGTH);
83
84 parts.headers.remove(header::ACCEPT_RANGES);
86
87 add_vary_accept_encoding(&mut parts.headers);
89
90 CompressionBody::compressed(body, codec, always_flush)
91 } else {
92 CompressionBody::passthrough(body)
93 };
94
95 Response::from_parts(parts, body)
96}
97
98fn has_content_encoding(headers: &header::HeaderMap) -> bool {
100 headers.contains_key(header::CONTENT_ENCODING)
101}
102
103fn has_content_range(headers: &header::HeaderMap) -> bool {
105 headers.contains_key(header::CONTENT_RANGE)
106}
107
108fn add_vary_accept_encoding(headers: &mut header::HeaderMap) {
110 for vary in headers.get_all(header::VARY) {
112 if let Ok(vary_str) = vary.to_str() {
113 let dominated = vary_str.split(',').any(|v| {
114 let v = v.trim();
115 v.eq_ignore_ascii_case("*") || v.eq_ignore_ascii_case("accept-encoding")
116 });
117 if dominated {
118 return;
119 }
120 }
121 }
122
123 headers.append(
125 header::VARY,
126 header::HeaderValue::from_static("accept-encoding"),
127 );
128}
129
130fn is_uncompressible_content_type(headers: &header::HeaderMap) -> bool {
132 let Some(content_type) = headers
133 .get(header::CONTENT_TYPE)
134 .and_then(|v| v.to_str().ok())
135 else {
136 return false;
137 };
138
139 let ct = content_type.to_ascii_lowercase();
140
141 if ct.starts_with("image/") {
143 return !ct.starts_with("image/svg+xml");
144 }
145
146 if ct.starts_with("application/grpc") {
148 return !ct.starts_with("application/grpc-web");
149 }
150
151 false
152}
153
154fn is_streaming_content_type(headers: &header::HeaderMap) -> bool {
156 headers
157 .get(header::CONTENT_TYPE)
158 .and_then(|v| v.to_str().ok())
159 .is_some_and(|content_type| {
160 let ct = content_type.to_ascii_lowercase();
161 ct.starts_with("text/event-stream") || ct.starts_with("application/grpc-web")
162 })
163}
164
165fn is_below_min_size(headers: &header::HeaderMap, min_size: usize) -> bool {
167 headers
168 .get(header::CONTENT_LENGTH)
169 .and_then(|v| v.to_str().ok())
170 .and_then(|v| v.parse::<usize>().ok())
171 .is_some_and(|len| len < min_size)
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    #[allow(unused_imports)]
    use crate::body::CompressState;

    /// Response type produced by `wrap_response` for the string bodies used here.
    type Wrapped = Response<CompressionBody<&'static str>>;

    /// Builds a header-less response around `body`.
    fn make_response(body: &'static str) -> Response<&'static str> {
        Response::new(body)
    }

    /// Builds a response around `body` carrying the given `(name, value)` headers.
    fn make_response_with_headers<I>(body: &'static str, headers: I) -> Response<&'static str>
    where
        I: IntoIterator<Item = (&'static str, &'static str)>,
    {
        let mut response = make_response(body);
        let header_map = response.headers_mut();
        for (name, value) in headers {
            header_map.insert(name, header::HeaderValue::from_static(value));
        }
        response
    }

    /// Runs `wrap_response` with gzip as the accepted codec.
    #[cfg(feature = "gzip")]
    fn gzip(response: Response<&'static str>, min_size: usize) -> Wrapped {
        wrap_response(response, Some(Codec::Gzip), min_size)
    }

    /// Panics with `failure` unless the body was passed through.
    #[track_caller]
    fn expect_passthrough(wrapped: &Wrapped, failure: &str) {
        match wrapped.body() {
            CompressionBody::Passthrough { .. } => {}
            _ => panic!("{}", failure),
        }
    }

    /// Panics with `failure` unless the body was wrapped for compression.
    #[cfg(feature = "gzip")]
    #[track_caller]
    fn expect_compressed(wrapped: &Wrapped, failure: &str) {
        match wrapped.body() {
            CompressionBody::Compressed { .. } => {}
            _ => panic!("{}", failure),
        }
    }

    /// Returns the `always_flush` flag of a compressed body, panicking if the
    /// body is not compressed.
    #[cfg(feature = "gzip")]
    #[track_caller]
    fn always_flush_of(wrapped: &Wrapped) -> bool {
        match wrapped.body() {
            CompressionBody::Compressed { state, .. } => state.always_flush(),
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_when_accept_encoding_present() {
        let wrapped = gzip(make_response("hello world"), 0);

        match wrapped.body() {
            CompressionBody::Compressed { state, .. } => {
                assert_eq!(state.state(), CompressState::Reading);
            }
            _ => panic!("Expected compressed body"),
        }

        let encoding = wrapped.headers().get(header::CONTENT_ENCODING).unwrap();
        assert_eq!(encoding, "gzip");
    }

    #[test]
    fn test_no_compress_when_no_accept_encoding() {
        let wrapped = wrap_response(make_response("hello world"), None, 0);

        expect_passthrough(&wrapped, "Expected passthrough body");
        assert!(wrapped.headers().get(header::CONTENT_ENCODING).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_when_content_encoding_present() {
        let response =
            make_response_with_headers("hello world", [("content-encoding", "identity")]);
        let wrapped = gzip(response, 0);

        expect_passthrough(&wrapped, "Expected passthrough body");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_png() {
        let response = make_response_with_headers("PNG data", [("content-type", "image/png")]);
        let wrapped = gzip(response, 0);

        expect_passthrough(&wrapped, "Expected passthrough body for image/png");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_jpeg() {
        let response = make_response_with_headers("JPEG data", [("content-type", "image/jpeg")]);
        let wrapped = gzip(response, 0);

        expect_passthrough(&wrapped, "Expected passthrough body for image/jpeg");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_gif() {
        let response = make_response_with_headers("GIF data", [("content-type", "image/gif")]);
        let wrapped = gzip(response, 0);

        expect_passthrough(&wrapped, "Expected passthrough body for image/gif");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_webp() {
        let response = make_response_with_headers("WebP data", [("content-type", "image/webp")]);
        let wrapped = gzip(response, 0);

        expect_passthrough(&wrapped, "Expected passthrough body for image/webp");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_image_svg() {
        let response =
            make_response_with_headers("<svg></svg>", [("content-type", "image/svg+xml")]);
        let wrapped = gzip(response, 0);

        expect_compressed(&wrapped, "Expected compressed body for image/svg+xml");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_image_svg_with_charset() {
        let response = make_response_with_headers(
            "<svg></svg>",
            [("content-type", "image/svg+xml; charset=utf-8")],
        );
        let wrapped = gzip(response, 0);

        expect_compressed(
            &wrapped,
            "Expected compressed body for image/svg+xml with charset",
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_text_html() {
        let response = make_response_with_headers("<html></html>", [("content-type", "text/html")]);
        let wrapped = gzip(response, 0);

        expect_compressed(&wrapped, "Expected compressed body for text/html");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_below_min_size() {
        let response = make_response_with_headers("small", [("content-length", "5")]);
        let wrapped = gzip(response, 100);

        expect_passthrough(&wrapped, "Expected passthrough body below min size");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_above_min_size() {
        let response =
            make_response_with_headers("large enough content", [("content-length", "200")]);
        let wrapped = gzip(response, 100);

        expect_compressed(&wrapped, "Expected compressed body above min size");
        assert!(wrapped.headers().get(header::CONTENT_LENGTH).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_unknown_size() {
        let wrapped = gzip(make_response("unknown size content"), 100);

        expect_compressed(&wrapped, "Expected compressed body for unknown size");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_when_x_accel_buffering_no() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "no")]);
        let wrapped = gzip(response, 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_always_flush_by_default() {
        let wrapped = gzip(make_response("normal data"), 0);

        assert!(!always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_x_accel_buffering_case_insensitive() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "NO")]);
        let wrapped = gzip(response, 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli_content_encoding() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Brotli), 0);

        let encoding = wrapped.headers().get(header::CONTENT_ENCODING).unwrap();
        assert_eq!(encoding, "br");
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_content_encoding() {
        let wrapped = wrap_response(make_response("hello world"), Some(Codec::Zstd), 0);

        let encoding = wrapped.headers().get(header::CONTENT_ENCODING).unwrap();
        assert_eq!(encoding, "zstd");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_application_grpc() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc")]);
        let wrapped = gzip(response, 0);

        expect_passthrough(&wrapped, "Expected passthrough body for application/grpc");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_application_grpc_with_suffix() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc+proto")]);
        let wrapped = gzip(response, 0);

        expect_passthrough(
            &wrapped,
            "Expected passthrough body for application/grpc+proto",
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_application_grpc_web() {
        let response =
            make_response_with_headers("grpc-web data", [("content-type", "application/grpc-web")]);
        let wrapped = gzip(response, 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_application_grpc_web_proto() {
        let response = make_response_with_headers(
            "grpc-web data",
            [("content-type", "application/grpc-web+proto")],
        );
        let wrapped = gzip(response, 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_text_event_stream() {
        let response =
            make_response_with_headers("event: data\n\n", [("content-type", "text/event-stream")]);
        let wrapped = gzip(response, 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_text_event_stream_with_charset() {
        let response = make_response_with_headers(
            "event: data\n\n",
            [("content-type", "text/event-stream; charset=utf-8")],
        );
        let wrapped = gzip(response, 0);

        assert!(always_flush_of(&wrapped));
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_range_response() {
        let response =
            make_response_with_headers("partial content", [("content-range", "bytes 0-99/200")]);
        let wrapped = gzip(response, 0);

        expect_passthrough(&wrapped, "Expected passthrough body for range response");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_added() {
        let wrapped = gzip(make_response("hello world"), 0);

        let vary = wrapped.headers().get(header::VARY).unwrap();
        assert_eq!(vary, "accept-encoding");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_appended() {
        let response = make_response_with_headers("hello world", [("vary", "origin")]);
        let wrapped = gzip(response, 0);

        let mut vary_values = Vec::new();
        for value in wrapped.headers().get_all(header::VARY).iter() {
            vary_values.push(value.to_str().unwrap());
        }
        assert_eq!(vary_values, vec!["origin", "accept-encoding"]);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_not_duplicated() {
        let response = make_response_with_headers("hello world", [("vary", "accept-encoding")]);
        let wrapped = gzip(response, 0);

        let vary = wrapped.headers().get(header::VARY).unwrap();
        assert_eq!(vary, "accept-encoding");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_star_not_modified() {
        let response = make_response_with_headers("hello world", [("vary", "*")]);
        let wrapped = gzip(response, 0);

        let vary = wrapped.headers().get(header::VARY).unwrap();
        assert_eq!(vary, "*");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_accept_ranges_removed() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = gzip(response, 0);

        assert!(wrapped.headers().get(header::ACCEPT_RANGES).is_none());
    }

    #[test]
    fn test_accept_ranges_kept_when_not_compressing() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, None, 0);

        let accept_ranges = wrapped.headers().get(header::ACCEPT_RANGES).unwrap();
        assert_eq!(accept_ranges, "bytes");
    }
}