1use std::fmt::Display;
6use std::future::Future;
7use std::time::Duration;
8
9use tokio::time::sleep;
10use tracing::{debug, warn};
11
/// Default total number of attempts (the first call plus retries).
const DEFAULT_MAX_ATTEMPTS: u32 = 3;

/// Default delay, in milliseconds, before the first retry.
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;

/// Default upper bound, in milliseconds, on any single retry delay.
const DEFAULT_MAX_DELAY_MS: u64 = 2000;

/// Exponential-backoff settings for retrying a fallible operation.
///
/// The delay after the zero-based attempt `k` fails is `initial_delay_ms * 2^k`
/// (saturating), capped at `max_delay_ms`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Total number of times the operation is invoked, including the first call.
    ///
    /// [`RetryConfig::new`] clamps this to at least 1; values assigned directly
    /// through this public field are not clamped.
    pub max_attempts: u32,
    /// Delay in milliseconds before the first retry; doubled for each later retry.
    pub initial_delay_ms: u64,
    /// Upper bound in milliseconds on any single retry delay.
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Returns 3 attempts, a 100 ms initial delay and a 2000 ms delay cap.
    fn default() -> Self {
        Self {
            max_attempts: DEFAULT_MAX_ATTEMPTS,
            initial_delay_ms: DEFAULT_INITIAL_DELAY_MS,
            max_delay_ms: DEFAULT_MAX_DELAY_MS,
        }
    }
}

impl RetryConfig {
    /// Creates a configuration.
    ///
    /// `max_attempts` is clamped to at least 1 so the operation always runs once.
    pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_attempts: max_attempts.max(1),
            initial_delay_ms,
            max_delay_ms,
        }
    }

    /// Returns how long to wait after the zero-based `attempt` has failed.
    ///
    /// Computes `initial_delay_ms * 2^attempt` with saturating arithmetic, so large
    /// attempt numbers never overflow, and caps the result at `max_delay_ms`.
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // `saturating_pow` yields u64::MAX once 2^attempt no longer fits (attempt >= 64);
        // `saturating_mul` then keeps the product from wrapping.
        let factor = 2u64.saturating_pow(attempt);
        let delay_ms = self
            .initial_delay_ms
            .saturating_mul(factor)
            .min(self.max_delay_ms);
        Duration::from_millis(delay_ms)
    }
}
62
63pub async fn with_retry<T, E, F, Fut>(config: &RetryConfig, operation: F) -> Result<T, E>
89where
90 F: Fn() -> Fut,
91 Fut: Future<Output = Result<T, E>>,
92 E: Display,
93{
94 let mut last_error: Option<E> = None;
95
96 for attempt in 0..config.max_attempts {
97 match operation().await {
98 Ok(result) => {
99 if attempt > 0 {
100 debug!("Operation succeeded on attempt {}", attempt + 1);
101 }
102 return Ok(result);
103 }
104 Err(e) => {
105 let is_last_attempt = attempt + 1 >= config.max_attempts;
106
107 if is_last_attempt {
108 warn!(
109 "Operation failed after {} attempts: {}",
110 config.max_attempts, e
111 );
112 last_error = Some(e);
113 } else {
114 let delay = config.delay_for_attempt(attempt);
115 warn!(
116 "Operation failed (attempt {}/{}): {}. Retrying in {:?}...",
117 attempt + 1,
118 config.max_attempts,
119 e,
120 delay
121 );
122 sleep(delay).await;
123 last_error = Some(e);
124 }
125 }
126 }
127 }
128
129 Err(last_error.expect("at least one attempt should have been made"))
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let config = RetryConfig::default();
        let result: Result<&str, &str> = with_retry(&config, || async { Ok("success") }).await;
        assert_eq!(result, Ok("success"));
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let config = RetryConfig::new(3, 10, 100);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();

        let result: Result<&str, &str> = with_retry(&config, || {
            let count = attempt_count_clone.clone();
            async move {
                let current = count.fetch_add(1, Ordering::SeqCst);
                if current < 2 {
                    Err("transient error")
                } else {
                    Ok("success")
                }
            }
        })
        .await;

        assert_eq!(result, Ok("success"));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_all_failures() {
        let config = RetryConfig::new(3, 10, 100);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();

        let result: Result<&str, &str> = with_retry(&config, || {
            let count = attempt_count_clone.clone();
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                Err("persistent error")
            }
        })
        .await;

        assert_eq!(result, Err("persistent error"));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    /// Every attempt fails with a distinct error; the one from the final attempt wins.
    #[tokio::test]
    async fn test_retry_returns_last_error() {
        let config = RetryConfig::new(3, 1, 1);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();

        let result: Result<&str, String> = with_retry(&config, || {
            let count = attempt_count_clone.clone();
            async move {
                let current = count.fetch_add(1, Ordering::SeqCst);
                Err(format!("error {}", current))
            }
        })
        .await;

        assert_eq!(result, Err("error 2".to_string()));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_new_clamps_zero_attempts_to_one() {
        assert_eq!(RetryConfig::new(0, 10, 100).max_attempts, 1);
    }

    #[test]
    fn test_delay_for_attempt_no_overflow() {
        let config = RetryConfig::new(100, 100, 2000);
        // 2^63 still fits in a u64 but the product saturates; from 64 on the
        // exponent itself exceeds the u64 range.
        assert_eq!(config.delay_for_attempt(63), Duration::from_millis(2000));
        assert_eq!(config.delay_for_attempt(64), Duration::from_millis(2000));
        assert_eq!(config.delay_for_attempt(99), Duration::from_millis(2000));
        assert_eq!(config.delay_for_attempt(u32::MAX), Duration::from_millis(2000));
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::new(5, 100, 1000);

        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(800));
        assert_eq!(config.delay_for_attempt(4), Duration::from_millis(1000));
    }
}