xitca_web/middleware/
decompress.rs

1//! decompression middleware
2
3use crate::service::Service;
4
/// Request body decompression middleware.
///
/// Inspects the `Content-Encoding` header of the request carried by [WebContext] and wraps the
/// request body in the matching decoder. Only encodings whose `compress-x` feature
/// (`compress-br`, `compress-gz`, `compress-de`) is enabled can be decoded. When the decoder
/// cannot be constructed from the header, the request is rejected with an error that renders as
/// `415 Unsupported Media Type`.
///
/// # Type mutation
/// `Decompress` mutates the request body type from `B` to `Coder<B>`. The service enclosed by it
/// must be able to handle this mutation, or use [TypeEraser] to erase it.
/// For more explanation please reference [type mutation](crate::middleware#type-mutation).
///
/// [WebContext]: crate::WebContext
/// [TypeEraser]: crate::middleware::eraser::TypeEraser
#[derive(Clone, Copy, Debug, Default)]
pub struct Decompress;
20
21impl<S, E> Service<Result<S, E>> for Decompress {
22    type Response = service::DecompressService<S>;
23    type Error = E;
24
25    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
26        res.map(service::DecompressService)
27    }
28}
29
30mod service {
31    use core::{cell::RefCell, convert::Infallible};
32
33    use http_encoding::{Coder, error::EncodingError};
34
35    use crate::{
36        body::BodyStream,
37        context::WebContext,
38        error::Error,
39        error::error_from_service,
40        http::{Request, StatusCode, WebResponse, const_header_value::TEXT_UTF8, header::CONTENT_TYPE},
41        service::ready::ReadyService,
42    };
43
44    use super::*;
45
46    pub struct DecompressService<S>(pub(super) S);
47
48    impl<'r, S, C, B, Res, Err> Service<WebContext<'r, C, B>> for DecompressService<S>
49    where
50        B: BodyStream + Default,
51        S: for<'rs> Service<WebContext<'rs, C, Coder<B>>, Response = Res, Error = Err>,
52        Err: Into<Error>,
53    {
54        type Response = Res;
55        type Error = Error;
56
57        async fn call(&self, mut ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
58            let (parts, ext) = ctx.take_request().into_parts();
59            let state = ctx.ctx;
60            let (ext, body) = ext.replace_body(());
61            let req = Request::from_parts(parts, ());
62
63            let decoder = http_encoding::try_decoder(req.headers(), body)?;
64            let mut body = RefCell::new(decoder);
65            let mut req = req.map(|_| ext);
66
67            self.0
68                .call(WebContext::new(&mut req, &mut body, state))
69                .await
70                .map_err(|e| {
71                    // restore original body as error path of other services may have use of it.
72                    let body = body.into_inner().into_inner();
73                    *ctx.body_borrow_mut() = body;
74                    e.into()
75                })
76        }
77    }
78
79    impl<S> ReadyService for DecompressService<S>
80    where
81        S: ReadyService,
82    {
83        type Ready = S::Ready;
84
85        #[inline]
86        async fn ready(&self) -> Self::Ready {
87            self.0.ready().await
88        }
89    }
90
91    error_from_service!(EncodingError);
92
93    impl<'r, C, B> Service<WebContext<'r, C, B>> for EncodingError {
94        type Response = WebResponse;
95        type Error = Infallible;
96
97        async fn call(&self, req: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
98            let mut res = req.into_response(format!("{self}"));
99            res.headers_mut().insert(CONTENT_TYPE, TEXT_UTF8);
100            *res.status_mut() = StatusCode::UNSUPPORTED_MEDIA_TYPE;
101            Ok(res)
102        }
103    }
104}
105
#[cfg(test)]
mod test {
    use http_encoding::{ContentEncoding, encoder};
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App,
        body::ResponseBody,
        handler::handler_service,
        http::header::CONTENT_ENCODING,
        http::{WebRequest, WebResponse},
        test::collect_body,
    };

    use super::*;

    const Q: &[u8] = b"what is the goal of life";
    const A: &str = "go dock for chip";

    /// Handler asserting the (decompressed) request body equals `Q` and replying with `A`.
    async fn handler(vec: Vec<u8>) -> &'static str {
        assert_eq!(Q, vec);
        A
    }

    #[test]
    fn build() {
        async fn noop() -> &'static str {
            "noop"
        }

        App::new()
            .at("/", handler_service(noop))
            .enclosed(Decompress)
            .finish()
            .call(())
            .now_or_panic()
            .unwrap()
            .call(WebRequest::default())
            .now_or_panic()
            .ok()
            .unwrap();
    }

    #[test]
    fn plain() {
        let req = WebRequest::default().map(|ext| ext.map_body(|_: ()| Q.into()));
        App::new()
            .at("/", handler_service(handler))
            .enclosed(Decompress)
            .finish()
            .call(())
            .now_or_panic()
            .unwrap()
            .call(req)
            .now_or_panic()
            .ok()
            .unwrap();
    }

    #[cfg(any(feature = "compress-br", feature = "compress-gz", feature = "compress-de"))]
    #[test]
    fn compressed() {
        // a hack to generate a compressed client request from server response.
        let res = WebResponse::<ResponseBody>::new(Q.into());

        // pick exactly one codec that is actually enabled, preferring br > gzip > deflate.
        // this test only compiles when at least one codec feature is on, so exactly one of the
        // bindings below exists.
        #[cfg(feature = "compress-br")]
        let encoding = ContentEncoding::Br;
        #[cfg(all(feature = "compress-gz", not(feature = "compress-br")))]
        let encoding = ContentEncoding::Gzip;
        #[cfg(all(feature = "compress-de", not(any(feature = "compress-br", feature = "compress-gz"))))]
        let encoding = ContentEncoding::Deflate;

        let (mut parts, body) = encoder(res, encoding).into_parts();

        let body = collect_body(body).now_or_panic().unwrap();

        let mut req = WebRequest::default().map(|ext| ext.map_body(|_: ()| body.into()));

        req.headers_mut()
            .insert(CONTENT_ENCODING, parts.headers.remove(CONTENT_ENCODING).unwrap());

        App::new()
            .at("/", handler_service(handler))
            .enclosed(Decompress)
            .finish()
            .call(())
            .now_or_panic()
            .unwrap()
            .call(req)
            .now_or_panic()
            .ok()
            .unwrap();
    }
}