rama_http/layer/decompression/request/service.rs

1use std::fmt;
2
3use crate::dep::http_body::Body;
4use crate::dep::http_body_util::{BodyExt, Empty, combinators::UnsyncBoxBody};
5use crate::headers::encoding::{AcceptEncoding, SupportedEncodings};
6use crate::layer::{
7    decompression::DecompressionBody,
8    decompression::body::BodyInner,
9    util::compression::{CompressionLevel, WrapBody},
10};
11use crate::{HeaderValue, Request, Response, StatusCode, header};
12use bytes::Buf;
13use rama_core::error::BoxError;
14use rama_core::{Context, Service};
15use rama_utils::macros::define_inner_service_accessors;
16
17/// Decompresses request bodies and calls its underlying service.
18///
19/// Transparently decompresses request bodies based on the `Content-Encoding` header.
20/// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type`
21/// status code will be returned with the accepted encodings in the `Accept-Encoding` header.
22///
23/// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type` but
24/// will call the underlying service with the unmodified request if the encoding is not supported.
25/// This is disabled by default.
26///
27/// See the [module docs](crate::layer::decompression) for more details.
28pub struct RequestDecompression<S> {
29    pub(super) inner: S,
30    pub(super) accept: AcceptEncoding,
31    pub(super) pass_through_unaccepted: bool,
32}
33
34impl<S: fmt::Debug> fmt::Debug for RequestDecompression<S> {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        f.debug_struct("RequestDecompression")
37            .field("inner", &self.inner)
38            .field("accept", &self.accept)
39            .field("pass_through_unaccepted", &self.pass_through_unaccepted)
40            .finish()
41    }
42}
43
44impl<S: Clone> Clone for RequestDecompression<S> {
45    fn clone(&self) -> Self {
46        RequestDecompression {
47            inner: self.inner.clone(),
48            accept: self.accept,
49            pass_through_unaccepted: self.pass_through_unaccepted,
50        }
51    }
52}
53
54impl<S, State, ReqBody, ResBody, D> Service<State, Request<ReqBody>> for RequestDecompression<S>
55where
56    S: Service<
57            State,
58            Request<DecompressionBody<ReqBody>>,
59            Response = Response<ResBody>,
60            Error: Into<BoxError>,
61        >,
62    State: Clone + Send + Sync + 'static,
63    ReqBody: Body + Send + 'static,
64    ResBody: Body<Data = D, Error: Into<BoxError>> + Send + 'static,
65    D: Buf + 'static,
66{
67    type Response = Response<UnsyncBoxBody<D, BoxError>>;
68    type Error = BoxError;
69
70    async fn serve(
71        &self,
72        ctx: Context<State>,
73        req: Request<ReqBody>,
74    ) -> Result<Self::Response, Self::Error> {
75        let (mut parts, body) = req.into_parts();
76
77        let body =
78            if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
79                match entry.get().as_bytes() {
80                    b"gzip" if self.accept.gzip() => {
81                        entry.remove();
82                        parts.headers.remove(header::CONTENT_LENGTH);
83                        BodyInner::gzip(WrapBody::new(body, CompressionLevel::default()))
84                    }
85                    b"deflate" if self.accept.deflate() => {
86                        entry.remove();
87                        parts.headers.remove(header::CONTENT_LENGTH);
88                        BodyInner::deflate(WrapBody::new(body, CompressionLevel::default()))
89                    }
90                    b"br" if self.accept.br() => {
91                        entry.remove();
92                        parts.headers.remove(header::CONTENT_LENGTH);
93                        BodyInner::brotli(WrapBody::new(body, CompressionLevel::default()))
94                    }
95                    b"zstd" if self.accept.zstd() => {
96                        entry.remove();
97                        parts.headers.remove(header::CONTENT_LENGTH);
98                        BodyInner::zstd(WrapBody::new(body, CompressionLevel::default()))
99                    }
100                    b"identity" => BodyInner::identity(body),
101                    _ if self.pass_through_unaccepted => BodyInner::identity(body),
102                    _ => return unsupported_encoding(self.accept).await,
103                }
104            } else {
105                BodyInner::identity(body)
106            };
107        let body = DecompressionBody::new(body);
108        let req = Request::from_parts(parts, body);
109        self.inner
110            .serve(ctx, req)
111            .await
112            .map(|res| res.map(|body| body.map_err(Into::into).boxed_unsync()))
113            .map_err(Into::into)
114    }
115}
116
117async fn unsupported_encoding<D>(
118    accept: AcceptEncoding,
119) -> Result<Response<UnsyncBoxBody<D, BoxError>>, BoxError>
120where
121    D: Buf + 'static,
122{
123    let res = Response::builder()
124        .header(
125            header::ACCEPT_ENCODING,
126            accept
127                .maybe_to_header_value()
128                .unwrap_or(HeaderValue::from_static("identity")),
129        )
130        .status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
131        .body(Empty::new().map_err(Into::into).boxed_unsync())
132        .unwrap();
133    Ok(res)
134}
135
136impl<S> RequestDecompression<S> {
137    /// Creates a new `RequestDecompression` wrapping the `service`.
138    pub fn new(service: S) -> Self {
139        Self {
140            inner: service,
141            accept: AcceptEncoding::default(),
142            pass_through_unaccepted: false,
143        }
144    }
145
146    define_inner_service_accessors!();
147
148    /// Passes through the request even when the encoding is not supported.
149    ///
150    /// By default pass-through is disabled.
151    pub fn pass_through_unaccepted(mut self, enabled: bool) -> Self {
152        self.pass_through_unaccepted = enabled;
153        self
154    }
155
156    /// Passes through the request even when the encoding is not supported.
157    ///
158    /// By default pass-through is disabled.
159    pub fn set_pass_through_unaccepted(&mut self, enabled: bool) -> &mut Self {
160        self.pass_through_unaccepted = enabled;
161        self
162    }
163
164    /// Sets whether to support gzip encoding.
165    pub fn gzip(mut self, enable: bool) -> Self {
166        self.accept.set_gzip(enable);
167        self
168    }
169
170    /// Sets whether to support gzip encoding.
171    pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
172        self.accept.set_gzip(enable);
173        self
174    }
175
176    /// Sets whether to support Deflate encoding.
177    pub fn deflate(mut self, enable: bool) -> Self {
178        self.accept.set_deflate(enable);
179        self
180    }
181
182    /// Sets whether to support Deflate encoding.
183    pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
184        self.accept.set_deflate(enable);
185        self
186    }
187
188    /// Sets whether to support Brotli encoding.
189    pub fn br(mut self, enable: bool) -> Self {
190        self.accept.set_br(enable);
191        self
192    }
193
194    /// Sets whether to support Brotli encoding.
195    pub fn set_br(&mut self, enable: bool) -> &mut Self {
196        self.accept.set_br(enable);
197        self
198    }
199
200    /// Sets whether to support Zstd encoding.
201    pub fn zstd(mut self, enable: bool) -> Self {
202        self.accept.set_zstd(enable);
203        self
204    }
205
206    /// Sets whether to support Zstd encoding.
207    pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
208        self.accept.set_zstd(enable);
209        self
210    }
211}