ntex_service/
map_err.rs

1use std::{fmt, marker::PhantomData, task::Context};
2
3use super::{Service, ServiceCtx, ServiceFactory};
4
/// Service for the `map_err` combinator, changing the type of a service's
/// error.
///
/// Every error reported by the wrapped service (from `ready`, `poll` and
/// `call`) is passed through the mapping function `F` to produce an `E`;
/// successful responses are forwarded unchanged.
///
/// This is created by calling `map_err` on a service.
pub struct MapErr<A, F, E> {
    /// The wrapped service.
    service: A,
    /// Converts the wrapped service's error into `E`.
    f: F,
    /// `E` only ever appears as the *output* of `f`; no `E` value is stored.
    /// `fn() -> E` records that (covariantly) without claiming ownership of
    /// an `E`, so `Send`/`Sync`/`Unpin` of this type do not depend on `E`.
    _t: PhantomData<fn() -> E>,
}
14
15impl<A, F, E> MapErr<A, F, E> {
16    /// Create new `MapErr` combinator
17    pub(crate) fn new<R>(service: A, f: F) -> Self
18    where
19        A: Service<R>,
20        F: Fn(A::Error) -> E,
21    {
22        Self {
23            service,
24            f,
25            _t: PhantomData,
26        }
27    }
28}
29
30impl<A, F, E> Clone for MapErr<A, F, E>
31where
32    A: Clone,
33    F: Clone,
34{
35    #[inline]
36    fn clone(&self) -> Self {
37        MapErr {
38            service: self.service.clone(),
39            f: self.f.clone(),
40            _t: PhantomData,
41        }
42    }
43}
44
45impl<A, F, E> fmt::Debug for MapErr<A, F, E>
46where
47    A: fmt::Debug,
48{
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("MapErr")
51            .field("svc", &self.service)
52            .field("map", &std::any::type_name::<F>())
53            .finish()
54    }
55}
56
57impl<A, R, F, E> Service<R> for MapErr<A, F, E>
58where
59    A: Service<R>,
60    F: Fn(A::Error) -> E,
61{
62    type Response = A::Response;
63    type Error = E;
64
65    #[inline]
66    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
67        ctx.ready(&self.service).await.map_err(&self.f)
68    }
69
70    #[inline]
71    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
72        self.service.poll(cx).map_err(&self.f)
73    }
74
75    #[inline]
76    async fn call(
77        &self,
78        req: R,
79        ctx: ServiceCtx<'_, Self>,
80    ) -> Result<Self::Response, Self::Error> {
81        ctx.call(&self.service, req).await.map_err(|e| (self.f)(e))
82    }
83
84    crate::forward_shutdown!(service);
85}
86
/// Factory for the `map_err` combinator, changing the type of a new
/// service's error.
///
/// Every service produced by the wrapped factory is wrapped in [`MapErr`]
/// together with a clone of the mapping function, so the created services
/// report `E` instead of the inner error type.
///
/// This is created by calling `map_err` on a service factory.
///
/// The struct itself carries no trait bounds; the `ServiceFactory` /
/// `Fn` requirements live on the impls that actually need them.
pub struct MapErrFactory<A, R, C, F, E> {
    /// The wrapped service factory.
    a: A,
    /// Error-mapping function, cloned into every created service.
    f: F,
    /// Ties the request type `R`, config type `C` and error type `E` to this
    /// type without storing values of them (and without affecting auto traits).
    e: PhantomData<fn(R, C) -> E>,
}
100
101impl<A, R, C, F, E> MapErrFactory<A, R, C, F, E>
102where
103    A: ServiceFactory<R, C>,
104    F: Fn(A::Error) -> E + Clone,
105{
106    /// Create new `MapErr` new service instance
107    pub(crate) fn new(a: A, f: F) -> Self {
108        Self {
109            a,
110            f,
111            e: PhantomData,
112        }
113    }
114}
115
116impl<A, R, C, F, E> Clone for MapErrFactory<A, R, C, F, E>
117where
118    A: ServiceFactory<R, C> + Clone,
119    F: Fn(A::Error) -> E + Clone,
120{
121    fn clone(&self) -> Self {
122        Self {
123            a: self.a.clone(),
124            f: self.f.clone(),
125            e: PhantomData,
126        }
127    }
128}
129
130impl<A, R, C, F, E> fmt::Debug for MapErrFactory<A, R, C, F, E>
131where
132    A: ServiceFactory<R, C> + fmt::Debug,
133    F: Fn(A::Error) -> E + Clone,
134{
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        f.debug_struct("MapErrFactory")
137            .field("factory", &self.a)
138            .field("map", &std::any::type_name::<F>())
139            .finish()
140    }
141}
142
143impl<A, R, C, F, E> ServiceFactory<R, C> for MapErrFactory<A, R, C, F, E>
144where
145    A: ServiceFactory<R, C>,
146    F: Fn(A::Error) -> E + Clone,
147{
148    type Response = A::Response;
149    type Error = E;
150
151    type Service = MapErr<A::Service, F, E>;
152    type InitError = A::InitError;
153
154    #[inline]
155    async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
156        self.a.create(cfg).await.map(|service| MapErr {
157            service,
158            f: self.f.clone(),
159            _t: PhantomData,
160        })
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{fn_factory, Pipeline};

    /// Test service: `.0 == true` makes `ready` fail, `.1` counts how many
    /// times `shutdown` ran. `call` always fails with `()`.
    #[derive(Debug, Clone)]
    struct Srv(bool, Rc<Cell<usize>>);

    /// A `Srv` that reports ready and has its own, unobserved shutdown counter.
    fn ready_srv() -> Srv {
        Srv(false, Rc::new(Cell::new(0)))
    }

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            match self.0 {
                true => Err(()),
                false => Ok(()),
            }
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Err(())
        }

        async fn shutdown(&self) {
            let seen = self.1.get();
            self.1.set(seen + 1);
        }
    }

    /// A readiness error is mapped, and shutdown reaches the inner service.
    #[ntex::test]
    async fn test_ready() {
        let shutdowns = Rc::new(Cell::new(0));
        let srv = Pipeline::new(Srv(true, shutdowns.clone()).map_err(|_| "error"));
        let res = srv.ready().await;
        assert_eq!(res, Err("error"));

        srv.shutdown().await;
        assert_eq!(shutdowns.get(), 1);
    }

    /// A failing call on a plain `MapErr` service yields the mapped error.
    #[ntex::test]
    async fn test_service() {
        let srv = Pipeline::new(ready_srv().map_err(|_| "error").clone());
        let res = srv.call(()).await;
        assert_eq!(res, Err("error"));

        let _ = format!("{srv:?}");
    }

    /// Same as `test_service`, but built through `chain`.
    #[ntex::test]
    async fn test_pipeline() {
        let srv = Pipeline::new(crate::chain(ready_srv()).map_err(|_| "error").clone());
        let res = srv.call(()).await;
        assert_eq!(res, Err("error"));

        let _ = format!("{srv:?}");
    }

    /// Services created by a `map_err` factory map their call errors.
    #[ntex::test]
    async fn test_factory() {
        let new_srv = fn_factory(|| async { Ok::<_, ()>(ready_srv()) })
            .map_err(|_| "error")
            .clone();
        let srv = Pipeline::new(new_srv.create(&()).await.unwrap());
        let res = srv.call(()).await;
        assert_eq!(res, Err("error"));

        let _ = format!("{new_srv:?}");
    }

    /// Same as `test_factory`, but built through `chain_factory`.
    #[ntex::test]
    async fn test_pipeline_factory() {
        let new_srv = crate::chain_factory(fn_factory(|| async { Ok::<Srv, ()>(ready_srv()) }))
            .map_err(|_| "error")
            .clone();
        let srv = Pipeline::new(new_srv.create(&()).await.unwrap());
        let res = srv.call(()).await;
        assert_eq!(res, Err("error"));

        let _ = format!("{new_srv:?}");
    }
}