rama_http/layer/compression/
mod.rs

//! Middleware that compresses response bodies.
//!
//! # Example
//!
//! Example showing how to respond with the compressed contents of a file.
//!
//! ```rust
//! use bytes::Bytes;
//! use futures_lite::stream::StreamExt;
//! use rama_core::error::BoxError;
//! use rama_http::dep::http_body::Frame;
//! use rama_http::dep::http_body_util::{BodyExt, StreamBody};
//! use rama_http::dep::http_body_util::combinators::BoxBody as InnerBoxBody;
//! use rama_http::layer::compression::CompressionLayer;
//! use rama_http::{Body, Request, Response, header::ACCEPT_ENCODING};
//! use rama_core::service::service_fn;
//! use rama_core::{Context, Service, Layer};
//! use std::convert::Infallible;
//! use tokio::fs::{self, File};
//! use tokio_util::io::ReaderStream;
//!
//! type BoxBody = InnerBoxBody<Bytes, std::io::Error>;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! async fn handle(_req: Request) -> Result<Response<BoxBody>, Infallible> {
//!     // Open the file.
//!     let file = File::open("Cargo.toml").await.expect("file missing");
//!     // Convert the file into a `Stream` of `Bytes`.
//!     let stream = ReaderStream::new(file);
//!     // Convert the stream into a stream of data `Frame`s.
//!     let stream = stream.map(|res| res.map(Frame::data));
//!     // Convert the `Stream` into a `Body`.
//!     let body = StreamBody::new(stream);
//!     // Erase the type because it's very hard to name in the function signature.
//!     let body = BodyExt::boxed(body);
//!     // Create response.
//!     Ok(Response::new(body))
//! }
//!
//! let service = (
//!     // Compress responses based on the `Accept-Encoding` header.
//!     CompressionLayer::new(),
//! ).into_layer(service_fn(handle));
//!
//! // Call the service.
//! let request = Request::builder()
//!     .header(ACCEPT_ENCODING, "gzip")
//!     .body(Body::default())?;
//!
//! let response = service
//!     .serve(Context::default(), request)
//!     .await?;
//!
//! assert_eq!(response.headers()["content-encoding"], "gzip");
//!
//! // Read the body
//! let bytes = response
//!     .into_body()
//!     .collect()
//!     .await?
//!     .to_bytes();
//!
//! // The compressed body should be smaller 🤞
//! let uncompressed_len = fs::read_to_string("Cargo.toml").await?.len();
//! assert!(bytes.len() < uncompressed_len);
//! #
//! # Ok(())
//! # }
//! ```
//!
75
76pub mod predicate;
77
78pub(crate) mod body;
79mod layer;
80mod pin_project_cfg;
81mod service;
82
83#[doc(inline)]
84pub use self::{
85    body::CompressionBody,
86    layer::CompressionLayer,
87    predicate::{DefaultPredicate, Predicate},
88    service::Compression,
89};
90#[doc(inline)]
91pub use crate::layer::util::compression::CompressionLevel;
92
#[cfg(test)]
mod tests {
    use super::*;

    use crate::layer::compression::predicate::SizeAbove;

    use crate::dep::http_body::Body as _;
    use crate::dep::http_body_util::BodyExt;
    use crate::header::{
        ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE,
    };
    use crate::{Body, HeaderValue, Request, Response};
    use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
    use flate2::read::GzDecoder;
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;
    use std::io::Read;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Payload served by [`handle`].
    const HELLO: &str = "Hello, World!";

    /// Predicate that approves compression for every response.
    ///
    /// Lets the tests compress tiny bodies such as [`HELLO`], which the
    /// default predicate's minimum-size check would presumably skip.
    #[derive(Clone)]
    struct Always;

    impl Predicate for Always {
        fn should_compress<B>(&self, _: &rama_http_types::Response<B>) -> bool
        where
            B: rama_http_types::dep::http_body::Body,
        {
            true
        }
    }

    /// Inner service used by most tests: answers `200 OK` with [`HELLO`].
    async fn handle(_req: Request) -> Result<Response, Infallible> {
        Ok(Response::new(Body::from(HELLO)))
    }

    /// Builds an empty-bodied request advertising `encoding` in `Accept-Encoding`.
    fn request_accepting(encoding: &'static str) -> Request {
        Request::builder()
            .header(ACCEPT_ENCODING, encoding)
            .body(Body::empty())
            .unwrap()
    }

    /// Builds a response whose body is twice [`SizeAbove::DEFAULT_MIN_SIZE`]
    /// bytes long, tagged with the given `Content-Type`, so that body size is
    /// not what prevents compression.
    fn large_response(content_type: &'static str) -> Response {
        let body = "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize);
        let mut res = Response::new(Body::from(body));
        res.headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(content_type));
        res
    }

    /// Gzip-decodes `compressed` into a UTF-8 string.
    ///
    /// Uses the blocking `flate2` decoder: it is much simpler than
    /// `async-compression`, and blocking inside a test is fine.
    fn gunzip_to_string(compressed: &[u8]) -> String {
        let mut decompressed = String::new();
        GzDecoder::new(compressed)
            .read_to_string(&mut decompressed)
            .expect("body should be gzip-encoded UTF-8");
        decompressed
    }

    /// Brotli-decodes `compressed` and returns the decoded bytes.
    async fn brotli_decode(compressed: &[u8]) -> Vec<u8> {
        let mut output = Vec::new();
        let mut decoder = BrotliDecoder::new(&mut output);
        decoder
            .write_all(compressed)
            .await
            .expect("couldn't brotli-decode");
        // `shutdown` (unlike `flush`) also requires the brotli stream to be
        // complete, so a truncated body fails the test.
        decoder
            .shutdown()
            .await
            .expect("couldn't finish brotli stream");
        output
    }

    #[tokio::test]
    async fn gzip_works() {
        let svc = Compression::new(service_fn(handle)).compress_when(Always);

        let res = svc
            .serve(Context::default(), request_accepting("gzip"))
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

        let compressed = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(gunzip_to_string(&compressed), HELLO);
    }

    #[tokio::test]
    async fn x_gzip_works() {
        let svc = Compression::new(service_fn(handle)).compress_when(Always);

        let res = svc
            .serve(Context::default(), request_accepting("x-gzip"))
            .await
            .unwrap();

        // x-gzip is treated as equivalent to gzip and answered as plain `gzip`;
        // every header with this name is checked, to be extra careful.
        assert_eq!(
            res.headers()
                .get_all(CONTENT_ENCODING)
                .iter()
                .collect::<Vec<&HeaderValue>>(),
            vec![HeaderValue::from_static("gzip")]
        );

        let compressed = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(gunzip_to_string(&compressed), HELLO);
    }

    #[tokio::test]
    async fn zstd_works() {
        let svc = Compression::new(service_fn(handle)).compress_when(Always);

        let res = svc
            .serve(Context::default(), request_accepting("zstd"))
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "zstd");

        let compressed = res.into_body().collect().await.unwrap().to_bytes();
        let decompressed = zstd::stream::decode_all(&compressed[..]).unwrap();
        assert_eq!(String::from_utf8(decompressed).unwrap(), HELLO);
    }

    #[tokio::test]
    async fn no_recompress() {
        const DATA: &str = "Hello, World! I'm already compressed with br!";

        let svc = service_fn(async |_| {
            let compressed = {
                let mut buf = Vec::new();
                let mut enc = BrotliEncoder::new(&mut buf);
                enc.write_all(DATA.as_bytes()).await?;
                // `shutdown` finishes the brotli stream; `flush` alone would
                // leave it unterminated.
                enc.shutdown().await?;
                buf
            };

            let resp = Response::builder()
                .header(CONTENT_ENCODING, "br")
                .body(Body::from(compressed))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let svc = Compression::new(svc);

        // The accept-encoding deliberately differs from the content-encoding
        // above, so a recompression would be visible in the response headers.
        let res = svc
            .serve(Context::default(), request_accepting("gzip"))
            .await
            .unwrap();

        // check we didn't recompress
        assert_eq!(
            res.headers()
                .get(CONTENT_ENCODING)
                .and_then(|h| h.to_str().ok())
                .unwrap_or_default(),
            "br",
        );

        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(brotli_decode(&data).await, DATA.as_bytes());
    }

    #[tokio::test]
    async fn will_not_compress_if_filtered_out() {
        const DATA: &str = "Hello world uncompressed";

        /// Predicate that allows every other response to be compressed,
        /// starting with a refusal for the first one.
        #[derive(Default, Clone)]
        struct EveryOtherResponse(Arc<AtomicU64>);

        impl Predicate for EveryOtherResponse {
            fn should_compress<B>(&self, _: &rama_http_types::Response<B>) -> bool
            where
                B: rama_http_types::dep::http_body::Body,
            {
                // `fetch_add` yields the previous count: 0 → no, 1 → yes, ...
                self.0.fetch_add(1, Ordering::SeqCst) % 2 != 0
            }
        }

        let svc = service_fn(async |_| {
            Ok::<_, std::io::Error>(Response::new(Body::from(DATA.as_bytes())))
        });
        let svc = Compression::new(svc).compress_when(EveryOtherResponse::default());

        // First response: the predicate refuses, so the body passes through untouched.
        let res = svc
            .serve(Context::default(), request_accepting("br"))
            .await
            .unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(data, DATA.as_bytes());

        // Second response: the predicate approves, so the body is brotli-encoded.
        let res = svc
            .serve(Context::default(), request_accepting("br"))
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "br");
        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(brotli_decode(&data).await, DATA.as_bytes());
    }

    #[tokio::test]
    async fn doesnt_compress_images() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(large_response("image/png"))
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .serve(Context::default(), request_accepting("gzip"))
            .await
            .unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn does_compress_svg() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(large_response("image/svg+xml"))
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .serve(Context::default(), request_accepting("gzip"))
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
    }

    #[tokio::test]
    async fn compress_with_quality() {
        const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
        let level = CompressionLevel::Best;

        let svc = service_fn(async |_| {
            Ok::<_, std::io::Error>(Response::new(Body::from(DATA.as_bytes())))
        });
        let svc = Compression::new(svc).quality(level);

        let res = svc
            .serve(Context::default(), request_accepting("br"))
            .await
            .unwrap();
        let compressed = res.into_body().collect().await.unwrap().to_bytes();

        // Reference encoding of the same payload at the same quality level;
        // `&[u8]` is already an `AsyncBufRead`, so no stream adapter is needed.
        let expected = {
            use async_compression::tokio::bufread::BrotliEncoder;

            let mut enc =
                BrotliEncoder::with_quality(DATA.as_bytes(), level.into_async_compression());
            let mut buf = Vec::new();
            enc.read_to_end(&mut buf).await.unwrap();
            buf
        };

        assert_eq!(
            compressed,
            expected.as_slice(),
            "Compression level is not respected"
        );
    }

    #[tokio::test]
    async fn should_not_compress_ranges() {
        let svc = service_fn(async |_| {
            let mut res = Response::new(Body::from("Hello"));
            let headers = res.headers_mut();
            headers.insert(ACCEPT_RANGES, HeaderValue::from_static("bytes"));
            headers.insert(CONTENT_RANGE, HeaderValue::from_static("bytes 0-4/*"));
            Ok::<_, std::io::Error>(res)
        });
        let svc = Compression::new(svc).compress_when(Always);

        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .header(RANGE, "bytes=0-4")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.headers()[ACCEPT_RANGES], "bytes");
        assert!(!res.headers().contains_key(CONTENT_ENCODING));

        // the body must come through uncompressed
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "Hello");
    }

    #[tokio::test]
    async fn should_strip_accept_ranges_header_when_compressing() {
        let svc = service_fn(async |_| {
            let mut res = Response::new(Body::from(HELLO));
            res.headers_mut()
                .insert(ACCEPT_RANGES, HeaderValue::from_static("bytes"));
            Ok::<_, std::io::Error>(res)
        });
        let svc = Compression::new(svc).compress_when(Always);

        let res = svc
            .serve(Context::default(), request_accepting("gzip"))
            .await
            .unwrap();

        assert!(!res.headers().contains_key(ACCEPT_RANGES));
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

        let compressed = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(gunzip_to_string(&compressed), HELLO);
    }

    #[tokio::test]
    async fn size_hint_identity() {
        const MSG: &str = "Hello, world!";
        let svc = service_fn(async |_| Ok::<_, std::io::Error>(Response::new(Body::from(MSG))));
        let svc = Compression::new(svc);

        // No `Accept-Encoding`, so the body is passed through and keeps its exact size hint.
        let req = Request::new(Body::empty());
        let res = svc.serve(Context::default(), req).await.unwrap();
        let body = res.into_body();
        assert_eq!(body.size_hint().exact().unwrap(), MSG.len() as u64);
    }
}