xitca_web/middleware/
eraser.rs

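//! Type eraser middleware: erases request body, response body or error types of a service tree into `xitca-web`'s concrete types.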
2
3use core::marker::PhantomData;
4
5use crate::service::Service;
6
#[doc(hidden)]
mod marker {
    //! Zero-sized tag types selecting which type a `TypeEraser` erases.

    /// Tag: erase the generic request body type.
    pub struct EraseReqBody;

    /// Tag: erase the generic response body type.
    pub struct EraseResBody;

    /// Tag: erase the generic service error type.
    pub struct EraseErr;
}
14
15use marker::*;
16
/// Type eraser middleware erases the "unwanted" complex types produced by a service tree
/// and exposes well-known concrete types that `xitca-web` can handle.
///
/// Three flavours are available:
/// - [`TypeEraser::request_body`] erases the request body type into [`RequestBody`]
///   (the wrapped service's response body is boxed into [`ResponseBody`] along the way).
/// - [`TypeEraser::response_body`] erases the response body type into [`ResponseBody`].
/// - [`TypeEraser::error`] erases the service error type into [`Error`].
///
/// `TypeEraser` itself is a zero-sized marker; cloning it never requires `M: Clone`.
///
/// # Example
/// ```rust
/// # use xitca_web::{
/// #   handler::handler_service,
/// #   middleware::{eraser::TypeEraser, limit::Limit, Group},
/// #   service::ServiceExt,
/// #   App, WebContext
/// #   };
/// // a handler function expecting xitca_web::body::RequestBody as its body type.
/// async fn handler(_: &WebContext<'_>) -> &'static str {
///     "hello,world!"
/// }
///
/// // a limit middleware that limits the request body to a max size of 1MB.
/// // this middleware produces a new request body type that the
/// // handler function doesn't know of.
/// let limit = Limit::new().set_request_body_max_size(1024 * 1024);
///
/// // an eraser middleware that transforms any request body into xitca_web::body::RequestBody.
/// let eraser = TypeEraser::request_body();
///
/// App::new()
///     .at("/", handler_service(handler))
///     // introduce the eraser middleware between the handler and the limit middleware
///     // to resolve the type difference between them.
///     // without it this piece of code would refuse to compile.
///     .enclosed(eraser.clone())
///     .enclosed(limit.clone());
///
/// // a group middleware is the suggested way of handling use cases like this.
/// let group = Group::new()
///     .enclosed(eraser)
///     .enclosed(limit);
///
/// // it's also suggested to group multiple type-mutating middlewares together and apply
/// // the eraser to them once if possible. the reason being that TypeEraser has a cost
/// // (e.g. boxing of bodies) and by grouping you only pay it once.
///
/// App::new()
///     .at("/", handler_service(handler))
///     .enclosed(group);
/// ```
///
/// [`RequestBody`]: crate::body::RequestBody
/// [`ResponseBody`]: crate::body::ResponseBody
/// [`Error`]: crate::error::Error
pub struct TypeEraser<M>(PhantomData<M>);

// Hand-written instead of `#[derive(Clone)]`: the derive would add an `M: Clone` bound,
// while the marker types are never cloned (only `PhantomData<M>` is stored).
impl<M> Clone for TypeEraser<M> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}
69
70impl TypeEraser<EraseReqBody> {
71    /// Erase generic request body type. making downstream middlewares observe [`RequestBody`].
72    ///
73    /// [`RequestBody`]: crate::body::RequestBody
74    pub const fn request_body() -> Self {
75        TypeEraser(PhantomData)
76    }
77}
78
79impl TypeEraser<EraseResBody> {
80    /// Erase generic response body type. making downstream middlewares observe [`ResponseBody`].
81    ///
82    /// [`ResponseBody`]: crate::body::ResponseBody
83    pub const fn response_body() -> Self {
84        TypeEraser(PhantomData)
85    }
86}
87
88impl TypeEraser<EraseErr> {
89    /// Erase generic E type from Service<Error = E>. making downstream middlewares observe [`Error`].
90    ///
91    /// [`Error`]: crate::error::Error
92    pub const fn error() -> Self {
93        TypeEraser(PhantomData)
94    }
95}
96
97impl<M, S, E> Service<Result<S, E>> for TypeEraser<M> {
98    type Response = service::EraserService<M, S>;
99    type Error = E;
100
101    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
102        res.map(|service| service::EraserService {
103            service,
104            _erase: PhantomData,
105        })
106    }
107}
108
109mod service {
110    use core::cell::RefCell;
111
112    use crate::{
113        WebContext,
114        body::{BodyStream, BoxBody},
115        body::{RequestBody, ResponseBody},
116        bytes::Bytes,
117        error::Error,
118        http::WebResponse,
119        service::ready::ReadyService,
120    };
121
122    use super::*;
123
124    pub struct EraserService<M, S> {
125        pub(super) service: S,
126        pub(super) _erase: PhantomData<M>,
127    }
128
129    impl<'r, S, C, ReqB, ResB, Err> Service<WebContext<'r, C, ReqB>> for EraserService<EraseReqBody, S>
130    where
131        S: for<'rs> Service<WebContext<'rs, C>, Response = WebResponse<ResB>, Error = Err>,
132        ReqB: BodyStream<Chunk = Bytes> + Default + 'static,
133        ResB: BodyStream<Chunk = Bytes> + 'static,
134    {
135        type Response = WebResponse;
136        type Error = Err;
137
138        async fn call(&self, mut ctx: WebContext<'r, C, ReqB>) -> Result<Self::Response, Self::Error> {
139            let body = ctx.take_body_mut();
140            let mut body = RefCell::new(RequestBody::Unknown(BoxBody::new(body)));
141            let WebContext { req, ctx, .. } = ctx;
142            let res = self.service.call(WebContext::new(req, &mut body, ctx)).await?;
143            Ok(res.map(ResponseBody::box_stream))
144        }
145    }
146
147    impl<S, Req, ResB> Service<Req> for EraserService<EraseResBody, S>
148    where
149        S: Service<Req, Response = WebResponse<ResB>>,
150        ResB: BodyStream<Chunk = Bytes> + 'static,
151    {
152        type Response = WebResponse;
153        type Error = S::Error;
154
155        #[inline]
156        async fn call(&self, req: Req) -> Result<Self::Response, Self::Error> {
157            let res = self.service.call(req).await?;
158            Ok(res.map(ResponseBody::box_stream))
159        }
160    }
161
162    impl<'r, C, B, S> Service<WebContext<'r, C, B>> for EraserService<EraseErr, S>
163    where
164        S: Service<WebContext<'r, C, B>>,
165        S::Error: Into<Error>,
166    {
167        type Response = S::Response;
168        type Error = Error;
169
170        #[inline]
171        async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
172            self.service.call(ctx).await.map_err(Into::into)
173        }
174    }
175
176    impl<M, S> ReadyService for EraserService<M, S>
177    where
178        S: ReadyService,
179    {
180        type Ready = S::Ready;
181
182        #[inline]
183        async fn ready(&self) -> Self::Ready {
184            self.service.ready().await
185        }
186    }
187}
188
#[cfg(test)]
mod test {
    use xitca_http::body::Once;
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        App, WebContext,
        bytes::Bytes,
        error::Error,
        handler::handler_service,
        http::{Request, StatusCode, WebResponse},
        middleware::Group,
        service::ServiceExt,
    };

    use super::*;

    async fn handler(_: &WebContext<'_>) -> &'static str {
        "996"
    }

    // replaces the wrapped service's response with an empty `WebResponse<Once<Bytes>>`, i.e. a
    // response whose body type is NOT the default `WebResponse` body type.
    async fn map_body<S, C, B, Err>(_: &S, _: WebContext<'_, C, B>) -> Result<WebResponse<Once<Bytes>>, Err>
    where
        S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Err>,
    {
        Ok(WebResponse::new(Once::new(Bytes::new())))
    }

    // pass-through middleware that only type checks when the wrapped service responds with the
    // default `WebResponse` body type.
    async fn middleware_fn<S, C, B, Err>(s: &S, ctx: WebContext<'_, C, B>) -> Result<WebResponse, Err>
    where
        S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Err>,
    {
        s.call(ctx).await
    }

    #[test]
    fn erase_body() {
        let res = App::new()
            // map WebResponse to WebResponse<Once<Bytes>> type.
            .at("/", handler_service(handler).enclosed_fn(map_body))
            // erase the body type to make it WebResponse type again.
            .enclosed(TypeEraser::response_body())
            // observe erased body type.
            .enclosed_fn(middleware_fn)
            .finish()
            .call(())
            .now_or_panic()
            .unwrap()
            .call(Request::default())
            .now_or_panic()
            .unwrap();

        // `map_body` builds its response with `WebResponse::new`, whose status is 200 OK.
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[test]
    fn erase_error() {
        // maps any error of the wrapped service into a `StatusCode`, i.e. an error type that
        // is not `Error`.
        async fn middleware_fn<S, C, B, Err>(s: &S, ctx: WebContext<'_, C, B>) -> Result<WebResponse, StatusCode>
        where
            S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Err>,
        {
            s.call(ctx).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
        }

        // only type checks when the wrapped service's error type is exactly `Error`.
        async fn middleware_fn2<S, C, B>(s: &S, ctx: WebContext<'_, C, B>) -> Result<WebResponse, Error>
        where
            S: for<'r> Service<WebContext<'r, C, B>, Response = WebResponse, Error = Error>,
        {
            s.call(ctx).await
        }

        let res = App::new()
            // erase the route service's error type into `Error`.
            .at("/", handler_service(handler).enclosed(TypeEraser::error()))
            .enclosed(
                // innermost first: `middleware_fn` turns errors into `StatusCode`, the eraser
                // converts that back into `Error`, which `middleware_fn2` requires.
                Group::new()
                    .enclosed_fn(middleware_fn)
                    .enclosed(TypeEraser::error())
                    .enclosed_fn(middleware_fn2),
            )
            .finish()
            .call(())
            .now_or_panic()
            .unwrap()
            .call(Request::default())
            .now_or_panic()
            .unwrap();

        // the handler is infallible, so the request succeeds with the handler's response.
        assert_eq!(res.status(), StatusCode::OK);
    }
}