ntex_util/services/
retry.rs

1#![allow(async_fn_in_trait)]
2use ntex_service::{Middleware, Service, ServiceCtx};
3
4/// Trait defines retry policy
5pub trait Policy<Req, S: Service<Req>>: Sized + Clone {
6    async fn retry(&mut self, req: &Req, res: &Result<S::Response, S::Error>) -> bool;
7
8    fn clone_request(&self, req: &Req) -> Option<Req>;
9}
10
/// Retry middleware.
///
/// Holds a retry policy `P` and wraps a service into a [`RetryService`] that
/// repeats calls as the policy dictates.
#[derive(Clone, Debug)]
pub struct Retry<P> {
    policy: P,
}
18
/// Retry service.
///
/// Forwards requests to the wrapped service `S` and calls it again whenever
/// the policy `P` requests another attempt.
#[derive(Clone, Debug)]
pub struct RetryService<P, S> {
    policy: P,
    service: S,
}
27
28impl<P> Retry<P> {
29    /// Create retry middleware
30    pub fn new(policy: P) -> Self {
31        Retry { policy }
32    }
33}
34
35impl<P: Clone, S> Middleware<S> for Retry<P> {
36    type Service = RetryService<P, S>;
37
38    fn create(&self, service: S) -> Self::Service {
39        RetryService {
40            service,
41            policy: self.policy.clone(),
42        }
43    }
44}
45
46impl<P, S> RetryService<P, S> {
47    /// Create retry service
48    pub fn new(policy: P, service: S) -> Self {
49        RetryService { policy, service }
50    }
51}
52
53impl<P, S, R> Service<R> for RetryService<P, S>
54where
55    P: Policy<R, S>,
56    S: Service<R>,
57{
58    type Response = S::Response;
59    type Error = S::Error;
60
61    ntex_service::forward_poll!(service);
62    ntex_service::forward_ready!(service);
63    ntex_service::forward_shutdown!(service);
64
65    async fn call(
66        &self,
67        mut request: R,
68        ctx: ServiceCtx<'_, Self>,
69    ) -> Result<S::Response, S::Error> {
70        let mut policy = self.policy.clone();
71        let mut cloned = policy.clone_request(&request);
72
73        loop {
74            let result = ctx.call(&self.service, request).await;
75
76            cloned = if let Some(req) = cloned.take() {
77                if policy.retry(&req, &result).await {
78                    request = req;
79                    policy.clone_request(&request)
80                } else {
81                    return result;
82                }
83            } else {
84                return result;
85            }
86        }
87    }
88}
89
/// Default retry policy.
///
/// Retries a call whenever it returns an error, as long as retries remain.
/// The wrapped counter is the number of extra attempts allowed after the
/// initial call; [`Default`] allows 3 of them.
#[derive(Copy, Clone, Debug)]
pub struct DefaultRetryPolicy(u16);

impl DefaultRetryPolicy {
    /// Creates a policy permitting up to `retry` additional attempts after the
    /// initial call.
    ///
    /// This is a `const fn`, so a policy can be built in `const`/`static`
    /// context.
    pub const fn new(retry: u16) -> Self {
        DefaultRetryPolicy(retry)
    }
}

impl Default for DefaultRetryPolicy {
    /// Policy allowing 3 retries.
    fn default() -> Self {
        DefaultRetryPolicy::new(3)
    }
}
108
109impl<R, S> Policy<R, S> for DefaultRetryPolicy
110where
111    R: Clone,
112    S: Service<R>,
113{
114    async fn retry(&mut self, _: &R, res: &Result<S::Response, S::Error>) -> bool {
115        if res.is_err() {
116            if self.0 == 0 {
117                false
118            } else {
119                self.0 -= 1;
120                true
121            }
122        } else {
123            false
124        }
125    }
126
127    fn clone_request(&self, req: &R) -> Option<R> {
128        Some(req.clone())
129    }
130}
131
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use ntex_service::{apply, fn_factory, Pipeline, ServiceFactory};

    use super::*;

    /// Test service that fails while its shared counter is non-zero,
    /// decrementing the counter on every failure, and succeeds once it hits zero.
    #[derive(Clone, Debug, PartialEq)]
    struct TestService(Rc<Cell<usize>>);

    impl Service<()> for TestService {
        type Response = ();
        type Error = ();

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            match self.0.get() {
                0 => Ok(()),
                remaining => {
                    self.0.set(remaining - 1);
                    Err(())
                }
            }
        }
    }

    #[ntex_macros::rt_test2]
    async fn test_retry() {
        // 5 pending failures vs. 1 initial call + 3 retries: the call still
        // fails, and exactly 4 failures are consumed.
        let failures = Rc::new(Cell::new(5));
        let service =
            RetryService::new(DefaultRetryPolicy::default(), TestService(failures.clone()));
        let svc = Pipeline::new(service.clone());
        assert_eq!(svc.call(()).await, Err(()));
        assert_eq!(svc.ready().await, Ok(()));
        svc.shutdown().await;
        assert_eq!(failures.get(), 1);

        // 2 pending failures fit within 3 retries, so the call succeeds.
        let middleware = Retry::new(DefaultRetryPolicy::new(3));
        let factory = apply(
            middleware.clone(),
            fn_factory(|| async { Ok::<_, ()>(TestService(Rc::new(Cell::new(2)))) }),
        );
        let srv = factory.pipeline(&()).await.unwrap();
        assert_eq!(srv.call(()).await, Ok(()));
    }
}