ntex_service/
map_err.rs

1use std::{fmt, marker::PhantomData, task::Context};
2
3use super::{Service, ServiceCtx, ServiceFactory};
4
/// Service for the `map_err` combinator, changing the type of a service's
/// error.
///
/// Every error produced by the wrapped service (from `ready`, `poll` or
/// `call`) is converted by the mapping function `F` into an `E`; successful
/// results pass through unchanged.
///
/// This is created by the `map_err` method on services and service chains.
pub struct MapErr<A, F, E> {
    service: A,
    f: F,
    // `MapErr` never stores an `E`; it only produces one through `f`.
    // `fn() -> E` records that without making `Send`/`Sync` depend on `E`
    // (a plain `PhantomData<E>` would, e.g. making the service `!Send` for
    // an `Rc` error type). This matches the marker used by `MapErrFactory`.
    _t: PhantomData<fn() -> E>,
}
14
15impl<A, F, E> MapErr<A, F, E> {
16    /// Create new `MapErr` combinator
17    pub(crate) fn new<R>(service: A, f: F) -> Self
18    where
19        A: Service<R>,
20        F: Fn(A::Error) -> E,
21    {
22        Self {
23            service,
24            f,
25            _t: PhantomData,
26        }
27    }
28}
29
30impl<A, F, E> Clone for MapErr<A, F, E>
31where
32    A: Clone,
33    F: Clone,
34{
35    #[inline]
36    fn clone(&self) -> Self {
37        MapErr {
38            service: self.service.clone(),
39            f: self.f.clone(),
40            _t: PhantomData,
41        }
42    }
43}
44
45impl<A, F, E> fmt::Debug for MapErr<A, F, E>
46where
47    A: fmt::Debug,
48{
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("MapErr")
51            .field("svc", &self.service)
52            .field("map", &std::any::type_name::<F>())
53            .finish()
54    }
55}
56
57impl<A, R, F, E> Service<R> for MapErr<A, F, E>
58where
59    A: Service<R>,
60    F: Fn(A::Error) -> E,
61{
62    type Response = A::Response;
63    type Error = E;
64
65    #[inline]
66    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
67        ctx.ready(&self.service).await.map_err(&self.f)
68    }
69
70    #[inline]
71    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
72        self.service.poll(cx).map_err(&self.f)
73    }
74
75    #[inline]
76    async fn call(
77        &self,
78        req: R,
79        ctx: ServiceCtx<'_, Self>,
80    ) -> Result<Self::Response, Self::Error> {
81        ctx.call(&self.service, req).await.map_err(|e| (self.f)(e))
82    }
83
84    crate::forward_shutdown!(service);
85}
86
/// Factory for the `map_err` combinator, changing the error type of the
/// services it creates.
///
/// Every service produced by the inner factory `A` is wrapped in [`MapErr`]
/// together with a clone of the mapping function `F`.
///
/// This is created by the `map_err` method on service factories and factory
/// chains.
// No trait bounds on the type itself: the impl blocks that need
// `ServiceFactory`/`Fn` state them, so merely naming this type never has to
// re-prove them.
pub struct MapErrFactory<A, R, C, F, E> {
    a: A,
    f: F,
    // Request, config and error types appear only at the type level.
    e: PhantomData<fn(R, C) -> E>,
}
100
101impl<A, R, C, F, E> MapErrFactory<A, R, C, F, E>
102where
103    A: ServiceFactory<R, C>,
104    F: Fn(A::Error) -> E + Clone,
105{
106    /// Create new `MapErr` new service instance
107    pub(crate) fn new(a: A, f: F) -> Self {
108        Self {
109            a,
110            f,
111            e: PhantomData,
112        }
113    }
114}
115
116impl<A, R, C, F, E> Clone for MapErrFactory<A, R, C, F, E>
117where
118    A: ServiceFactory<R, C> + Clone,
119    F: Fn(A::Error) -> E + Clone,
120{
121    fn clone(&self) -> Self {
122        Self {
123            a: self.a.clone(),
124            f: self.f.clone(),
125            e: PhantomData,
126        }
127    }
128}
129
130impl<A, R, C, F, E> fmt::Debug for MapErrFactory<A, R, C, F, E>
131where
132    A: ServiceFactory<R, C> + fmt::Debug,
133    F: Fn(A::Error) -> E + Clone,
134{
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        f.debug_struct("MapErrFactory")
137            .field("factory", &self.a)
138            .field("map", &std::any::type_name::<F>())
139            .finish()
140    }
141}
142
143impl<A, R, C, F, E> ServiceFactory<R, C> for MapErrFactory<A, R, C, F, E>
144where
145    A: ServiceFactory<R, C>,
146    F: Fn(A::Error) -> E + Clone,
147{
148    type Response = A::Response;
149    type Error = E;
150
151    type Service = MapErr<A::Service, F, E>;
152    type InitError = A::InitError;
153
154    #[inline]
155    async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
156        self.a.create(cfg).await.map(|service| MapErr {
157            service,
158            f: self.f.clone(),
159            _t: PhantomData,
160        })
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, fn_factory};

    /// Test service: field 0 makes `ready` fail, field 1 counts `shutdown`
    /// calls; `call` always fails.
    #[derive(Debug, Clone)]
    struct Srv(bool, Rc<Cell<usize>>);

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            if self.0 { Err(()) } else { Ok(()) }
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Err(())
        }

        async fn shutdown(&self) {
            self.1.set(self.1.get() + 1);
        }
    }

    /// Service whose `poll` always fails, used to exercise `MapErr::poll`.
    struct PollErr;

    impl Service<()> for PollErr {
        type Response = ();
        type Error = ();

        fn poll(&self, _: &mut std::task::Context<'_>) -> Result<(), Self::Error> {
            Err(())
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }
    }

    #[ntex::test]
    async fn test_ready() {
        let cnt_sht = Rc::new(Cell::new(0));
        let srv = Pipeline::new(Srv(true, cnt_sht.clone()).map_err(|_| "error"));
        let res = srv.ready().await;
        assert_eq!(res, Err("error"));

        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 1);
    }

    #[ntex::test]
    async fn test_poll() {
        let srv = MapErr::new::<()>(PollErr, |_: ()| "error");
        let res = std::future::poll_fn(|cx| {
            std::task::Poll::Ready(Service::<()>::poll(&srv, cx))
        })
        .await;
        assert_eq!(res, Err("error"));
    }

    #[ntex::test]
    async fn test_service() {
        let srv = Pipeline::new(
            Srv(false, Rc::new(Cell::new(0)))
                .map_err(|_| "error")
                .clone(),
        );
        let res = srv.call(()).await;
        assert_eq!(res, Err("error"));

        let _ = format!("{srv:?}");
    }

    #[ntex::test]
    async fn test_pipeline() {
        let srv = Pipeline::new(
            crate::chain(Srv(false, Rc::new(Cell::new(0))))
                .map_err(|_| "error")
                .clone(),
        );
        let res = srv.call(()).await;
        assert_eq!(res, Err("error"));

        let _ = format!("{srv:?}");
    }

    #[ntex::test]
    async fn test_factory() {
        let new_srv =
            fn_factory(|| async { Ok::<_, ()>(Srv(false, Rc::new(Cell::new(0)))) })
                .map_err(|_| "error")
                .clone();
        let srv = Pipeline::new(new_srv.create(&()).await.unwrap());
        let res = srv.call(()).await;
        assert_eq!(res, Err("error"));
        let _ = format!("{new_srv:?}");
    }

    #[ntex::test]
    async fn test_pipeline_factory() {
        let new_srv = crate::chain_factory(fn_factory(|| async {
            Ok::<Srv, ()>(Srv(false, Rc::new(Cell::new(0))))
        }))
        .map_err(|_| "error")
        .clone();
        let srv = Pipeline::new(new_srv.create(&()).await.unwrap());
        let res = srv.call(()).await;
        assert_eq!(res, Err("error"));
        let _ = format!("{new_srv:?}");
    }
}