1use crate::Result;
7use std::future::Future;
8use std::time::Duration;
9
/// Knobs that control how [`retry_on_transient_error`] bounds and paces its attempts.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Attempt budget: a failure on or after this attempt number is returned to the caller.
    pub max_attempts: u32,
    /// Delay slept after the first retryable failure.
    pub initial_backoff: Duration,
    /// Ceiling applied to the delay every time it is scaled up.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by after each sleep.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 100 ms and doubling up to a 5 s ceiling.
    fn default() -> Self {
        RetryConfig::doubling(3, Duration::from_millis(100), Duration::from_secs(5))
    }
}

impl RetryConfig {
    /// Preset with short waits: three attempts, starting at 10 ms and doubling
    /// up to a 1 s ceiling.
    pub fn fast() -> Self {
        Self::doubling(3, Duration::from_millis(10), Duration::from_secs(1))
    }

    /// Preset with long waits: five attempts, starting at 500 ms and doubling
    /// up to a 30 s ceiling.
    pub fn slow() -> Self {
        Self::doubling(5, Duration::from_millis(500), Duration::from_secs(30))
    }

    /// Builds a configuration whose delay doubles after every sleep; all the
    /// built-in presets share this multiplier and differ only in the other fields.
    fn doubling(max_attempts: u32, initial_backoff: Duration, max_backoff: Duration) -> Self {
        Self {
            max_attempts,
            initial_backoff,
            max_backoff,
            backoff_multiplier: 2.0,
        }
    }
}
67
68#[tracing::instrument(skip(operation, config), fields(max_attempts = config.max_attempts))]
90pub async fn retry_on_transient_error<F, Fut, T>(operation: F, config: RetryConfig) -> Result<T>
91where
92 F: Fn() -> Fut,
93 Fut: Future<Output = Result<T>>,
94{
95 let mut attempt = 0;
96 let mut backoff = config.initial_backoff;
97
98 loop {
99 attempt += 1;
100
101 match operation().await {
102 Ok(result) => {
103 if attempt > 1 {
104 tracing::info!(
105 "Operation succeeded on attempt {}/{}",
106 attempt,
107 config.max_attempts
108 );
109 }
110 return Ok(result);
111 }
112 Err(e) => {
113 if attempt >= config.max_attempts {
114 tracing::error!("Operation failed after {} attempts: {}", attempt, e);
115 return Err(e);
116 }
117
118 if !e.is_retryable() {
119 tracing::warn!(
120 "Non-retryable error encountered on attempt {}: {}",
121 attempt,
122 e
123 );
124 return Err(e);
125 }
126
127 tracing::warn!(
128 "Retryable error on attempt {}/{}, retrying after {:?}: {}",
129 attempt,
130 config.max_attempts,
131 backoff,
132 e
133 );
134
135 tokio::time::sleep(backoff).await;
136 backoff = std::cmp::min(
137 Duration::from_secs_f64(backoff.as_secs_f64() * config.backoff_multiplier),
138 config.max_backoff,
139 );
140 }
141 }
142 }
143}
144
/// Plain counters summarising retried operations, plus ratios derived from them.
///
/// Nothing in this type updates the counters; callers set the public fields.
#[derive(Clone, Debug, Default)]
pub struct RetryStats {
    /// Number of operations counted; the denominator of both derived ratios.
    pub total_operations: u64,
    /// Number of operations counted as successful; the numerator of [`RetryStats::success_rate`].
    pub successful_operations: u64,
    /// Number of operations counted as failed (not read by any method of this type).
    pub failed_operations: u64,
    /// Number of retries counted; the numerator of
    /// [`RetryStats::avg_retries_per_operation`].
    pub total_retries: u64,
}

impl RetryStats {
    /// Creates a stats record with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Fraction of operations that succeeded, or `0.0` when none were counted.
    pub fn success_rate(&self) -> f64 {
        Self::per_operation(self.successful_operations, self.total_operations)
    }

    /// Mean number of retries per operation, or `0.0` when none were counted.
    pub fn avg_retries_per_operation(&self) -> f64 {
        Self::per_operation(self.total_retries, self.total_operations)
    }

    /// Divides `count` by `operations`, yielding `0.0` instead of dividing by zero.
    fn per_operation(count: u64, operations: u64) -> f64 {
        match operations {
            0 => 0.0,
            n => count as f64 / n as f64,
        }
    }
}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::StorageError;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Asserts that two computed ratios agree to within floating-point rounding,
    /// rather than relying on bit-exact equality of division results.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= f64::EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// An operation that succeeds on the first call returns its value unchanged.
    #[tokio::test]
    async fn test_retry_succeeds_immediately() {
        let config = RetryConfig::fast();

        let result = retry_on_transient_error(|| async { Ok(42) }, config).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    /// The first two calls fail with a pool-timeout database error, which this
    /// test expects to be retried; the third call succeeds.
    #[tokio::test]
    async fn test_retry_succeeds_eventually() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(10),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };

        let result = retry_on_transient_error(
            || {
                let c = counter_clone.clone();
                async move {
                    let count = c.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(StorageError::Database(sqlx::Error::PoolTimedOut))
                    } else {
                        Ok(42)
                    }
                }
            },
            config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// An operation that always fails is invoked exactly `max_attempts` times
    /// before the error is returned.
    #[tokio::test]
    async fn test_retry_fails_after_max_attempts() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(10),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };

        let result = retry_on_transient_error(
            || {
                let c = counter_clone.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(StorageError::Database(sqlx::Error::PoolTimedOut))
                }
            },
            config,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A validation error is expected to end the loop after a single call.
    #[tokio::test]
    async fn test_retry_stops_on_non_retryable_error() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RetryConfig::fast();

        let result = retry_on_transient_error(
            || {
                let c = counter_clone.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(StorageError::validation("Invalid input"))
                }
            },
            config,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(5));
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_fast() {
        let config = RetryConfig::fast();
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(10));
        assert_eq!(config.max_backoff, Duration::from_secs(1));
        // Checked for parity with the default-preset test.
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_slow() {
        let config = RetryConfig::slow();
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_backoff, Duration::from_millis(500));
        assert_eq!(config.max_backoff, Duration::from_secs(30));
        // Checked for parity with the default-preset test.
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_stats() {
        let mut stats = RetryStats::new();
        // The zero-operation path returns the literal 0.0, so exact equality is safe.
        assert_eq!(stats.success_rate(), 0.0);
        assert_eq!(stats.avg_retries_per_operation(), 0.0);

        stats.total_operations = 100;
        stats.successful_operations = 95;
        stats.failed_operations = 5;
        stats.total_retries = 20;

        assert_close(stats.success_rate(), 0.95);
        assert_close(stats.avg_retries_per_operation(), 0.2);
    }
}