1use std::time::Duration;
7use tokio::time::sleep;
8use tracing::warn;
9
10use crate::error::{Error, Result};
11
/// Tuning knobs for the exponential backoff used by [`with_retry`].
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// How many additional attempts are made after the first one fails (0 disables retrying).
    pub max_retries: u32,
    /// Delay before the first retry; later delays grow from this value.
    pub initial_delay: Duration,
    /// Ceiling for the exponentially grown delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each successive attempt.
    pub backoff_multiplier: f64,
    /// Whether to add a random extra of up to 25% to each delay.
    pub jitter: bool,
}
26
27impl Default for RetryConfig {
28 fn default() -> Self {
29 Self {
30 max_retries: 3,
31 initial_delay: Duration::from_millis(100),
32 max_delay: Duration::from_secs(10),
33 backoff_multiplier: 2.0,
34 jitter: true,
35 }
36 }
37}
38
39impl RetryConfig {
40 pub fn none() -> Self {
42 Self {
43 max_retries: 0,
44 ..Default::default()
45 }
46 }
47
48 pub fn aggressive() -> Self {
50 Self {
51 max_retries: 5,
52 initial_delay: Duration::from_millis(200),
53 max_delay: Duration::from_secs(30),
54 backoff_multiplier: 2.0,
55 jitter: true,
56 }
57 }
58
59 fn delay_for_attempt(&self, attempt: u32) -> Duration {
61 let base_delay =
62 self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
63
64 let delay_ms = base_delay.min(self.max_delay.as_millis() as f64);
65
66 let delay_ms = if self.jitter {
67 let jitter = delay_ms * 0.25 * rand_simple();
69 delay_ms + jitter
70 } else {
71 delay_ms
72 };
73
74 Duration::from_millis(delay_ms as u64)
75 }
76}
77
/// Returns a pseudo-random value in `[0.0, 1.0)` used for retry jitter.
///
/// This is not cryptographically secure, and does not need to be: it only has to decorrelate
/// retries from concurrent callers. Each call builds a fresh [`RandomState`]. Hashers from
/// different `RandomState` instances are documented to be unlikely to agree, so hashing the
/// current sub-second time through one yields well-mixed bits.
///
/// The raw clock value is deliberately not used on its own. On platforms whose wall clock has
/// only microsecond (or coarser) resolution, `subsec_nanos() % 1000` is always 0, which would
/// silently disable jitter.
fn rand_simple() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u32(nanos);
    let bits = hasher.finish();

    // Keep the top 53 bits, which fit an f64 mantissa exactly. The result is then uniform
    // over [0, 1) and can never round up to 1.0.
    (bits >> 11) as f64 / (1u64 << 53) as f64
}
87
88pub async fn with_retry<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
90where
91 F: FnMut() -> Fut,
92 Fut: std::future::Future<Output = Result<T>>,
93{
94 let mut last_error = None;
95
96 for attempt in 0..=config.max_retries {
97 match f().await {
98 Ok(result) => return Ok(result),
99 Err(e) => {
100 if !e.is_retryable() {
101 return Err(e);
102 }
103
104 if attempt < config.max_retries {
105 let delay = config.delay_for_attempt(attempt);
106 warn!(
107 attempt = attempt + 1,
108 max_retries = config.max_retries,
109 delay_ms = delay.as_millis(),
110 error = %e,
111 "Request failed, retrying"
112 );
113 sleep(delay).await;
114 }
115
116 last_error = Some(e);
117 }
118 }
119 }
120
121 Err(last_error.unwrap_or_else(|| Error::http("unknown retry error")))
122}
123
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_preset_configs() {
        assert_eq!(RetryConfig::none().max_retries, 0);

        let aggressive = RetryConfig::aggressive();
        assert_eq!(aggressive.max_retries, 5);
        assert_eq!(aggressive.initial_delay, Duration::from_millis(200));
        assert_eq!(aggressive.max_delay, Duration::from_secs(30));
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig {
            initial_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            jitter: false,
            ..Default::default()
        };

        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
    }

    #[test]
    fn test_delay_capped_at_max_delay() {
        let config = RetryConfig {
            jitter: false,
            ..Default::default()
        };
        // 100 ms * 2^10 = 102.4 s, well past the 10 s default cap.
        assert_eq!(config.delay_for_attempt(10), config.max_delay);
    }

    #[test]
    fn test_jitter_stays_within_quarter_of_base() {
        let config = RetryConfig::default();
        for _ in 0..32 {
            let delay = config.delay_for_attempt(0);
            assert!(delay >= Duration::from_millis(100));
            assert!(delay <= Duration::from_millis(125));
        }
    }

    #[test]
    fn test_rand_simple_in_unit_interval() {
        for _ in 0..32 {
            let r = rand_simple();
            assert!((0.0..1.0).contains(&r));
        }
    }

    #[tokio::test]
    async fn test_with_retry_success() {
        let config = RetryConfig::default();
        let result = with_retry(&config, || async { Ok::<_, Error>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }
}