// cs_utils/utils/futures/with_retries.rs

1use std::{
2    error::Error as StdError,
3    fmt::Debug,
4};
5
6use futures::Future;
7use thiserror::Error as ThisError;
8
9use crate::futures::wait;
10
11#[derive(ThisError, Debug)]
12pub enum BaseError<S: AsRef<str> + ToString + Send = String> {
13    #[error("{0}")]
14    Generic(#[from] Box<dyn StdError + Send + Sync>),
15
16    #[error("{0}")]
17    Msg(S),
18}
19
20impl<T, S: AsRef<str> + ToString + Send> From<BaseError<S>> for Result<T, BaseError<S>> {
21    fn from(error: BaseError<S>) -> Self {
22        return Err(error);
23    }
24}
25
26pub struct WithRetriesOptions<
27    TError: Into<Box<dyn StdError + Send + Sync>> = BaseError,
28> {
29    pub(crate) should_stop: Box<dyn Fn(&TError, usize) -> bool + Send>,
30    pub(crate) retries: usize,
31    pub(crate) retry_delay: u64,
32}
33
34impl<TError: Into<Box<dyn StdError + Send + Sync>>> WithRetriesOptions<TError> {
35    pub fn retries(
36        self,
37        retries: usize,
38    ) -> WithRetriesOptions<TError> {
39        return WithRetriesOptions {
40            retries,
41            ..self
42        };
43    }
44
45    pub fn decrement_retries(
46        &mut self,
47    ) {
48        // make sure we don't violate `usize` type by substracting from zero
49        assert!(
50            self.retries > 0,
51            "Retries must be greater than 0.",
52        );
53
54        self.retries -= 1;
55    }
56
57    pub fn retry_delay(
58        self,
59        retry_delay_ms: u64,
60    ) -> WithRetriesOptions<TError> {
61        return WithRetriesOptions {
62            retry_delay: retry_delay_ms,
63            ..self
64        };
65    }
66
67    pub fn should_stop_retries(
68        self,
69        should_stop_retries: impl Fn(&TError, usize) -> bool + Send + 'static,
70    ) -> WithRetriesOptions<TError> {
71        return WithRetriesOptions {
72            should_stop: Box::new(should_stop_retries),
73            ..self
74        };
75    }
76}
77
78impl<TError: Into<Box<dyn StdError + Send + Sync>>> Default for WithRetriesOptions<TError> {
79    fn default() -> WithRetriesOptions<TError> {
80        return WithRetriesOptions {
81            retries: 2,
82            retry_delay: 0,
83            should_stop: Box::new(|_, _| {
84                return false;
85            }),
86        }
87    }
88}
89
/// Awaits a single attempt of a retried job and reports whether retrying is over.
///
/// Returns the job's result paired with a "done" flag. The flag is `true` when the
/// job succeeded or when no retries are left (`retries == 0`), and `false` for an
/// error that may still be retried. No `should_stop` predicate is consulted here.
pub async fn with_retries_inner<
    T: Send,
    TError: Into<Box<dyn StdError + Send + Sync>>,
    TFuture: Future<Output = Result<T, TError>> + Send,
>(
    job: TFuture,
    retries: usize,
) -> (Result<T, TError>, bool) {
    let outcome = job.await;

    // success always finishes; a failure finishes only once retries are exhausted
    let is_done = outcome.is_ok() || retries == 0;

    (outcome, is_done)
}
122
123/// !Experimental! Might be unstable in regard of runtime stability and API surface.
124/// 
125/// Executes a future with retries if it fails.
126pub async fn with_retries<
127    T: Send,
128    TError: Into<Box<dyn StdError + Send + Sync>>,
129    TFuture: Future<Output = Result<T, TError>> + Send,
130>(
131    mut future_factory: impl FnMut(usize) -> TFuture,
132    mut options: WithRetriesOptions<TError>,
133) -> Result<T, TError> {
134    loop {
135        let (result, is_done) = with_retries_inner(
136            future_factory(options.retries),
137            options.retries,
138        ).await;
139
140        // if done, return the result
141        if is_done {
142            return result;
143        }
144        
145        // get error value and assert that the result is not `Ok()`
146        let error = match result {
147            Err(err) => err,
148            Ok(_) => {
149                panic!("Result cannot be `Ok` here.");
150            },
151        };
152
153        // pass `error` reference and `retry` number to the closure,
154        // if `true` is returned, - return the error result and stop
155        if (&options.should_stop)(&error, options.retries) {
156            return Err(error);
157        }
158
159        // wait for `retry_delay` milliseconds before retrying
160        wait(options.retry_delay).await;
161
162        // remove one retry
163        options.decrement_retries();
164    }
165}
166
// Unit tests for `with_retries`: future factories, the `should_stop` predicate
// and the retry delay.
#[cfg(test)]
mod tests {
    use std::error::Error as StdError;

    use cs_utils::{
        random_number,
        futures::{wait_random, with_retries, WithRetriesOptions},
    };
    use anyhow::anyhow;
    use thiserror::Error as ThisError;

    use crate::futures::with_retries::BaseError;

    mod future_factory {
        use super::*;

        use tokio::sync::mpsc;

        /// The `expected_result` variable is external for the future factory here.
        #[tokio::test]
        async fn factory_depends_on_external_variable() {
            let expected_result = random_number(0..=u64::MAX);

            let result = with_retries(|_| {
                    return async move {
                        wait_random(1..=5).await;

                        // help rust to infer the `Result` type
                        if false {
                            return BaseError::Msg("Something bad happened.").into();
                        }

                        return Ok(expected_result);
                    };
                },
                Default::default(),
            ).await.expect("Must succeed.");

            assert_eq!(
                result,
                expected_result,
                "Must return the expected result.",
            );
        }

        #[tokio::test]
        async fn factory_depends_on_retries_left() {
            let result = with_retries(|retries_left| async move {
                    wait_random(1..=5).await;

                    // help rust to infer the `Result` type
                    if false {
                        return Err(BaseError::Msg("Something bad happened."));
                    }

                    return Ok(retries_left);
                },
                Default::default(),
            ).await.expect("Must succeed.");

            assert_eq!(
                result,
                2,
                "Must return the default `retries` value result.",
            );
        }

        #[tokio::test]
        async fn factory_without_move() {
            let result = with_retries(|_| async {
                    wait_random(1..=5).await;

                    // help rust to infer the `Result` type
                    if false {
                        return Err(BaseError::Msg("Something bad happened."));
                    }

                    return Ok(10);
                },
                Default::default(),
            ).await.expect("Must succeed.");

            assert_eq!(
                result,
                10,
                "Must return the expected result.",
            );
        }

        #[tokio::test]
        async fn factory_takes_number_of_remaining_retries() {
            let (sender, mut receiver) = mpsc::channel(1024);

            tokio::try_join!(
                tokio::spawn(async move {
                    return with_retries(move |retries_left| {
                            let tx = sender.clone();

                            return async move {
                                wait_random(1..=5).await;

                                tx
                                    .send(retries_left).await
                                    .expect("Cannot send retries left number.");

                                if retries_left > 0 {
                                    let err: Box<dyn StdError + Send + Sync> = anyhow!("Some error.").into();
                                    return Err(err);
                                }

                                return Ok(());
                            };
                        },
                        Default::default(),
                    ).await.expect("Must succeed.");
                }),
                tokio::spawn(async move {
                    let mut expected_retries_left = 2;

                    while let Some(retries_left) = receiver.recv().await {
                        assert_eq!(
                            retries_left,
                            expected_retries_left,
                            "Must return the expected retries left.",
                        );

                        if expected_retries_left == 0 {
                            break;
                        }

                        expected_retries_left -= 1;
                    }
                }),
            ).unwrap();
        }
    }

    mod should_stop_retries {
        use crate::{traits::Random, test::random_vec, random_str};

        use super::*;

        #[rstest::rstest]
        #[case(2, vec![2], true)]
        #[case(1, vec![2, 1], true)]
        #[case(0, vec![2, 1, 0], false)]
        #[tokio::test]
        async fn defines_if_should_stop_retries(
            #[case] stop_at: usize,
            #[case] expected_retries_left: Vec<usize>,
            #[case] is_error: bool,
        ) {
            let options = WithRetriesOptions::default()
                .should_stop_retries(move |_error, retries_left| {
                    return retries_left == stop_at;
                });

            let mut items = vec![];
            let result = with_retries(|retries_left| {
                    items.push(retries_left);
                    return async move {
                        wait_random(1..=5).await;

                        if retries_left > 0 {
                            return Err(anyhow!("Some error."));
                        }

                        return Ok(());
                    };
                },
                options,
            ).await;

            assert_eq!(
                items,
                expected_retries_left,
            );

            assert_eq!(
                result.is_err(),
                is_error,
                "Must return correct result.",
            );
        }

        #[tokio::test]
        async fn receives_original_error() {
            let options = WithRetriesOptions::default()
                .should_stop_retries(move |error, retry_number| {
                    assert_eq!(
                        format!("{}", error),
                        format!("Oh no! #{}", retry_number),
                        "Must accept original error.",
                    );

                    return false;
                });

            with_retries(|retries_left| {
                    return async move {
                        wait_random(1..=5).await;

                        if retries_left > 0 {
                            return Err(BaseError::Msg(format!("Oh no! #{}", retries_left)));
                        }

                        return Ok(());
                    };
                },
                options,
            ).await.expect("Must succeed.");
        }

        #[derive(ThisError, Debug)]
        pub enum TestError<S: AsRef<str> + ToString = String> {
            // #[error("data store disconnected")]
            // Disconnect(#[from] io::Error),

            // #[error("the data for key `{0}` is not available")]
            // Redaction(String),

            // #[error("invalid header (expected {expected:?}, found {found:?})")]
            // InvalidHeader {
            //     expected: String,
            //     found: String,
            // },

            #[error("{0}")]
            Msg(S),

            #[error("{0}")]
            Anyhow(anyhow::Error),

            #[error("Unknown error. [code: {0}]")]
            Unknown(u128),
        }

        impl Clone for TestError {
            fn clone(&self) -> TestError {
                match self {
                    TestError::Msg(msg) => {
                        return TestError::Msg(msg.clone());
                    },
                    TestError::Anyhow(msg) => {
                        return TestError::Anyhow(anyhow!(msg.to_string()));
                    },
                    TestError::Unknown(id) => return TestError::Unknown(id.clone()),
                };
            }
        }

        impl PartialEq for TestError {
            fn eq(&self, other: &TestError) -> bool {
                match (self, other) {
                    (TestError::Msg(left_msg), TestError::Msg(right_msg)) => {
                        return left_msg == right_msg;
                    },
                    (TestError::Anyhow(left_err), TestError::Anyhow(right_err)) => {
                        return format!("{}", left_err) == format!("{}", right_err);
                    },
                    (TestError::Unknown(left_id), TestError::Unknown(right_id)) => {
                        return left_id == right_id;
                    },
                    _ => return false,
                };
            }
        }

        impl Random for TestError {
            fn random() -> TestError {
                match random_number(0..=2) {
                    0 => {
                        return TestError::Msg(
                            format!("Something bad happened. [code: {}]", random_str(16)),
                        );
                    },
                    1 => {
                        return TestError::Anyhow(anyhow!(
                            format!("Oh no! [code: {}]", random_str(16)),
                        ));
                    },
                    2 => {
                        return TestError::Unknown(random_number(0..=u128::MAX))
                    },
                    _ => unreachable!(),
                };
            }
        }

        #[rstest::rstest]
        #[case(random_number(0..=8))]
        #[case(random_number(0..=8))]
        #[case(random_number(0..=8))]
        #[case(random_number(0..=8))]
        #[case(random_number(0..=8))]
        #[case(random_number(0..=8))]
        #[case(random_number(0..=8))]
        #[case(random_number(0..=8))]
        #[tokio::test]
        async fn receives_original_error_of_type(
            #[case] retry_count: u32,
        ) {
            let errors: Vec<TestError> = random_vec(retry_count + 1);

            let expected_errors = errors.clone();
            let options = WithRetriesOptions::default()
                .should_stop_retries(move |error, retry_number| {
                    assert_eq!(
                        error,
                        &expected_errors[retry_number],
                        "Must receive original errors.",
                    );

                    return false;
                })
                .retries(retry_count as usize);

            let result = with_retries(|retries_left|
                {
                    let sent_errors = errors.clone();
                    return async move {
                        wait_random(1..=5).await;

                        if retries_left > 0 {
                            return Err(sent_errors[retries_left].clone());
                        }

                        return Ok(retries_left);
                    };
                },
                options,
            ).await.expect("Must succeed.");

            assert_eq!(
                result,
                0,
                "Must exhaust all retry attempts.",
            );
        }

        #[rstest::rstest]
        #[case(random_number(0..=2), TestError::random())]
        #[case(random_number(0..=2), TestError::random())]
        #[case(random_number(0..=2), TestError::random())]
        #[case(random_number(0..=2), TestError::random())]
        #[case(random_number(0..=2), TestError::random())]
        #[case(random_number(0..=2), TestError::random())]
        #[case(random_number(0..=2), TestError::random())]
        #[case(random_number(0..=2), TestError::random())]
        #[tokio::test]
        async fn can_stop_by_error(
            #[case] stop_at: usize,
            #[case] test_error: TestError,
        ) {
            let expected_error = test_error.clone();

            let options = WithRetriesOptions::default()
                .should_stop_retries(move |error, retry_number| {
                    assert!(
                        retry_number >= stop_at,
                        "Must not run after `true` is returned",
                    );

                    if error == &expected_error {
                        assert_eq!(
                            retry_number,
                            stop_at,
                            "The expected error must be thrown at the expected iteration.",
                        );

                        return true;
                    }

                    return false;
                });

            let result = with_retries(|retries_left|
                {
                    let sent_error = test_error.clone();
                    return async move {
                        wait_random(1..=5).await;

                        if retries_left == stop_at {
                            return Err(sent_error);
                        }

                        if retries_left > 0 {
                            return Err(TestError::random());
                        }

                        return Ok(());
                    };
                },
                options,
            ).await.expect_err("Must fail.");

            assert_eq!(
                result,
                test_error,
                "Must complete with the expected error.",
            );
        }
    }

    mod retry_delay {
        use std::time::{Duration, Instant};

        use super::*;

        /// Every retried failure is followed by a `retry_delay` pause, so succeeding
        /// after two retries must take at least two delays in total.
        #[tokio::test]
        async fn waits_before_retry() {
            let retry_delay_ms = 20;
            let options = WithRetriesOptions::default()
                .retries(2)
                .retry_delay(retry_delay_ms);

            let started_at = Instant::now();
            let result = with_retries(|retries_left| async move {
                    if retries_left > 0 {
                        return Err(BaseError::Msg("Something bad happened."));
                    }

                    return Ok(retries_left);
                },
                options,
            ).await.expect("Must succeed.");

            assert_eq!(
                result,
                0,
                "Must succeed on the last attempt.",
            );

            assert!(
                started_at.elapsed() >= Duration::from_millis(2 * retry_delay_ms),
                "Must wait `retry_delay` milliseconds before each retry.",
            );
        }
    }
}