rama_http/layer/compress_adapter/
service.rs1use crate::headers::encoding::{Encoding, parse_accept_encoding_headers};
2use crate::layer::{
3 compression::{self, CompressionBody, CompressionLevel},
4 decompression::{self, DecompressionBody},
5 util::compression::WrapBody,
6};
7use rama_core::{Context, Service, error::BoxError};
8use rama_http_types::{
9 HeaderValue, Request, Response,
10 dep::http_body::Body,
11 header::{CONTENT_ENCODING, CONTENT_LENGTH},
12};
13use rama_utils::macros::define_inner_service_accessors;
14
15pub struct CompressAdaptService<S> {
26 pub(crate) inner: S,
27 pub(crate) quality: CompressionLevel,
28}
29
30impl<S> std::fmt::Debug for CompressAdaptService<S>
31where
32 S: std::fmt::Debug,
33{
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("CompressAdaptService")
36 .field("inner", &self.inner)
37 .field("quality", &self.quality)
38 .finish()
39 }
40}
41
42impl<S> Clone for CompressAdaptService<S>
43where
44 S: Clone,
45{
46 fn clone(&self) -> Self {
47 Self {
48 inner: self.inner.clone(),
49 quality: self.quality,
50 }
51 }
52}
53
54impl<S> CompressAdaptService<S> {
55 pub fn new(service: S) -> CompressAdaptService<S> {
57 Self {
58 inner: service,
59 quality: CompressionLevel::default(),
60 }
61 }
62}
63
64impl<S> CompressAdaptService<S> {
65 define_inner_service_accessors!();
66
67 pub fn quality(mut self, quality: CompressionLevel) -> Self {
69 self.quality = quality;
70 self
71 }
72
73 pub fn set_quality(&mut self, quality: CompressionLevel) -> &mut Self {
75 self.quality = quality;
76 self
77 }
78}
79
80impl<ReqBody, ResBody, S, State> Service<State, Request<ReqBody>> for CompressAdaptService<S>
81where
82 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
83 ResBody:
84 Body<Data: Send + 'static, Error: Into<BoxError> + Send + Sync + 'static> + Send + 'static,
85 ReqBody: Send + 'static,
86 State: Clone + Send + Sync + 'static,
87{
88 type Response = Response<CompressionBody<DecompressionBody<ResBody>>>;
89 type Error = S::Error;
90
91 #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)]
92 async fn serve(
93 &self,
94 ctx: Context<State>,
95 req: Request<ReqBody>,
96 ) -> Result<Self::Response, Self::Error> {
97 let requested_encoding =
98 parse_accept_encoding_headers(req.headers(), true).collect::<Vec<_>>();
99
100 let res = self.inner.serve(ctx, req).await?;
101 let (mut parts, body) = res.into_parts();
102
103 match Encoding::maybe_from_content_encoding_header(&parts.headers, true) {
104 Some(server_encoding)
105 if !requested_encoding
106 .iter()
107 .any(|qv| qv.value == server_encoding) =>
108 {
109 tracing::trace!(
110 %server_encoding,
111 "server encoded not supported by requested client encoding, decompressing"
112 );
113 let decompress_body = DecompressionBody::new(match server_encoding {
114 Encoding::Identity => decompression::body::BodyInner::identity(body),
115 Encoding::Deflate => decompression::body::BodyInner::deflate(WrapBody::new(
116 body,
117 CompressionLevel::default(),
118 )),
119 Encoding::Gzip => decompression::body::BodyInner::gzip(WrapBody::new(
120 body,
121 CompressionLevel::default(),
122 )),
123 Encoding::Brotli => decompression::body::BodyInner::brotli(WrapBody::new(
124 body,
125 CompressionLevel::default(),
126 )),
127 Encoding::Zstd => decompression::body::BodyInner::zstd(WrapBody::new(
128 body,
129 CompressionLevel::default(),
130 )),
131 });
132
133 parts.headers.remove(CONTENT_LENGTH);
134 parts.headers.remove(CONTENT_ENCODING);
135
136 let final_body = match Encoding::maybe_preferred_encoding(
137 requested_encoding.into_iter(),
138 ) {
139 Some(client_encoding) => {
140 tracing::trace!(
141 %server_encoding,
142 %client_encoding,
143 "re-encode decompressed response body into preferred client encoding"
144 );
145 parts
146 .headers
147 .insert(CONTENT_ENCODING, HeaderValue::from(client_encoding));
148 match client_encoding {
149 Encoding::Identity => CompressionBody::new(
150 compression::body::BodyInner::identity(decompress_body),
151 ),
152 Encoding::Deflate => {
153 CompressionBody::new(compression::body::BodyInner::deflate(
154 WrapBody::new(decompress_body, self.quality),
155 ))
156 }
157 Encoding::Gzip => {
158 CompressionBody::new(compression::body::BodyInner::gzip(
159 WrapBody::new(decompress_body, self.quality),
160 ))
161 }
162 Encoding::Brotli => {
163 CompressionBody::new(compression::body::BodyInner::brotli(
164 WrapBody::new(decompress_body, self.quality),
165 ))
166 }
167 Encoding::Zstd => {
168 CompressionBody::new(compression::body::BodyInner::zstd(
169 WrapBody::new(decompress_body, self.quality),
170 ))
171 }
172 }
173 }
174 None => CompressionBody::new(compression::body::BodyInner::identity(
175 decompress_body,
176 )),
177 };
178
179 Ok(Response::from_parts(parts, final_body))
180 }
181 _ => {
182 tracing::trace!("no action required for server response encoding");
183 let body = CompressionBody::new(compression::body::BodyInner::identity(
184 DecompressionBody::new(decompression::body::BodyInner::identity(body)),
185 ));
186 Ok(Response::from_parts(parts, body))
187 }
188 }
189 }
190}