rama_http/layer/decompression/
service.rs

1use std::fmt;
2
3use super::{body::BodyInner, DecompressionBody};
4use crate::dep::http_body::Body;
5use crate::layer::util::{
6    compression::{AcceptEncoding, CompressionLevel, WrapBody},
7    content_encoding::SupportedEncodings,
8};
9use crate::{
10    header::{self, ACCEPT_ENCODING},
11    Request, Response,
12};
13use rama_core::{Context, Service};
14use rama_utils::macros::define_inner_service_accessors;
15
16/// Decompresses response bodies of the underlying service.
17///
18/// This adds the `Accept-Encoding` header to requests and transparently decompresses response
19/// bodies based on the `Content-Encoding` header.
20///
21/// See the [module docs](crate::layer::decompression) for more details.
22pub struct Decompression<S> {
23    pub(crate) inner: S,
24    pub(crate) accept: AcceptEncoding,
25}
26
27impl<S> Decompression<S> {
28    /// Creates a new `Decompression` wrapping the `service`.
29    pub fn new(service: S) -> Self {
30        Self {
31            inner: service,
32            accept: AcceptEncoding::default(),
33        }
34    }
35
36    define_inner_service_accessors!();
37
38    /// Sets whether to request the gzip encoding.
39    pub fn gzip(mut self, enable: bool) -> Self {
40        self.accept.set_gzip(enable);
41        self
42    }
43
44    /// Sets whether to request the gzip encoding.
45    pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
46        self.accept.set_gzip(enable);
47        self
48    }
49
50    /// Sets whether to request the Deflate encoding.
51    pub fn deflate(mut self, enable: bool) -> Self {
52        self.accept.set_deflate(enable);
53        self
54    }
55
56    /// Sets whether to request the Deflate encoding.
57    pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
58        self.accept.set_deflate(enable);
59        self
60    }
61
62    /// Sets whether to request the Brotli encoding.
63    pub fn br(mut self, enable: bool) -> Self {
64        self.accept.set_br(enable);
65        self
66    }
67
68    /// Sets whether to request the Brotli encoding.
69    pub fn set_br(&mut self, enable: bool) -> &mut Self {
70        self.accept.set_br(enable);
71        self
72    }
73
74    /// Sets whether to request the Zstd encoding.
75    pub fn zstd(mut self, enable: bool) -> Self {
76        self.accept.set_zstd(enable);
77        self
78    }
79
80    /// Sets whether to request the Zstd encoding.
81    pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
82        self.accept.set_zstd(enable);
83        self
84    }
85}
86
87impl<S: fmt::Debug> fmt::Debug for Decompression<S> {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        f.debug_struct("Decompression")
90            .field("inner", &self.inner)
91            .field("accept", &self.accept)
92            .finish()
93    }
94}
95
96impl<S: Clone> Clone for Decompression<S> {
97    fn clone(&self) -> Self {
98        Decompression {
99            inner: self.inner.clone(),
100            accept: self.accept,
101        }
102    }
103}
104
105impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for Decompression<S>
106where
107    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
108    State: Clone + Send + Sync + 'static,
109    ReqBody: Send + 'static,
110    ResBody: Body<Data: Send + 'static, Error: Send + 'static> + Send + 'static,
111{
112    type Response = Response<DecompressionBody<ResBody>>;
113    type Error = S::Error;
114
115    async fn serve(
116        &self,
117        ctx: Context<State>,
118        mut req: Request<ReqBody>,
119    ) -> Result<Self::Response, Self::Error> {
120        if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) {
121            if let Some(accept) = self.accept.to_header_value() {
122                entry.insert(accept);
123            }
124        }
125
126        let res = self.inner.serve(ctx, req).await?;
127
128        let (mut parts, body) = res.into_parts();
129
130        let res =
131            if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
132                let body = match entry.get().as_bytes() {
133                    b"gzip" if self.accept.gzip() => DecompressionBody::new(BodyInner::gzip(
134                        WrapBody::new(body, CompressionLevel::default()),
135                    )),
136
137                    b"deflate" if self.accept.deflate() => DecompressionBody::new(
138                        BodyInner::deflate(WrapBody::new(body, CompressionLevel::default())),
139                    ),
140
141                    b"br" if self.accept.br() => DecompressionBody::new(BodyInner::brotli(
142                        WrapBody::new(body, CompressionLevel::default()),
143                    )),
144
145                    b"zstd" if self.accept.zstd() => DecompressionBody::new(BodyInner::zstd(
146                        WrapBody::new(body, CompressionLevel::default()),
147                    )),
148
149                    _ => {
150                        return Ok(Response::from_parts(
151                            parts,
152                            DecompressionBody::new(BodyInner::identity(body)),
153                        ))
154                    }
155                };
156
157                entry.remove();
158                parts.headers.remove(header::CONTENT_LENGTH);
159
160                Response::from_parts(parts, body)
161            } else {
162                Response::from_parts(parts, DecompressionBody::new(BodyInner::identity(body)))
163            };
164
165        Ok(res)
166    }
167}