#![allow(clippy::option_if_let_else)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::match_same_arms)]

use std::future::Future;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{debug, warn};

use crate::PostgresError;

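/// Strategy for retrying transient database failures with exponential
/// backoff and optional jitter.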
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryStrategy {
    /// Maximum number of attempts, including the initial one.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub base_delay: Duration,
    /// Upper bound applied to the computed backoff delay.
    pub max_delay: Duration,
    /// Factor by which the delay grows with each successive retry.
    pub backoff_multiplier: f64,
    /// Whether to randomize each delay to avoid retry stampedes.
    pub use_jitter: bool,
}

impl Default for RetryStrategy {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
            use_jitter: true,
        }
    }
}

impl RetryStrategy {
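    /// More attempts with gentler backoff, for operations that should keep
    /// trying through longer outages (e.g. startup connectivity checks).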
    pub const fn conservative() -> Self {
        Self {
            max_attempts: 5,
            base_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            use_jitter: true,
        }
    }

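    /// Few attempts with short delays and no jitter, for latency-sensitive
    /// paths where failing fast is preferable.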
    pub const fn aggressive() -> Self {
        Self {
            max_attempts: 2,
            base_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(2),
            backoff_multiplier: 3.0,
            use_jitter: false,
        }
    }

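    /// Returns the delay to sleep before the given (1-based) attempt:
    /// `base_delay * backoff_multiplier^(attempt - 1)`, capped at `max_delay`
    /// and optionally jittered. Attempt 0 yields `Duration::ZERO`.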
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }

        let delay_ms = self.base_delay.as_millis() as f64
            * self.backoff_multiplier.powi((attempt - 1) as i32);

        let delay = Duration::from_millis(delay_ms as u64);
        let capped_delay = std::cmp::min(delay, self.max_delay);

        if self.use_jitter {
            add_jitter(capped_delay)
        } else {
            capped_delay
        }
    }
}

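/// Applies roughly ±20% random jitter so concurrent retries spread out.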
fn add_jitter(delay: Duration) -> Duration {
    use rand::Rng;
    let jitter_factor = rand::rng().random_range(0.8..1.2);
    let jittered_ms = (delay.as_millis() as f64 * jitter_factor) as u64;
    Duration::from_millis(jittered_ms)
}

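/// Error returned by [`retry_operation`] when it gives up.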
#[derive(Debug, Error)]
pub enum RetryError {
    #[error("All retry attempts exhausted after {attempts} tries. Last error: {last_error}")]
    ExhaustedAttempts {
        attempts: u32,
        last_error: PostgresError,
    },

    #[error("Non-retryable error: {0}")]
    NonRetryable(PostgresError),
}

impl From<RetryError> for PostgresError {
    fn from(error: RetryError) -> Self {
        match error {
            RetryError::ExhaustedAttempts { last_error, .. } => last_error,
            RetryError::NonRetryable(error) => error,
        }
    }
}

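/// Returns `true` for errors that are plausibly transient (I/O failures,
/// pool exhaustion, serialization failures, deadlocks, dropped connections)
/// and `false` for errors that will not improve on retry.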
pub fn is_retryable_error(error: &PostgresError) -> bool {
    match error {
        PostgresError::Connection(sqlx_error) => {
            use sqlx::Error;
            match sqlx_error {
                Error::Io(_) | Error::Protocol(_) | Error::PoolTimedOut | Error::PoolClosed => true,
                Error::Database(db_err) => {
                    if let Some(code) = db_err.code() {
                        matches!(
                            code.as_ref(),
                            "40001"       // serialization_failure
                                | "40P01" // deadlock_detected
                                | "53300" // too_many_connections
                                | "08000" // connection_exception
                                | "08003" // connection_does_not_exist
                                | "08006" // connection_failure
                                | "08001" // sqlclient_unable_to_establish_sqlconnection
                                | "08004" // sqlserver_rejected_establishment_of_sqlconnection
                        )
                    } else {
                        false
                    }
                }
                _ => false,
            }
        }
        PostgresError::PoolCreation(_) => true,
        PostgresError::Transaction(_) => true,
        PostgresError::Migration(_) => false,
        PostgresError::Serialization(_) => false,
    }
}

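/// Runs `operation` up to `strategy.max_attempts` times, sleeping between
/// attempts according to [`RetryStrategy::calculate_delay`]. Errors classified
/// as non-retryable by [`is_retryable_error`] are returned immediately.
///
/// A minimal usage sketch (the pool and query are illustrative, not part of
/// this module):
///
/// ```ignore
/// let strategy = RetryStrategy::default();
/// let result = retry_operation(&strategy, "ping", || async {
///     sqlx::query("SELECT 1")
///         .execute(&pool)
///         .await
///         .map_err(PostgresError::Connection)
/// })
/// .await?;
/// ```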
pub async fn retry_operation<F, Fut, T, E>(
    strategy: &RetryStrategy,
    operation_name: &str,
    mut operation: F,
) -> Result<T, RetryError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: Into<PostgresError> + std::fmt::Debug,
{
    let mut last_error = None;

    for attempt in 0..strategy.max_attempts {
        match operation().await {
            Ok(result) => {
                if attempt > 0 {
                    debug!(
                        "Operation '{}' succeeded on attempt {} after retries",
                        operation_name,
                        attempt + 1
                    );
                }
                return Ok(result);
            }
            Err(error) => {
                let postgres_error = error.into();

                if !is_retryable_error(&postgres_error) {
                    warn!(
                        "Operation '{}' failed with non-retryable error: {:?}",
                        operation_name, postgres_error
                    );
                    return Err(RetryError::NonRetryable(postgres_error));
                }

                last_error = Some(postgres_error);

                if attempt < strategy.max_attempts - 1 {
                    let delay = strategy.calculate_delay(attempt + 1);
                    warn!(
                        "Operation '{}' failed on attempt {}, retrying in {:?}. Error: {:?}",
                        operation_name,
                        attempt + 1,
                        delay,
                        last_error.as_ref().unwrap()
                    );
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }

    let final_error = last_error.expect("Should have at least one error");
    Err(RetryError::ExhaustedAttempts {
        attempts: strategy.max_attempts,
        last_error: final_error,
    })
}

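/// Convenience wrapper around `retry_operation` for an expression that
/// evaluates to a `Result`.
///
/// A usage sketch (the query and pool are illustrative):
///
/// ```ignore
/// let strategy = RetryStrategy::default();
/// let result = retry_db_operation!(
///     &strategy,
///     "ping",
///     sqlx::query("SELECT 1").execute(&pool).await.map_err(PostgresError::Connection)
/// )?;
/// ```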
#[macro_export]
macro_rules! retry_db_operation {
    ($strategy:expr, $operation_name:expr, $operation:expr) => {
        $crate::retry::retry_operation($strategy, $operation_name, || async { $operation }).await
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_retry_strategy_delay_calculation() {
        let strategy = RetryStrategy {
            max_attempts: 5,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(2),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        assert_eq!(strategy.calculate_delay(0), Duration::ZERO);
        assert_eq!(strategy.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(strategy.calculate_delay(2), Duration::from_millis(200));
        assert_eq!(strategy.calculate_delay(3), Duration::from_millis(400));
        assert_eq!(strategy.calculate_delay(4), Duration::from_millis(800));
    }

    #[test]
    fn test_retry_strategy_presets() {
        let conservative = RetryStrategy::conservative();
        assert_eq!(conservative.max_attempts, 5);
        assert!(conservative.use_jitter);

        let aggressive = RetryStrategy::aggressive();
        assert_eq!(aggressive.max_attempts, 2);
        assert!(!aggressive.use_jitter);
    }

    #[tokio::test]
    async fn test_retry_operation_success_first_attempt() {
        let strategy = RetryStrategy::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<i32, PostgresError>(42)
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_operation_success_after_retries() {
        let strategy = RetryStrategy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(PostgresError::Connection(sqlx::Error::PoolTimedOut))
                } else {
                    Ok::<i32, PostgresError>(42)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_operation_exhausted_attempts() {
        let strategy = RetryStrategy {
            max_attempts: 2,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, PostgresError>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            }
        })
        .await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RetryError::ExhaustedAttempts { attempts: 2, .. }
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error(&PostgresError::Connection(
            sqlx::Error::PoolTimedOut
        )));
        assert!(is_retryable_error(&PostgresError::PoolCreation(
            "test".to_string()
        )));
        assert!(is_retryable_error(&PostgresError::Transaction(
            "test".to_string()
        )));

        assert!(!is_retryable_error(&PostgresError::Migration(
            "test".to_string()
        )));
        assert!(!is_retryable_error(&PostgresError::Serialization(
            serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::Other, "test"))
        )));
    }
}