1use crate::body::CompressionBody;
2use crate::codec::Codec;
3use http::{Response, header};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9pin_project! {
10 pub struct ResponseFuture<F> {
12 #[pin]
13 inner: F,
14 accepted_codec: Option<Codec>,
15 min_size: usize,
16 }
17}
18
19impl<F> ResponseFuture<F> {
20 pub(crate) fn new(inner: F, accepted_codec: Option<Codec>, min_size: usize) -> Self {
21 Self {
22 inner,
23 accepted_codec,
24 min_size,
25 }
26 }
27}
28
29impl<F, B, E> Future for ResponseFuture<F>
30where
31 F: Future<Output = Result<Response<B>, E>>,
32{
33 type Output = Result<Response<CompressionBody<B>>, E>;
34
35 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
36 let this = self.project();
37
38 match this.inner.poll(cx) {
39 Poll::Pending => Poll::Pending,
40 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
41 Poll::Ready(Ok(response)) => {
42 let response = wrap_response(response, *this.accepted_codec, *this.min_size);
43 Poll::Ready(Ok(response))
44 }
45 }
46 }
47}
48
49fn wrap_response<B>(
51 response: Response<B>,
52 accepted_codec: Option<Codec>,
53 min_size: usize,
54) -> Response<CompressionBody<B>> {
55 let (mut parts, body) = response.into_parts();
56
57 let should_compress = accepted_codec.is_some()
59 && !has_content_encoding(&parts.headers)
60 && !has_content_range(&parts.headers)
61 && !is_uncompressible_content_type(&parts.headers)
62 && !is_below_min_size(&parts.headers, min_size);
63
64 let always_flush = parts
66 .headers
67 .get("x-accel-buffering")
68 .and_then(|v| v.to_str().ok())
69 .is_some_and(|v| v.eq_ignore_ascii_case("no"))
70 || is_streaming_content_type(&parts.headers);
71
72 let body = if should_compress {
73 let codec = accepted_codec.unwrap();
74
75 parts.headers.insert(
77 header::CONTENT_ENCODING,
78 header::HeaderValue::from_static(codec.content_encoding()),
79 );
80
81 parts.headers.remove(header::CONTENT_LENGTH);
83
84 parts.headers.remove(header::ACCEPT_RANGES);
86
87 add_vary_accept_encoding(&mut parts.headers);
89
90 CompressionBody::compressed(body, codec, always_flush)
91 } else {
92 CompressionBody::passthrough(body)
93 };
94
95 Response::from_parts(parts, body)
96}
97
98fn has_content_encoding(headers: &header::HeaderMap) -> bool {
100 headers.contains_key(header::CONTENT_ENCODING)
101}
102
103fn has_content_range(headers: &header::HeaderMap) -> bool {
105 headers.contains_key(header::CONTENT_RANGE)
106}
107
108fn add_vary_accept_encoding(headers: &mut header::HeaderMap) {
110 for vary in headers.get_all(header::VARY) {
112 if let Ok(vary_str) = vary.to_str() {
113 let dominated = vary_str.split(',').any(|v| {
114 let v = v.trim();
115 v.eq_ignore_ascii_case("*") || v.eq_ignore_ascii_case("accept-encoding")
116 });
117 if dominated {
118 return;
119 }
120 }
121 }
122
123 headers.append(
125 header::VARY,
126 header::HeaderValue::from_static("accept-encoding"),
127 );
128}
129
130fn is_uncompressible_content_type(headers: &header::HeaderMap) -> bool {
132 let Some(content_type) = headers
133 .get(header::CONTENT_TYPE)
134 .and_then(|v| v.to_str().ok())
135 else {
136 return false;
137 };
138
139 let ct = content_type.to_ascii_lowercase();
140
141 if ct.starts_with("image/") {
143 return !ct.starts_with("image/svg+xml");
144 }
145
146 if ct.starts_with("application/grpc") {
148 return !ct.starts_with("application/grpc-web");
149 }
150
151 false
152}
153
154fn is_streaming_content_type(headers: &header::HeaderMap) -> bool {
156 headers
157 .get(header::CONTENT_TYPE)
158 .and_then(|v| v.to_str().ok())
159 .is_some_and(|content_type| {
160 let ct = content_type.to_ascii_lowercase();
161 ct.starts_with("text/event-stream") || ct.starts_with("application/grpc-web")
162 })
163}
164
165fn is_below_min_size(headers: &header::HeaderMap, min_size: usize) -> bool {
167 headers
168 .get(header::CONTENT_LENGTH)
169 .and_then(|v| v.to_str().ok())
170 .and_then(|v| v.parse::<usize>().ok())
171 .is_some_and(|len| len < min_size)
172}
173
#[cfg(test)]
mod tests {
    //! Unit tests for `wrap_response`'s compress/pass-through decision and the
    //! header rewriting it performs.

    use super::*;
    use crate::body::CompressState;

    fn make_response(body: &'static str) -> Response<&'static str> {
        Response::new(body)
    }

    fn make_response_with_headers<I>(body: &'static str, headers: I) -> Response<&'static str>
    where
        I: IntoIterator<Item = (&'static str, &'static str)>,
    {
        let mut response = Response::new(body);
        for (name, value) in headers {
            response
                .headers_mut()
                .insert(name, header::HeaderValue::from_static(value));
        }
        response
    }

    /// Collects every `Vary` value on `response`, in order. Unlike
    /// `HeaderMap::get`, this sees duplicates.
    fn vary_values<B>(response: &Response<B>) -> Vec<&str> {
        response
            .headers()
            .get_all(header::VARY)
            .iter()
            .map(|v| v.to_str().unwrap())
            .collect()
    }

    #[test]
    fn test_compress_when_accept_encoding_present() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert_eq!(state.state(), CompressState::Reading);
            }
            _ => panic!("Expected compressed body"),
        }

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    #[test]
    fn test_no_compress_when_no_accept_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, None, 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body"),
        }

        assert!(wrapped.headers().get(header::CONTENT_ENCODING).is_none());
    }

    #[test]
    fn test_no_compress_when_content_encoding_present() {
        let response =
            make_response_with_headers("hello world", [("content-encoding", "identity")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body"),
        }

        // The pre-existing encoding must be left as the handler set it.
        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "identity"
        );
    }

    #[test]
    fn test_no_compress_image_png() {
        let response = make_response_with_headers("PNG data", [("content-type", "image/png")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body for image/png"),
        }
    }

    #[test]
    fn test_no_compress_image_jpeg() {
        let response = make_response_with_headers("JPEG data", [("content-type", "image/jpeg")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body for image/jpeg"),
        }
    }

    #[test]
    fn test_no_compress_image_gif() {
        let response = make_response_with_headers("GIF data", [("content-type", "image/gif")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body for image/gif"),
        }
    }

    #[test]
    fn test_no_compress_image_webp() {
        let response = make_response_with_headers("WebP data", [("content-type", "image/webp")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body for image/webp"),
        }
    }

    #[test]
    fn test_compress_image_svg() {
        let response =
            make_response_with_headers("<svg></svg>", [("content-type", "image/svg+xml")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { .. } => {}
            _ => panic!("Expected compressed body for image/svg+xml"),
        }
    }

    #[test]
    fn test_compress_image_svg_with_charset() {
        let response = make_response_with_headers(
            "<svg></svg>",
            [("content-type", "image/svg+xml; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { .. } => {}
            _ => panic!("Expected compressed body for image/svg+xml with charset"),
        }
    }

    #[test]
    fn test_compress_text_html() {
        let response = make_response_with_headers("<html></html>", [("content-type", "text/html")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { .. } => {}
            _ => panic!("Expected compressed body for text/html"),
        }
    }

    #[test]
    fn test_no_compress_below_min_size() {
        let response = make_response_with_headers("small", [("content-length", "5")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body below min size"),
        }

        // An uncompressed response keeps its declared length.
        assert_eq!(wrapped.headers().get(header::CONTENT_LENGTH).unwrap(), "5");
    }

    #[test]
    fn test_compress_above_min_size() {
        let response =
            make_response_with_headers("large enough content", [("content-length", "200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { .. } => {}
            _ => panic!("Expected compressed body above min size"),
        }

        assert!(wrapped.headers().get(header::CONTENT_LENGTH).is_none());
    }

    #[test]
    fn test_compress_unknown_size() {
        let response = make_response("unknown size content");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 100);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { .. } => {}
            _ => panic!("Expected compressed body for unknown size"),
        }
    }

    #[test]
    fn test_always_flush_when_x_accel_buffering_no() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "no")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert!(state.always_flush());
            }
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    fn test_no_always_flush_by_default() {
        let response = make_response("normal data");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert!(!state.always_flush());
            }
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    fn test_x_accel_buffering_case_insensitive() {
        let response = make_response_with_headers("streaming data", [("x-accel-buffering", "NO")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert!(state.always_flush());
            }
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    fn test_brotli_content_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Brotli), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "br"
        );
    }

    #[test]
    fn test_zstd_content_encoding() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Zstd), 0);

        assert_eq!(
            wrapped.headers().get(header::CONTENT_ENCODING).unwrap(),
            "zstd"
        );
    }

    #[test]
    fn test_no_compress_application_grpc() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body for application/grpc"),
        }
    }

    #[test]
    fn test_no_compress_application_grpc_with_suffix() {
        let response =
            make_response_with_headers("grpc data", [("content-type", "application/grpc+proto")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body for application/grpc+proto"),
        }
    }

    #[test]
    fn test_compress_application_grpc_web() {
        let response =
            make_response_with_headers("grpc-web data", [("content-type", "application/grpc-web")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert!(state.always_flush());
            }
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    fn test_compress_application_grpc_web_proto() {
        let response = make_response_with_headers(
            "grpc-web data",
            [("content-type", "application/grpc-web+proto")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert!(state.always_flush());
            }
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    fn test_always_flush_text_event_stream() {
        let response =
            make_response_with_headers("event: data\n\n", [("content-type", "text/event-stream")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert!(state.always_flush());
            }
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    fn test_always_flush_text_event_stream_with_charset() {
        let response = make_response_with_headers(
            "event: data\n\n",
            [("content-type", "text/event-stream; charset=utf-8")],
        );
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Compressed { state, .. } => {
                assert!(state.always_flush());
            }
            _ => panic!("Expected compressed body"),
        }
    }

    #[test]
    fn test_no_compress_range_response() {
        let response =
            make_response_with_headers("partial content", [("content-range", "bytes 0-99/200")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        match wrapped.body() {
            crate::body::CompressionBody::Passthrough { .. } => {}
            _ => panic!("Expected passthrough body for range response"),
        }
    }

    #[test]
    fn test_vary_header_added() {
        let response = make_response("hello world");
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(vary_values(&wrapped), vec!["accept-encoding"]);
    }

    #[test]
    fn test_vary_header_appended() {
        let response = make_response_with_headers("hello world", [("vary", "origin")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(vary_values(&wrapped), vec!["origin", "accept-encoding"]);
    }

    #[test]
    fn test_vary_header_not_duplicated() {
        let response = make_response_with_headers("hello world", [("vary", "accept-encoding")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        // `get` only returns the first value, so inspect all of them to catch
        // an appended duplicate.
        assert_eq!(vary_values(&wrapped), vec!["accept-encoding"]);
    }

    #[test]
    fn test_vary_header_list_containing_accept_encoding_not_modified() {
        let response =
            make_response_with_headers("hello world", [("vary", "Origin, Accept-Encoding")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(vary_values(&wrapped), vec!["Origin, Accept-Encoding"]);
    }

    #[test]
    fn test_vary_header_star_not_modified() {
        let response = make_response_with_headers("hello world", [("vary", "*")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert_eq!(vary_values(&wrapped), vec!["*"]);
    }

    #[test]
    fn test_accept_ranges_removed() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, Some(Codec::Gzip), 0);

        assert!(wrapped.headers().get(header::ACCEPT_RANGES).is_none());
    }

    #[test]
    fn test_accept_ranges_kept_when_not_compressing() {
        let response = make_response_with_headers("hello world", [("accept-ranges", "bytes")]);
        let wrapped = wrap_response(response, None, 0);

        assert_eq!(
            wrapped.headers().get(header::ACCEPT_RANGES).unwrap(),
            "bytes"
        );
    }
}