1use std::{fmt, marker::PhantomData, task::Context};
2
3use super::{Service, ServiceCtx, ServiceFactory};
4
/// Service wrapper that converts the error type of an inner service.
///
/// Every error produced by `service` (from readiness checks, polling and calls)
/// is passed through `f`, turning an `A::Error` into an `E`.
pub struct MapErr<A, F, E> {
    /// The wrapped service.
    service: A,
    /// Maps the inner service's error into `E`.
    f: F,
    // `E` is only ever *produced* by `f`; no value of it is stored. A `fn() -> E`
    // marker keeps the type covariant in `E` (as before) without making `Send`,
    // `Sync` or drop-check depend on `E`, matching the `PhantomData<fn(R, C) -> E>`
    // marker used by `MapErrFactory`.
    _t: PhantomData<fn() -> E>,
}
14
15impl<A, F, E> MapErr<A, F, E> {
16 pub(crate) fn new<R>(service: A, f: F) -> Self
18 where
19 A: Service<R>,
20 F: Fn(A::Error) -> E,
21 {
22 Self {
23 service,
24 f,
25 _t: PhantomData,
26 }
27 }
28}
29
30impl<A, F, E> Clone for MapErr<A, F, E>
31where
32 A: Clone,
33 F: Clone,
34{
35 #[inline]
36 fn clone(&self) -> Self {
37 MapErr {
38 service: self.service.clone(),
39 f: self.f.clone(),
40 _t: PhantomData,
41 }
42 }
43}
44
45impl<A, F, E> fmt::Debug for MapErr<A, F, E>
46where
47 A: fmt::Debug,
48{
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("MapErr")
51 .field("svc", &self.service)
52 .field("map", &std::any::type_name::<F>())
53 .finish()
54 }
55}
56
57impl<A, R, F, E> Service<R> for MapErr<A, F, E>
58where
59 A: Service<R>,
60 F: Fn(A::Error) -> E,
61{
62 type Response = A::Response;
63 type Error = E;
64
65 #[inline]
66 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
67 ctx.ready(&self.service).await.map_err(&self.f)
68 }
69
70 #[inline]
71 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
72 self.service.poll(cx).map_err(&self.f)
73 }
74
75 #[inline]
76 async fn call(
77 &self,
78 req: R,
79 ctx: ServiceCtx<'_, Self>,
80 ) -> Result<Self::Response, Self::Error> {
81 ctx.call(&self.service, req).await.map_err(|e| (self.f)(e))
82 }
83
84 crate::forward_shutdown!(service);
85}
86
/// Factory producing `MapErr` services: each service created by the inner
/// factory `a` is wrapped so its errors are converted by a clone of `f`.
///
/// Trait bounds are kept on the `impl` blocks that need them rather than on the
/// struct definition. Struct where-clauses are not implied by Rust, so every impl
/// already restates them, and merely naming this type imposes no requirements.
pub struct MapErrFactory<A, R, C, F, E> {
    /// Inner factory whose services get wrapped.
    a: A,
    /// Error-mapping function, cloned into every created service.
    f: F,
    // `R`, `C` and `E` are type-level only; the `fn` marker keeps them out of
    // auto-trait and drop-check reasoning.
    e: PhantomData<fn(R, C) -> E>,
}
100
101impl<A, R, C, F, E> MapErrFactory<A, R, C, F, E>
102where
103 A: ServiceFactory<R, C>,
104 F: Fn(A::Error) -> E + Clone,
105{
106 pub(crate) fn new(a: A, f: F) -> Self {
108 Self {
109 a,
110 f,
111 e: PhantomData,
112 }
113 }
114}
115
116impl<A, R, C, F, E> Clone for MapErrFactory<A, R, C, F, E>
117where
118 A: ServiceFactory<R, C> + Clone,
119 F: Fn(A::Error) -> E + Clone,
120{
121 fn clone(&self) -> Self {
122 Self {
123 a: self.a.clone(),
124 f: self.f.clone(),
125 e: PhantomData,
126 }
127 }
128}
129
130impl<A, R, C, F, E> fmt::Debug for MapErrFactory<A, R, C, F, E>
131where
132 A: ServiceFactory<R, C> + fmt::Debug,
133 F: Fn(A::Error) -> E + Clone,
134{
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 f.debug_struct("MapErrFactory")
137 .field("factory", &self.a)
138 .field("map", &std::any::type_name::<F>())
139 .finish()
140 }
141}
142
143impl<A, R, C, F, E> ServiceFactory<R, C> for MapErrFactory<A, R, C, F, E>
144where
145 A: ServiceFactory<R, C>,
146 F: Fn(A::Error) -> E + Clone,
147{
148 type Response = A::Response;
149 type Error = E;
150
151 type Service = MapErr<A::Service, F, E>;
152 type InitError = A::InitError;
153
154 #[inline]
155 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
156 self.a.create(cfg).await.map(|service| MapErr {
157 service,
158 f: self.f.clone(),
159 _t: PhantomData,
160 })
161 }
162}
163
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{Pipeline, fn_factory};

    /// Test service: when `.0` is set, `ready` fails; `.1` counts `shutdown`
    /// invocations; `call` always fails.
    #[derive(Debug, Clone)]
    struct Srv(bool, Rc<Cell<usize>>);

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            if self.0 {
                return Err(());
            }
            Ok(())
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Err(())
        }

        async fn shutdown(&self) {
            let calls = self.1.get();
            self.1.set(calls + 1);
        }
    }

    #[ntex::test]
    async fn test_ready() {
        let shutdowns = Rc::new(Cell::new(0));
        let srv = Pipeline::new(Srv(true, Rc::clone(&shutdowns)).map_err(|_| "error"));
        assert_eq!(srv.ready().await, Err("error"));

        // Shutdown must be forwarded to the wrapped service exactly once.
        srv.shutdown().await;
        assert_eq!(shutdowns.get(), 1);
    }

    #[ntex::test]
    async fn test_service() {
        let srv = Pipeline::new(
            Srv(false, Rc::new(Cell::new(0)))
                .map_err(|_| "error")
                .clone(),
        );
        assert_eq!(srv.call(()).await, Err("error"));

        let _ = format!("{srv:?}");
    }

    #[ntex::test]
    async fn test_pipeline() {
        let srv = Pipeline::new(
            crate::chain(Srv(false, Rc::new(Cell::new(0))))
                .map_err(|_| "error")
                .clone(),
        );
        assert_eq!(srv.call(()).await, Err("error"));

        let _ = format!("{srv:?}");
    }

    #[ntex::test]
    async fn test_factory() {
        let new_srv =
            fn_factory(|| async { Ok::<_, ()>(Srv(false, Rc::new(Cell::new(0)))) })
                .map_err(|_| "error")
                .clone();
        let srv = Pipeline::new(new_srv.create(&()).await.unwrap());
        assert_eq!(srv.call(()).await, Err("error"));

        let _ = format!("{new_srv:?}");
    }

    #[ntex::test]
    async fn test_pipeline_factory() {
        let new_srv = crate::chain_factory(fn_factory(|| async {
            Ok::<Srv, ()>(Srv(false, Rc::new(Cell::new(0))))
        }))
        .map_err(|_| "error")
        .clone();
        let srv = Pipeline::new(new_srv.create(&()).await.unwrap());
        assert_eq!(srv.call(()).await, Err("error"));

        let _ = format!("{new_srv:?}");
    }
}