rama_http/layer/decompression/service.rs

1use std::fmt;
2
3use super::{DecompressionBody, body::BodyInner};
4use crate::dep::http_body::Body;
5use crate::headers::encoding::{AcceptEncoding, SupportedEncodings};
6use crate::layer::util::compression::{CompressionLevel, WrapBody};
7use crate::{
8    Request, Response,
9    header::{self, ACCEPT_ENCODING},
10};
11use rama_core::{Context, Service};
12use rama_utils::macros::define_inner_service_accessors;
13
14/// Decompresses response bodies of the underlying service.
15///
16/// This adds the `Accept-Encoding` header to requests and transparently decompresses response
17/// bodies based on the `Content-Encoding` header.
18///
19/// See the [module docs](crate::layer::decompression) for more details.
20pub struct Decompression<S> {
21    pub(crate) inner: S,
22    pub(crate) accept: AcceptEncoding,
23}
24
25impl<S> Decompression<S> {
26    /// Creates a new `Decompression` wrapping the `service`.
27    pub fn new(service: S) -> Self {
28        Self {
29            inner: service,
30            accept: AcceptEncoding::default(),
31        }
32    }
33
34    define_inner_service_accessors!();
35
36    /// Sets whether to request the gzip encoding.
37    pub fn gzip(mut self, enable: bool) -> Self {
38        self.accept.set_gzip(enable);
39        self
40    }
41
42    /// Sets whether to request the gzip encoding.
43    pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
44        self.accept.set_gzip(enable);
45        self
46    }
47
48    /// Sets whether to request the Deflate encoding.
49    pub fn deflate(mut self, enable: bool) -> Self {
50        self.accept.set_deflate(enable);
51        self
52    }
53
54    /// Sets whether to request the Deflate encoding.
55    pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
56        self.accept.set_deflate(enable);
57        self
58    }
59
60    /// Sets whether to request the Brotli encoding.
61    pub fn br(mut self, enable: bool) -> Self {
62        self.accept.set_br(enable);
63        self
64    }
65
66    /// Sets whether to request the Brotli encoding.
67    pub fn set_br(&mut self, enable: bool) -> &mut Self {
68        self.accept.set_br(enable);
69        self
70    }
71
72    /// Sets whether to request the Zstd encoding.
73    pub fn zstd(mut self, enable: bool) -> Self {
74        self.accept.set_zstd(enable);
75        self
76    }
77
78    /// Sets whether to request the Zstd encoding.
79    pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
80        self.accept.set_zstd(enable);
81        self
82    }
83}
84
85impl<S: fmt::Debug> fmt::Debug for Decompression<S> {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        f.debug_struct("Decompression")
88            .field("inner", &self.inner)
89            .field("accept", &self.accept)
90            .finish()
91    }
92}
93
94impl<S: Clone> Clone for Decompression<S> {
95    fn clone(&self) -> Self {
96        Decompression {
97            inner: self.inner.clone(),
98            accept: self.accept,
99        }
100    }
101}
102
103impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for Decompression<S>
104where
105    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
106    State: Clone + Send + Sync + 'static,
107    ReqBody: Send + 'static,
108    ResBody: Body<Data: Send + 'static, Error: Send + 'static> + Send + 'static,
109{
110    type Response = Response<DecompressionBody<ResBody>>;
111    type Error = S::Error;
112
113    async fn serve(
114        &self,
115        ctx: Context<State>,
116        mut req: Request<ReqBody>,
117    ) -> Result<Self::Response, Self::Error> {
118        if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) {
119            if let Some(accept) = self.accept.maybe_to_header_value() {
120                entry.insert(accept);
121            }
122        }
123
124        let res = self.inner.serve(ctx, req).await?;
125
126        let (mut parts, body) = res.into_parts();
127
128        let res =
129            if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
130                let body = match entry.get().as_bytes() {
131                    b"gzip" if self.accept.gzip() => DecompressionBody::new(BodyInner::gzip(
132                        WrapBody::new(body, CompressionLevel::default()),
133                    )),
134
135                    b"deflate" if self.accept.deflate() => DecompressionBody::new(
136                        BodyInner::deflate(WrapBody::new(body, CompressionLevel::default())),
137                    ),
138
139                    b"br" if self.accept.br() => DecompressionBody::new(BodyInner::brotli(
140                        WrapBody::new(body, CompressionLevel::default()),
141                    )),
142
143                    b"zstd" if self.accept.zstd() => DecompressionBody::new(BodyInner::zstd(
144                        WrapBody::new(body, CompressionLevel::default()),
145                    )),
146
147                    _ => {
148                        return Ok(Response::from_parts(
149                            parts,
150                            DecompressionBody::new(BodyInner::identity(body)),
151                        ));
152                    }
153                };
154
155                entry.remove();
156                parts.headers.remove(header::CONTENT_LENGTH);
157
158                Response::from_parts(parts, body)
159            } else {
160                Response::from_parts(parts, DecompressionBody::new(BodyInner::identity(body)))
161            };
162
163        Ok(res)
164    }
165}