1use backoff::{backoff::Backoff, ExponentialBackoff, ExponentialBackoffBuilder};
7use std::future::Future;
8use std::time::Duration;
9use tracing::{debug, warn};
10
/// Tunable parameters for retrying a fallible operation with exponential backoff.
///
/// The delay-related fields feed the backoff policy built by `create_backoff`;
/// `max_retries` is enforced by `retry_async`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries allowed after the initial attempt.
    pub max_retries: u32,
    /// Initial backoff interval, in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound for the backoff interval, in milliseconds.
    pub max_delay_ms: u64,
    /// Multiplier handed to the backoff policy to grow the interval.
    pub backoff_multiplier: f64,
    /// When `true` the backoff uses a randomization factor of 0.25; when `false`, 0.0.
    pub use_jitter: bool,
}
25
26impl Default for RetryConfig {
27 fn default() -> Self {
28 Self {
29 max_retries: 3,
30 base_delay_ms: 1000, max_delay_ms: 30_000, backoff_multiplier: 2.0, use_jitter: true,
34 }
35 }
36}
37
38impl RetryConfig {
39 pub fn fast() -> Self {
41 Self {
42 max_retries: 5,
43 base_delay_ms: 100, max_delay_ms: 5_000, backoff_multiplier: 1.5, use_jitter: true,
47 }
48 }
49
50 pub fn slow() -> Self {
52 Self {
53 max_retries: 3,
54 base_delay_ms: 5000, max_delay_ms: 60_000, backoff_multiplier: 2.0, use_jitter: true,
58 }
59 }
60}
61
/// How `retry_async` should react to a failed attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// The error is returned immediately; no retry is attempted.
    Permanent,
    /// The operation is retried after the next backoff delay.
    Retryable,
    /// The operation is retried after twice the next backoff delay.
    RateLimited,
}
72
73fn create_backoff(config: &RetryConfig) -> ExponentialBackoff {
75 let mut builder = ExponentialBackoffBuilder::new();
76 builder.with_initial_interval(Duration::from_millis(config.base_delay_ms));
77 builder.with_max_interval(Duration::from_millis(config.max_delay_ms));
78 builder.with_multiplier(config.backoff_multiplier);
79 builder.with_max_elapsed_time(None); if !config.use_jitter {
82 builder.with_randomization_factor(0.0);
83 } else {
84 builder.with_randomization_factor(0.25);
86 }
87
88 builder.build()
89}
90
91pub async fn retry_async<F, Fut, T, E, C>(
140 mut operation: F,
141 classifier: C,
142 config: &RetryConfig,
143 operation_name: &str,
144) -> Result<T, E>
145where
146 F: FnMut() -> Fut,
147 Fut: Future<Output = Result<T, E>>,
148 E: std::fmt::Display + Clone,
149 C: Fn(&E) -> ErrorClass,
150{
151 debug!(
152 "Starting operation '{}' with retry config: max_retries={}, base_delay={}ms",
153 operation_name, config.max_retries, config.base_delay_ms
154 );
155
156 let mut backoff = create_backoff(config);
157 let mut attempts = 0u32;
158
159 loop {
160 attempts += 1;
161 debug!("Attempt {} for '{}'", attempts, operation_name);
162
163 match operation().await {
164 Ok(result) => {
165 if attempts > 1 {
166 debug!(
167 "Operation '{}' succeeded after {} attempts",
168 operation_name, attempts
169 );
170 }
171 return Ok(result);
172 }
173 Err(error) => {
174 let error_class = classifier(&error);
175
176 warn!(
177 "Operation '{}' failed (attempt {}): {} (class: {:?})",
178 operation_name, attempts, error, error_class
179 );
180
181 match error_class {
183 ErrorClass::Permanent => {
184 debug!("Error is permanent, not retrying");
185 return Err(error);
186 }
187 ErrorClass::Retryable | ErrorClass::RateLimited => {
188 if attempts > config.max_retries {
189 warn!(
190 "Operation '{}' failed after {} attempts",
191 operation_name, attempts
192 );
193 return Err(error);
194 }
195
196 let delay = if let Some(duration) = backoff.next_backoff() {
198 if error_class == ErrorClass::RateLimited {
200 duration * 2
201 } else {
202 duration
203 }
204 } else {
205 warn!("Backoff exhausted for '{}'", operation_name);
207 return Err(error);
208 };
209
210 debug!("Retrying '{}' after {:?}", operation_name, delay);
211 tokio::time::sleep(delay).await;
212 }
213 }
214 }
215 }
216 }
217}
218
219pub async fn retry_with_backoff<F, Fut, T>(
221 operation: F,
222 config: &RetryConfig,
223 operation_name: &str,
224) -> Result<T, String>
225where
226 F: FnMut() -> Fut,
227 Fut: Future<Output = Result<T, String>>,
228{
229 retry_async(
230 operation,
231 |_| ErrorClass::Retryable, config,
233 operation_name,
234 )
235 .await
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// An operation that succeeds immediately has its value passed through.
    #[tokio::test]
    async fn test_retry_succeeds_first_attempt() {
        let cfg = RetryConfig::fast();
        let always_ok = || async { Ok::<_, String>("success") };

        let outcome = retry_async(always_ok, |_| ErrorClass::Retryable, &cfg, "test_op").await;

        assert_eq!(outcome.unwrap(), "success");
    }

    /// Two retryable failures followed by a success take three calls in total.
    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);
        let cfg = RetryConfig::fast();

        let outcome = retry_async(
            || {
                let calls = Arc::clone(&shared);
                async move {
                    // `fetch_add` returns the previous value: calls 0 and 1 fail.
                    match calls.fetch_add(1, Ordering::SeqCst) {
                        0 | 1 => Err("temporary failure".to_string()),
                        _ => Ok("success"),
                    }
                }
            },
            |_| ErrorClass::Retryable,
            &cfg,
            "test_op",
        )
        .await;

        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A permanent error stops after the very first call.
    #[tokio::test]
    async fn test_retry_permanent_error_no_retry() {
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);
        let cfg = RetryConfig::fast();

        let outcome = retry_async(
            || {
                let calls = Arc::clone(&shared);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>("permanent error".to_string())
                }
            },
            |_| ErrorClass::Permanent,
            &cfg,
            "test_op",
        )
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// With `max_retries: 2` an always-failing operation is called three
    /// times: the initial attempt plus two retries.
    #[tokio::test]
    async fn test_retry_exhausts_all_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);

        let cfg = RetryConfig {
            max_retries: 2,
            base_delay_ms: 10,
            max_delay_ms: 100,
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let outcome = retry_async(
            || {
                let calls = Arc::clone(&shared);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>("always fails".to_string())
                }
            },
            |_| ErrorClass::Retryable,
            &cfg,
            "test_op",
        )
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// The config values are copied into the policy; jitter off means factor 0.
    #[test]
    fn test_create_backoff_config() {
        let cfg = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let policy = create_backoff(&cfg);

        assert_eq!(policy.initial_interval, Duration::from_millis(100));
        assert_eq!(policy.max_interval, Duration::from_millis(10_000));
        assert_eq!(policy.multiplier, 2.0);
        assert_eq!(policy.randomization_factor, 0.0);
    }

    /// Jitter on means a randomization factor of 0.25.
    #[test]
    fn test_create_backoff_with_jitter() {
        let cfg = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
            use_jitter: true,
        };

        let policy = create_backoff(&cfg);

        assert_eq!(policy.randomization_factor, 0.25);
    }

    /// Spot-checks the `fast` and `slow` presets.
    #[test]
    fn test_retry_config_presets() {
        let quick = RetryConfig::fast();
        assert_eq!(quick.base_delay_ms, 100);
        assert_eq!(quick.max_retries, 5);

        let patient = RetryConfig::slow();
        assert_eq!(patient.base_delay_ms, 5000);
        assert_eq!(patient.max_retries, 3);
    }
}