requiem_service/
map_err.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use super::{Service, ServiceFactory};
7
/// Service for the `map_err` combinator, changing the type of a service's
/// error.
///
/// Errors reported by the wrapped service, both from readiness checks and from
/// calls, are passed through the mapping function. Successful results are
/// forwarded unchanged.
///
/// This is created by the `Service::map_err` method.
pub struct MapErr<A, F, E> {
    /// The wrapped service.
    service: A,
    /// Function applied to every error produced by `service`.
    f: F,
    /// Records the mapped error type `E`, which is not otherwise stored.
    _t: PhantomData<E>,
}
17
18impl<A, F, E> MapErr<A, F, E> {
19    /// Create new `MapErr` combinator
20    pub(crate) fn new(service: A, f: F) -> Self
21    where
22        A: Service,
23        F: Fn(A::Error) -> E,
24    {
25        Self {
26            service,
27            f,
28            _t: PhantomData,
29        }
30    }
31}
32
33impl<A, F, E> Clone for MapErr<A, F, E>
34where
35    A: Clone,
36    F: Clone,
37{
38    fn clone(&self) -> Self {
39        MapErr {
40            service: self.service.clone(),
41            f: self.f.clone(),
42            _t: PhantomData,
43        }
44    }
45}
46
47impl<A, F, E> Service for MapErr<A, F, E>
48where
49    A: Service,
50    F: Fn(A::Error) -> E + Clone,
51{
52    type Request = A::Request;
53    type Response = A::Response;
54    type Error = E;
55    type Future = MapErrFuture<A, F, E>;
56
57    fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58        self.service.poll_ready(ctx).map_err(&self.f)
59    }
60
61    fn call(&mut self, req: A::Request) -> Self::Future {
62        MapErrFuture::new(self.service.call(req), self.f.clone())
63    }
64}
65
66#[pin_project::pin_project]
67pub struct MapErrFuture<A, F, E>
68where
69    A: Service,
70    F: Fn(A::Error) -> E,
71{
72    f: F,
73    #[pin]
74    fut: A::Future,
75}
76
77impl<A, F, E> MapErrFuture<A, F, E>
78where
79    A: Service,
80    F: Fn(A::Error) -> E,
81{
82    fn new(fut: A::Future, f: F) -> Self {
83        MapErrFuture { f, fut }
84    }
85}
86
87impl<A, F, E> Future for MapErrFuture<A, F, E>
88where
89    A: Service,
90    F: Fn(A::Error) -> E,
91{
92    type Output = Result<A::Response, E>;
93
94    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
95        let this = self.project();
96        this.fut.poll(cx).map_err(this.f)
97    }
98}
99
/// Factory for the `map_err` combinator, changing the type of a new
/// service's error.
///
/// Each service built by this factory is a `MapErr` that wraps the inner
/// factory's service together with a clone of `f`. The factory's own
/// `InitError` is left unchanged.
///
/// This is created by the `ServiceFactory::map_err` method.
///
/// The struct itself carries no trait bounds. The impls that need
/// `A: ServiceFactory` and `F: Fn(A::Error) -> E` state them.
pub struct MapErrServiceFactory<A, F, E> {
    /// The wrapped factory.
    a: A,
    /// Error-mapping function, cloned into each produced service.
    f: F,
    /// Records the mapped error type `E`, which is not otherwise stored.
    e: PhantomData<E>,
}
113
114impl<A, F, E> MapErrServiceFactory<A, F, E>
115where
116    A: ServiceFactory,
117    F: Fn(A::Error) -> E + Clone,
118{
119    /// Create new `MapErr` new service instance
120    pub(crate) fn new(a: A, f: F) -> Self {
121        Self {
122            a,
123            f,
124            e: PhantomData,
125        }
126    }
127}
128
129impl<A, F, E> Clone for MapErrServiceFactory<A, F, E>
130where
131    A: ServiceFactory + Clone,
132    F: Fn(A::Error) -> E + Clone,
133{
134    fn clone(&self) -> Self {
135        Self {
136            a: self.a.clone(),
137            f: self.f.clone(),
138            e: PhantomData,
139        }
140    }
141}
142
143impl<A, F, E> ServiceFactory for MapErrServiceFactory<A, F, E>
144where
145    A: ServiceFactory,
146    F: Fn(A::Error) -> E + Clone,
147{
148    type Request = A::Request;
149    type Response = A::Response;
150    type Error = E;
151
152    type Config = A::Config;
153    type Service = MapErr<A::Service, F, E>;
154    type InitError = A::InitError;
155    type Future = MapErrServiceFuture<A, F, E>;
156
157    fn new_service(&self, cfg: A::Config) -> Self::Future {
158        MapErrServiceFuture::new(self.a.new_service(cfg), self.f.clone())
159    }
160}
161
162#[pin_project::pin_project]
163pub struct MapErrServiceFuture<A, F, E>
164where
165    A: ServiceFactory,
166    F: Fn(A::Error) -> E,
167{
168    #[pin]
169    fut: A::Future,
170    f: F,
171}
172
173impl<A, F, E> MapErrServiceFuture<A, F, E>
174where
175    A: ServiceFactory,
176    F: Fn(A::Error) -> E,
177{
178    fn new(fut: A::Future, f: F) -> Self {
179        MapErrServiceFuture { f, fut }
180    }
181}
182
183impl<A, F, E> Future for MapErrServiceFuture<A, F, E>
184where
185    A: ServiceFactory,
186    F: Fn(A::Error) -> E + Clone,
187{
188    type Output = Result<MapErr<A::Service, F, E>, A::InitError>;
189
190    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
191        let this = self.project();
192        if let Poll::Ready(svc) = this.fut.poll(cx)? {
193            Poll::Ready(Ok(MapErr::new(svc, this.f.clone())))
194        } else {
195            Poll::Pending
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use futures_util::future::{err, lazy, ok, Ready};

    use super::*;
    use crate::{IntoServiceFactory, Service, ServiceFactory};

    /// Service whose readiness check and calls always fail with `()`.
    struct Srv;

    impl Service for Srv {
        type Request = ();
        type Response = ();
        type Error = ();
        type Future = Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Err(()))
        }

        fn call(&mut self, _: ()) -> Self::Future {
            err(())
        }
    }

    /// Service that is always ready and always answers `7`. It is used to
    /// check that successful results pass through the combinator untouched.
    struct OkSrv;

    impl Service for OkSrv {
        type Request = ();
        type Response = u8;
        type Error = ();
        type Future = Ready<Result<u8, ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: ()) -> Self::Future {
            ok(7)
        }
    }

    #[actix_rt::test]
    async fn test_poll_ready() {
        let mut srv = Srv.map_err(|_| "error");
        let res = lazy(|cx| srv.poll_ready(cx)).await;
        assert_eq!(res, Poll::Ready(Err("error")));
    }

    #[actix_rt::test]
    async fn test_call() {
        let mut srv = Srv.map_err(|_| "error");
        let res = srv.call(()).await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap(), "error");
    }

    #[actix_rt::test]
    async fn test_ok_passthrough() {
        let mut srv = OkSrv.map_err(|_| "error");
        let ready = lazy(|cx| srv.poll_ready(cx)).await;
        assert_eq!(ready, Poll::Ready(Ok(())));
        assert_eq!(srv.call(()).await, Ok(7));
    }

    #[actix_rt::test]
    async fn test_new_service() {
        let new_srv = (|| ok::<_, ()>(Srv)).into_factory().map_err(|_| "error");
        let mut srv = new_srv.new_service(&()).await.unwrap();
        let res = srv.call(()).await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap(), "error");
    }

    #[actix_rt::test]
    async fn test_new_service_init_error_untouched() {
        // `map_err` only maps the built service's `Error`; the factory's
        // `InitError` must come through unmapped.
        let new_srv = (|| err::<Srv, _>("init")).into_factory().map_err(|_| "error");
        let res = new_srv.new_service(&()).await;
        assert_eq!(res.err(), Some("init"));
    }
}