// ntex_util/services/retry.rs

1#![allow(async_fn_in_trait)]
2use ntex_service::{Middleware, Middleware2, Service, ServiceCtx};
3
4/// Trait defines retry policy
5pub trait Policy<Req, S: Service<Req>>: Sized + Clone {
6    async fn retry(&mut self, req: &Req, res: &Result<S::Response, S::Error>) -> bool;
7
8    fn clone_request(&self, req: &Req) -> Option<Req>;
9}
10
/// Retry middleware.
///
/// Wraps a service in a [`RetryService`] so that failed calls can be retried
/// according to the policy `P`.
#[derive(Clone, Debug)]
pub struct Retry<P> {
    policy: P,
}
18
/// Retry service.
///
/// Forwards calls to the inner `service` and re-issues them while the policy
/// `P` asks for a retry.
#[derive(Clone, Debug)]
pub struct RetryService<P, S> {
    policy: P,
    service: S,
}
27
28impl<P> Retry<P> {
29    /// Create retry middleware
30    pub fn new(policy: P) -> Self {
31        Retry { policy }
32    }
33}
34
35impl<P: Clone, S> Middleware<S> for Retry<P> {
36    type Service = RetryService<P, S>;
37
38    fn create(&self, service: S) -> Self::Service {
39        RetryService {
40            service,
41            policy: self.policy.clone(),
42        }
43    }
44}
45
46impl<P: Clone, S, C> Middleware2<S, C> for Retry<P> {
47    type Service = RetryService<P, S>;
48
49    fn create(&self, service: S, _: C) -> Self::Service {
50        RetryService {
51            service,
52            policy: self.policy.clone(),
53        }
54    }
55}
56
57impl<P, S> RetryService<P, S> {
58    /// Create retry service
59    pub fn new(policy: P, service: S) -> Self {
60        RetryService { policy, service }
61    }
62}
63
64impl<P, S, R> Service<R> for RetryService<P, S>
65where
66    P: Policy<R, S>,
67    S: Service<R>,
68{
69    type Response = S::Response;
70    type Error = S::Error;
71
72    ntex_service::forward_poll!(service);
73    ntex_service::forward_ready!(service);
74    ntex_service::forward_shutdown!(service);
75
76    async fn call(
77        &self,
78        mut request: R,
79        ctx: ServiceCtx<'_, Self>,
80    ) -> Result<S::Response, S::Error> {
81        let mut policy = self.policy.clone();
82        let mut cloned = policy.clone_request(&request);
83
84        loop {
85            let result = ctx.call(&self.service, request).await;
86
87            cloned = if let Some(req) = cloned.take() {
88                if policy.retry(&req, &result).await {
89                    request = req;
90                    policy.clone_request(&request)
91                } else {
92                    return result;
93                }
94            } else {
95                return result;
96            }
97        }
98    }
99}
100
/// Default retry policy.
///
/// Retries a call whenever the wrapped service returns an error, up to the
/// configured number of retries. A policy built via [`Default`] allows 3
/// retries, i.e. at most 4 attempts in total.
#[derive(Clone, Copy, Debug)]
pub struct DefaultRetryPolicy(u16);
106
107impl DefaultRetryPolicy {
108    /// Create default retry policy
109    pub fn new(retry: u16) -> Self {
110        DefaultRetryPolicy(retry)
111    }
112}
113
114impl Default for DefaultRetryPolicy {
115    fn default() -> Self {
116        DefaultRetryPolicy::new(3)
117    }
118}
119
120impl<R, S> Policy<R, S> for DefaultRetryPolicy
121where
122    R: Clone,
123    S: Service<R>,
124{
125    async fn retry(&mut self, _: &R, res: &Result<S::Response, S::Error>) -> bool {
126        if res.is_err() {
127            if self.0 == 0 {
128                false
129            } else {
130                self.0 -= 1;
131                true
132            }
133        } else {
134            false
135        }
136    }
137
138    fn clone_request(&self, req: &R) -> Option<R> {
139        Some(req.clone())
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use ntex_service::{Pipeline, ServiceFactory, apply, apply2, fn_factory};

    use super::*;

    /// Test service that fails while its shared counter is non-zero
    /// (decrementing it on each failure) and succeeds once it reaches zero.
    #[derive(Clone, Debug, PartialEq)]
    struct TestService(Rc<Cell<usize>>);

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            let cnt = self.0.get();
            if cnt == 0 {
                Ok(())
            } else {
                self.0.set(cnt - 1);
                Err(())
            }
        }
    }

    /// Policy that always asks for a retry but can never clone the request,
    /// so `RetryService` must stop after the first attempt.
    #[derive(Clone, Debug)]
    struct NoClonePolicy;

    impl Policy<(), TestService> for NoClonePolicy {
        async fn retry(&mut self, _: &(), _: &Result<(), ()>) -> bool {
            true
        }

        fn clone_request(&self, _: &()) -> Option<()> {
            None
        }
    }

    #[ntex::test]
    async fn test_retry() {
        let cnt = Rc::new(Cell::new(5));
        let svc = Pipeline::new(
            RetryService::new(DefaultRetryPolicy::default(), TestService(cnt.clone()))
                .clone(),
        );
        assert_eq!(svc.call(()).await, Err(()));
        assert_eq!(svc.ready().await, Ok(()));
        svc.shutdown().await;
        assert_eq!(cnt.get(), 1);

        let factory = apply(
            Retry::new(DefaultRetryPolicy::new(3)).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(Rc::new(Cell::new(2)))) }),
        );
        let srv = factory.pipeline(&()).await.unwrap();
        assert_eq!(srv.call(()).await, Ok(()));

        let factory = apply2(
            Retry::new(DefaultRetryPolicy::new(3)).clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(Rc::new(Cell::new(2)))) }),
        );
        let srv = factory.pipeline(&()).await.unwrap();
        assert_eq!(srv.call(()).await, Ok(()));
    }

    #[ntex::test]
    async fn test_retry_zero_retries() {
        // A zero budget means exactly one attempt.
        let cnt = Rc::new(Cell::new(5));
        let svc = Pipeline::new(RetryService::new(
            DefaultRetryPolicy::new(0),
            TestService(cnt.clone()),
        ));
        assert_eq!(svc.call(()).await, Err(()));
        assert_eq!(cnt.get(), 4);
    }

    #[ntex::test]
    async fn test_retry_request_not_cloneable() {
        // Without a spare copy of the request, no retry can happen even though
        // the policy would allow it.
        let cnt = Rc::new(Cell::new(5));
        let svc = Pipeline::new(RetryService::new(NoClonePolicy, TestService(cnt.clone())));
        assert_eq!(svc.call(()).await, Err(()));
        assert_eq!(cnt.get(), 4);
    }

    #[ntex::test]
    async fn test_retry_budget_is_per_call() {
        // The policy is cloned for every call, so the second call starts with a
        // fresh budget of 3 retries.
        let cnt = Rc::new(Cell::new(6));
        let svc = Pipeline::new(RetryService::new(
            DefaultRetryPolicy::new(3),
            TestService(cnt.clone()),
        ));
        // Four failing attempts: 6 -> 2.
        assert_eq!(svc.call(()).await, Err(()));
        assert_eq!(cnt.get(), 2);
        // Two failures (2 -> 0), then success on the third attempt.
        assert_eq!(svc.call(()).await, Ok(()));
        assert_eq!(cnt.get(), 0);
    }
}