rama_http/layer/compression/
service.rs

1use super::body::BodyInner;
2use super::predicate::{DefaultPredicate, Predicate};
3use super::CompressionBody;
4use super::CompressionLevel;
5use crate::dep::http_body::Body;
6use crate::layer::util::compression::WrapBody;
7use crate::layer::util::{compression::AcceptEncoding, content_encoding::Encoding};
8use crate::{header, Request, Response};
9use rama_core::{Context, Service};
10use rama_utils::macros::define_inner_service_accessors;
11
12/// Compress response bodies of the underlying service.
13///
14/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
15/// `Content-Encoding` header to responses.
16///
17/// See the [module docs](crate::layer::compression) for more details.
18pub struct Compression<S, P = DefaultPredicate> {
19    pub(crate) inner: S,
20    pub(crate) accept: AcceptEncoding,
21    pub(crate) predicate: P,
22    pub(crate) quality: CompressionLevel,
23}
24
25impl<S, P> std::fmt::Debug for Compression<S, P>
26where
27    S: std::fmt::Debug,
28    P: std::fmt::Debug,
29{
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("Compression")
32            .field("inner", &self.inner)
33            .field("accept", &self.accept)
34            .field("predicate", &self.predicate)
35            .field("quality", &self.quality)
36            .finish()
37    }
38}
39
40impl<S, P> Clone for Compression<S, P>
41where
42    S: Clone,
43    P: Clone,
44{
45    fn clone(&self) -> Self {
46        Self {
47            inner: self.inner.clone(),
48            accept: self.accept,
49            predicate: self.predicate.clone(),
50            quality: self.quality,
51        }
52    }
53}
54
55impl<S> Compression<S, DefaultPredicate> {
56    /// Creates a new `Compression` wrapping the `service`.
57    pub fn new(service: S) -> Compression<S, DefaultPredicate> {
58        Self {
59            inner: service,
60            accept: AcceptEncoding::default(),
61            predicate: DefaultPredicate::default(),
62            quality: CompressionLevel::default(),
63        }
64    }
65}
66
67impl<S, P> Compression<S, P> {
68    define_inner_service_accessors!();
69
70    /// Sets whether to enable the gzip encoding.
71    pub fn gzip(mut self, enable: bool) -> Self {
72        self.accept.set_gzip(enable);
73        self
74    }
75
76    /// Sets whether to enable the gzip encoding.
77    pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
78        self.accept.set_gzip(enable);
79        self
80    }
81
82    /// Sets whether to enable the Deflate encoding.
83    pub fn deflate(mut self, enable: bool) -> Self {
84        self.accept.set_deflate(enable);
85        self
86    }
87
88    /// Sets whether to enable the Deflate encoding.
89    pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
90        self.accept.set_deflate(enable);
91        self
92    }
93
94    /// Sets whether to enable the Brotli encoding.
95    pub fn br(mut self, enable: bool) -> Self {
96        self.accept.set_br(enable);
97        self
98    }
99
100    /// Sets whether to enable the Brotli encoding.
101    pub fn set_br(&mut self, enable: bool) -> &mut Self {
102        self.accept.set_br(enable);
103        self
104    }
105
106    /// Sets whether to enable the Zstd encoding.
107    pub fn zstd(mut self, enable: bool) -> Self {
108        self.accept.set_zstd(enable);
109        self
110    }
111
112    /// Sets whether to enable the Zstd encoding.
113    pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
114        self.accept.set_zstd(enable);
115        self
116    }
117
118    /// Sets the compression quality.
119    pub fn quality(mut self, quality: CompressionLevel) -> Self {
120        self.quality = quality;
121        self
122    }
123
124    /// Sets the compression quality.
125    pub fn set_quality(&mut self, quality: CompressionLevel) -> &mut Self {
126        self.quality = quality;
127        self
128    }
129
130    /// Replace the current compression predicate.
131    ///
132    /// Predicates are used to determine whether a response should be compressed or not.
133    ///
134    /// The default predicate is [`DefaultPredicate`]. See its documentation for more
135    /// details on which responses it wont compress.
136    ///
137    /// # Changing the compression predicate
138    ///
139    /// ```
140    /// use rama_http::layer::compression::{
141    ///     Compression,
142    ///     predicate::{Predicate, NotForContentType, DefaultPredicate},
143    /// };
144    /// use rama_core::service::service_fn;
145    ///
146    /// // Placeholder service_fn
147    /// let service = service_fn(|_: ()| async {
148    ///     Ok::<_, std::io::Error>(rama_http::Response::new(()))
149    /// });
150    ///
151    /// // build our custom compression predicate
152    /// // its recommended to still include `DefaultPredicate` as part of
153    /// // custom predicates
154    /// let predicate = DefaultPredicate::new()
155    ///     // don't compress responses who's `content-type` starts with `application/json`
156    ///     .and(NotForContentType::new("application/json"));
157    ///
158    /// let service = Compression::new(service).compress_when(predicate);
159    /// ```
160    ///
161    /// See [`predicate`](super::predicate) for more utilities for building compression predicates.
162    ///
163    /// Responses that are already compressed (ie have a `content-encoding` header) will _never_ be
164    /// recompressed, regardless what they predicate says.
165    pub fn compress_when<C>(self, predicate: C) -> Compression<S, C>
166    where
167        C: Predicate,
168    {
169        Compression {
170            inner: self.inner,
171            accept: self.accept,
172            predicate,
173            quality: self.quality,
174        }
175    }
176}
177
178impl<ReqBody, ResBody, S, P, State> Service<State, Request<ReqBody>> for Compression<S, P>
179where
180    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
181    ResBody: Body<Data: Send + 'static, Error: Send + 'static> + Send + 'static,
182    P: Predicate + Send + Sync + 'static,
183    ReqBody: Send + 'static,
184    State: Clone + Send + Sync + 'static,
185{
186    type Response = Response<CompressionBody<ResBody>>;
187    type Error = S::Error;
188
189    #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)]
190    async fn serve(
191        &self,
192        ctx: Context<State>,
193        req: Request<ReqBody>,
194    ) -> Result<Self::Response, Self::Error> {
195        let encoding = Encoding::from_headers(req.headers(), self.accept);
196
197        let res = self.inner.serve(ctx, req).await?;
198
199        // never recompress responses that are already compressed
200        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
201            // never compress responses that are ranges
202            && !res.headers().contains_key(header::CONTENT_RANGE)
203            && self.predicate.should_compress(&res);
204
205        let (mut parts, body) = res.into_parts();
206
207        if should_compress {
208            parts
209                .headers
210                .append(header::VARY, header::ACCEPT_ENCODING.into());
211        }
212
213        let body = match (should_compress, encoding) {
214            // if compression is _not_ supported or the client doesn't accept it
215            (false, _) | (_, Encoding::Identity) => {
216                return Ok(Response::from_parts(
217                    parts,
218                    CompressionBody::new(BodyInner::identity(body)),
219                ))
220            }
221
222            (_, Encoding::Gzip) => {
223                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
224            }
225            (_, Encoding::Deflate) => {
226                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
227            }
228            (_, Encoding::Brotli) => {
229                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
230            }
231            (_, Encoding::Zstd) => {
232                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
233            }
234            #[allow(unreachable_patterns)]
235            (true, _) => {
236                // This should never happen because the `AcceptEncoding` struct which is used to determine
237                // `self.encoding` will only enable the different compression algorithms if the
238                // corresponding crate feature has been enabled. This means
239                // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the
240                // features enabled.
241                //
242                // The match arm is still required though because the `fs` feature uses the
243                // Encoding struct independently and requires no compression logic to be enabled.
244                // This means a combination of an individual compression feature and `fs` will fail
245                // to compile without this branch even though it will never be reached.
246                //
247                // To safeguard against refactors that changes this relationship or other bugs the
248                // server will return an uncompressed response instead of panicking since that could
249                // become a ddos attack vector.
250                return Ok(Response::from_parts(
251                    parts,
252                    CompressionBody::new(BodyInner::identity(body)),
253                ));
254            }
255        };
256
257        parts.headers.remove(header::ACCEPT_RANGES);
258        parts.headers.remove(header::CONTENT_LENGTH);
259
260        parts
261            .headers
262            .insert(header::CONTENT_ENCODING, encoding.into_header_value());
263
264        let res = Response::from_parts(parts, body);
265        Ok(res)
266    }
267}