rama_http/layer/decompression/
service.rs1use std::fmt;
2
3use super::{body::BodyInner, DecompressionBody};
4use crate::dep::http_body::Body;
5use crate::layer::util::{
6 compression::{AcceptEncoding, CompressionLevel, WrapBody},
7 content_encoding::SupportedEncodings,
8};
9use crate::{
10 header::{self, ACCEPT_ENCODING},
11 Request, Response,
12};
13use rama_core::{Context, Service};
14use rama_utils::macros::define_inner_service_accessors;
15
16pub struct Decompression<S> {
23 pub(crate) inner: S,
24 pub(crate) accept: AcceptEncoding,
25}
26
27impl<S> Decompression<S> {
28 pub fn new(service: S) -> Self {
30 Self {
31 inner: service,
32 accept: AcceptEncoding::default(),
33 }
34 }
35
36 define_inner_service_accessors!();
37
38 pub fn gzip(mut self, enable: bool) -> Self {
40 self.accept.set_gzip(enable);
41 self
42 }
43
44 pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
46 self.accept.set_gzip(enable);
47 self
48 }
49
50 pub fn deflate(mut self, enable: bool) -> Self {
52 self.accept.set_deflate(enable);
53 self
54 }
55
56 pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
58 self.accept.set_deflate(enable);
59 self
60 }
61
62 pub fn br(mut self, enable: bool) -> Self {
64 self.accept.set_br(enable);
65 self
66 }
67
68 pub fn set_br(&mut self, enable: bool) -> &mut Self {
70 self.accept.set_br(enable);
71 self
72 }
73
74 pub fn zstd(mut self, enable: bool) -> Self {
76 self.accept.set_zstd(enable);
77 self
78 }
79
80 pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
82 self.accept.set_zstd(enable);
83 self
84 }
85}
86
87impl<S: fmt::Debug> fmt::Debug for Decompression<S> {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 f.debug_struct("Decompression")
90 .field("inner", &self.inner)
91 .field("accept", &self.accept)
92 .finish()
93 }
94}
95
96impl<S: Clone> Clone for Decompression<S> {
97 fn clone(&self) -> Self {
98 Decompression {
99 inner: self.inner.clone(),
100 accept: self.accept,
101 }
102 }
103}
104
105impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for Decompression<S>
106where
107 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
108 State: Clone + Send + Sync + 'static,
109 ReqBody: Send + 'static,
110 ResBody: Body<Data: Send + 'static, Error: Send + 'static> + Send + 'static,
111{
112 type Response = Response<DecompressionBody<ResBody>>;
113 type Error = S::Error;
114
115 async fn serve(
116 &self,
117 ctx: Context<State>,
118 mut req: Request<ReqBody>,
119 ) -> Result<Self::Response, Self::Error> {
120 if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) {
121 if let Some(accept) = self.accept.to_header_value() {
122 entry.insert(accept);
123 }
124 }
125
126 let res = self.inner.serve(ctx, req).await?;
127
128 let (mut parts, body) = res.into_parts();
129
130 let res =
131 if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) {
132 let body = match entry.get().as_bytes() {
133 b"gzip" if self.accept.gzip() => DecompressionBody::new(BodyInner::gzip(
134 WrapBody::new(body, CompressionLevel::default()),
135 )),
136
137 b"deflate" if self.accept.deflate() => DecompressionBody::new(
138 BodyInner::deflate(WrapBody::new(body, CompressionLevel::default())),
139 ),
140
141 b"br" if self.accept.br() => DecompressionBody::new(BodyInner::brotli(
142 WrapBody::new(body, CompressionLevel::default()),
143 )),
144
145 b"zstd" if self.accept.zstd() => DecompressionBody::new(BodyInner::zstd(
146 WrapBody::new(body, CompressionLevel::default()),
147 )),
148
149 _ => {
150 return Ok(Response::from_parts(
151 parts,
152 DecompressionBody::new(BodyInner::identity(body)),
153 ))
154 }
155 };
156
157 entry.remove();
158 parts.headers.remove(header::CONTENT_LENGTH);
159
160 Response::from_parts(parts, body)
161 } else {
162 Response::from_parts(parts, DecompressionBody::new(BodyInner::identity(body)))
163 };
164
165 Ok(res)
166 }
167}