1use std::time::Duration;
6
7use crate::alias::RetryConfig;
8use crate::error::{Error, Result};
9
10pub async fn retry_with_backoff<T, F, Fut, R>(
26 config: &RetryConfig,
27 mut operation: F,
28 is_retryable: R,
29) -> Result<T>
30where
31 F: FnMut() -> Fut,
32 Fut: std::future::Future<Output = Result<T>>,
33 R: Fn(&Error) -> bool,
34{
35 let mut attempt = 0;
36
37 loop {
38 attempt += 1;
39
40 match operation().await {
41 Ok(result) => return Ok(result),
42 Err(e) => {
43 if attempt >= config.max_attempts || !is_retryable(&e) {
44 return Err(e);
45 }
46
47 let backoff = calculate_backoff(config, attempt);
48 tracing::debug!(
49 attempt = attempt,
50 backoff_ms = backoff.as_millis(),
51 error = %e,
52 "Retrying after transient error"
53 );
54
55 tokio::time::sleep(backoff).await;
56 }
57 }
58 }
59}
60
61fn calculate_backoff(config: &RetryConfig, attempt: u32) -> Duration {
63 let base_ms = config.initial_backoff_ms * (1u64 << (attempt - 1).min(10));
65 let capped_ms = base_ms.min(config.max_backoff_ms);
66
67 let jitter_ms = rand_jitter(capped_ms);
69 Duration::from_millis(capped_ms + jitter_ms)
70}
71
/// Returns a pseudo-random jitter value in `0..max` (always `0` when `max` is
/// 0 or 1).
///
/// The sub-second part of the wall clock is fed through a hasher created from
/// a fresh `RandomState`, and the hash is reduced modulo `max`. Hashing,
/// rather than reducing the raw nanosecond count, keeps the result well
/// spread even when the clock only advances in coarse steps: a raw count that
/// is always a multiple of 1000 would make `% 100` constantly zero.
///
/// This is not suitable for anything security-sensitive; it only needs to
/// de-synchronize retries.
fn rand_jitter(max: u64) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    // A clock set before the epoch falls back to a zero duration.
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u32(nanos);

    // `max.max(1)` avoids a division by zero when no jitter range is available.
    hasher.finish() % max.max(1)
}
81
82pub fn is_retryable_error(error: &Error) -> bool {
84 match error {
85 Error::Network(msg) => {
86 let msg_lower = msg.to_lowercase();
88 msg_lower.contains("timeout")
89 || msg_lower.contains("connection reset")
90 || msg_lower.contains("connection refused")
91 || msg_lower.contains("503")
92 || msg_lower.contains("service unavailable")
93 || msg_lower.contains("too many requests")
94 || msg_lower.contains("429")
95 || msg_lower.contains("request rate")
96 || msg_lower.contains("slow down")
97 }
98 Error::Io(e) => {
99 matches!(
101 e.kind(),
102 std::io::ErrorKind::ConnectionReset
103 | std::io::ErrorKind::ConnectionRefused
104 | std::io::ErrorKind::TimedOut
105 | std::io::ErrorKind::Interrupted
106 )
107 }
108 Error::Auth(_)
110 | Error::NotFound(_)
111 | Error::AliasNotFound(_)
112 | Error::Conflict(_)
113 | Error::InvalidPath(_)
114 | Error::Config(_)
115 | Error::UnsupportedFeature(_) => false,
116 Error::General(msg) => {
118 let msg_lower = msg.to_lowercase();
119 msg_lower.contains("timeout") || msg_lower.contains("temporary")
120 }
121 _ => false,
122 }
123}
124
/// Fluent builder that collects retry settings and turns them into a
/// `RetryConfig` via `build`.
#[derive(Clone, Debug)]
pub struct RetryBuilder {
    /// Attempt budget copied into `RetryConfig::max_attempts`.
    max_attempts: u32,
    /// Starting backoff in milliseconds, copied into
    /// `RetryConfig::initial_backoff_ms`.
    initial_backoff_ms: u64,
    /// Backoff ceiling in milliseconds, copied into
    /// `RetryConfig::max_backoff_ms`.
    max_backoff_ms: u64,
}
132
133impl RetryBuilder {
134 pub fn new() -> Self {
135 Self {
136 max_attempts: 3,
137 initial_backoff_ms: 100,
138 max_backoff_ms: 10000,
139 }
140 }
141
142 pub fn max_attempts(mut self, n: u32) -> Self {
143 self.max_attempts = n;
144 self
145 }
146
147 pub fn initial_backoff_ms(mut self, ms: u64) -> Self {
148 self.initial_backoff_ms = ms;
149 self
150 }
151
152 pub fn max_backoff_ms(mut self, ms: u64) -> Self {
153 self.max_backoff_ms = ms;
154 self
155 }
156
157 pub fn build(self) -> RetryConfig {
158 RetryConfig {
159 max_attempts: self.max_attempts,
160 initial_backoff_ms: self.initial_backoff_ms,
161 max_backoff_ms: self.max_backoff_ms,
162 }
163 }
164}
165
166impl Default for RetryBuilder {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// The base delay doubles with every attempt; jitter adds strictly less
    /// than the base, so each result falls in `[base, 2 * base)`.
    #[test]
    fn test_calculate_backoff() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff_ms: 100,
            max_backoff_ms: 10000,
        };

        let b1 = calculate_backoff(&config, 1);
        assert!(b1.as_millis() >= 100 && b1.as_millis() < 200);

        let b2 = calculate_backoff(&config, 2);
        assert!(b2.as_millis() >= 200 && b2.as_millis() < 400);

        let b3 = calculate_backoff(&config, 3);
        assert!(b3.as_millis() >= 400 && b3.as_millis() < 800);
    }

    /// Once the exponential term exceeds `max_backoff_ms`, the delay is the cap
    /// plus jitter.
    #[test]
    fn test_backoff_cap() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff_ms: 1000,
            max_backoff_ms: 5000,
        };

        let b = calculate_backoff(&config, 10);
        // Lower bound: the capped base delay itself must always be present.
        assert!(b.as_millis() >= 5000);
        // Upper bound: the cap plus jitter, which never exceeds the cap.
        assert!(b.as_millis() <= 10000);
    }

    /// Transient network messages are retryable; auth and not-found are not.
    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error(&Error::Network(
            "connection timeout".to_string()
        )));
        assert!(is_retryable_error(&Error::Network(
            "503 Service Unavailable".to_string()
        )));
        assert!(is_retryable_error(&Error::Network(
            "429 Too Many Requests".to_string()
        )));

        assert!(!is_retryable_error(&Error::Auth(
            "access denied".to_string()
        )));

        assert!(!is_retryable_error(&Error::NotFound(
            "object not found".to_string()
        )));
    }

    /// Values given to the builder arrive unchanged in the built config.
    #[test]
    fn test_retry_builder() {
        let config = RetryBuilder::new()
            .max_attempts(5)
            .initial_backoff_ms(200)
            .max_backoff_ms(20000)
            .build();

        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_backoff_ms, 200);
        assert_eq!(config.max_backoff_ms, 20000);
    }

    /// A successful first attempt returns immediately without retrying.
    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let config = RetryConfig::default();
        let mut calls = 0;

        let result = retry_with_backoff(
            &config,
            || {
                calls += 1;
                async { Ok::<_, Error>(42) }
            },
            |_| true,
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls, 1);
    }

    /// Two retryable failures followed by a success take exactly three calls.
    #[tokio::test]
    async fn test_retry_success_after_failure() {
        // Tiny backoff values keep the test fast.
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff_ms: 1,
            max_backoff_ms: 10,
        };
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = retry_with_backoff(
            &config,
            || {
                let cc = Arc::clone(&call_count_clone);
                async move {
                    // `fetch_add` returns the previous value: 0 and 1 fail, 2 succeeds.
                    let count = cc.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(Error::Network("timeout".to_string()))
                    } else {
                        Ok(42)
                    }
                }
            },
            is_retryable_error,
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// An operation that always fails is called `max_attempts` times, then the
    /// error is returned.
    #[tokio::test]
    async fn test_retry_exhausted() {
        let config = RetryConfig {
            max_attempts: 2,
            initial_backoff_ms: 1,
            max_backoff_ms: 10,
        };
        let mut calls = 0;

        let result: Result<()> = retry_with_backoff(
            &config,
            || {
                calls += 1;
                async { Err(Error::Network("always fails".to_string())) }
            },
            |_| true,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(calls, 2);
    }

    /// A non-retryable error stops after the first call even with attempts left.
    #[tokio::test]
    async fn test_retry_non_retryable() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff_ms: 1,
            max_backoff_ms: 10,
        };
        let mut calls = 0;

        let result: Result<()> = retry_with_backoff(
            &config,
            || {
                calls += 1;
                async { Err(Error::NotFound("not found".to_string())) }
            },
            is_retryable_error,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(calls, 1);
    }
}