eventcore_postgres/retry.rs

//! Retry strategies and utilities for `PostgreSQL` operations
//!
//! This module provides robust retry mechanisms with exponential backoff
//! for handling transient database failures.
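//!
//! A minimal usage sketch (`fetch_event_count` and `pool` are hypothetical
//! stand-ins, and the import path assumes this module is exported as
//! `eventcore_postgres::retry`):
//!
//! ```ignore
//! use eventcore_postgres::retry::{retry_operation, RetryStrategy};
//!
//! let strategy = RetryStrategy::conservative();
//! let count = retry_operation(&strategy, "count_events", || {
//!     // Clone cheap handles into the closure so each attempt builds a fresh future.
//!     let pool = pool.clone();
//!     async move { fetch_event_count(&pool).await }
//! })
//! .await?;
//! ```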

#![allow(clippy::option_if_let_else)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::match_same_arms)]

use std::future::Future;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{debug, warn};

use crate::PostgresError;

/// Retry strategy configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryStrategy {
    /// Maximum number of retry attempts
    pub max_attempts: u32,
    /// Base delay between attempts
    pub base_delay: Duration,
    /// Maximum delay (exponential backoff cap)
    pub max_delay: Duration,
    /// Multiplier for exponential backoff
    pub backoff_multiplier: f64,
    /// Whether to add jitter to prevent thundering herd
    pub use_jitter: bool,
}

impl Default for RetryStrategy {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
            use_jitter: true,
        }
    }
}

impl RetryStrategy {
    /// Create a conservative retry strategy (more attempts, gentler backoff)
    /// for critical operations
    pub const fn conservative() -> Self {
        Self {
            max_attempts: 5,
            base_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            use_jitter: true,
        }
    }

    /// Create an aggressive, fail-fast retry strategy (fewer attempts, steeper
    /// backoff) for non-critical operations
    pub const fn aggressive() -> Self {
        Self {
            max_attempts: 2,
            base_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(2),
            backoff_multiplier: 3.0,
            use_jitter: false,
        }
    }

    /// Calculate the delay for a given attempt number; attempt `0` returns no delay
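    ///
    /// A sketch of the backoff progression with jitter disabled so the values
    /// are deterministic (fenced as `ignore` because the doctest import path
    /// is assumed):
    ///
    /// ```ignore
    /// let strategy = RetryStrategy {
    ///     max_attempts: 5,
    ///     base_delay: Duration::from_millis(100),
    ///     max_delay: Duration::from_secs(1),
    ///     backoff_multiplier: 2.0,
    ///     use_jitter: false,
    /// };
    /// assert_eq!(strategy.calculate_delay(0), Duration::ZERO);
    /// assert_eq!(strategy.calculate_delay(1), Duration::from_millis(100));
    /// assert_eq!(strategy.calculate_delay(2), Duration::from_millis(200));
    /// assert_eq!(strategy.calculate_delay(3), Duration::from_millis(400));
    /// // 100ms * 2^4 = 1600ms exceeds `max_delay`, so the delay is capped at 1s.
    /// assert_eq!(strategy.calculate_delay(5), Duration::from_secs(1));
    /// ```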
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }

        let delay_ms =
            self.base_delay.as_millis() as f64 * self.backoff_multiplier.powi((attempt - 1) as i32);

        let delay = Duration::from_millis(delay_ms as u64);
        let capped_delay = std::cmp::min(delay, self.max_delay);

        if self.use_jitter {
            add_jitter(capped_delay)
        } else {
            capped_delay
        }
    }
}

/// Add random jitter to prevent thundering herd effect
fn add_jitter(delay: Duration) -> Duration {
    use rand::Rng;
    let jitter_factor = rand::rng().random_range(0.8..1.2);
    let jittered_ms = (delay.as_millis() as f64 * jitter_factor) as u64;
    Duration::from_millis(jittered_ms)
}

/// Errors that can occur during retry operations
#[derive(Debug, Error)]
pub enum RetryError {
    /// All retry attempts exhausted
    #[error("All retry attempts exhausted after {attempts} tries. Last error: {last_error}")]
    ExhaustedAttempts {
        /// Number of attempts made
        attempts: u32,
        /// The last error encountered
        last_error: PostgresError,
    },

    /// Non-retryable error encountered
    #[error("Non-retryable error: {0}")]
    NonRetryable(PostgresError),
}

impl From<RetryError> for PostgresError {
    fn from(error: RetryError) -> Self {
        match error {
            RetryError::ExhaustedAttempts { last_error, .. } => last_error,
            RetryError::NonRetryable(error) => error,
        }
    }
}

/// Determines if an error is retryable
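///
/// A classification sketch using variants from this crate's `PostgresError`
/// (fenced as `ignore` because the doctest import path is assumed):
///
/// ```ignore
/// // Transient failures such as pool or transaction problems are retried...
/// assert!(is_retryable_error(&PostgresError::Transaction("deadlock".to_string())));
/// // ...while schema migration failures are surfaced immediately.
/// assert!(!is_retryable_error(&PostgresError::Migration("bad schema".to_string())));
/// ```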
pub fn is_retryable_error(error: &PostgresError) -> bool {
    match error {
        PostgresError::Connection(sqlx_error) => {
            use sqlx::Error;
            match sqlx_error {
                // Connection issues are retryable
                Error::Io(_) | Error::Protocol(_) | Error::PoolTimedOut | Error::PoolClosed => true,
                // Database errors might be retryable depending on the type
                Error::Database(db_err) => {
                    if let Some(code) = db_err.code() {
                        // PostgreSQL error codes that are retryable
                        matches!(
                            code.as_ref(),
                            "40001" | // serialization_failure
                            "40P01" | // deadlock_detected
                            "53300" | // too_many_connections
                            "08000" | // connection_exception
                            "08003" | // connection_does_not_exist
                            "08006" | // connection_failure
                            "08001" | // sqlclient_unable_to_establish_sqlconnection
                            "08004" // sqlserver_rejected_establishment_of_sqlconnection
                        )
                    } else {
                        false
                    }
                }
                // Configuration and other errors are not retryable
                _ => false,
            }
        }
        PostgresError::PoolCreation(_) => true, // Pool creation can be retried
        PostgresError::Transaction(_) => true,  // Transaction errors might be temporary
        PostgresError::Migration(_) => false,   // Migration errors are not retryable
        PostgresError::Serialization(_) => false, // Serialization errors are not retryable
    }
}

/// Execute an operation with retry logic.
///
/// Retryable errors (see [`is_retryable_error`]) are retried up to
/// `strategy.max_attempts` times using the strategy's backoff delays;
/// non-retryable errors are returned immediately as [`RetryError::NonRetryable`].
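///
/// A handling sketch (`op` is a hypothetical closure returning
/// `Result<_, PostgresError>`; fenced as `ignore` because of the assumed paths):
///
/// ```ignore
/// match retry_operation(&RetryStrategy::default(), "append_events", op).await {
///     Ok(value) => println!("succeeded: {value:?}"),
///     Err(RetryError::ExhaustedAttempts { attempts, last_error }) => {
///         eprintln!("gave up after {attempts} attempts: {last_error}");
///     }
///     Err(RetryError::NonRetryable(error)) => eprintln!("fatal: {error}"),
/// }
/// ```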
pub async fn retry_operation<F, Fut, T, E>(
    strategy: &RetryStrategy,
    operation_name: &str,
    mut operation: F,
) -> Result<T, RetryError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: Into<PostgresError> + std::fmt::Debug,
{
    let mut last_error = None;

    for attempt in 0..strategy.max_attempts {
        match operation().await {
            Ok(result) => {
                if attempt > 0 {
                    debug!(
                        "Operation '{}' succeeded on attempt {} after retries",
                        operation_name,
                        attempt + 1
                    );
                }
                return Ok(result);
            }
            Err(error) => {
                let postgres_error = error.into();

                // Check if this error is retryable
                if !is_retryable_error(&postgres_error) {
                    warn!(
                        "Operation '{}' failed with non-retryable error: {:?}",
                        operation_name, postgres_error
                    );
                    return Err(RetryError::NonRetryable(postgres_error));
                }

                last_error = Some(postgres_error);

                // If this is not the last attempt, wait before retrying
                if attempt < strategy.max_attempts - 1 {
                    let delay = strategy.calculate_delay(attempt + 1);
                    warn!(
                        "Operation '{}' failed on attempt {}, retrying in {:?}. Error: {:?}",
                        operation_name,
                        attempt + 1,
                        delay,
                        last_error.as_ref().unwrap()
                    );
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }

    // All attempts exhausted
    let final_error = last_error.expect("Should have at least one error");
    Err(RetryError::ExhaustedAttempts {
        attempts: strategy.max_attempts,
        last_error: final_error,
    })
}

/// Macro for retrying a database operation expression with the given retry strategy
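///
/// A usage sketch (`load_stream_version`, `pool`, and `stream_id` are
/// hypothetical; the expression is re-evaluated inside an async block on every
/// attempt, and the trailing `?` relies on the `From<RetryError>` impl above to
/// convert the error into a `PostgresError`):
///
/// ```ignore
/// let version = retry_db_operation!(
///     &RetryStrategy::default(),
///     "load_stream_version",
///     load_stream_version(&pool, &stream_id).await
/// )?;
/// ```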
#[macro_export]
macro_rules! retry_db_operation {
    ($strategy:expr, $operation_name:expr, $operation:expr) => {
        $crate::retry::retry_operation($strategy, $operation_name, || async { $operation }).await
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_retry_strategy_delay_calculation() {
        let strategy = RetryStrategy {
            max_attempts: 5,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(2),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        // First attempt should have no delay
        assert_eq!(strategy.calculate_delay(0), Duration::ZERO);

        // Second attempt: base delay
        assert_eq!(strategy.calculate_delay(1), Duration::from_millis(100));

        // Third attempt: base * 2
        assert_eq!(strategy.calculate_delay(2), Duration::from_millis(200));

        // Fourth attempt: base * 4
        assert_eq!(strategy.calculate_delay(3), Duration::from_millis(400));

        // Fifth attempt: base * 8 = 800ms, still below the 2s cap
        assert_eq!(strategy.calculate_delay(4), Duration::from_millis(800));
    }

    #[test]
    fn test_retry_strategy_presets() {
        let conservative = RetryStrategy::conservative();
        assert_eq!(conservative.max_attempts, 5);
        assert!(conservative.use_jitter);

        let aggressive = RetryStrategy::aggressive();
        assert_eq!(aggressive.max_attempts, 2);
        assert!(!aggressive.use_jitter);
    }

    #[tokio::test]
    async fn test_retry_operation_success_first_attempt() {
        let strategy = RetryStrategy::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<i32, PostgresError>(42)
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_operation_success_after_retries() {
        let strategy = RetryStrategy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1), // Fast for testing
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    // Fail first two attempts
                    Err(PostgresError::Connection(sqlx::Error::PoolTimedOut))
                } else {
                    // Succeed on third attempt
                    Ok::<i32, PostgresError>(42)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_operation_exhausted_attempts() {
        let strategy = RetryStrategy {
            max_attempts: 2,
            base_delay: Duration::from_millis(1), // Fast for testing
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = retry_operation(&strategy, "test_operation", || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, PostgresError>(PostgresError::Connection(sqlx::Error::PoolTimedOut))
            }
        })
        .await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RetryError::ExhaustedAttempts { attempts: 2, .. }
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_is_retryable_error() {
        // Retryable errors
        assert!(is_retryable_error(&PostgresError::Connection(
            sqlx::Error::PoolTimedOut
        )));
        assert!(is_retryable_error(&PostgresError::PoolCreation(
            "test".to_string()
        )));
        assert!(is_retryable_error(&PostgresError::Transaction(
            "test".to_string()
        )));

        // Non-retryable errors
        assert!(!is_retryable_error(&PostgresError::Migration(
            "test".to_string()
        )));
        assert!(!is_retryable_error(&PostgresError::Serialization(
            serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::Other, "test"))
        )));
    }
}