1use std::{
2 error::Error as StdError,
3 fmt::Debug,
4};
5
6use futures::Future;
7use thiserror::Error as ThisError;
8
9use crate::futures::wait;
10
11#[derive(ThisError, Debug)]
12pub enum BaseError<S: AsRef<str> + ToString + Send = String> {
13 #[error("{0}")]
14 Generic(#[from] Box<dyn StdError + Send + Sync>),
15
16 #[error("{0}")]
17 Msg(S),
18}
19
20impl<T, S: AsRef<str> + ToString + Send> From<BaseError<S>> for Result<T, BaseError<S>> {
21 fn from(error: BaseError<S>) -> Self {
22 return Err(error);
23 }
24}
25
26pub struct WithRetriesOptions<
27 TError: Into<Box<dyn StdError + Send + Sync>> = BaseError,
28> {
29 pub(crate) should_stop: Box<dyn Fn(&TError, usize) -> bool + Send>,
30 pub(crate) retries: usize,
31 pub(crate) retry_delay: u64,
32}
33
34impl<TError: Into<Box<dyn StdError + Send + Sync>>> WithRetriesOptions<TError> {
35 pub fn retries(
36 self,
37 retries: usize,
38 ) -> WithRetriesOptions<TError> {
39 return WithRetriesOptions {
40 retries,
41 ..self
42 };
43 }
44
45 pub fn decrement_retries(
46 &mut self,
47 ) {
48 assert!(
50 self.retries > 0,
51 "Retries must be greater than 0.",
52 );
53
54 self.retries -= 1;
55 }
56
57 pub fn retry_delay(
58 self,
59 retry_delay_ms: u64,
60 ) -> WithRetriesOptions<TError> {
61 return WithRetriesOptions {
62 retry_delay: retry_delay_ms,
63 ..self
64 };
65 }
66
67 pub fn should_stop_retries(
68 self,
69 should_stop_retries: impl Fn(&TError, usize) -> bool + Send + 'static,
70 ) -> WithRetriesOptions<TError> {
71 return WithRetriesOptions {
72 should_stop: Box::new(should_stop_retries),
73 ..self
74 };
75 }
76}
77
78impl<TError: Into<Box<dyn StdError + Send + Sync>>> Default for WithRetriesOptions<TError> {
79 fn default() -> WithRetriesOptions<TError> {
80 return WithRetriesOptions {
81 retries: 2,
82 retry_delay: 0,
83 should_stop: Box::new(|_, _| {
84 return false;
85 }),
86 }
87 }
88}
89
/// Awaits `job` once and reports whether the retry loop should finish.
///
/// Returns the job's result together with an `is_done` flag:
/// - `true` when the job succeeded;
/// - `true` when it failed and `retries` is zero (nothing left to retry);
/// - `false` when it failed and retries remain.
pub async fn with_retries_inner<
    T: Send,
    TError: Into<Box<dyn StdError + Send + Sync>>,
    TFuture: Future<Output = Result<T, TError>> + Send,
>(
    job: TFuture,
    retries: usize,
) -> (Result<T, TError>, bool) {
    let outcome = job.await;
    let is_done = outcome.is_ok() || retries == 0;

    (outcome, is_done)
}
122
123pub async fn with_retries<
127 T: Send,
128 TError: Into<Box<dyn StdError + Send + Sync>>,
129 TFuture: Future<Output = Result<T, TError>> + Send,
130>(
131 mut future_factory: impl FnMut(usize) -> TFuture,
132 mut options: WithRetriesOptions<TError>,
133) -> Result<T, TError> {
134 loop {
135 let (result, is_done) = with_retries_inner(
136 future_factory(options.retries),
137 options.retries,
138 ).await;
139
140 if is_done {
142 return result;
143 }
144
145 let error = match result {
147 Err(err) => err,
148 Ok(_) => {
149 panic!("Result cannot be `Ok` here.");
150 },
151 };
152
153 if (&options.should_stop)(&error, options.retries) {
156 return Err(error);
157 }
158
159 wait(options.retry_delay).await;
161
162 options.decrement_retries();
164 }
165}
166
167#[cfg(test)]
168mod tests {
169 use std::error::Error as StdError;
170
171 use cs_utils::{
172 random_number,
173 futures::{wait_random, with_retries, WithRetriesOptions},
174 };
175 use anyhow::anyhow;
176 use thiserror::Error as ThisError;
177
178 use crate::futures::with_retries::BaseError;
179
180 mod future_factory {
181 use super::*;
182
183 use tokio::sync::mpsc;
184
        /// A value captured from the test's scope must reach the future built
        /// by the factory, and its `Ok` value must be returned by `with_retries`.
        #[tokio::test]
        async fn factory_depends_on_external_variable() {
            let expected_result = random_number(0..=u64::MAX);

            let result = with_retries(|_| {
                return async move {
                    wait_random(1..=5).await;

                    // Never taken. This branch exists only to pin the future's
                    // error type to `BaseError<&str>` so inference succeeds. It
                    // goes through the `From<BaseError<S>> for Result` impl.
                    if false {
                        return BaseError::Msg("Something bad happened.").into();
                    }

                    return Ok(expected_result);
                };
            },
            Default::default(),
            ).await.expect("Must succeed.");

            assert_eq!(
                result,
                expected_result,
                "Must return the expected result.",
            );
        }
211
        /// The factory receives the number of retries left. With default
        /// options the first attempt sees `2`, and since it succeeds, that value
        /// is returned.
        #[tokio::test]
        async fn factory_depends_on_retries_left() {
            let result = with_retries(|retries_left| async move {
                wait_random(1..=5).await;

                // Never taken; pins the error type to `BaseError<&str>` for inference.
                if false {
                    return Err(BaseError::Msg("Something bad happened."));
                }

                return Ok(retries_left);
            },
            Default::default(),
            ).await.expect("Must succeed.");

            assert_eq!(
                result,
                2,
                "Must return the default `retries` value result.",
            );
        }
233
        /// The factory may return a non-`move` async block.
        #[tokio::test]
        async fn factory_without_move() {
            let result = with_retries(|_| async {
                wait_random(1..=5).await;

                // Never taken; pins the error type to `BaseError<&str>` for inference.
                if false {
                    return Err(BaseError::Msg("Something bad happened."));
                }

                return Ok(10);
            },
            Default::default(),
            ).await.expect("Must succeed.");

            assert_eq!(
                result,
                10,
                "Must return the expected result.",
            );
        }
255
256 #[tokio::test]
257 async fn factory_takes_number_of_remaining_retries() {
258 let (sender, mut receiver) = mpsc::channel(1024);
259
260 tokio::try_join!(
261 tokio::spawn(async move {
262 return with_retries(move |retries_left| {
263 let tx = sender.clone();
264
265 return async move {
266 wait_random(1..=5).await;
267
268 tx
269 .send(retries_left).await
270 .expect("Cannot send retries left number.");
271
272 if retries_left > 0 {
273 let err: Box<dyn StdError + Send + Sync> = anyhow!("Some error.").into();
274 return Err(err);
275 }
276
277 return Ok(());
278 };
279 },
280 Default::default(),
281 ).await.expect("Must succeed.");
282 }),
283 tokio::spawn(async move {
284 let mut expected_retries_left = 2;
285
286 while let Some(retries_left) = receiver.recv().await {
287 assert_eq!(
288 retries_left,
289 expected_retries_left,
290 "Must return the expected retries left.",
291 );
292
293 if expected_retries_left == 0 {
294 break;
295 }
296
297 expected_retries_left -= 1;
298 }
299 }),
300 ).unwrap();
301 }
302 }
303
304 mod should_stop_retries {
305 use crate::{traits::Random, test::random_vec, random_str};
306
307 use super::*;
308
309 #[rstest::rstest]
310 #[case(2, vec![2], true)]
311 #[case(1, vec![2, 1], true)]
312 #[case(0, vec![2, 1, 0], false)]
313 #[tokio::test]
314 async fn defines_if_should_stop_retries(
315 #[case] stop_at: usize,
316 #[case] expected_retries_left: Vec<usize>,
317 #[case] is_error: bool,
318 ) {
319 let options = WithRetriesOptions::default()
320 .should_stop_retries(move |_error, retries_left| {
321 return retries_left == stop_at;
322 });
323
324 let mut items = vec![];
325 let result = with_retries(|retries_left| {
326 items.push(retries_left);
327 return async move {
328 wait_random(1..=5).await;
329
330 if retries_left > 0 {
331 return Err(anyhow!("Some error."));
332 }
333
334 return Ok(());
335 };
336 },
337 options,
338 ).await;
339
340 assert_eq!(
341 items,
342 expected_retries_left,
343 );
344
345 assert_eq!(
346 result.is_err(),
347 is_error,
348 "Must return correct result.",
349 );
350 }
351
352 #[tokio::test]
353 async fn receives_original_error() {
354 let options = WithRetriesOptions::default()
355 .should_stop_retries(move |error, retry_number| {
356 assert_eq!(
357 format!("{}", error),
358 format!("Oh no! #{}", retry_number),
359 "Must accept original error.",
360 );
361
362 return false;
363 });
364
365 with_retries(|retries_left| {
366 return async move {
367 wait_random(1..=5).await;
368
369 if retries_left > 0 {
370 return Err(BaseError::Msg(format!("Oh no! #{}", retries_left)));
371 }
372
373 return Ok(());
374 };
375 },
376 options,
377 ).await.expect("Must succeed.");
378 }
379
380 #[derive(ThisError, Debug)]
381 pub enum TestError<S: AsRef<str> + ToString = String> {
382 #[error("{0}")]
395 Msg(S),
396
397 #[error("{0}")]
398 Anyhow(anyhow::Error),
399
400 #[error("Unknown error. [code: {0}]")]
401 Unknown(u128),
402 }
403
404 impl Clone for TestError {
405 fn clone(&self) -> TestError {
406 match self {
407 TestError::Msg(msg) => {
408 return TestError::Msg(msg.clone());
409 },
410 TestError::Anyhow(msg) => {
411 return TestError::Anyhow(anyhow!(msg.to_string()));
412 },
413 TestError::Unknown(id) => return TestError::Unknown(id.clone()),
414 };
415 }
416 }
417
418 impl PartialEq for TestError {
419 fn eq(&self, other: &TestError) -> bool {
420 match (self, other) {
421 (TestError::Msg(left_msg), TestError::Msg(right_msg)) => {
422 return left_msg == right_msg;
423 },
424 (TestError::Anyhow(left_err), TestError::Anyhow(right_err)) => {
425 return format!("{}", left_err) == format!("{}", right_err);
426 },
427 (TestError::Unknown(left_id), TestError::Unknown(right_id)) => {
428 return left_id == right_id;
429 },
430 _ => return false,
431 };
432 }
433 }
434
435 impl Random for TestError {
436 fn random() -> TestError {
437 match random_number(0..=2) {
438 0 => {
439 return TestError::Msg(
440 format!("Something bad happened. [code: {}]", random_str(16)),
441 );
442 },
443 1 => {
444 return TestError::Anyhow(anyhow!(
445 format!("Oh no! [code: {}]", random_str(16)),
446 ));
447 },
448 2 => {
449 return TestError::Unknown(random_number(0..=u128::MAX))
450 },
451 _ => unreachable!(),
452 };
453 }
454 }
455
456 #[rstest::rstest]
457 #[case(random_number(0..=8))]
458 #[case(random_number(0..=8))]
459 #[case(random_number(0..=8))]
460 #[case(random_number(0..=8))]
461 #[case(random_number(0..=8))]
462 #[case(random_number(0..=8))]
463 #[case(random_number(0..=8))]
464 #[case(random_number(0..=8))]
465 #[tokio::test]
466 async fn receives_original_error_of_type(
467 #[case] retry_count: u32,
468 ) {
469 let errors: Vec<TestError> = random_vec(retry_count + 1);
470
471 let expected_errors = errors.clone();
472 let options = WithRetriesOptions::default()
473 .should_stop_retries(move |error, retry_number| {
474 assert_eq!(
475 error,
476 &expected_errors[retry_number],
477 "Must receive original errors.",
478 );
479
480 return false;
481 })
482 .retries(retry_count as usize);
483
484 let result = with_retries(|retries_left|
485 {
486 let sent_errors = errors.clone();
487 return async move {
488 wait_random(1..=5).await;
489
490 if retries_left > 0 {
491 return Err(sent_errors[retries_left].clone());
492 }
493
494 return Ok(retries_left);
495 };
496 },
497 options,
498 ).await.expect("Must succeed.");
499
500 assert_eq!(
501 result,
502 0,
503 "Must exhaust all retry attempts.",
504 );
505 }
506
507 #[rstest::rstest]
508 #[case(random_number(0..=2), TestError::random())]
509 #[case(random_number(0..=2), TestError::random())]
510 #[case(random_number(0..=2), TestError::random())]
511 #[case(random_number(0..=2), TestError::random())]
512 #[case(random_number(0..=2), TestError::random())]
513 #[case(random_number(0..=2), TestError::random())]
514 #[case(random_number(0..=2), TestError::random())]
515 #[case(random_number(0..=2), TestError::random())]
516 #[tokio::test]
517 async fn can_stop_by_error(
518 #[case] stop_at: usize,
519 #[case] test_error: TestError,
520 ) {
521 let expected_error = test_error.clone();
522
523 let options = WithRetriesOptions::default()
524 .should_stop_retries(move |error, retry_number| {
525 assert!(
526 retry_number >= stop_at,
527 "Must not run after `true` is returned",
528 );
529
530 if error == &expected_error {
531 assert_eq!(
532 retry_number,
533 stop_at,
534 "The expected error must be thrown at the expected iteration.",
535 );
536
537 return true;
538 }
539
540 return false;
541 });
542
543 let result = with_retries(|retries_left|
544 {
545 let sent_error = test_error.clone();
546 return async move {
547 wait_random(1..=5).await;
548
549 if retries_left == stop_at {
550 return Err(sent_error);
551 }
552
553 if retries_left > 0 {
554 return Err(TestError::random());
555 }
556
557 return Ok(());
558 };
559 },
560 options,
561 ).await.expect_err("Must fail.");
562
563 assert_eq!(
564 result,
565 test_error,
566 "Must complete with the expected error.",
567 );
568 }
569 }
570
571 mod retry_delay {
572 #[tokio::test]
573 async fn waits_before_retry() {
574 todo!();
575 }
576 }
577}