1use crate::body::CompressionBody;
2use crate::codec::Codec;
3use http::{Response, header};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
pin_project! {
    /// Future that drives an inner response future to completion and then
    /// passes the successful response through `wrap_response`, which decides
    /// whether the body is compressed or passed through untouched.
    pub struct ResponseFuture<F> {
        // The inner future being polled; structurally pinned so it can be
        // polled in place through the projection.
        #[pin]
        inner: F,
        // Codec to compress with; `None` means every response is passed through.
        accepted_codec: Option<Codec>,
        // Responses whose declared `Content-Length` is below this value are
        // not compressed.
        min_size: usize,
    }
}
18
19impl<F> ResponseFuture<F> {
20 pub(crate) fn new(inner: F, accepted_codec: Option<Codec>, min_size: usize) -> Self {
21 Self {
22 inner,
23 accepted_codec,
24 min_size,
25 }
26 }
27}
28
29impl<F, B, E> Future for ResponseFuture<F>
30where
31 F: Future<Output = Result<Response<B>, E>>,
32{
33 type Output = Result<Response<CompressionBody<B>>, E>;
34
35 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36 let this = self.project();
37
38 match this.inner.poll(cx) {
39 Poll::Pending => Poll::Pending,
40 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
41 Poll::Ready(Ok(response)) => {
42 let response = wrap_response(response, *this.accepted_codec, *this.min_size);
43 Poll::Ready(Ok(response))
44 }
45 }
46 }
47}
48
49fn wrap_response<B>(
51 response: Response<B>,
52 accepted_codec: Option<Codec>,
53 min_size: usize,
54) -> Response<CompressionBody<B>> {
55 let (mut parts, body) = response.into_parts();
56
57 let dominated_codec = accepted_codec.filter(|_| {
59 !has_content_encoding(&parts.headers)
60 && !has_content_range(&parts.headers)
61 && !is_uncompressible_content_type(&parts.headers)
62 && !is_below_min_size(&parts.headers, min_size)
63 });
64
65 let body = if let Some(codec) = dominated_codec {
66 let always_flush = parts
68 .headers
69 .get("x-accel-buffering")
70 .and_then(|v| v.to_str().ok())
71 .is_some_and(|v| v.eq_ignore_ascii_case("no"))
72 || is_streaming_content_type(&parts.headers);
73
74 parts.headers.insert(
76 header::CONTENT_ENCODING,
77 header::HeaderValue::from_static(codec.content_encoding()),
78 );
79
80 parts.headers.remove(header::CONTENT_LENGTH);
82
83 parts.headers.remove(header::ACCEPT_RANGES);
85
86 add_vary_accept_encoding(&mut parts.headers);
88
89 CompressionBody::compressed(body, codec, always_flush)
90 } else {
91 CompressionBody::passthrough(body)
92 };
93
94 Response::from_parts(parts, body)
95}
96
97fn has_content_encoding(headers: &header::HeaderMap) -> bool {
99 headers.contains_key(header::CONTENT_ENCODING)
100}
101
102fn has_content_range(headers: &header::HeaderMap) -> bool {
104 headers.contains_key(header::CONTENT_RANGE)
105}
106
107fn add_vary_accept_encoding(headers: &mut header::HeaderMap) {
109 for vary in headers.get_all(header::VARY) {
111 if let Ok(vary_str) = vary.to_str() {
112 let dominated = vary_str.split(',').any(|v| {
113 let v = v.trim();
114 v.eq_ignore_ascii_case("*") || v.eq_ignore_ascii_case("accept-encoding")
115 });
116 if dominated {
117 return;
118 }
119 }
120 }
121
122 headers.append(
124 header::VARY,
125 header::HeaderValue::from_static("accept-encoding"),
126 );
127}
128
129fn is_uncompressible_content_type(headers: &header::HeaderMap) -> bool {
131 let Some(content_type) = headers
132 .get(header::CONTENT_TYPE)
133 .and_then(|v| v.to_str().ok())
134 else {
135 return false;
136 };
137
138 if content_type.starts_with("image/") {
140 return !content_type.starts_with("image/svg+xml");
141 }
142
143 if content_type.starts_with("application/grpc") {
145 return !content_type.starts_with("application/grpc-web");
146 }
147
148 false
149}
150
151fn is_streaming_content_type(headers: &header::HeaderMap) -> bool {
153 headers
154 .get(header::CONTENT_TYPE)
155 .and_then(|v| v.to_str().ok())
156 .is_some_and(|ct| {
157 ct.starts_with("text/event-stream") || ct.starts_with("application/grpc-web")
158 })
159}
160
161fn is_below_min_size(headers: &header::HeaderMap, min_size: usize) -> bool {
163 headers
164 .get(header::CONTENT_LENGTH)
165 .and_then(|v| v.to_str().ok())
166 .and_then(|v| v.parse::<usize>().ok())
167 .is_some_and(|len| len < min_size)
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    // `CompressState` is only referenced by tests gated on the `gzip` feature.
    #[allow(unused_imports)]
    use crate::body::CompressState;

    /// Response type produced by `wrap_response` for the string bodies used here.
    type Wrapped = Response<CompressionBody<&'static str>>;

    /// Builds a header-less response with the given body.
    fn make_response(body: &'static str) -> Response<&'static str> {
        Response::new(body)
    }

    /// Builds a response with the given body and `(name, value)` headers.
    fn make_response_with_headers<I>(body: &'static str, headers: I) -> Response<&'static str>
    where
        I: IntoIterator<Item = (&'static str, &'static str)>,
    {
        let mut response = Response::new(body);
        for (name, value) in headers {
            response
                .headers_mut()
                .insert(name, header::HeaderValue::from_static(value));
        }
        response
    }

    /// Fails the calling test unless the body was passed through untouched.
    #[track_caller]
    fn assert_passthrough(wrapped: &Wrapped, context: &str) {
        assert!(
            matches!(wrapped.body(), CompressionBody::Passthrough { .. }),
            "Expected passthrough body {}",
            context
        );
    }

    /// Fails the calling test unless the body was wrapped for compression.
    #[cfg(feature = "gzip")]
    #[track_caller]
    fn assert_compressed(wrapped: &Wrapped, context: &str) {
        assert!(
            matches!(wrapped.body(), CompressionBody::Compressed { .. }),
            "Expected compressed body {}",
            context
        );
    }

    /// Fails the calling test unless the body is compressed and its
    /// always-flush setting equals `expected`.
    #[cfg(feature = "gzip")]
    #[track_caller]
    fn assert_always_flush(wrapped: &Wrapped, expected: bool) {
        match wrapped.body() {
            CompressionBody::Compressed { state, .. } => {
                assert_eq!(state.always_flush(), expected);
            }
            _ => panic!("Expected compressed body"),
        }
    }

    /// Collects every `Vary` value in order, so tests can detect duplicates
    /// rather than only inspecting the first value.
    #[cfg(feature = "gzip")]
    fn vary_values(wrapped: &Wrapped) -> Vec<&str> {
        wrapped
            .headers()
            .get_all(header::VARY)
            .iter()
            .map(|v| v.to_str().unwrap())
            .collect()
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_when_accept_encoding_present() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            CompressionBody::Compressed { state, .. } => {
                assert_eq!(state.state(), CompressState::Reading);
            }
            _ => panic!("Expected compressed body"),
        }

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    #[test]
    fn test_no_compress_when_no_accept_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, None, 0);

        assert_passthrough(&wrapped, "without an accepted codec");
        assert!(wrapped.headers().get(header::CONTENT_ENCODING).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_when_content_encoding_present() {
        let response =
            make_response_with_headers("hello world", [("content-encoding", "identity")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "when content-encoding is already set");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_png() {
        let response = make_response_with_headers("PNG data", [("content-type", "image/png")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "for image/png");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_jpeg() {
        let response = make_response_with_headers("JPEG data", [("content-type", "image/jpeg")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "for image/jpeg");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_gif() {
        let response = make_response_with_headers("GIF data", [("content-type", "image/gif")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "for image/gif");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_image_webp() {
        let response = make_response_with_headers("WebP data", [("content-type", "image/webp")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "for image/webp");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_image_svg() {
        let response =
            make_response_with_headers("<svg></svg>", [("content-type", "image/svg+xml")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_compressed(&wrapped, "for image/svg+xml");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_image_svg_with_charset() {
        let response = make_response_with_headers(
            "<svg></svg>",
            [("content-type", "image/svg+xml; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_compressed(&wrapped, "for image/svg+xml with charset");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_text_html() {
        let response = make_response_with_headers("<html></html>", [("content-type", "text/html")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_compressed(&wrapped, "for text/html");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_below_min_size() {
        let response = make_response_with_headers("small", [("content-length", "5")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        assert_passthrough(&wrapped, "below min size");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_above_min_size() {
        let response =
            make_response_with_headers("large enough content", [("content-length", "200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        assert_compressed(&wrapped, "above min size");
        // The declared length describes the uncompressed body and must be dropped.
        assert!(wrapped.headers().get(header::CONTENT_LENGTH).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_unknown_size() {
        // No Content-Length header: the size is unknown, so compression proceeds.
        let response = make_response("unknown size content");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        assert_compressed(&wrapped, "for unknown size");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_when_x_accel_buffering_no() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "no")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_always_flush_by_default() {
        let response = make_response("normal data");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, false);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_x_accel_buffering_case_insensitive() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "NO")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli_content_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Brotli), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "br"
        );
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_zstd_content_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Zstd), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "zstd"
        );
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_application_grpc() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "for application/grpc");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_application_grpc_with_suffix() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc+proto")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "for application/grpc+proto");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_application_grpc_web() {
        let response =
            make_response_with_headers("grpc-web data", [("content-type", "application/grpc-web")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compress_application_grpc_web_proto() {
        let response = make_response_with_headers(
            "grpc-web data",
            [("content-type", "application/grpc-web+proto")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_text_event_stream() {
        let response =
            make_response_with_headers("event: data\n\n", [("content-type", "text/event-stream")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_always_flush_text_event_stream_with_charset() {
        let response = make_response_with_headers(
            "event: data\n\n",
            [("content-type", "text/event-stream; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_always_flush(&wrapped, true);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_no_compress_range_response() {
        let response =
            make_response_with_headers("partial content", [("content-range", "bytes 0-99/200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_passthrough(&wrapped, "for range response");
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_added() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(vary_values(&wrapped), vec!["accept-encoding"]);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_appended() {
        let response = make_response_with_headers("hello world", [("vary", "origin")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(vary_values(&wrapped), vec!["origin", "accept-encoding"]);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_not_duplicated() {
        let response = make_response_with_headers("hello world", [("vary", "accept-encoding")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Compare the full list: checking only the first value would not
        // notice a second, duplicated `accept-encoding` entry.
        assert_eq!(vary_values(&wrapped), vec!["accept-encoding"]);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_vary_header_star_not_modified() {
        let response = make_response_with_headers("hello world", [("vary", "*")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // Compare the full list so an appended `accept-encoding` is detected.
        assert_eq!(vary_values(&wrapped), vec!["*"]);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_accept_ranges_removed() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(wrapped.headers().get(header::ACCEPT_RANGES).is_none());
    }

    #[test]
    fn test_accept_ranges_kept_when_not_compressing() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, None, 0);

        assert_eq!(
            wrapped.headers().get(header::ACCEPT_RANGES).unwrap(),
            "bytes"
        );
    }
}