Skip to main content

dataprof_db/
retry.rs

1//! Connection retry logic with exponential backoff
2
3use crate::DataProfilerError;
4use crate::security::sanitize_error_message;
5use std::time::Duration;
6use tokio::time::sleep;
7
/// Tuning knobs for retrying database operations with exponential backoff.
///
/// An operation is attempted once and then retried up to `max_retries` more times.
/// The wait between attempts starts at `initial_delay` and is multiplied by
/// `backoff_multiplier` after every retry, with the grown value capped at `max_delay`.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// How many extra attempts are made after the first one fails.
    pub max_retries: u32,
    /// Base wait before the first retry (before any jitter is applied).
    pub initial_delay: Duration,
    /// Ceiling applied whenever the backoff delay is grown.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each retry (e.g. `2.0` doubles it).
    pub backoff_multiplier: f32,
    /// When `true`, each wait is randomly rescaled so that many clients failing at
    /// once do not all retry in lock-step (the "thundering herd" problem).
    pub use_jitter: bool,
}
22
23impl Default for RetryConfig {
24    fn default() -> Self {
25        Self {
26            max_retries: 3,
27            initial_delay: Duration::from_millis(100),
28            max_delay: Duration::from_secs(10),
29            backoff_multiplier: 2.0,
30            use_jitter: true,
31        }
32    }
33}
34
35/// Retry a database operation with exponential backoff
36pub async fn retry_database_operation<T, F, Fut, E>(
37    config: &RetryConfig,
38    operation: F,
39    operation_name: &str,
40) -> Result<T, DataProfilerError>
41where
42    F: Fn() -> Fut,
43    Fut: std::future::Future<Output = Result<T, E>>,
44    E: std::fmt::Display + Send + Sync + 'static,
45{
46    let mut last_error_msg = None;
47    let mut delay = config.initial_delay;
48
49    for attempt in 0..=config.max_retries {
50        match operation().await {
51            Ok(result) => return Ok(result),
52            Err(error) => {
53                last_error_msg = Some(error.to_string());
54
55                if attempt < config.max_retries {
56                    let actual_delay = if config.use_jitter {
57                        add_jitter(delay)
58                    } else {
59                        delay
60                    };
61
62                    let sanitized_error = sanitize_error_message(&error.to_string());
63                    log::warn!(
64                        "Database operation '{}' failed on attempt {}/{}, retrying in {:?}: {}",
65                        operation_name,
66                        attempt + 1,
67                        config.max_retries + 1,
68                        actual_delay,
69                        sanitized_error
70                    );
71
72                    sleep(actual_delay).await;
73                    delay = std::cmp::min(
74                        Duration::from_millis(
75                            (delay.as_millis() as f32 * config.backoff_multiplier) as u64,
76                        ),
77                        config.max_delay,
78                    );
79                }
80            }
81        }
82    }
83
84    Err(DataProfilerError::DatabaseRetryExhausted {
85        operation: operation_name.to_string(),
86        attempts: config.max_retries + 1,
87        last_error: last_error_msg.unwrap_or_else(|| "unknown error".to_string()),
88    })
89}
90
91/// Add jitter to delay to avoid thundering herd problem
92fn add_jitter(delay: Duration) -> Duration {
93    use rand::Rng;
94    let mut rng = rand::rng();
95    let jitter_factor = rng.random_range(0.5..1.5);
96    Duration::from_millis((delay.as_millis() as f64 * jitter_factor) as u64)
97}
98
/// Decide whether an error message describes a transient, connection-level failure
/// that is worth retrying.
///
/// The message is lower-cased and searched for any of a fixed list of lower-case
/// markers (for example "connection", "timeout" or "database is locked").
pub fn is_retryable_error(error: &str) -> bool {
    // Several entries (e.g. "connection reset") are already covered by "connection";
    // they are listed anyway to document the specific failures this is meant to catch.
    const RETRYABLE_MARKERS: &[&str] = &[
        "connection",
        "timeout",
        "network",
        "temporary",
        "unavailable",
        "broken pipe",
        "connection reset",
        "connection refused",
        "host unreachable",
        "too many connections",
        "database is locked",
        "server has gone away",
        "connection timed out",
    ];

    let lowered = error.to_lowercase();
    RETRYABLE_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
117
118/// Enhanced retry logic that only retries on connection errors
119pub async fn retry_on_connection_error<T, F, Fut, E>(
120    config: &RetryConfig,
121    operation: F,
122    operation_name: &str,
123) -> Result<T, DataProfilerError>
124where
125    F: Fn() -> Fut,
126    Fut: std::future::Future<Output = Result<T, E>>,
127    E: std::fmt::Display + Send + Sync + 'static,
128{
129    let mut last_error_msg = None;
130    let mut delay = config.initial_delay;
131
132    for attempt in 0..=config.max_retries {
133        match operation().await {
134            Ok(result) => return Ok(result),
135            Err(error) => {
136                let error_str = error.to_string();
137
138                if !is_retryable_error(&error_str) {
139                    return Err(DataProfilerError::database_query(&error_str));
140                }
141
142                last_error_msg = Some(error_str);
143
144                if attempt < config.max_retries {
145                    let actual_delay = if config.use_jitter {
146                        add_jitter(delay)
147                    } else {
148                        delay
149                    };
150
151                    let sanitized_error = sanitize_error_message(&error.to_string());
152                    log::warn!(
153                        "Retryable database error in '{}' (attempt {}/{}), retrying in {:?}: {}",
154                        operation_name,
155                        attempt + 1,
156                        config.max_retries + 1,
157                        actual_delay,
158                        sanitized_error
159                    );
160
161                    sleep(actual_delay).await;
162                    delay = std::cmp::min(
163                        Duration::from_millis(
164                            (delay.as_millis() as f32 * config.backoff_multiplier) as u64,
165                        ),
166                        config.max_delay,
167                    );
168                }
169            }
170        }
171    }
172
173    Err(DataProfilerError::DatabaseRetryExhausted {
174        operation: operation_name.to_string(),
175        attempts: config.max_retries + 1,
176        last_error: last_error_msg.unwrap_or_else(|| "unknown error".to_string()),
177    })
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Deterministic, fast configuration for tests that must actually sleep between attempts.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
            backoff_multiplier: 2.0,
            use_jitter: false,
        }
    }

    #[tokio::test]
    async fn test_retry_success_after_failure() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_database_operation(
            &config,
            || {
                let c = counter_clone.clone();
                async move {
                    let count = c.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err("Connection failed")
                    } else {
                        Ok("Success")
                    }
                }
            },
            "test_operation",
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.expect("Expected successful result"), "Success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// When every attempt fails, the error reports the total number of attempts made.
    #[tokio::test]
    async fn test_retry_exhausted_reports_attempts() {
        let config = fast_config(2);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_database_operation(
            &config,
            || {
                let c = counter_clone.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("Connection failed")
                }
            },
            "always_fails",
        )
        .await;

        assert!(matches!(
            result,
            Err(DataProfilerError::DatabaseRetryExhausted { attempts: 3, .. })
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A non-connection error must not be retried at all.
    #[tokio::test]
    async fn test_connection_retry_skips_non_retryable_errors() {
        let config = fast_config(3);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_on_connection_error(
            &config,
            || {
                let c = counter_clone.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("Syntax error near SELECT")
                }
            },
            "bad_query",
        )
        .await;

        assert!(result.is_err());
        assert!(!matches!(
            result,
            Err(DataProfilerError::DatabaseRetryExhausted { .. })
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// A transient connection error is retried until the operation succeeds.
    #[tokio::test]
    async fn test_connection_retry_recovers_from_transient_errors() {
        let config = fast_config(3);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_on_connection_error(
            &config,
            || {
                let c = counter_clone.clone();
                async move {
                    if c.fetch_add(1, Ordering::SeqCst) == 0 {
                        Err("Connection reset by peer")
                    } else {
                        Ok(42)
                    }
                }
            },
            "flaky_query",
        )
        .await;

        assert_eq!(result.expect("Expected recovery after one retry"), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error("Connection refused"));
        assert!(is_retryable_error("Database timeout"));
        assert!(is_retryable_error("Network error"));
        assert!(is_retryable_error("Too many connections"));
        assert!(is_retryable_error("database is locked"));

        assert!(!is_retryable_error("Syntax error"));
        assert!(!is_retryable_error("Permission denied"));
        assert!(!is_retryable_error("Table not found"));
    }
}