1use std::{fmt, marker::PhantomData, task::Context};
2
3use super::{Service, ServiceCtx, ServiceFactory};
4
/// Service for the `map_err` combinator.
///
/// Readiness checks and calls are forwarded to the wrapped `service`, and any
/// error it returns is converted into `E` with `f`.
pub struct MapErr<A, F, E> {
    /// The wrapped service.
    service: A,
    /// Converts the wrapped service's error into `E`.
    f: F,
    // `MapErr` never stores an `E`. A `fn() -> E` marker keeps the type
    // covariant in `E` without making `Send`/`Sync`/`Unpin` depend on `E`
    // (mirrors the `fn(R, C) -> E` marker used by `MapErrFactory`).
    _t: PhantomData<fn() -> E>,
}
14
15impl<A, F, E> MapErr<A, F, E> {
16 pub(crate) fn new<R>(service: A, f: F) -> Self
18 where
19 A: Service<R>,
20 F: Fn(A::Error) -> E,
21 {
22 Self {
23 service,
24 f,
25 _t: PhantomData,
26 }
27 }
28}
29
30impl<A, F, E> Clone for MapErr<A, F, E>
31where
32 A: Clone,
33 F: Clone,
34{
35 #[inline]
36 fn clone(&self) -> Self {
37 MapErr {
38 service: self.service.clone(),
39 f: self.f.clone(),
40 _t: PhantomData,
41 }
42 }
43}
44
45impl<A, F, E> fmt::Debug for MapErr<A, F, E>
46where
47 A: fmt::Debug,
48{
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("MapErr")
51 .field("svc", &self.service)
52 .field("map", &std::any::type_name::<F>())
53 .finish()
54 }
55}
56
57impl<A, R, F, E> Service<R> for MapErr<A, F, E>
58where
59 A: Service<R>,
60 F: Fn(A::Error) -> E,
61{
62 type Response = A::Response;
63 type Error = E;
64
65 #[inline]
66 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
67 ctx.ready(&self.service).await.map_err(&self.f)
68 }
69
70 #[inline]
71 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
72 self.service.poll(cx).map_err(&self.f)
73 }
74
75 #[inline]
76 async fn call(
77 &self,
78 req: R,
79 ctx: ServiceCtx<'_, Self>,
80 ) -> Result<Self::Response, Self::Error> {
81 ctx.call(&self.service, req).await.map_err(|e| (self.f)(e))
82 }
83
84 crate::forward_shutdown!(service);
85}
86
/// Factory for the `map_err` combinator.
///
/// It creates services with the wrapped factory `a`, and wraps each created
/// service in a `MapErr` that converts errors with a clone of `f`.
///
/// Trait bounds are intentionally left off the struct definition and declared
/// only on the `impl` blocks that need them. Struct bounds are not implied at
/// use sites, so they only add noise.
pub struct MapErrFactory<A, R, C, F, E> {
    /// The wrapped service factory.
    a: A,
    /// Error mapper, cloned into every created service.
    f: F,
    /// Marker for the request (`R`), config (`C`) and mapped error (`E`)
    /// types; none of them is stored.
    e: PhantomData<fn(R, C) -> E>,
}
100
101impl<A, R, C, F, E> MapErrFactory<A, R, C, F, E>
102where
103 A: ServiceFactory<R, C>,
104 F: Fn(A::Error) -> E + Clone,
105{
106 pub(crate) fn new(a: A, f: F) -> Self {
108 Self {
109 a,
110 f,
111 e: PhantomData,
112 }
113 }
114}
115
116impl<A, R, C, F, E> Clone for MapErrFactory<A, R, C, F, E>
117where
118 A: ServiceFactory<R, C> + Clone,
119 F: Fn(A::Error) -> E + Clone,
120{
121 fn clone(&self) -> Self {
122 Self {
123 a: self.a.clone(),
124 f: self.f.clone(),
125 e: PhantomData,
126 }
127 }
128}
129
130impl<A, R, C, F, E> fmt::Debug for MapErrFactory<A, R, C, F, E>
131where
132 A: ServiceFactory<R, C> + fmt::Debug,
133 F: Fn(A::Error) -> E + Clone,
134{
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 f.debug_struct("MapErrFactory")
137 .field("factory", &self.a)
138 .field("map", &std::any::type_name::<F>())
139 .finish()
140 }
141}
142
143impl<A, R, C, F, E> ServiceFactory<R, C> for MapErrFactory<A, R, C, F, E>
144where
145 A: ServiceFactory<R, C>,
146 F: Fn(A::Error) -> E + Clone,
147{
148 type Response = A::Response;
149 type Error = E;
150
151 type Service = MapErr<A::Service, F, E>;
152 type InitError = A::InitError;
153
154 #[inline]
155 async fn create(&self, cfg: C) -> Result<Self::Service, Self::InitError> {
156 self.a.create(cfg).await.map(|service| MapErr {
157 service,
158 f: self.f.clone(),
159 _t: PhantomData,
160 })
161 }
162}
163
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;
    use crate::{fn_factory, Pipeline};

    /// Stub service. `.0 == true` makes `ready` fail, `.1` counts `shutdown`
    /// invocations, and every `call` fails with `()`.
    #[derive(Debug, Clone)]
    struct Srv(bool, Rc<Cell<usize>>);

    /// Builds a stub that is always ready and has its own shutdown counter.
    fn ready_srv() -> Srv {
        Srv(false, Rc::new(Cell::new(0)))
    }

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            match self.0 {
                true => Err(()),
                false => Ok(()),
            }
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Err(())
        }

        async fn shutdown(&self) {
            let seen = self.1.get();
            self.1.set(seen + 1);
        }
    }

    #[ntex::test]
    async fn test_ready() {
        let shutdowns = Rc::new(Cell::new(0));
        let pipe = Pipeline::new(Srv(true, shutdowns.clone()).map_err(|_| "error"));
        assert_eq!(pipe.ready().await, Err("error"));

        pipe.shutdown().await;
        assert_eq!(shutdowns.get(), 1);
    }

    #[ntex::test]
    async fn test_service() {
        let pipe = Pipeline::new(ready_srv().map_err(|_| "error").clone());
        let outcome = pipe.call(()).await;
        assert_eq!(outcome, Err("error"));

        let _ = format!("{:?}", pipe);
    }

    #[ntex::test]
    async fn test_pipeline() {
        let pipe = Pipeline::new(crate::chain(ready_srv()).map_err(|_| "error").clone());
        let outcome = pipe.call(()).await;
        assert_eq!(outcome, Err("error"));

        let _ = format!("{:?}", pipe);
    }

    #[ntex::test]
    async fn test_factory() {
        let factory = fn_factory(|| async { Ok::<_, ()>(ready_srv()) })
            .map_err(|_| "error")
            .clone();
        let pipe = Pipeline::new(factory.create(&()).await.unwrap());
        let outcome = pipe.call(()).await;
        assert_eq!(outcome, Err("error"));

        let _ = format!("{:?}", factory);
    }

    #[ntex::test]
    async fn test_pipeline_factory() {
        let factory =
            crate::chain_factory(fn_factory(|| async { Ok::<Srv, ()>(ready_srv()) }))
                .map_err(|_| "error")
                .clone();
        let pipe = Pipeline::new(factory.create(&()).await.unwrap());
        let outcome = pipe.call(()).await;
        assert_eq!(outcome, Err("error"));

        let _ = format!("{:?}", factory);
    }
}