rama_http/layer/compress_adapter/
service.rs

1use crate::headers::encoding::{Encoding, parse_accept_encoding_headers};
2use crate::layer::{
3    compression::{self, CompressionBody, CompressionLevel},
4    decompression::{self, DecompressionBody},
5    util::compression::WrapBody,
6};
7use rama_core::{Context, Service, error::BoxError};
8use rama_http_types::{
9    HeaderValue, Request, Response,
10    dep::http_body::Body,
11    header::{CONTENT_ENCODING, CONTENT_LENGTH},
12};
13use rama_utils::macros::define_inner_service_accessors;
14
15/// Service which tracks the original 'Accept-Encoding' header and compares
16/// it with the server 'Content-Encoding' header, to adapt the response if needed.
17///
18/// ## Example
19///
20/// `Accept-Encoding: gzip` and `Content-Encoding: zstd` will result in:
21///
22/// ```text
23/// compress_gzip(decompress_zstd(body))
24/// ```
25pub struct CompressAdaptService<S> {
26    pub(crate) inner: S,
27    pub(crate) quality: CompressionLevel,
28}
29
30impl<S> std::fmt::Debug for CompressAdaptService<S>
31where
32    S: std::fmt::Debug,
33{
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("CompressAdaptService")
36            .field("inner", &self.inner)
37            .field("quality", &self.quality)
38            .finish()
39    }
40}
41
42impl<S> Clone for CompressAdaptService<S>
43where
44    S: Clone,
45{
46    fn clone(&self) -> Self {
47        Self {
48            inner: self.inner.clone(),
49            quality: self.quality,
50        }
51    }
52}
53
54impl<S> CompressAdaptService<S> {
55    /// Creates a new `CompressAdaptService` wrapping the `service`.
56    pub fn new(service: S) -> CompressAdaptService<S> {
57        Self {
58            inner: service,
59            quality: CompressionLevel::default(),
60        }
61    }
62}
63
64impl<S> CompressAdaptService<S> {
65    define_inner_service_accessors!();
66
67    /// Sets the compression quality.
68    pub fn quality(mut self, quality: CompressionLevel) -> Self {
69        self.quality = quality;
70        self
71    }
72
73    /// Sets the compression quality.
74    pub fn set_quality(&mut self, quality: CompressionLevel) -> &mut Self {
75        self.quality = quality;
76        self
77    }
78}
79
80impl<ReqBody, ResBody, S, State> Service<State, Request<ReqBody>> for CompressAdaptService<S>
81where
82    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
83    ResBody:
84        Body<Data: Send + 'static, Error: Into<BoxError> + Send + Sync + 'static> + Send + 'static,
85    ReqBody: Send + 'static,
86    State: Clone + Send + Sync + 'static,
87{
88    type Response = Response<CompressionBody<DecompressionBody<ResBody>>>;
89    type Error = S::Error;
90
91    #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)]
92    async fn serve(
93        &self,
94        ctx: Context<State>,
95        req: Request<ReqBody>,
96    ) -> Result<Self::Response, Self::Error> {
97        let requested_encoding =
98            parse_accept_encoding_headers(req.headers(), true).collect::<Vec<_>>();
99
100        let res = self.inner.serve(ctx, req).await?;
101        let (mut parts, body) = res.into_parts();
102
103        match Encoding::maybe_from_content_encoding_header(&parts.headers, true) {
104            Some(server_encoding)
105                if !requested_encoding
106                    .iter()
107                    .any(|qv| qv.value == server_encoding) =>
108            {
109                tracing::trace!(
110                    %server_encoding,
111                    "server encoded not supported by requested client encoding, decompressing"
112                );
113                let decompress_body = DecompressionBody::new(match server_encoding {
114                    Encoding::Identity => decompression::body::BodyInner::identity(body),
115                    Encoding::Deflate => decompression::body::BodyInner::deflate(WrapBody::new(
116                        body,
117                        CompressionLevel::default(),
118                    )),
119                    Encoding::Gzip => decompression::body::BodyInner::gzip(WrapBody::new(
120                        body,
121                        CompressionLevel::default(),
122                    )),
123                    Encoding::Brotli => decompression::body::BodyInner::brotli(WrapBody::new(
124                        body,
125                        CompressionLevel::default(),
126                    )),
127                    Encoding::Zstd => decompression::body::BodyInner::zstd(WrapBody::new(
128                        body,
129                        CompressionLevel::default(),
130                    )),
131                });
132
133                parts.headers.remove(CONTENT_LENGTH);
134                parts.headers.remove(CONTENT_ENCODING);
135
136                let final_body = match Encoding::maybe_preferred_encoding(
137                    requested_encoding.into_iter(),
138                ) {
139                    Some(client_encoding) => {
140                        tracing::trace!(
141                            %server_encoding,
142                            %client_encoding,
143                            "re-encode decompressed response body into preferred client encoding"
144                        );
145                        parts
146                            .headers
147                            .insert(CONTENT_ENCODING, HeaderValue::from(client_encoding));
148                        match client_encoding {
149                            Encoding::Identity => CompressionBody::new(
150                                compression::body::BodyInner::identity(decompress_body),
151                            ),
152                            Encoding::Deflate => {
153                                CompressionBody::new(compression::body::BodyInner::deflate(
154                                    WrapBody::new(decompress_body, self.quality),
155                                ))
156                            }
157                            Encoding::Gzip => {
158                                CompressionBody::new(compression::body::BodyInner::gzip(
159                                    WrapBody::new(decompress_body, self.quality),
160                                ))
161                            }
162                            Encoding::Brotli => {
163                                CompressionBody::new(compression::body::BodyInner::brotli(
164                                    WrapBody::new(decompress_body, self.quality),
165                                ))
166                            }
167                            Encoding::Zstd => {
168                                CompressionBody::new(compression::body::BodyInner::zstd(
169                                    WrapBody::new(decompress_body, self.quality),
170                                ))
171                            }
172                        }
173                    }
174                    None => CompressionBody::new(compression::body::BodyInner::identity(
175                        decompress_body,
176                    )),
177                };
178
179                Ok(Response::from_parts(parts, final_body))
180            }
181            _ => {
182                tracing::trace!("no action required for server response encoding");
183                let body = CompressionBody::new(compression::body::BodyInner::identity(
184                    DecompressionBody::new(decompression::body::BodyInner::identity(body)),
185                ));
186                Ok(Response::from_parts(parts, body))
187            }
188        }
189    }
190}