1use std::cmp::min;
72use std::future::Future;
73use std::time::Duration;
74
75use futures_timer::Delay;
76#[cfg(feature = "rand")]
77use rand::Rng;
78#[cfg(feature = "rand")]
79use rand::distr::OpenClosed01;
80#[cfg(feature = "rand")]
81use rand::rng;
82
83pub async fn retry<T>(task: T) -> Result<T::Item, T::Error>
89where
90 T: Task,
91{
92 retry_if(task, Always).await
93}
94
95pub async fn retry_if<T, C>(task: T, condition: C) -> Result<T::Item, T::Error>
102where
103 T: Task,
104 C: Condition<T::Error>,
105{
106 RetryPolicy::default().retry_if(task, condition).await
107}
108
109pub async fn collect<T, C, S>(
120 task: T,
121 condition: C,
122 start_value: S,
123) -> Result<Vec<T::Item>, T::Error>
124where
125 T: TaskWithParameter<S>,
126 C: SuccessCondition<T::Item, S>,
127{
128 RetryPolicy::default()
129 .collect(task, condition, start_value)
130 .await
131}
132
133pub async fn collect_and_retry<T, C, D, S>(
147 task: T,
148 success_condition: C,
149 error_condition: D,
150 start_value: S,
151) -> Result<Vec<T::Item>, T::Error>
152where
153 T: TaskWithParameter<S>,
154 C: SuccessCondition<T::Item, S>,
155 D: Condition<T::Error>,
156 S: Clone,
157{
158 RetryPolicy::default()
159 .collect_and_retry(task, success_condition, error_condition, start_value)
160 .await
161}
162
/// How the delay between consecutive retries evolves.
#[derive(Copy, Clone)]
enum Backoff {
    /// Every retry waits the policy's base delay.
    Fixed,
    /// The delay is multiplied by `exponent` after every retry
    /// (before any jitter or cap is applied).
    Exponential {
        /// Growth factor applied after every retry.
        exponent: f64,
    },
}
168
169impl Default for Backoff {
170 fn default() -> Self {
171 Backoff::Exponential { exponent: 2.0 }
172 }
173}
174
175impl Backoff {
176 fn iter(self, policy: &RetryPolicy) -> BackoffIter {
177 BackoffIter {
178 backoff: self,
179 current: 1.0,
180 #[cfg(feature = "rand")]
181 jitter: policy.jitter,
182 delay: policy.delay,
183 max_delay: policy.max_delay,
184 max_retries: policy.max_retries,
185 }
186 }
187}
188
189struct BackoffIter {
190 backoff: Backoff,
191 current: f64,
192 #[cfg(feature = "rand")]
193 jitter: bool,
194 delay: Duration,
195 max_delay: Option<Duration>,
196 max_retries: usize,
197}
198
199impl Iterator for BackoffIter {
200 type Item = Duration;
201
202 fn next(&mut self) -> Option<Self::Item> {
203 if self.max_retries > 0 {
204 let factor = match self.backoff {
205 Backoff::Fixed => self.current,
206 Backoff::Exponential { exponent } => {
207 let factor = self.current;
208 let next_factor = self.current * exponent;
209 self.current = next_factor;
210 factor
211 }
212 };
213
214 let mut delay = self.delay.mul_f64(factor);
215 #[cfg(feature = "rand")]
216 {
217 if self.jitter {
218 delay = jitter(delay);
219 }
220 }
221 if let Some(max_delay) = self.max_delay {
222 delay = min(delay, max_delay);
223 }
224 self.max_retries -= 1;
225
226 return Some(delay);
227 }
228 None
229 }
230}
231
232#[derive(Clone)]
238pub struct RetryPolicy {
239 backoff: Backoff,
240 #[cfg(feature = "rand")]
241 jitter: bool,
242 delay: Duration,
243 max_delay: Option<Duration>,
244 max_retries: usize,
245}
246
247impl Default for RetryPolicy {
248 fn default() -> Self {
249 Self {
250 backoff: Backoff::default(),
251 delay: Duration::from_secs(1),
252 #[cfg(feature = "rand")]
253 jitter: false,
254 max_delay: None,
255 max_retries: 5,
256 }
257 }
258}
259
/// Scales `duration` by a random factor sampled from `OpenClosed01`, i.e. a
/// factor in `(0, 1]`.
///
/// The result keeps `Duration`'s nanosecond resolution, so sub-millisecond
/// delays are jittered instead of being truncated to zero.
#[cfg(feature = "rand")]
fn jitter(duration: Duration) -> Duration {
    let factor: f64 = rng().sample(OpenClosed01);
    if factor >= 1.0 {
        // A factor of exactly 1.0 leaves the delay unchanged. Returning early also
        // avoids `mul_f64` panicking for durations near `Duration::MAX`, whose
        // `f64` seconds value rounds up past the representable range.
        duration
    } else {
        duration.mul_f64(factor)
    }
}
268
269impl RetryPolicy {
270 fn backoffs(&self) -> impl Iterator<Item = Duration> + '_ {
271 self.backoff.iter(self)
272 }
273
274 pub fn exponential(delay: Duration) -> Self {
289 Self {
290 backoff: Backoff::Exponential { exponent: 2.0f64 },
291 delay,
292 ..Self::default()
293 }
294 }
295
296 pub fn fixed(delay: Duration) -> Self {
306 Self {
307 backoff: Backoff::Fixed,
308 delay,
309 ..Self::default()
310 }
311 }
312
313 pub fn with_backoff_exponent(mut self, exp: f64) -> Self {
317 if let Backoff::Exponential { ref mut exponent } = self.backoff {
318 *exponent = exp;
319 }
320 self
321 }
322
323 #[cfg(feature = "rand")]
328 pub fn with_jitter(mut self, jitter: bool) -> Self {
329 self.jitter = jitter;
330 self
331 }
332
333 pub fn with_max_delay(mut self, max: Duration) -> Self {
335 self.max_delay = Some(max);
336 self
337 }
338
339 pub fn with_max_retries(mut self, max: usize) -> Self {
341 self.max_retries = max;
342 self
343 }
344
345 pub async fn retry<T>(&self, task: T) -> Result<T::Item, T::Error>
347 where
348 T: Task,
349 {
350 self.retry_if(task, Always).await
351 }
352
353 pub async fn collect<T, C, S>(
356 &self,
357 task: T,
358 condition: C,
359 start_value: S,
360 ) -> Result<Vec<T::Item>, T::Error>
361 where
362 T: TaskWithParameter<S>,
363 C: SuccessCondition<T::Item, S>,
364 {
365 let mut backoffs = self.backoffs();
366 let mut condition = condition;
367 let mut task = task;
368 let mut results = vec![];
369 let mut input = start_value;
370
371 loop {
372 match task.call(input).await {
373 Ok(result) => {
374 let maybe_new_input = condition.retry_with(&result);
375 results.push(result);
376
377 if let Some(new_input) = maybe_new_input {
378 if let Some(delay) = backoffs.next() {
379 #[cfg(feature = "log")]
380 {
381 log::trace!(
382 "task succeeded and condition is met. will run again in {:?}",
383 delay
384 );
385 }
386 let () = Delay::new(delay).await;
387 input = new_input;
388 continue;
389 }
390 }
391
392 return Ok(results);
393 }
394 Err(err) => return Err(err),
395 }
396 }
397 }
398
399 pub async fn collect_and_retry<T, C, D, S>(
404 &self,
405 task: T,
406 success_condition: C,
407 error_condition: D,
408 start_value: S,
409 ) -> Result<Vec<T::Item>, T::Error>
410 where
411 T: TaskWithParameter<S>,
412 C: SuccessCondition<T::Item, S>,
413 D: Condition<T::Error>,
414 S: Clone,
415 {
416 let mut success_backoffs = self.backoffs();
417 let mut error_backoffs = self.backoffs();
418 let mut success_condition = success_condition;
419 let mut error_condition = error_condition;
420 let mut task = task;
421 let mut results = vec![];
422 let mut input = start_value.clone();
423 let mut last_result = start_value;
424
425 loop {
426 match task.call(input).await {
427 Ok(result) => {
428 let maybe_new_input = success_condition.retry_with(&result);
429 results.push(result);
430
431 if let Some(new_input) = maybe_new_input {
432 if let Some(delay) = success_backoffs.next() {
433 #[cfg(feature = "log")]
434 {
435 log::trace!(
436 "task succeeded and condition is met. will run again in {:?}",
437 delay
438 );
439 }
440 let () = Delay::new(delay).await;
441 input = new_input.clone();
442 last_result = new_input;
443 continue;
444 }
445 }
446
447 return Ok(results);
448 }
449 Err(err) => {
450 if error_condition.is_retryable(&err) {
451 if let Some(delay) = error_backoffs.next() {
452 #[cfg(feature = "log")]
453 {
454 log::trace!(
455 "task failed with error {:?}. will try again in {:?}",
456 err,
457 delay
458 );
459 }
460 let () = Delay::new(delay).await;
461 input = last_result.clone();
462 continue;
463 }
464 }
465 return Err(err);
466 }
467 }
468 }
469 }
470
471 pub async fn retry_if<T, C>(&self, task: T, condition: C) -> Result<T::Item, T::Error>
474 where
475 T: Task,
476 C: Condition<T::Error>,
477 {
478 let mut backoffs = self.backoffs();
479 let mut task = task;
480 let mut condition = condition;
481 loop {
482 match task.call().await {
483 Ok(result) => return Ok(result),
484 Err(err) => {
485 if condition.is_retryable(&err) {
486 if let Some(delay) = backoffs.next() {
487 #[cfg(feature = "log")]
488 {
489 log::trace!(
490 "task failed with error {:?}. will try again in {:?}",
491 err,
492 delay
493 );
494 }
495 let () = Delay::new(delay).await;
496 continue;
497 }
498 }
499 return Err(err);
500 }
501 }
502 }
503 }
504}
505
/// Decides whether a failed attempt should be retried.
///
/// Implemented for every `FnMut(&E) -> bool` closure, so a closure can be
/// passed anywhere a condition is expected.
pub trait Condition<E> {
    /// Returns `true` if `error` deserves another attempt.
    fn is_retryable(&mut self, error: &E) -> bool;
}

/// A condition that treats every error as retryable.
struct Always;

impl<E> Condition<E> for Always {
    #[inline]
    fn is_retryable(&mut self, _error: &E) -> bool {
        true
    }
}

impl<E, F: FnMut(&E) -> bool> Condition<E> for F {
    fn is_retryable(&mut self, error: &E) -> bool {
        (self)(error)
    }
}
532
/// Decides, after a successful attempt, whether the task should run again and
/// with which input.
///
/// Implemented for every `FnMut(&R) -> Option<S>` closure, including stateful
/// ones.
pub trait SuccessCondition<R, S> {
    /// Returns `Some(input)` to request another run with `input`, or `None` to
    /// stop.
    fn retry_with(&mut self, result: &R) -> Option<S>;
}

// `retry_with` already takes `&mut self`, so stateful `FnMut` closures can be
// supported just like `Condition` does. Every `Fn` closure is also `FnMut`, so
// this bound accepts everything the previous `Fn` bound did.
impl<F, R, S> SuccessCondition<R, S> for F
where
    F: FnMut(&R) -> Option<S>,
{
    fn retry_with(&mut self, result: &R) -> Option<S> {
        self(result)
    }
}
551
/// An asynchronous operation that receives an input of type `P` on every run.
///
/// Implemented for every `FnMut(P) -> Fut` closure whose future resolves to a
/// `Result`.
pub trait TaskWithParameter<P> {
    /// Value produced by a successful run.
    type Item;
    /// Error produced by a failed run. It must be `Debug`; the `log` feature
    /// traces errors with `{:?}`.
    type Error: std::fmt::Debug;
    /// Future returned by [`TaskWithParameter::call`].
    type Fut: Future<Output = Result<Self::Item, Self::Error>>;
    /// Creates the future for one run with `parameter`.
    fn call(&mut self, parameter: P) -> Self::Fut;
}

impl<P, I, E, Fut, F> TaskWithParameter<P> for F
where
    F: FnMut(P) -> Fut,
    Fut: Future<Output = Result<I, E>>,
    E: std::fmt::Debug,
{
    type Item = I;
    type Error = E;
    type Fut = Fut;

    fn call(&mut self, parameter: P) -> Self::Fut {
        (self)(parameter)
    }
}
580
/// An asynchronous operation that can be started any number of times.
///
/// Implemented for every `FnMut() -> Fut` closure whose future resolves to a
/// `Result`, including parameterless `async fn`s.
pub trait Task {
    /// Value produced by a successful run.
    type Item;
    /// Error produced by a failed run. It must be `Debug`; the `log` feature
    /// traces errors with `{:?}`.
    type Error: std::fmt::Debug;
    /// Future returned by [`Task::call`].
    type Fut: Future<Output = Result<Self::Item, Self::Error>>;
    /// Creates the future for one run of the task.
    fn call(&mut self) -> Self::Fut;
}

impl<I, E, Fut, F> Task for F
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<I, E>>,
    E: std::fmt::Debug,
{
    type Item = I;
    type Error = E;
    type Fut = Fut;

    fn call(&mut self) -> Self::Fut {
        (self)()
    }
}
609
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Compile-time check: only accepts values that are `Send`.
    fn assert_send(_: impl Send) {}

    #[test]
    fn retry_policy_is_send() {
        assert_send(RetryPolicy::default());
    }

    #[test]
    #[cfg(feature = "rand")]
    fn jitter_adds_variance_to_durations() {
        let base = Duration::from_secs(1);
        let jittered = jitter(base);
        assert_ne!(jittered, base);
        // The random factor lies in (0, 1], so jitter can only shorten a delay.
        assert!(jittered <= base);
    }

    #[test]
    fn backoff_default() {
        match Backoff::default() {
            Backoff::Exponential { exponent } => assert_relative_eq!(exponent, 2.0),
            Backoff::Fixed => panic!("Default backoff expected to be exponential!"),
        }
    }

    #[test]
    fn fixed_backoff() {
        let policy = RetryPolicy::fixed(Duration::from_secs(1));
        let delays: Vec<Duration> = policy.backoffs().collect();
        // One delay per retry; the default policy allows five retries.
        assert_eq!(delays, vec![Duration::from_secs(1); 5]);
    }

    #[test]
    fn exponential_backoff() {
        let policy = RetryPolicy::exponential(Duration::from_secs(1));
        let secs: Vec<f64> = policy.backoffs().map(|d| d.as_secs_f64()).collect();
        assert_eq!(secs.len(), 5);
        for (actual, expected) in secs.into_iter().zip([1.0, 2.0, 4.0, 8.0, 16.0]) {
            assert_relative_eq!(actual, expected);
        }
    }

    #[test]
    fn exponential_backoff_factor() {
        let policy = RetryPolicy::exponential(Duration::from_secs(1)).with_backoff_exponent(1.5);
        let secs: Vec<f64> = policy.backoffs().map(|d| d.as_secs_f64()).collect();
        assert_eq!(secs.len(), 5);
        for (actual, expected) in secs.into_iter().zip([1.0, 1.5, 2.25, 3.375, 5.0625]) {
            assert_relative_eq!(actual, expected);
        }
    }

    #[test]
    fn backoff_exponent_is_ignored_by_fixed_policy() {
        let policy = RetryPolicy::fixed(Duration::from_secs(1)).with_backoff_exponent(3.0);
        assert!(policy.backoffs().all(|d| d == Duration::from_secs(1)));
    }

    #[test]
    fn max_delay_caps_every_backoff() {
        let policy = RetryPolicy::exponential(Duration::from_secs(1))
            .with_max_delay(Duration::from_secs(3));
        let delays: Vec<Duration> = policy.backoffs().collect();
        assert_eq!(delays, [1, 2, 3, 3, 3].map(Duration::from_secs));
    }

    #[test]
    fn max_retries_sets_number_of_backoffs() {
        let policy = RetryPolicy::fixed(Duration::from_millis(1)).with_max_retries(2);
        assert_eq!(policy.backoffs().count(), 2);
        let no_retries = RetryPolicy::default().with_max_retries(0);
        assert_eq!(no_retries.backoffs().next(), None);
    }

    #[test]
    fn always_is_always_retryable() {
        assert!(Always.is_retryable(&()));
    }

    #[test]
    fn closures_impl_condition() {
        fn assert_condition(_: impl Condition<()>) {}
        // `foo` deliberately mirrors the `&E` argument `Condition` expects, so
        // taking the unit value by reference is intentional.
        #[allow(clippy::trivially_copy_pass_by_ref)]
        fn foo(_err: &()) -> bool {
            true
        }
        assert_condition(foo);
        assert_condition(|_err: &()| true);
    }

    #[test]
    fn closures_impl_task() {
        fn assert_task(_: impl Task) {}
        async fn foo() -> Result<u32, ()> {
            Ok(42)
        }
        assert_task(foo);
        assert_task(|| async { Ok::<u32, ()>(42) });
    }

    #[test]
    fn retried_futures_are_send_when_tasks_are_send() {
        assert_send(RetryPolicy::default().retry(|| async { Ok::<u32, ()>(42) }));
    }

    #[tokio::test]
    async fn collect_retries_when_condition_is_met() {
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .collect(
                |input: u32| async move { Ok::<u32, ()>(input + 1) },
                |result: &u32| if *result < 2 { Some(*result) } else { None },
                0,
            )
            .await;
        assert_eq!(result, Ok(vec![1, 2]));
    }

    #[tokio::test]
    async fn collect_does_not_retry_when_condition_is_not_met() {
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .collect(
                |input: u32| async move { Ok::<u32, ()>(input + 1) },
                |result: &u32| if *result < 1 { Some(*result) } else { None },
                0,
            )
            .await;
        assert_eq!(result, Ok(vec![1]));
    }

    #[tokio::test]
    async fn collect_stops_at_first_error() {
        let mut calls = 0;
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .collect(
                |_input: u32| {
                    calls += 1;
                    async { Err::<u32, &str>("boom") }
                },
                |result: &u32| Some(*result),
                0,
            )
            .await;
        assert_eq!(result, Err("boom"));
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn collect_and_retry_retries_when_success_condition_is_met() {
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .collect_and_retry(
                |input: u32| async move { Ok::<u32, u32>(input + 1) },
                |result: &u32| if *result < 2 { Some(*result) } else { None },
                |err: &u32| *err > 1,
                0,
            )
            .await;
        assert_eq!(result, Ok(vec![1, 2]));
    }

    #[tokio::test]
    async fn collect_and_retry_does_not_retry_when_success_condition_is_not_met() {
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .collect_and_retry(
                |input: u32| async move { Ok::<u32, u32>(input + 1) },
                |result: &u32| if *result < 1 { Some(*result) } else { None },
                |err: &u32| *err > 1,
                0,
            )
            .await;
        assert_eq!(result, Ok(vec![1]));
    }

    #[tokio::test]
    async fn collect_and_retry_retries_when_error_condition_is_met() {
        let mut task_ran = 0;
        let _ = RetryPolicy::fixed(Duration::from_millis(1))
            .collect_and_retry(
                |_input: u32| {
                    task_ran += 1;
                    async move { Err::<u32, u32>(0) }
                },
                |result: &u32| if *result < 2 { Some(*result) } else { None },
                |err: &u32| *err == 0,
                0,
            )
            .await;
        // One initial attempt plus the default five retries.
        assert_eq!(task_ran, 6);
    }

    #[tokio::test]
    async fn collect_and_retry_retries_failed_attempt_with_same_input() {
        let mut seen = Vec::new();
        let mut failed_once = false;
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .collect_and_retry(
                |input: u32| {
                    seen.push(input);
                    let fail = input == 1 && !failed_once;
                    if fail {
                        failed_once = true;
                    }
                    async move {
                        if fail {
                            Err::<u32, u32>(0)
                        } else {
                            Ok(input + 1)
                        }
                    }
                },
                |result: &u32| if *result < 3 { Some(*result) } else { None },
                |err: &u32| *err == 0,
                0,
            )
            .await;
        assert_eq!(result, Ok(vec![1, 2, 3]));
        assert_eq!(seen, vec![0, 1, 1, 2]);
    }

    #[tokio::test]
    async fn collect_and_retry_does_not_retry_when_error_condition_is_not_met() {
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .collect_and_retry(
                |input: u32| async move { Err::<u32, u32>(input + 1) },
                |result: &u32| if *result < 1 { Some(*result) } else { None },
                |err: &u32| *err > 1,
                0,
            )
            .await;
        assert_eq!(result, Err(1));
    }

    #[tokio::test]
    async fn ok_futures_yield_ok() {
        let result = RetryPolicy::default()
            .retry(|| async { Ok::<u32, ()>(42) })
            .await;
        assert_eq!(result, Ok(42));
    }

    #[tokio::test]
    async fn failed_futures_yield_err() {
        let mut attempts = 0;
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .retry(|| {
                attempts += 1;
                async { Err::<u32, ()>(()) }
            })
            .await;
        assert_eq!(result, Err(()));
        // One initial attempt plus the default five retries.
        assert_eq!(attempts, 6);
    }

    #[tokio::test]
    async fn retry_if_stops_on_non_retryable_error() {
        let mut attempts = 0;
        let result = RetryPolicy::fixed(Duration::from_millis(1))
            .retry_if(
                || {
                    attempts += 1;
                    async { Err::<u32, u32>(7) }
                },
                |err: &u32| *err != 7,
            )
            .await;
        assert_eq!(result, Err(7));
        assert_eq!(attempts, 1);
    }
}