// async_retry/retry.rs

use alloc::{boxed::Box, vec, vec::Vec};
use core::{
    fmt,
    future::Future,
    marker::PhantomData,
    mem,
    ops::ControlFlow,
    pin::Pin,
    task::{Context, Poll},
};

use async_sleep::{sleep, Sleepble};
use futures_util::{future::FusedFuture, FutureExt as _};
use pin_project_lite::pin_project;
use retry_policy::RetryPolicy;

use crate::error::Error;

/// A single attempt: a boxed, pinned, type-erased future yielding `Result<T, E>`.
type BoxedAttempt<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

/// Type-erased factory that is invoked to build each attempt future.
type RetryFutureRepeater<T, E> = Box<dyn FnMut() -> BoxedAttempt<T, E> + Send>;

22//
23pin_project! {
24    pub struct Retry<SLEEP, POL, T, E> {
25        policy: POL,
26        future_repeater: RetryFutureRepeater<T, E>,
27        //
28        state: State<T, E>,
29        attempts: usize,
30        errors: Option<Vec<E>>,
31        //
32        phantom: PhantomData<(SLEEP, T, E)>,
33    }
34}
35
36impl<SLEEP, POL, T, E> fmt::Debug for Retry<SLEEP, POL, T, E>
37where
38    POL: fmt::Debug,
39{
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        f.debug_struct("Retry")
42            .field("policy", &self.policy)
43            .field("future_repeater", &"")
44            .finish()
45    }
46}
47
48impl<SLEEP, POL, T, E> Retry<SLEEP, POL, T, E> {
49    pub(crate) fn new(policy: POL, future_repeater: RetryFutureRepeater<T, E>) -> Self {
50        Self {
51            policy,
52            future_repeater,
53            //
54            state: State::Pending,
55            attempts: 0,
56            errors: Some(vec![]),
57            //
58            phantom: PhantomData,
59        }
60    }
61}
62
/// Position of a `Retry` future within its retry loop.
enum State<T, E> {
    /// No attempt is in flight; the next poll builds one.
    Pending,
    /// An attempt future is in flight.
    Fut(Pin<Box<dyn Future<Output = Result<T, E>> + Send>>),
    /// Waiting out the back-off delay before the next attempt.
    Sleep(Pin<Box<dyn Future<Output = ()> + Send>>),
    /// The final result has been returned; polling again panics.
    Done,
}

impl<T, E> fmt::Debug for State<T, E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The boxed futures are opaque, so only the variant name is printed.
        let name = match self {
            Self::Pending => "Pending",
            Self::Fut(_) => "Fut",
            Self::Sleep(_) => "Sleep",
            Self::Done => "Done",
        };
        f.write_str(name)
    }
}

81//
82pub fn retry<SLEEP, POL, F, Fut, T, E>(policy: POL, future_repeater: F) -> Retry<SLEEP, POL, T, E>
83where
84    SLEEP: Sleepble + 'static,
85    POL: RetryPolicy<E>,
86    F: Fn() -> Fut + Send + 'static,
87    Fut: Future<Output = Result<T, E>> + Send + 'static,
88{
89    Retry::new(policy, Box::new(move || Box::pin(future_repeater())))
90}
91
92//
93impl<SLEEP, POL, T, E> FusedFuture for Retry<SLEEP, POL, T, E>
94where
95    SLEEP: Sleepble + 'static,
96    POL: RetryPolicy<E>,
97{
98    fn is_terminated(&self) -> bool {
99        matches!(self.state, State::Done)
100    }
101}
102
103//
104impl<SLEEP, POL, T, E> Future for Retry<SLEEP, POL, T, E>
105where
106    SLEEP: Sleepble + 'static,
107    POL: RetryPolicy<E>,
108{
109    type Output = Result<T, Error<E>>;
110
111    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
112        let this = self.project();
113
114        loop {
115            match this.state {
116                State::Pending => {
117                    let future = (this.future_repeater)();
118
119                    //
120                    *this.state = State::Fut(future);
121
122                    continue;
123                }
124                State::Fut(future) => {
125                    match future.poll_unpin(cx) {
126                        Poll::Ready(Ok(x)) => {
127                            //
128                            *this.state = State::Done;
129                            *this.attempts = 0;
130                            *this.errors = Some(Vec::new());
131
132                            break Poll::Ready(Ok(x));
133                        }
134                        Poll::Ready(Err(err)) => {
135                            //
136                            *this.attempts += 1;
137
138                            //
139                            let ret = this.policy.next_step(&err, *this.attempts);
140
141                            //
142                            if let Some(errors) = this.errors.as_mut() {
143                                errors.push(err)
144                            } else {
145                                unreachable!()
146                            }
147
148                            match ret {
149                                ControlFlow::Continue(dur) => {
150                                    //
151                                    *this.state = State::Sleep(Box::pin(sleep::<SLEEP>(dur)));
152
153                                    continue;
154                                }
155                                ControlFlow::Break(stop_reason) => {
156                                    let errors = this.errors.take().expect("unreachable!()");
157
158                                    //
159                                    *this.state = State::Done;
160                                    *this.attempts = 0;
161                                    *this.errors = Some(Vec::new());
162
163                                    break Poll::Ready(Err(Error::new(stop_reason, errors)));
164                                }
165                            }
166                        }
167                        Poll::Pending => {
168                            break Poll::Pending;
169                        }
170                    }
171                }
172                State::Sleep(future) => match future.poll_unpin(cx) {
173                    Poll::Ready(_) => {
174                        //
175                        *this.state = State::Pending;
176
177                        continue;
178                    }
179                    Poll::Pending => {
180                        break Poll::Pending;
181                    }
182                },
183                State::Done => panic!("cannot poll Retry twice"),
184            }
185        }
186    }
187}
188
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
    use super::*;

    use core::{
        sync::atomic::{AtomicUsize, Ordering},
        time::Duration,
    };

    use async_sleep::impl_tokio::Sleep;
    use retry_policy::{
        policies::SimplePolicy,
        retry_backoff::backoffs::FnBackoff,
        retry_predicate::predicates::{AlwaysPredicate, FnPredicate},
        StopReason,
    };

    /// Fixed back-off delay used by every test policy.
    const BACKOFF: Duration = Duration::from_millis(100);

    /// Asserts that `elapsed` covers exactly `sleeps` back-off delays: at least that much
    /// time passed (the sleeps really happened), but less than one additional delay (no
    /// extra sleep was taken). The slack absorbs timer granularity and scheduler jitter,
    /// which made the previous `<= N + 5ms` bounds flaky.
    fn assert_backoff_elapsed(elapsed: Duration, sleeps: u32) {
        let min = BACKOFF * sleeps;
        let max = min + BACKOFF;
        assert!(
            elapsed >= min && elapsed < max,
            "elapsed {:?}, expected within [{:?}, {:?})",
            elapsed,
            min,
            max
        );
    }

    #[tokio::test]
    async fn test_retry_with_max_retries_reached() {
        #[derive(Debug, PartialEq)]
        struct FError(usize);
        async fn f(n: usize) -> Result<(), FError> {
            Err(FError(n))
        }

        let policy = SimplePolicy::new(AlwaysPredicate, 3, FnBackoff::from(|_| BACKOFF));

        let now = std::time::Instant::now();

        let err = retry::<Sleep, _, _, _, _, _>(policy, || f(0))
            .await
            .expect_err("every attempt fails, so retry must fail");
        assert_eq!(&err.stop_reason, &StopReason::MaxRetriesReached);
        assert_eq!(err.errors(), &[FError(0), FError(0), FError(0), FError(0)]);

        // Four attempts are separated by three sleeps.
        assert_backoff_elapsed(now.elapsed(), 3);
    }

    #[tokio::test]
    async fn test_retry_with_max_retries_reached_for_tokio_spawn() {
        #[derive(Debug, PartialEq)]
        struct FError(usize);
        async fn f(n: usize) -> Result<(), FError> {
            Err(FError(n))
        }

        let policy = SimplePolicy::new(AlwaysPredicate, 3, FnBackoff::from(|_| BACKOFF));

        // The handle must be awaited: otherwise the test returns before the task runs,
        // the runtime shuts down, and the assertions inside are silently skipped. A panic
        // inside the task surfaces here as a `JoinError`.
        tokio::spawn(async move {
            let now = std::time::Instant::now();

            let err = retry::<Sleep, _, _, _, _, _>(policy, || f(0))
                .await
                .expect_err("every attempt fails, so retry must fail");
            assert_eq!(&err.stop_reason, &StopReason::MaxRetriesReached);
            assert_eq!(err.errors(), &[FError(0), FError(0), FError(0), FError(0)]);

            assert_backoff_elapsed(now.elapsed(), 3);
        })
        .await
        .expect("spawned retry task panicked");
    }

    #[tokio::test]
    async fn test_retry_with_predicate_failed() {
        #[derive(Debug, PartialEq)]
        struct FError(usize);
        async fn f(n: usize) -> Result<(), FError> {
            Err(FError(n))
        }

        // `AtomicUsize::new` is `const`, so no lazy initialisation is needed.
        static N: AtomicUsize = AtomicUsize::new(0);

        let policy = SimplePolicy::new(
            FnPredicate::from(|FError(n): &FError| [0, 1].contains(n)),
            3,
            FnBackoff::from(|_| BACKOFF),
        );

        let now = std::time::Instant::now();

        let err = retry::<Sleep, _, _, _, _, _>(policy, || f(N.fetch_add(1, Ordering::SeqCst)))
            .await
            .expect_err("the predicate rejects the third error, so retry must fail");
        assert_eq!(&err.stop_reason, &StopReason::PredicateFailed);
        assert_eq!(err.errors(), &[FError(0), FError(1), FError(2)]);

        // Three attempts are separated by two sleeps.
        assert_backoff_elapsed(now.elapsed(), 2);
    }
}