//! Retry strategies with exponential backoff and jitter for transient
//! Postgres failures.

#![allow(clippy::option_if_let_else)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::match_same_arms)]

use std::future::Future;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{debug, warn};

use crate::PostgresError;

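/// Tunable policy for retrying failed database operations with exponential
/// backoff. A minimal construction sketch (the values shown are illustrative,
/// not recommendations):
///
/// ```ignore
/// let strategy = RetryStrategy {
///     max_attempts: 4,
///     ..RetryStrategy::default()
/// };
/// ```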
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryStrategy {
    /// Maximum number of attempts, including the initial one.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub base_delay: Duration,
    /// Upper bound applied to the computed delay (before jitter).
    pub max_delay: Duration,
    /// Factor by which the delay grows with each further attempt.
    pub backoff_multiplier: f64,
    /// Whether to randomize delays by +/-20% to avoid thundering-herd retries.
    pub use_jitter: bool,
}

impl Default for RetryStrategy {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
            use_jitter: true,
        }
    }
}

impl RetryStrategy {
    /// A patient preset: more attempts, longer delays, gentler backoff.
    pub const fn conservative() -> Self {
        Self {
            max_attempts: 5,
            base_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            use_jitter: true,
        }
    }

    /// A fail-fast preset: few attempts with short, steep backoff.
    pub const fn aggressive() -> Self {
        Self {
            max_attempts: 2,
            base_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(2),
            backoff_multiplier: 3.0,
            use_jitter: false,
        }
    }

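    /// Returns the delay to wait before the given attempt (attempt `0` is the
    /// initial try and incurs no delay). The delay grows by
    /// `backoff_multiplier` per attempt and is capped at `max_delay`; jitter,
    /// when enabled, is applied after the cap and may exceed it by up to 20%.
    ///
    /// Expected progression with the default strategy and jitter disabled:
    ///
    /// ```ignore
    /// let strategy = RetryStrategy { use_jitter: false, ..RetryStrategy::default() };
    /// assert_eq!(strategy.calculate_delay(0), Duration::ZERO);
    /// assert_eq!(strategy.calculate_delay(1), Duration::from_millis(100));
    /// assert_eq!(strategy.calculate_delay(2), Duration::from_millis(200));
    /// ```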
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }

        let delay_ms =
            self.base_delay.as_millis() as f64 * self.backoff_multiplier.powi((attempt - 1) as i32);

        let delay = Duration::from_millis(delay_ms as u64);
        let capped_delay = std::cmp::min(delay, self.max_delay);

        if self.use_jitter {
            add_jitter(capped_delay)
        } else {
            capped_delay
        }
    }
}

/// Randomizes a delay by +/-20% so that concurrent clients back off at
/// slightly different times instead of retrying in lockstep.
fn add_jitter(delay: Duration) -> Duration {
    use rand::Rng;
    let jitter_factor = rand::rng().random_range(0.8..1.2);
    let jittered_ms = (delay.as_millis() as f64 * jitter_factor) as u64;
    Duration::from_millis(jittered_ms)
}

/// Error returned by [`retry_operation`] when an operation ultimately fails.
#[derive(Debug, Error)]
pub enum RetryError {
    /// Every attempt failed with a retryable error.
    #[error("All retry attempts exhausted after {attempts} tries. Last error: {last_error}")]
    ExhaustedAttempts {
        /// How many attempts were made.
        attempts: u32,
        /// The error from the final attempt.
        last_error: PostgresError,
    },

    /// The operation failed with an error that retrying cannot fix.
    #[error("Non-retryable error: {0}")]
    NonRetryable(PostgresError),
}

impl From<RetryError> for PostgresError {
    fn from(error: RetryError) -> Self {
        match error {
            RetryError::ExhaustedAttempts { last_error, .. } => last_error,
            RetryError::NonRetryable(error) => error,
        }
    }
}

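/// Classifies an error as transient (worth retrying) or permanent.
///
/// Connection-level I/O, protocol, and pool failures are treated as
/// transient, as are the SQLSTATE codes for serialization failures,
/// deadlocks, connection exceptions, and connection exhaustion listed below.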
pub fn is_retryable_error(error: &PostgresError) -> bool {
    match error {
        PostgresError::Connection(sqlx_error) => {
            use sqlx::Error;
            match sqlx_error {
                // Network, protocol, and pool-level failures are transient.
                Error::Io(_) | Error::Protocol(_) | Error::PoolTimedOut | Error::PoolClosed => true,
                // Server-reported errors are retryable only for specific SQLSTATE codes.
                Error::Database(db_err) => {
                    if let Some(code) = db_err.code() {
                        matches!(
                            code.as_ref(),
                            "40001" // serialization_failure
                                | "40P01" // deadlock_detected
                                | "53300" // too_many_connections
                                | "08000" // connection_exception
                                | "08003" // connection_does_not_exist
                                | "08006" // connection_failure
                                | "08001" // sqlclient_unable_to_establish_sqlconnection
                                | "08004" // sqlserver_rejected_establishment_of_sqlconnection
                        )
                    } else {
                        false
                    }
                }
                _ => false,
            }
        }
        PostgresError::PoolCreation(_) => true,
        PostgresError::Transaction(_) => true,
        PostgresError::Migration(_) => false,
        PostgresError::Serialization(_) => false,
    }
}

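/// Runs `operation` until it succeeds, a non-retryable error occurs, or
/// `strategy.max_attempts` is exhausted, sleeping between attempts according
/// to the strategy.
///
/// A minimal usage sketch; `fetch_user`, `pool`, and `user_id` are
/// illustrative, not part of this module:
///
/// ```ignore
/// let strategy = RetryStrategy::default();
/// let user = retry_operation(&strategy, "fetch_user", || async {
///     fetch_user(&pool, user_id).await
/// })
/// .await?;
/// ```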
pub async fn retry_operation<F, Fut, T, E>(
    strategy: &RetryStrategy,
    operation_name: &str,
    mut operation: F,
) -> Result<T, RetryError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: Into<PostgresError> + std::fmt::Debug,
{
    let mut last_error = None;

    for attempt in 0..strategy.max_attempts {
        match operation().await {
            Ok(result) => {
                if attempt > 0 {
                    debug!(
                        "Operation '{}' succeeded on attempt {} after retries",
                        operation_name,
                        attempt + 1
                    );
                }
                return Ok(result);
            }
            Err(error) => {
                let postgres_error = error.into();

                // Bail out immediately on errors that retrying cannot fix.
                if !is_retryable_error(&postgres_error) {
                    warn!(
                        "Operation '{}' failed with non-retryable error: {:?}",
                        operation_name, postgres_error
                    );
                    return Err(RetryError::NonRetryable(postgres_error));
                }

                // Sleep before the next attempt, but not after the final one.
                if attempt < strategy.max_attempts - 1 {
                    let delay = strategy.calculate_delay(attempt + 1);
                    warn!(
                        "Operation '{}' failed on attempt {}, retrying in {:?}. Error: {:?}",
                        operation_name,
                        attempt + 1,
                        delay,
                        postgres_error
                    );
                    tokio::time::sleep(delay).await;
                }

                last_error = Some(postgres_error);
            }
        }
    }

    // The loop records an error on every failed attempt, so this only panics
    // if the strategy was configured with zero attempts.
    let final_error = last_error.expect("max_attempts must be at least 1");
    Err(RetryError::ExhaustedAttempts {
        attempts: strategy.max_attempts,
        last_error: final_error,
    })
}

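/// Convenience wrapper around [`retry_operation`] that wraps an expression in
/// a retryable closure. A hypothetical call site, assuming `ping` returns a
/// `Result` whose error converts into `PostgresError`:
///
/// ```ignore
/// let strategy = RetryStrategy::default();
/// let result = retry_db_operation!(&strategy, "ping", ping(&pool).await);
/// ```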
#[macro_export]
macro_rules! retry_db_operation {
    ($strategy:expr, $operation_name:expr, $operation:expr) => {
        $crate::retry::retry_operation($strategy, $operation_name, || async { $operation }).await
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_retry_strategy_delay_calculation() {
        let strategy = RetryStrategy {
            max_attempts: 5,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(2),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        // Attempt 0 is the initial try and has no delay.
        assert_eq!(strategy.calculate_delay(0), Duration::ZERO);
        // Each subsequent attempt doubles the delay: 100ms, 200ms, 400ms, 800ms.
        assert_eq!(strategy.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(strategy.calculate_delay(2), Duration::from_millis(200));
        assert_eq!(strategy.calculate_delay(3), Duration::from_millis(400));
        assert_eq!(strategy.calculate_delay(4), Duration::from_millis(800));
    }

    #[test]
    fn test_retry_strategy_presets() {
        let conservative = RetryStrategy::conservative();
        assert_eq!(conservative.max_attempts, 5);
        assert!(conservative.use_jitter);

        let aggressive = RetryStrategy::aggressive();
        assert_eq!(aggressive.max_attempts, 2);
        assert!(!aggressive.use_jitter);
    }

    #[tokio::test]
    async fn test_retry_operation_success_first_attempt() {
        let strategy = RetryStrategy::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<i32, PostgresError>(42)
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_operation_success_after_retries() {
        let strategy = RetryStrategy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                // Fail the first two attempts with a retryable error.
                if count < 2 {
                    Err(PostgresError::Connection(sqlx::Error::PoolTimedOut))
                } else {
                    Ok::<i32, PostgresError>(42)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_operation_exhausted_attempts() {
        let strategy = RetryStrategy {
            max_attempts: 2,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, PostgresError>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            }
        })
        .await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RetryError::ExhaustedAttempts { attempts: 2, .. }
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error(&PostgresError::Connection(
            sqlx::Error::PoolTimedOut
        )));
        assert!(is_retryable_error(&PostgresError::PoolCreation(
            "test".to_string()
        )));
        assert!(is_retryable_error(&PostgresError::Transaction(
            "test".to_string()
        )));

        assert!(!is_retryable_error(&PostgresError::Migration(
            "test".to_string()
        )));
        assert!(!is_retryable_error(&PostgresError::Serialization(
            serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::Other, "test"))
        )));
    }
}