1use alloc::boxed::Box;
2use core::{convert::Infallible, fmt, future::Future, time::Duration};
3
4use async_sleep::{
5 timeout::{timeout, Error as TimeoutError},
6 Sleepble,
7};
8use futures_util::TryFutureExt as _;
9use retry_policy::{retry_predicate::RetryPredicate, RetryPolicy};
10
11use crate::retry::Retry;
12
13pub fn retry_with_timeout<SLEEP, POL, F, Fut, T, E>(
15 policy: POL,
16 future_repeater: F,
17 every_performance_timeout_dur: Duration,
18) -> Retry<SLEEP, POL, T, ErrorWrapper<E>>
19where
20 SLEEP: Sleepble + 'static,
21 POL: RetryPolicy<ErrorWrapper<E>>,
22 F: Fn() -> Fut + Send + 'static,
23 Fut: Future<Output = Result<T, E>> + Send + 'static,
24{
25 Retry::<SLEEP, _, _, _>::new(
26 policy,
27 Box::new(move || {
28 let fut = future_repeater();
29 Box::pin(
30 timeout::<SLEEP, _>(every_performance_timeout_dur, Box::pin(fut)).map_ok_or_else(
31 |err| Err(ErrorWrapper::Timeout(err)),
32 |ret| match ret {
33 Ok(x) => Ok(x),
34 Err(err) => Err(ErrorWrapper::Inner(err)),
35 },
36 ),
37 )
38 }),
39 )
40}
41
42pub fn retry_with_timeout_for_non_logic_error<SLEEP, POL, F, Fut, T>(
44 policy: POL,
45 future_repeater: F,
46 every_performance_timeout_dur: Duration,
47) -> Retry<SLEEP, POL, T, ErrorWrapper<Infallible>>
48where
49 SLEEP: Sleepble + 'static,
50 POL: RetryPolicy<ErrorWrapper<Infallible>>,
51 F: Fn() -> Fut + Send + 'static,
52 Fut: Future<Output = T> + Send + 'static,
53{
54 Retry::<SLEEP, _, _, _>::new(
55 policy,
56 Box::new(move || {
57 let fut = future_repeater();
58 Box::pin(
59 timeout::<SLEEP, _>(every_performance_timeout_dur, Box::pin(fut))
60 .map_ok_or_else(|err| Err(ErrorWrapper::Timeout(err)), |x| Ok(x)),
61 )
62 }),
63 )
64}
65
66pub enum ErrorWrapper<T> {
70 Inner(T),
71 Timeout(TimeoutError),
72}
73
74impl<T> fmt::Debug for ErrorWrapper<T>
75where
76 T: fmt::Debug,
77{
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 match self {
80 ErrorWrapper::Inner(err) => f.debug_tuple("ErrorWrapper::Inner").field(err).finish(),
81 ErrorWrapper::Timeout(err) => {
82 f.debug_tuple("ErrorWrapper::Timeout").field(err).finish()
83 }
84 }
85 }
86}
87
88impl<T> fmt::Display for ErrorWrapper<T>
89where
90 T: fmt::Debug,
91{
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 write!(f, "{self:?}")
94 }
95}
96
/// With the `std` feature enabled, `ErrorWrapper` can be used as a standard error
/// (relying on its `Debug` and `Display` implementations).
#[cfg(feature = "std")]
impl<T: fmt::Debug> std::error::Error for ErrorWrapper<T> {}
99
100impl<T> ErrorWrapper<T> {
101 pub fn is_inner(&self) -> bool {
102 matches!(self, Self::Inner(_))
103 }
104
105 pub fn is_timeout(&self) -> bool {
106 matches!(self, Self::Timeout(_))
107 }
108
109 pub fn into_inner(self) -> Option<T> {
110 match self {
111 Self::Inner(x) => Some(x),
112 Self::Timeout(_) => None,
113 }
114 }
115}
116
/// Adapts a retry predicate over `E` into one over [`ErrorWrapper<E>`].
///
/// [`ErrorWrapper::Inner`] errors are delegated to the wrapped predicate, and
/// [`ErrorWrapper::Timeout`] errors always test `true`.
pub struct PredicateWrapper<T> {
    /// Predicate consulted for [`ErrorWrapper::Inner`] errors.
    inner: T,
}
123
124impl<T> fmt::Debug for PredicateWrapper<T>
125where
126 T: fmt::Debug,
127{
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 f.debug_struct("PredicateWrapper")
130 .field("inner", &self.inner)
131 .finish()
132 }
133}
134
135impl<T> PredicateWrapper<T> {
136 pub fn new(inner: T) -> Self {
137 Self { inner }
138 }
139}
140
141impl<E, P> RetryPredicate<ErrorWrapper<E>> for PredicateWrapper<P>
142where
143 P: RetryPredicate<E>,
144{
145 fn test(&self, params: &ErrorWrapper<E>) -> bool {
146 match params {
147 ErrorWrapper::Inner(inner_params) => self.inner.test(inner_params),
148 ErrorWrapper::Timeout(_) => true,
149 }
150 }
151}
152
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
    use super::*;

    use core::sync::atomic::{AtomicUsize, Ordering};

    use async_sleep::impl_tokio::Sleep;
    use retry_policy::{
        policies::SimplePolicy,
        retry_backoff::backoffs::FnBackoff,
        retry_predicate::predicates::{AlwaysPredicate, FnPredicate},
        StopReason,
    };

    /// Slack allowed above the ideal wall-clock time.
    ///
    /// These tests use real timers, so a 5 ms margin is flaky on loaded machines.
    /// 25 ms stays below the 30 ms difference between the 50 ms timeout and the 80 ms
    /// sleeps, so a timeout that failed to fire still pushes the elapsed time out of range.
    const ELAPSED_SLACK_MS: u128 = 25;

    /// Asserts `expected_ms <= elapsed <= expected_ms + ELAPSED_SLACK_MS`.
    fn assert_elapsed_near(elapsed: Duration, expected_ms: u128) {
        let elapsed_ms = elapsed.as_millis();
        let upper_ms = expected_ms + ELAPSED_SLACK_MS;
        assert!(
            (expected_ms..=upper_ms).contains(&elapsed_ms),
            "elapsed {elapsed_ms} ms, expected {expected_ms}..={upper_ms} ms"
        );
    }

    #[tokio::test]
    async fn test_retry_with_timeout() {
        #[derive(Debug, PartialEq)]
        struct FError(usize);

        // Attempt 1 sleeps past the 50 ms timeout; every other attempt fails immediately.
        async fn f(n: usize) -> Result<(), FError> {
            if n == 1 {
                tokio::time::sleep(tokio::time::Duration::from_millis(80)).await;
            }
            Err(FError(n))
        }

        static N: AtomicUsize = AtomicUsize::new(0);

        let policy = SimplePolicy::new(
            PredicateWrapper::new(FnPredicate::from(|FError(n): &FError| [0, 1].contains(n))),
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        let now = std::time::Instant::now();

        match retry_with_timeout::<Sleep, _, _, _, _, _>(
            policy,
            || f(N.fetch_add(1, Ordering::SeqCst)),
            Duration::from_millis(50),
        )
        .await
        {
            Ok(_) => panic!("expected the retry to stop with an error"),
            Err(err) => {
                assert_eq!(&err.stop_reason, &StopReason::PredicateFailed);
                assert_eq!(err.errors().len(), 3);
                for (i, err) in err.errors().iter().enumerate() {
                    println!("{i} {err:?}");
                    match i {
                        0 => match err {
                            ErrorWrapper::Inner(FError(n)) => {
                                assert_eq!(*n, 0)
                            }
                            err => panic!("{i} {err:?}"),
                        },
                        1 => match err {
                            ErrorWrapper::Timeout(TimeoutError::Timeout(dur)) => {
                                assert_eq!(*dur, Duration::from_millis(50));
                            }
                            err => panic!("{i} {err:?}"),
                        },
                        2 => match err {
                            ErrorWrapper::Inner(FError(n)) => {
                                assert_eq!(*n, 2)
                            }
                            err => panic!("{i} {err:?}"),
                        },
                        n => panic!("{n} {err:?}"),
                    }
                }
            }
        }

        // backoff (100) + attempt 1 timing out (50) + backoff (100).
        assert_elapsed_near(now.elapsed(), 250);
    }

    #[tokio::test]
    async fn test_retry_with_timeout_for_unresult() {
        // Only the first attempt outlives the 50 ms timeout.
        async fn f(n: usize) {
            if n == 0 {
                tokio::time::sleep(tokio::time::Duration::from_millis(80)).await;
            }
        }

        static N: AtomicUsize = AtomicUsize::new(0);

        let policy = SimplePolicy::new(
            PredicateWrapper::new(AlwaysPredicate),
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        let now = std::time::Instant::now();

        if let Err(err) = retry_with_timeout_for_non_logic_error::<Sleep, _, _, _, ()>(
            policy,
            || f(N.fetch_add(1, Ordering::SeqCst)),
            Duration::from_millis(50),
        )
        .await
        {
            panic!("{err:?}")
        }

        // attempt 0 timing out (50) + backoff (100); attempt 1 succeeds immediately.
        assert_elapsed_near(now.elapsed(), 150);
    }

    #[tokio::test]
    async fn test_retry_with_timeout_for_non_logic_error_with_max_retries_reached() {
        // Every attempt outlives the 50 ms timeout.
        async fn f(_n: usize) {
            tokio::time::sleep(tokio::time::Duration::from_millis(80)).await;
        }

        static N: AtomicUsize = AtomicUsize::new(0);

        let policy = SimplePolicy::new(
            PredicateWrapper::new(AlwaysPredicate),
            3,
            FnBackoff::from(|_| Duration::from_millis(100)),
        );

        let now = std::time::Instant::now();

        match retry_with_timeout_for_non_logic_error::<Sleep, _, _, _, ()>(
            policy,
            || f(N.fetch_add(1, Ordering::SeqCst)),
            Duration::from_millis(50),
        )
        .await
        {
            Ok(_) => panic!("expected the retry to stop with an error"),
            Err(err) => {
                assert_eq!(&err.stop_reason, &StopReason::MaxRetriesReached);
                assert_eq!(err.errors().len(), 4);
                for (i, err) in err.errors().iter().enumerate() {
                    println!("{i} {err:?}");
                    match i {
                        0..=3 => match err {
                            ErrorWrapper::Timeout(TimeoutError::Timeout(dur)) => {
                                assert_eq!(*dur, Duration::from_millis(50));
                            }
                            err => panic!("{i} {err:?}"),
                        },

                        n => panic!("{n} {err:?}"),
                    }
                }
            }
        }

        // 4 attempts timing out (4 * 50) + 3 backoffs (3 * 100).
        assert_elapsed_near(now.elapsed(), 500);
    }
}
325
#[cfg(test)]
mod tests_without_std {
    use super::*;

    #[test]
    fn test_error_wrapper() {
        // `Inner` reports itself as inner and hands back its payload.
        let wrapped = ErrorWrapper::Inner(());
        assert!(wrapped.is_inner());
        assert!(!wrapped.is_timeout());
        assert_eq!(wrapped.into_inner(), Some(()));

        // `Timeout` reports itself as a timeout and carries no inner payload.
        let deadline_err = TimeoutError::Timeout(Duration::from_secs(1));
        let timed_out = ErrorWrapper::<()>::Timeout(deadline_err);
        assert!(!timed_out.is_inner());
        assert!(timed_out.is_timeout());
        assert_eq!(timed_out.into_inner(), None);
    }
}