rama_http/layer/compression/
layer.rs

1use super::predicate::DefaultPredicate;
2use super::{Compression, Predicate};
3use crate::layer::util::compression::CompressionLevel;
4use rama_core::Layer;
5use rama_http_types::headers::encoding::AcceptEncoding;
6
7/// Compress response bodies of the underlying service.
8///
9/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
10/// `Content-Encoding` header to responses.
11///
12/// See the [module docs](crate::layer::compression) for more details.
13#[derive(Clone, Debug, Default)]
14pub struct CompressionLayer<P = DefaultPredicate> {
15    accept: AcceptEncoding,
16    predicate: P,
17    quality: CompressionLevel,
18}
19
20impl<S, P> Layer<S> for CompressionLayer<P>
21where
22    P: Predicate,
23{
24    type Service = Compression<S, P>;
25
26    fn layer(&self, inner: S) -> Self::Service {
27        Compression {
28            inner,
29            accept: self.accept,
30            predicate: self.predicate.clone(),
31            quality: self.quality,
32        }
33    }
34
35    fn into_layer(self, inner: S) -> Self::Service {
36        Compression {
37            inner,
38            accept: self.accept,
39            predicate: self.predicate,
40            quality: self.quality,
41        }
42    }
43}
44
45impl CompressionLayer {
46    /// Creates a new [`CompressionLayer`].
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    /// Sets whether to enable the gzip encoding.
52    pub fn gzip(mut self, enable: bool) -> Self {
53        self.accept.set_gzip(enable);
54        self
55    }
56
57    /// Sets whether to enable the gzip encoding.
58    pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
59        self.accept.set_gzip(enable);
60        self
61    }
62
63    /// Sets whether to enable the Deflate encoding.
64    pub fn deflate(mut self, enable: bool) -> Self {
65        self.accept.set_deflate(enable);
66        self
67    }
68
69    /// Sets whether to enable the Deflate encoding.
70    pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
71        self.accept.set_deflate(enable);
72        self
73    }
74
75    /// Sets whether to enable the Brotli encoding.
76    pub fn br(mut self, enable: bool) -> Self {
77        self.accept.set_br(enable);
78        self
79    }
80
81    /// Sets whether to enable the Brotli encoding.
82    pub fn set_br(&mut self, enable: bool) -> &mut Self {
83        self.accept.set_br(enable);
84        self
85    }
86
87    /// Sets whether to enable the Zstd encoding.
88    pub fn zstd(mut self, enable: bool) -> Self {
89        self.accept.set_zstd(enable);
90        self
91    }
92
93    /// Sets whether to enable the Zstd encoding.
94    pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
95        self.accept.set_zstd(enable);
96        self
97    }
98
99    /// Sets the compression quality.
100    pub fn quality(mut self, quality: CompressionLevel) -> Self {
101        self.quality = quality;
102        self
103    }
104
105    /// Sets the compression quality.
106    pub fn set_quality(&mut self, quality: CompressionLevel) -> &mut Self {
107        self.quality = quality;
108        self
109    }
110
111    /// Replace the current compression predicate.
112    ///
113    /// See [`Compression::compress_when`] for more details.
114    pub fn compress_when<C>(self, predicate: C) -> CompressionLayer<C>
115    where
116        C: Predicate,
117    {
118        CompressionLayer {
119            accept: self.accept,
120            predicate,
121            quality: self.quality,
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    use crate::dep::http_body_util::BodyExt;
    use crate::{Body, Request, Response, header::ACCEPT_ENCODING};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;
    use tokio::fs::File;
    use tokio_util::io::ReaderStream;

    /// Test handler that streams the crate's `Cargo.toml` back as the response body.
    async fn handle(_req: Request) -> Result<Response, Infallible> {
        // Open the file.
        let file = File::open("Cargo.toml").await.expect("file missing");
        // Convert the file into a `Stream`.
        let stream = ReaderStream::new(file);
        // Convert the `Stream` into a `Body`.
        let body = Body::from_stream(stream);
        // Create response.
        Ok(Response::new(body))
    }

    /// Serves one request advertising `gzip, deflate, br` through `layer`
    /// wrapped around [`handle`], asserts that the response was encoded with
    /// `expected_encoding`, and returns the length of the compressed body.
    async fn compressed_len(
        layer: CompressionLayer,
        expected_encoding: &str,
    ) -> Result<usize, BoxError> {
        // Compress responses based on the `Accept-Encoding` header.
        let service = layer.into_layer(service_fn(handle));

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.serve(Context::default(), request).await?;

        assert_eq!(response.headers()["content-encoding"], expected_encoding);

        // Read the whole (compressed) body.
        let bytes = response.into_body().collect().await?.to_bytes();

        Ok(bytes.len())
    }

    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), BoxError> {
        let deflate_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .br(false)
            .gzip(false);

        let deflate_bytes_len = compressed_len(deflate_only_layer, "deflate").await?;

        let br_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .gzip(false)
            .deflate(false);

        let br_byte_length = compressed_len(br_only_layer, "br").await?;

        // Check that the corresponding algorithms are actually used:
        // br should compress better than deflate.
        assert!(br_byte_length < deflate_bytes_len);

        Ok(())
    }

    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), BoxError> {
        // Test ensuring that zstd compression will not exceed an 8MiB window size; browsers do not
        // accept responses using 16MiB+ window sizes.

        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; 18_874_368])))
        }
        // zstd will (I believe) lower its window size if a larger one isn't beneficial and
        // it knows the size of the input; use an 18MiB body to ensure it would want a
        // >=16MiB window (though it might not be able to see the input size here).

        let zstd_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .br(false)
            .deflate(false)
            .gzip(false);

        let service = zstd_layer.into_layer(service_fn(zeroes));

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;

        let response = service.serve(Context::default(), request).await?;

        assert_eq!(response.headers()["content-encoding"], "zstd");

        let body = response.into_body();
        let bytes = body.collect().await?.to_bytes();
        let mut dec = zstd::Decoder::new(&*bytes)?;
        dec.window_log_max(23)?; // Limit window size accepted by decoder to 2 ^ 23 bytes (8MiB)

        // Decoding fails if the encoder used a window larger than the limit above.
        std::io::copy(&mut dec, &mut std::io::sink())?;

        Ok(())
    }
}