1use backoff::{backoff::Backoff, ExponentialBackoff};
2use std::time::Duration;
3
/// Configuration for retrying a fallible asynchronous operation with
/// exponential backoff between attempts.
///
/// Build one with [`RetryPolicy::new`] or [`Default::default`] and adjust it
/// with the `with_*` builder methods, or set the public fields directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Maximum number of retries performed after the initial attempt, so an
    /// operation is invoked at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Initial interval handed to the exponential backoff.
    pub initial_delay: Duration,
    /// Upper bound handed to the exponential backoff as its maximum interval.
    pub max_delay: Duration,
}
14
15impl Default for RetryPolicy {
16 fn default() -> Self {
17 Self {
18 max_retries: 2,
19 initial_delay: Duration::from_millis(500),
20 max_delay: Duration::from_secs(32),
21 }
22 }
23}
24
25impl RetryPolicy {
26 pub fn new() -> Self {
28 Self::default()
29 }
30
31 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
33 self.max_retries = max_retries;
34 self
35 }
36
37 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
39 self.initial_delay = delay;
40 self
41 }
42
43 pub fn with_max_delay(mut self, delay: Duration) -> Self {
45 self.max_delay = delay;
46 self
47 }
48
49 pub async fn retry<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
51 where
52 F: FnMut() -> Fut,
53 Fut: std::future::Future<Output = Result<T, E>>,
54 E: std::fmt::Debug,
55 {
56 let mut backoff = ExponentialBackoff {
57 initial_interval: self.initial_delay,
58 max_interval: self.max_delay,
59 max_elapsed_time: None,
60 ..Default::default()
61 };
62
63 let mut attempt = 0;
64 loop {
65 match f().await {
66 Ok(result) => return Ok(result),
67 Err(err) => {
68 attempt += 1;
69 if attempt > self.max_retries {
70 return Err(err);
71 }
72
73 if let Some(delay) = backoff.next_backoff() {
74 tracing::debug!(
75 "Retrying after {:?}, attempt {}/{}",
76 delay,
77 attempt,
78 self.max_retries
79 );
80 tokio::time::sleep(delay).await;
81 } else {
82 return Err(err);
83 }
84 }
85 }
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// An operation that succeeds immediately yields its value unchanged.
    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let policy = RetryPolicy::default();
        let outcome = policy.retry(|| async { Ok::<i32, String>(42) }).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    /// One failure followed by a success is retried and ends in `Ok`
    /// after exactly two invocations.
    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let policy = RetryPolicy::default();

        let outcome = policy
            .retry(move || {
                let counter = Arc::clone(&counter);
                async move {
                    // `fetch_add` returns the previous value, so adding one
                    // gives the ordinal of the current invocation.
                    let nth_call = counter.fetch_add(1, Ordering::SeqCst) + 1;
                    if nth_call >= 2 {
                        Ok(42)
                    } else {
                        Err("transient error")
                    }
                }
            })
            .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// With one retry allowed, an always-failing operation runs twice
    /// (initial attempt plus one retry) and the error is returned.
    #[tokio::test]
    async fn test_retry_fails_after_max_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let policy = RetryPolicy::default().with_max_retries(1);

        let outcome = policy
            .retry(move || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>("permanent error")
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}