rama_http/layer/compression/service.rs

1use super::CompressionBody;
2use super::CompressionLevel;
3use super::body::BodyInner;
4use super::predicate::{DefaultPredicate, Predicate};
5use crate::dep::http_body::Body;
6use crate::headers::encoding::{AcceptEncoding, Encoding};
7use crate::layer::util::compression::WrapBody;
8use crate::{Request, Response, header};
9use rama_core::{Context, Service};
10use rama_http_types::HeaderValue;
11use rama_utils::macros::define_inner_service_accessors;
12
13/// Compress response bodies of the underlying service.
14///
15/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
16/// `Content-Encoding` header to responses.
17///
18/// See the [module docs](crate::layer::compression) for more details.
19pub struct Compression<S, P = DefaultPredicate> {
20    pub(crate) inner: S,
21    pub(crate) accept: AcceptEncoding,
22    pub(crate) predicate: P,
23    pub(crate) quality: CompressionLevel,
24}
25
26impl<S, P> std::fmt::Debug for Compression<S, P>
27where
28    S: std::fmt::Debug,
29    P: std::fmt::Debug,
30{
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("Compression")
33            .field("inner", &self.inner)
34            .field("accept", &self.accept)
35            .field("predicate", &self.predicate)
36            .field("quality", &self.quality)
37            .finish()
38    }
39}
40
41impl<S, P> Clone for Compression<S, P>
42where
43    S: Clone,
44    P: Clone,
45{
46    fn clone(&self) -> Self {
47        Self {
48            inner: self.inner.clone(),
49            accept: self.accept,
50            predicate: self.predicate.clone(),
51            quality: self.quality,
52        }
53    }
54}
55
56impl<S> Compression<S, DefaultPredicate> {
57    /// Creates a new `Compression` wrapping the `service`.
58    pub fn new(service: S) -> Compression<S, DefaultPredicate> {
59        Self {
60            inner: service,
61            accept: AcceptEncoding::default(),
62            predicate: DefaultPredicate::default(),
63            quality: CompressionLevel::default(),
64        }
65    }
66}
67
68impl<S, P> Compression<S, P> {
69    define_inner_service_accessors!();
70
71    /// Sets whether to enable the gzip encoding.
72    pub fn gzip(mut self, enable: bool) -> Self {
73        self.accept.set_gzip(enable);
74        self
75    }
76
77    /// Sets whether to enable the gzip encoding.
78    pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
79        self.accept.set_gzip(enable);
80        self
81    }
82
83    /// Sets whether to enable the Deflate encoding.
84    pub fn deflate(mut self, enable: bool) -> Self {
85        self.accept.set_deflate(enable);
86        self
87    }
88
89    /// Sets whether to enable the Deflate encoding.
90    pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
91        self.accept.set_deflate(enable);
92        self
93    }
94
95    /// Sets whether to enable the Brotli encoding.
96    pub fn br(mut self, enable: bool) -> Self {
97        self.accept.set_br(enable);
98        self
99    }
100
101    /// Sets whether to enable the Brotli encoding.
102    pub fn set_br(&mut self, enable: bool) -> &mut Self {
103        self.accept.set_br(enable);
104        self
105    }
106
107    /// Sets whether to enable the Zstd encoding.
108    pub fn zstd(mut self, enable: bool) -> Self {
109        self.accept.set_zstd(enable);
110        self
111    }
112
113    /// Sets whether to enable the Zstd encoding.
114    pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
115        self.accept.set_zstd(enable);
116        self
117    }
118
119    /// Sets the compression quality.
120    pub fn quality(mut self, quality: CompressionLevel) -> Self {
121        self.quality = quality;
122        self
123    }
124
125    /// Sets the compression quality.
126    pub fn set_quality(&mut self, quality: CompressionLevel) -> &mut Self {
127        self.quality = quality;
128        self
129    }
130
131    /// Replace the current compression predicate.
132    ///
133    /// Predicates are used to determine whether a response should be compressed or not.
134    ///
135    /// The default predicate is [`DefaultPredicate`]. See its documentation for more
136    /// details on which responses it wont compress.
137    ///
138    /// # Changing the compression predicate
139    ///
140    /// ```
141    /// use rama_http::layer::compression::{
142    ///     Compression,
143    ///     predicate::{Predicate, NotForContentType, DefaultPredicate},
144    /// };
145    /// use rama_core::service::service_fn;
146    ///
147    /// // Placeholder service_fn
148    /// let service = service_fn(async |_: ()| {
149    ///     Ok::<_, std::io::Error>(rama_http::Response::new(()))
150    /// });
151    ///
152    /// // build our custom compression predicate
153    /// // its recommended to still include `DefaultPredicate` as part of
154    /// // custom predicates
155    /// let predicate = DefaultPredicate::new()
156    ///     // don't compress responses who's `content-type` starts with `application/json`
157    ///     .and(NotForContentType::new("application/json"));
158    ///
159    /// let service = Compression::new(service).compress_when(predicate);
160    /// ```
161    ///
162    /// See [`predicate`](super::predicate) for more utilities for building compression predicates.
163    ///
164    /// Responses that are already compressed (ie have a `content-encoding` header) will _never_ be
165    /// recompressed, regardless what they predicate says.
166    pub fn compress_when<C>(self, predicate: C) -> Compression<S, C>
167    where
168        C: Predicate,
169    {
170        Compression {
171            inner: self.inner,
172            accept: self.accept,
173            predicate,
174            quality: self.quality,
175        }
176    }
177}
178
179impl<ReqBody, ResBody, S, P, State> Service<State, Request<ReqBody>> for Compression<S, P>
180where
181    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
182    ResBody: Body<Data: Send + 'static, Error: Send + 'static> + Send + 'static,
183    P: Predicate + Send + Sync + 'static,
184    ReqBody: Send + 'static,
185    State: Clone + Send + Sync + 'static,
186{
187    type Response = Response<CompressionBody<ResBody>>;
188    type Error = S::Error;
189
190    #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)]
191    async fn serve(
192        &self,
193        ctx: Context<State>,
194        req: Request<ReqBody>,
195    ) -> Result<Self::Response, Self::Error> {
196        let encoding = Encoding::from_accept_encoding_headers(req.headers(), self.accept);
197
198        let res = self.inner.serve(ctx, req).await?;
199
200        // never recompress responses that are already compressed
201        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
202            // never compress responses that are ranges
203            && !res.headers().contains_key(header::CONTENT_RANGE)
204            && self.predicate.should_compress(&res);
205
206        let (mut parts, body) = res.into_parts();
207
208        if should_compress {
209            parts
210                .headers
211                .append(header::VARY, header::ACCEPT_ENCODING.into());
212        }
213
214        let body = match (should_compress, encoding) {
215            // if compression is _not_ supported or the client doesn't accept it
216            (false, _) | (_, Encoding::Identity) => {
217                return Ok(Response::from_parts(
218                    parts,
219                    CompressionBody::new(BodyInner::identity(body)),
220                ));
221            }
222
223            (_, Encoding::Gzip) => {
224                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
225            }
226            (_, Encoding::Deflate) => {
227                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
228            }
229            (_, Encoding::Brotli) => {
230                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
231            }
232            (_, Encoding::Zstd) => {
233                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
234            }
235            #[allow(unreachable_patterns)]
236            (true, _) => {
237                // This should never happen because the `AcceptEncoding` struct which is used to determine
238                // `self.encoding` will only enable the different compression algorithms if the
239                // corresponding crate feature has been enabled. This means
240                // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the
241                // features enabled.
242                //
243                // The match arm is still required though because the `fs` feature uses the
244                // Encoding struct independently and requires no compression logic to be enabled.
245                // This means a combination of an individual compression feature and `fs` will fail
246                // to compile without this branch even though it will never be reached.
247                //
248                // To safeguard against refactors that changes this relationship or other bugs the
249                // server will return an uncompressed response instead of panicking since that could
250                // become a ddos attack vector.
251                return Ok(Response::from_parts(
252                    parts,
253                    CompressionBody::new(BodyInner::identity(body)),
254                ));
255            }
256        };
257
258        parts.headers.remove(header::ACCEPT_RANGES);
259        parts.headers.remove(header::CONTENT_LENGTH);
260
261        parts
262            .headers
263            .insert(header::CONTENT_ENCODING, HeaderValue::from(encoding));
264
265        let res = Response::from_parts(parts, body);
266        Ok(res)
267    }
268}