1use anyhow::{anyhow, Result};
32use serde::{Deserialize, Serialize};
33use std::future::Future;
34use std::time::Duration;
35use tokio::time::sleep;
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
39pub enum BackoffStrategy {
40 Fixed,
42 Exponential,
44 Linear,
46}
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
50pub enum JitterType {
51 None,
53 Full,
55 Equal,
57 Decorrelated,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct RetryPolicy {
64 pub max_attempts: u32,
66 pub base_delay: Duration,
68 pub max_delay: Duration,
70 pub strategy: BackoffStrategy,
72 pub jitter: JitterType,
74 pub backoff_multiplier: f64,
76 pub total_timeout: Option<Duration>,
78}
79
80impl RetryPolicy {
81 pub fn exponential(base_delay: Duration, max_attempts: u32) -> Self {
87 Self {
88 max_attempts,
89 base_delay,
90 max_delay: Duration::from_secs(60),
91 strategy: BackoffStrategy::Exponential,
92 jitter: JitterType::Equal,
93 backoff_multiplier: 2.0,
94 total_timeout: None,
95 }
96 }
97
98 pub fn fixed(delay: Duration, max_attempts: u32) -> Self {
100 Self {
101 max_attempts,
102 base_delay: delay,
103 max_delay: delay,
104 strategy: BackoffStrategy::Fixed,
105 jitter: JitterType::None,
106 backoff_multiplier: 1.0,
107 total_timeout: None,
108 }
109 }
110
111 pub fn linear(base_delay: Duration, max_attempts: u32) -> Self {
113 Self {
114 max_attempts,
115 base_delay,
116 max_delay: Duration::from_secs(60),
117 strategy: BackoffStrategy::Linear,
118 jitter: JitterType::Equal,
119 backoff_multiplier: 1.0,
120 total_timeout: None,
121 }
122 }
123
124 pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
126 self.max_delay = max_delay;
127 self
128 }
129
130 pub fn with_jitter(mut self, jitter: JitterType) -> Self {
132 self.jitter = jitter;
133 self
134 }
135
136 pub fn with_timeout(mut self, timeout: Duration) -> Self {
138 self.total_timeout = Some(timeout);
139 self
140 }
141
142 pub fn with_multiplier(mut self, multiplier: f64) -> Self {
144 self.backoff_multiplier = multiplier;
145 self
146 }
147
148 fn calculate_delay(&self, attempt: u32) -> Duration {
150 if attempt == 0 {
151 return Duration::from_secs(0);
152 }
153
154 let base_ms = self.base_delay.as_millis() as f64;
155
156 let computed_delay_ms = match self.strategy {
157 BackoffStrategy::Fixed => base_ms,
158 BackoffStrategy::Exponential => {
159 base_ms * self.backoff_multiplier.powi(attempt as i32 - 1)
160 }
161 BackoffStrategy::Linear => base_ms * attempt as f64,
162 };
163
164 let capped_ms = computed_delay_ms.min(self.max_delay.as_millis() as f64);
166
167 let final_ms = match self.jitter {
169 JitterType::None => capped_ms,
170 JitterType::Full => {
171 fastrand::f64() * capped_ms
173 }
174 JitterType::Equal => {
175 capped_ms / 2.0 + (fastrand::f64() * capped_ms / 2.0)
177 }
178 JitterType::Decorrelated => {
179 let last_delay = if attempt > 1 {
181 self.calculate_delay(attempt - 1).as_millis() as f64
182 } else {
183 base_ms
184 };
185 let random_delay = base_ms + (fastrand::f64() * (last_delay * 3.0 - base_ms));
186 random_delay.min(self.max_delay.as_millis() as f64)
187 }
188 };
189
190 Duration::from_millis(final_ms as u64)
191 }
192
193 pub async fn retry<F, Fut, T, E>(&self, mut f: F) -> Result<T>
201 where
202 F: FnMut() -> Fut,
203 Fut: Future<Output = Result<T, E>>,
204 E: std::error::Error + Send + Sync + 'static,
205 {
206 let start_time = std::time::Instant::now();
207 let mut last_error = None;
208
209 for attempt in 0..self.max_attempts {
210 if let Some(timeout) = self.total_timeout {
212 if start_time.elapsed() >= timeout {
213 return Err(anyhow!("Retry timeout exceeded after {attempt} attempts"));
214 }
215 }
216
217 match f().await {
219 Ok(result) => return Ok(result),
220 Err(e) => {
221 last_error = Some(e);
222
223 if attempt + 1 < self.max_attempts {
225 let delay = self.calculate_delay(attempt + 1);
226 sleep(delay).await;
227 }
228 }
229 }
230 }
231
232 if let Some(e) = last_error {
234 Err(anyhow!(
235 "Operation failed after {} attempts: {}",
236 self.max_attempts,
237 e
238 ))
239 } else {
240 Err(anyhow!(
241 "Operation failed after {} attempts",
242 self.max_attempts
243 ))
244 }
245 }
246}
247
248impl Default for RetryPolicy {
249 fn default() -> Self {
250 Self::exponential(Duration::from_millis(100), 3)
251 }
252}
253
254pub trait Retryable<T, E> {
256 fn with_retry(self, policy: RetryPolicy) -> impl Future<Output = Result<T>>;
258}
259
260#[derive(Debug, Clone, Default, Serialize, Deserialize)]
262pub struct RetryStats {
263 pub total_attempts: u64,
265 pub successful_ops: u64,
267 pub failed_ops: u64,
269 pub total_delay_ms: u64,
271}
272
#[cfg(test)]
mod tests {
    //! Unit tests for delay computation and the async retry loop.

    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds the I/O error returned by failing test operations.
    fn io_error(message: &'static str) -> std::io::Error {
        std::io::Error::other(message)
    }

    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);

        let result = policy
            .retry(|| async { Ok::<_, std::io::Error>("success") })
            .await;

        assert_eq!(result.expect("first attempt succeeds"), "success");
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let calls = AtomicU32::new(0);
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);

        let result = policy
            .retry(|| {
                let seen = calls.fetch_add(1, Ordering::SeqCst);
                async move {
                    if seen < 2 {
                        Err(io_error("Transient failure"))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;

        assert_eq!(result.expect("third attempt succeeds"), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_all_attempts_fail() {
        let calls = AtomicU32::new(0);
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);

        let result = policy
            .retry(|| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { Err::<&str, _>(io_error("Always fails")) }
            })
            .await;

        let err = result.expect_err("every attempt fails");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert_eq!(
            err.to_string(),
            "Operation failed after 3 attempts: Always fails"
        );
    }

    #[tokio::test]
    async fn test_zero_attempts_never_invokes_operation() {
        let calls = AtomicU32::new(0);
        let policy = RetryPolicy::fixed(Duration::from_millis(10), 0);

        let result = policy
            .retry(|| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { Ok::<(), std::io::Error>(()) }
            })
            .await;

        let err = result.expect_err("no attempt is allowed");
        assert_eq!(calls.load(Ordering::SeqCst), 0);
        assert_eq!(err.to_string(), "Operation failed after 0 attempts");
    }

    #[test]
    fn test_attempt_zero_has_no_delay() {
        let policy = RetryPolicy::exponential(Duration::from_millis(100), 3);

        assert_eq!(policy.calculate_delay(0), Duration::from_secs(0));
    }

    #[test]
    fn test_fixed_backoff() {
        let policy = RetryPolicy::fixed(Duration::from_millis(50), 3);

        for attempt in 1..=3 {
            assert_eq!(policy.calculate_delay(attempt).as_millis(), 50);
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let policy = RetryPolicy::exponential(Duration::from_millis(100), 4)
            .with_jitter(JitterType::None);

        assert_eq!(policy.calculate_delay(1).as_millis(), 100);
        assert_eq!(policy.calculate_delay(2).as_millis(), 200);
        assert_eq!(policy.calculate_delay(3).as_millis(), 400);
    }

    #[test]
    fn test_linear_backoff() {
        let policy =
            RetryPolicy::linear(Duration::from_millis(100), 4).with_jitter(JitterType::None);

        assert_eq!(policy.calculate_delay(1).as_millis(), 100);
        assert_eq!(policy.calculate_delay(2).as_millis(), 200);
        assert_eq!(policy.calculate_delay(3).as_millis(), 300);
    }

    #[test]
    fn test_max_delay_cap() {
        let policy = RetryPolicy::exponential(Duration::from_millis(100), 10)
            .with_max_delay(Duration::from_millis(500))
            .with_jitter(JitterType::None);

        // 100 ms * 2^4 = 1600 ms, which must be clamped to exactly the cap.
        assert_eq!(policy.calculate_delay(5), Duration::from_millis(500));
    }

    #[test]
    fn test_jitter_full() {
        let policy =
            RetryPolicy::exponential(Duration::from_millis(100), 3).with_jitter(JitterType::Full);

        for _ in 0..10 {
            assert!(policy.calculate_delay(1) <= Duration::from_millis(100));
        }
    }

    #[test]
    fn test_jitter_equal() {
        let policy =
            RetryPolicy::exponential(Duration::from_millis(100), 3).with_jitter(JitterType::Equal);

        for _ in 0..10 {
            let ms = policy.calculate_delay(1).as_millis();
            assert!((50..=100).contains(&ms), "delay out of range: {ms} ms");
        }
    }

    #[tokio::test]
    async fn test_timeout() {
        let policy = RetryPolicy::exponential(Duration::from_millis(50), 10)
            .with_timeout(Duration::from_millis(150));

        let start = std::time::Instant::now();
        let result = policy
            .retry(|| async { Err::<&str, _>(io_error("Always fails")) })
            .await;
        let elapsed = start.elapsed();

        let err = result.expect_err("the time budget runs out first");
        assert!(err.to_string().starts_with("Retry timeout exceeded"));
        assert!(elapsed < Duration::from_millis(500));
        assert!(elapsed >= Duration::from_millis(150));
    }
}