1use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
2use std::time::Duration;
3use tracing::{debug, warn};
4
/// Tuning knobs for exponential-backoff retries.
///
/// Consumed by `retry_with_backoff` and convertible into a `backoff` crate policy via
/// `RetryConfig::to_backoff`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Upper bound on the *total* number of attempts, including the first call.
    ///
    /// `retry_with_backoff` gives up once this many attempts have failed, so `3` means
    /// at most three calls, and `0` or `1` both mean a single attempt with no retry.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_interval: Duration,
    /// Ceiling on the delay once it has been grown by `multiplier`.
    pub max_interval: Duration,
    /// Factor applied to the delay after each retry.
    pub multiplier: f64,
}
13
14impl Default for RetryConfig {
15 fn default() -> Self {
16 Self {
17 max_retries: 3,
18 initial_interval: Duration::from_millis(500),
19 max_interval: Duration::from_secs(30),
20 multiplier: 2.0,
21 }
22 }
23}
24
25impl RetryConfig {
26 pub fn to_backoff(&self) -> ExponentialBackoff {
27 ExponentialBackoffBuilder::new()
28 .with_initial_interval(self.initial_interval)
29 .with_max_interval(self.max_interval)
30 .with_multiplier(self.multiplier)
31 .with_max_elapsed_time(Some(Duration::from_secs(60)))
32 .build()
33 }
34}
35
36pub async fn retry_with_backoff<F, Fut, T, E>(
38 operation: F,
39 config: &RetryConfig,
40) -> Result<T, E>
41where
42 F: Fn() -> Fut,
43 Fut: std::future::Future<Output = Result<T, E>>,
44 E: std::fmt::Display,
45{
46 let mut delay = config.initial_interval;
47 let mut attempt = 0;
48
49 loop {
50 attempt += 1;
51
52 match operation().await {
53 Ok(result) => {
54 if attempt > 1 {
55 debug!("Operation succeeded on attempt {}", attempt);
56 }
57 return Ok(result);
58 }
59 Err(e) => {
60 if attempt >= config.max_retries {
61 warn!("Operation failed after {} attempts: {}", attempt, e);
62 return Err(e);
63 }
64
65 warn!(
66 "Operation failed (attempt {}): {}. Retrying in {:?}...",
67 attempt, e, delay
68 );
69
70 tokio::time::sleep(delay).await;
72
73 delay = std::cmp::min(
75 Duration::from_millis((delay.as_millis() as f64 * config.multiplier) as u64),
76 config.max_interval,
77 );
78 }
79 }
80 }
81}
82
/// Named retry presets; expand one into concrete settings with
/// `RetryStrategy::to_config`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryStrategy {
    /// Many quick retries with a short initial delay and gentle growth.
    Aggressive,
    /// The `RetryConfig::default()` settings.
    Standard,
    /// Few retries with long initial delays and steep growth.
    Conservative,
    /// A single attempt with no retries.
    None,
}
91
92impl RetryStrategy {
93 pub fn to_config(&self) -> RetryConfig {
94 match self {
95 Self::Aggressive => RetryConfig {
96 max_retries: 5,
97 initial_interval: Duration::from_millis(200),
98 max_interval: Duration::from_secs(10),
99 multiplier: 1.5,
100 },
101 Self::Standard => RetryConfig::default(),
102 Self::Conservative => RetryConfig {
103 max_retries: 2,
104 initial_interval: Duration::from_secs(2),
105 max_interval: Duration::from_secs(60),
106 multiplier: 3.0,
107 },
108 Self::None => RetryConfig {
109 max_retries: 0,
110 initial_interval: Duration::from_millis(0),
111 max_interval: Duration::from_millis(0),
112 multiplier: 1.0,
113 },
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Default attempt budget with millisecond-scale waits, so the async tests do not
    /// spend seconds sleeping on the 500 ms / 30 s production defaults.
    fn fast_config() -> RetryConfig {
        RetryConfig {
            initial_interval: Duration::from_millis(1),
            max_interval: Duration::from_millis(10),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_retry_eventually_succeeds() {
        let attempt = Arc::new(AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result = retry_with_backoff(
            move || {
                let attempt = attempt_clone.clone();
                async move {
                    let count = attempt.fetch_add(1, Ordering::SeqCst) + 1;
                    if count < 3 {
                        Err("Temporary failure")
                    } else {
                        Ok("Success")
                    }
                }
            },
            &fast_config(),
        )
        .await;

        assert_eq!(result, Ok("Success"));
        assert_eq!(attempt.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_respects_max_attempts() {
        let attempt = Arc::new(AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result = retry_with_backoff(
            move || {
                let attempt = attempt_clone.clone();
                async move {
                    attempt.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("Always fails")
                }
            },
            &RetryConfig {
                max_retries: 3,
                ..fast_config()
            },
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempt.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_exponential_backoff_timing() {
        let start = std::time::Instant::now();

        let _ = retry_with_backoff(
            || async { Err::<(), _>("Fail") },
            &RetryConfig {
                max_retries: 3,
                initial_interval: Duration::from_millis(100),
                multiplier: 2.0,
                ..Default::default()
            },
        )
        .await;

        let elapsed = start.elapsed();

        // Three attempts means two sleeps: 100 ms, then 100 ms * 2.0 = 200 ms.
        assert!(
            elapsed >= Duration::from_millis(300),
            "Expected at least 300ms, got {}ms",
            elapsed.as_millis()
        );
    }

    #[test]
    fn test_retry_strategy_aggressive() {
        let config = RetryStrategy::Aggressive.to_config();
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_interval, Duration::from_millis(200));
    }

    #[test]
    fn test_retry_strategy_standard() {
        let config = RetryStrategy::Standard.to_config();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_interval, Duration::from_millis(500));
    }

    #[test]
    fn test_retry_strategy_conservative() {
        let config = RetryStrategy::Conservative.to_config();
        assert_eq!(config.max_retries, 2);
        assert_eq!(config.initial_interval, Duration::from_secs(2));
    }

    #[test]
    fn test_retry_strategy_none() {
        let config = RetryStrategy::None.to_config();
        assert_eq!(config.max_retries, 0);
    }

    #[tokio::test]
    async fn test_strategy_none_makes_single_attempt() {
        let attempt = Arc::new(AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result = retry_with_backoff(
            move || {
                let attempt = attempt_clone.clone();
                async move {
                    attempt.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("Always fails")
                }
            },
            &RetryStrategy::None.to_config(),
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempt.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_no_retry_on_immediate_success() {
        let attempt = Arc::new(AtomicU32::new(0));
        let attempt_clone = attempt.clone();

        let result = retry_with_backoff(
            move || {
                let attempt = attempt_clone.clone();
                async move {
                    attempt.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("Success")
                }
            },
            &RetryConfig::default(),
        )
        .await;

        assert_eq!(result, Ok("Success"));
        assert_eq!(attempt.load(Ordering::SeqCst), 1);
    }
}