rama_http/layer/decompression/request/
service.rs1use std::fmt;
2
3use crate::dep::http_body::Body;
4use crate::dep::http_body_util::{BodyExt, Empty, combinators::UnsyncBoxBody};
5use crate::headers::encoding::{AcceptEncoding, SupportedEncodings};
6use crate::layer::{
7 decompression::DecompressionBody,
8 decompression::body::BodyInner,
9 util::compression::{CompressionLevel, WrapBody},
10};
11use crate::{HeaderValue, Request, Response, StatusCode, header};
12use bytes::Buf;
13use rama_core::error::BoxError;
14use rama_core::{Context, Service};
15use rama_utils::macros::define_inner_service_accessors;
16
17pub struct RequestDecompression<S> {
29 pub(super) inner: S,
30 pub(super) accept: AcceptEncoding,
31 pub(super) pass_through_unaccepted: bool,
32}
33
34impl<S: fmt::Debug> fmt::Debug for RequestDecompression<S> {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 f.debug_struct("RequestDecompression")
37 .field("inner", &self.inner)
38 .field("accept", &self.accept)
39 .field("pass_through_unaccepted", &self.pass_through_unaccepted)
40 .finish()
41 }
42}
43
44impl<S: Clone> Clone for RequestDecompression<S> {
45 fn clone(&self) -> Self {
46 RequestDecompression {
47 inner: self.inner.clone(),
48 accept: self.accept,
49 pass_through_unaccepted: self.pass_through_unaccepted,
50 }
51 }
52}
53
54impl<S, State, ReqBody, ResBody, D> Service<State, Request<ReqBody>> for RequestDecompression<S>
55where
56 S: Service<
57 State,
58 Request<DecompressionBody<ReqBody>>,
59 Response = Response<ResBody>,
60 Error: Into<BoxError>,
61 >,
62 State: Clone + Send + Sync + 'static,
63 ReqBody: Body + Send + 'static,
64 ResBody: Body<Data = D, Error: Into<BoxError>> + Send + 'static,
65 D: Buf + 'static,
66{
67 type Response = Response<UnsyncBoxBody<D, BoxError>>;
68 type Error = BoxError;
69
70 async fn serve(
71 &self,
72 ctx: Context<State>,
73 req: Request<ReqBody>,
74 ) -> Result<Self::Response, Self::Error> {
75 let (mut parts, body) = req.into_parts();
76
77 let body =
78 if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
79 match entry.get().as_bytes() {
80 b"gzip" if self.accept.gzip() => {
81 entry.remove();
82 parts.headers.remove(header::CONTENT_LENGTH);
83 BodyInner::gzip(WrapBody::new(body, CompressionLevel::default()))
84 }
85 b"deflate" if self.accept.deflate() => {
86 entry.remove();
87 parts.headers.remove(header::CONTENT_LENGTH);
88 BodyInner::deflate(WrapBody::new(body, CompressionLevel::default()))
89 }
90 b"br" if self.accept.br() => {
91 entry.remove();
92 parts.headers.remove(header::CONTENT_LENGTH);
93 BodyInner::brotli(WrapBody::new(body, CompressionLevel::default()))
94 }
95 b"zstd" if self.accept.zstd() => {
96 entry.remove();
97 parts.headers.remove(header::CONTENT_LENGTH);
98 BodyInner::zstd(WrapBody::new(body, CompressionLevel::default()))
99 }
100 b"identity" => BodyInner::identity(body),
101 _ if self.pass_through_unaccepted => BodyInner::identity(body),
102 _ => return unsupported_encoding(self.accept).await,
103 }
104 } else {
105 BodyInner::identity(body)
106 };
107 let body = DecompressionBody::new(body);
108 let req = Request::from_parts(parts, body);
109 self.inner
110 .serve(ctx, req)
111 .await
112 .map(|res| res.map(|body| body.map_err(Into::into).boxed_unsync()))
113 .map_err(Into::into)
114 }
115}
116
117async fn unsupported_encoding<D>(
118 accept: AcceptEncoding,
119) -> Result<Response<UnsyncBoxBody<D, BoxError>>, BoxError>
120where
121 D: Buf + 'static,
122{
123 let res = Response::builder()
124 .header(
125 header::ACCEPT_ENCODING,
126 accept
127 .maybe_to_header_value()
128 .unwrap_or(HeaderValue::from_static("identity")),
129 )
130 .status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
131 .body(Empty::new().map_err(Into::into).boxed_unsync())
132 .unwrap();
133 Ok(res)
134}
135
136impl<S> RequestDecompression<S> {
137 pub fn new(service: S) -> Self {
139 Self {
140 inner: service,
141 accept: AcceptEncoding::default(),
142 pass_through_unaccepted: false,
143 }
144 }
145
146 define_inner_service_accessors!();
147
148 pub fn pass_through_unaccepted(mut self, enabled: bool) -> Self {
152 self.pass_through_unaccepted = enabled;
153 self
154 }
155
156 pub fn set_pass_through_unaccepted(&mut self, enabled: bool) -> &mut Self {
160 self.pass_through_unaccepted = enabled;
161 self
162 }
163
164 pub fn gzip(mut self, enable: bool) -> Self {
166 self.accept.set_gzip(enable);
167 self
168 }
169
170 pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
172 self.accept.set_gzip(enable);
173 self
174 }
175
176 pub fn deflate(mut self, enable: bool) -> Self {
178 self.accept.set_deflate(enable);
179 self
180 }
181
182 pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
184 self.accept.set_deflate(enable);
185 self
186 }
187
188 pub fn br(mut self, enable: bool) -> Self {
190 self.accept.set_br(enable);
191 self
192 }
193
194 pub fn set_br(&mut self, enable: bool) -> &mut Self {
196 self.accept.set_br(enable);
197 self
198 }
199
200 pub fn zstd(mut self, enable: bool) -> Self {
202 self.accept.set_zstd(enable);
203 self
204 }
205
206 pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
208 self.accept.set_zstd(enable);
209 self
210 }
211}