async_retry/
retry_with_timeout.rs

1use alloc::boxed::Box;
2use core::{convert::Infallible, fmt, future::Future, time::Duration};
3
4use async_sleep::{
5    timeout::{timeout, Error as TimeoutError},
6    Sleepble,
7};
8use futures_util::TryFutureExt as _;
9use retry_policy::{retry_predicate::RetryPredicate, RetryPolicy};
10
11use crate::retry::Retry;
12
13//
14pub fn retry_with_timeout<SLEEP, POL, F, Fut, T, E>(
15    policy: POL,
16    future_repeater: F,
17    every_performance_timeout_dur: Duration,
18) -> Retry<SLEEP, POL, T, ErrorWrapper<E>>
19where
20    SLEEP: Sleepble + 'static,
21    POL: RetryPolicy<ErrorWrapper<E>>,
22    F: Fn() -> Fut + Send + 'static,
23    Fut: Future<Output = Result<T, E>> + Send + 'static,
24{
25    Retry::<SLEEP, _, _, _>::new(
26        policy,
27        Box::new(move || {
28            let fut = future_repeater();
29            Box::pin(
30                timeout::<SLEEP, _>(every_performance_timeout_dur, Box::pin(fut)).map_ok_or_else(
31                    |err| Err(ErrorWrapper::Timeout(err)),
32                    |ret| match ret {
33                        Ok(x) => Ok(x),
34                        Err(err) => Err(ErrorWrapper::Inner(err)),
35                    },
36                ),
37            )
38        }),
39    )
40}
41
42//
43pub fn retry_with_timeout_for_non_logic_error<SLEEP, POL, F, Fut, T>(
44    policy: POL,
45    future_repeater: F,
46    every_performance_timeout_dur: Duration,
47) -> Retry<SLEEP, POL, T, ErrorWrapper<Infallible>>
48where
49    SLEEP: Sleepble + 'static,
50    POL: RetryPolicy<ErrorWrapper<Infallible>>,
51    F: Fn() -> Fut + Send + 'static,
52    Fut: Future<Output = T> + Send + 'static,
53{
54    Retry::<SLEEP, _, _, _>::new(
55        policy,
56        Box::new(move || {
57            let fut = future_repeater();
58            Box::pin(
59                timeout::<SLEEP, _>(every_performance_timeout_dur, Box::pin(fut))
60                    .map_ok_or_else(|err| Err(ErrorWrapper::Timeout(err)), |x| Ok(x)),
61            )
62        }),
63    )
64}
65
66//
67//
68//
69pub enum ErrorWrapper<T> {
70    Inner(T),
71    Timeout(TimeoutError),
72}
73
74impl<T> fmt::Debug for ErrorWrapper<T>
75where
76    T: fmt::Debug,
77{
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        match self {
80            ErrorWrapper::Inner(err) => f.debug_tuple("ErrorWrapper::Inner").field(err).finish(),
81            ErrorWrapper::Timeout(err) => {
82                f.debug_tuple("ErrorWrapper::Timeout").field(err).finish()
83            }
84        }
85    }
86}
87
88impl<T> fmt::Display for ErrorWrapper<T>
89where
90    T: fmt::Debug,
91{
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        write!(f, "{self:?}")
94    }
95}
96
/// With the `std` feature, `ErrorWrapper` is a standard error type. It relies
/// on the `Debug` and `Display` impls and keeps the default `source()`.
#[cfg(feature = "std")]
impl<T: fmt::Debug> std::error::Error for ErrorWrapper<T> {}
99
100impl<T> ErrorWrapper<T> {
101    pub fn is_inner(&self) -> bool {
102        matches!(self, Self::Inner(_))
103    }
104
105    pub fn is_timeout(&self) -> bool {
106        matches!(self, Self::Timeout(_))
107    }
108
109    pub fn into_inner(self) -> Option<T> {
110        match self {
111            Self::Inner(x) => Some(x),
112            Self::Timeout(_) => None,
113        }
114    }
115}
116
/// Lifts a predicate over `E` into a predicate over `ErrorWrapper<E>`.
///
/// `ErrorWrapper::Inner` errors are judged by the wrapped predicate.
/// `ErrorWrapper::Timeout` errors always test `true`.
pub struct PredicateWrapper<T> {
    /// The predicate consulted for `ErrorWrapper::Inner` errors.
    inner: T,
}
123
124impl<T> fmt::Debug for PredicateWrapper<T>
125where
126    T: fmt::Debug,
127{
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        f.debug_struct("PredicateWrapper")
130            .field("inner", &self.inner)
131            .finish()
132    }
133}
134
135impl<T> PredicateWrapper<T> {
136    pub fn new(inner: T) -> Self {
137        Self { inner }
138    }
139}
140
141impl<E, P> RetryPredicate<ErrorWrapper<E>> for PredicateWrapper<P>
142where
143    P: RetryPredicate<E>,
144{
145    fn test(&self, params: &ErrorWrapper<E>) -> bool {
146        match params {
147            ErrorWrapper::Inner(inner_params) => self.inner.test(inner_params),
148            ErrorWrapper::Timeout(_) => true,
149        }
150    }
151}
152
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
    use super::*;

    use core::{
        sync::atomic::{AtomicUsize, Ordering},
        time::Duration,
    };

    use async_sleep::impl_tokio::Sleep;
    use retry_policy::{
        policies::SimplePolicy,
        retry_backoff::backoffs::FnBackoff,
        retry_predicate::predicates::{AlwaysPredicate, FnPredicate},
        StopReason,
    };

    /// Scheduler jitter tolerated on top of a run's theoretical duration.
    ///
    /// Tokio timers do not complete before their deadline, so the lower bound
    /// of every timing check is the exact sum of timeouts and backoffs. The
    /// upper bound only catches runaway delays. Windows of a few milliseconds
    /// are flaky on loaded CI machines.
    const TIMING_SLACK_MS: u128 = 100;

    /// Asserts that the time since `started` is at least `min_ms` and at most
    /// `min_ms + TIMING_SLACK_MS`.
    fn assert_elapsed(started: std::time::Instant, min_ms: u128) {
        let elapsed_ms = started.elapsed().as_millis();
        let max_ms = min_ms + TIMING_SLACK_MS;
        assert!(
            (min_ms..=max_ms).contains(&elapsed_ms),
            "elapsed {elapsed_ms} ms, expected {min_ms}..={max_ms} ms"
        );
    }

    #[tokio::test]
    async fn test_retry_with_timeout() {
        #[derive(Debug, PartialEq)]
        struct FError(usize);

        // Attempt 1 outlives the 50 ms timeout; every attempt that finishes
        // in time fails with `FError(attempt)`.
        async fn f(n: usize) -> Result<(), FError> {
            if n == 1 {
                tokio::time::sleep(tokio::time::Duration::from_millis(80)).await;
            }
            Err(FError(n))
        }

        static N: AtomicUsize = AtomicUsize::new(0);

        // Retry inner errors 0 and 1 (timeouts are always retried); stop on
        // anything else.
        let policy = SimplePolicy::new(
            PredicateWrapper::new(FnPredicate::from(|FError(n): &FError| [0, 1].contains(n))),
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        let started = std::time::Instant::now();

        match retry_with_timeout::<Sleep, _, _, _, _, _>(
            policy,
            || f(N.fetch_add(1, Ordering::SeqCst)),
            Duration::from_millis(50),
        )
        .await
        {
            Ok(_) => panic!("expected the retry to stop with PredicateFailed"),
            Err(err) => {
                assert_eq!(&err.stop_reason, &StopReason::PredicateFailed);
                // Guard against the per-index checks below passing vacuously.
                assert_eq!(err.errors().iter().count(), 3);
                for (i, err) in err.errors().iter().enumerate() {
                    println!("{i} {err:?}");
                    match i {
                        0 => match err {
                            ErrorWrapper::Inner(FError(n)) => {
                                assert_eq!(*n, 0)
                            }
                            err => panic!("{i} {err:?}"),
                        },
                        1 => match err {
                            ErrorWrapper::Timeout(TimeoutError::Timeout(dur)) => {
                                assert_eq!(*dur, Duration::from_millis(50));
                            }
                            err => panic!("{i} {err:?}"),
                        },
                        2 => match err {
                            ErrorWrapper::Inner(FError(n)) => {
                                assert_eq!(*n, 2)
                            }
                            err => panic!("{i} {err:?}"),
                        },
                        n => panic!("{n} {err:?}"),
                    }
                }
            }
        }

        // 0 (attempt 0) + 100 (backoff) + 50 (attempt 1 timeout) + 100 (backoff) + 0.
        assert_elapsed(started, 250);
    }

    #[tokio::test]
    async fn test_retry_with_timeout_for_unresult() {
        // Attempt 0 outlives the 50 ms timeout; later attempts finish at once.
        async fn f(n: usize) {
            if n == 0 {
                tokio::time::sleep(tokio::time::Duration::from_millis(80)).await;
            }
        }

        static N: AtomicUsize = AtomicUsize::new(0);

        let policy = SimplePolicy::new(
            PredicateWrapper::new(AlwaysPredicate),
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        let started = std::time::Instant::now();

        match retry_with_timeout_for_non_logic_error::<Sleep, _, _, _, ()>(
            policy,
            || f(N.fetch_add(1, Ordering::SeqCst)),
            Duration::from_millis(50),
        )
        .await
        {
            Ok(_) => {}
            Err(err) => {
                panic!("{err:?}")
            }
        }

        // 50 (attempt 0 timeout) + 100 (backoff) + 0 (attempt 1 succeeds).
        assert_elapsed(started, 150);
    }

    #[tokio::test]
    async fn test_retry_with_timeout_for_non_logic_error_with_max_retries_reached() {
        // Every attempt outlives the 50 ms timeout.
        async fn f(_n: usize) {
            tokio::time::sleep(tokio::time::Duration::from_millis(80)).await;
        }

        static N: AtomicUsize = AtomicUsize::new(0);

        let policy = SimplePolicy::new(
            PredicateWrapper::new(AlwaysPredicate),
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        let started = std::time::Instant::now();

        match retry_with_timeout_for_non_logic_error::<Sleep, _, _, _, ()>(
            policy,
            || f(N.fetch_add(1, Ordering::SeqCst)),
            Duration::from_millis(50),
        )
        .await
        {
            Ok(_) => panic!("expected the retry to stop with MaxRetriesReached"),
            Err(err) => {
                assert_eq!(&err.stop_reason, &StopReason::MaxRetriesReached);
                // The first attempt plus 3 retries, all timed out.
                assert_eq!(err.errors().iter().count(), 4);
                for (i, err) in err.errors().iter().enumerate() {
                    println!("{i} {err:?}");
                    match i {
                        0..=3 => match err {
                            ErrorWrapper::Timeout(TimeoutError::Timeout(dur)) => {
                                assert_eq!(*dur, Duration::from_millis(50));
                            }
                            err => panic!("{i} {err:?}"),
                        },

                        n => panic!("{n} {err:?}"),
                    }
                }
            }
        }

        // 4 x 50 (timeouts) + 3 x 100 (backoffs).
        assert_elapsed(started, 500);
    }
}
325
#[cfg(test)]
mod tests_without_std {
    use super::*;

    #[test]
    fn test_error_wrapper() {
        // An `Inner` error reports itself as inner and hands its payload back.
        let wrapped = ErrorWrapper::Inner(());
        assert!(wrapped.is_inner(), "Inner must report is_inner()");
        assert!(!wrapped.is_timeout(), "Inner must not report is_timeout()");
        assert_eq!(wrapped.into_inner(), Some(()));

        // A `Timeout` error reports itself as a timeout and carries no payload.
        let timed_out: ErrorWrapper<()> =
            ErrorWrapper::Timeout(TimeoutError::Timeout(Duration::from_secs(1)));
        assert!(!timed_out.is_inner(), "Timeout must not report is_inner()");
        assert!(timed_out.is_timeout(), "Timeout must report is_timeout()");
        assert!(timed_out.into_inner().is_none());
    }
}
345}