1use crate::DataProfilerError;
4use crate::security::sanitize_error_message;
5use std::time::Duration;
6use tokio::time::sleep;
7
/// Retry policy for database operations: how many times to retry, how long to wait
/// between attempts, and whether to randomise the wait.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of retries after the first attempt; an operation is attempted at most
    /// `max_retries + 1` times.
    pub max_retries: u32,
    /// Delay slept before the first retry.
    pub initial_delay: Duration,
    /// Ceiling for the backoff delay. Jitter is applied to the capped value, so an
    /// individual sleep can exceed this when `use_jitter` is enabled.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each retry (`2.0` doubles it).
    pub backoff_multiplier: f32,
    /// Whether each sleep is randomly scaled instead of using the exact backoff delay.
    pub use_jitter: bool,
}
22
23impl Default for RetryConfig {
24 fn default() -> Self {
25 Self {
26 max_retries: 3,
27 initial_delay: Duration::from_millis(100),
28 max_delay: Duration::from_secs(10),
29 backoff_multiplier: 2.0,
30 use_jitter: true,
31 }
32 }
33}
34
35pub async fn retry_database_operation<T, F, Fut, E>(
37 config: &RetryConfig,
38 operation: F,
39 operation_name: &str,
40) -> Result<T, DataProfilerError>
41where
42 F: Fn() -> Fut,
43 Fut: std::future::Future<Output = Result<T, E>>,
44 E: std::fmt::Display + Send + Sync + 'static,
45{
46 let mut last_error_msg = None;
47 let mut delay = config.initial_delay;
48
49 for attempt in 0..=config.max_retries {
50 match operation().await {
51 Ok(result) => return Ok(result),
52 Err(error) => {
53 last_error_msg = Some(error.to_string());
54
55 if attempt < config.max_retries {
56 let actual_delay = if config.use_jitter {
57 add_jitter(delay)
58 } else {
59 delay
60 };
61
62 let sanitized_error = sanitize_error_message(&error.to_string());
63 log::warn!(
64 "Database operation '{}' failed on attempt {}/{}, retrying in {:?}: {}",
65 operation_name,
66 attempt + 1,
67 config.max_retries + 1,
68 actual_delay,
69 sanitized_error
70 );
71
72 sleep(actual_delay).await;
73 delay = std::cmp::min(
74 Duration::from_millis(
75 (delay.as_millis() as f32 * config.backoff_multiplier) as u64,
76 ),
77 config.max_delay,
78 );
79 }
80 }
81 }
82 }
83
84 Err(DataProfilerError::DatabaseRetryExhausted {
85 operation: operation_name.to_string(),
86 attempts: config.max_retries + 1,
87 last_error: last_error_msg.unwrap_or_else(|| "unknown error".to_string()),
88 })
89}
90
91fn add_jitter(delay: Duration) -> Duration {
93 use rand::Rng;
94 let mut rng = rand::rng();
95 let jitter_factor = rng.random_range(0.5..1.5);
96 Duration::from_millis((delay.as_millis() as f64 * jitter_factor) as u64)
97}
98
/// Heuristically decides whether a database error message describes a transient failure.
///
/// The message is lower-cased and checked for any of a fixed set of phrases associated
/// with connectivity problems, timeouts and temporary contention. Anything else (syntax
/// errors, missing tables, permission failures, ...) is treated as non-retryable.
pub fn is_retryable_error(error: &str) -> bool {
    // All patterns are lower-case. "connection" already covers "connection reset",
    // "connection refused", "connection timed out" and "too many connections".
    const TRANSIENT_PATTERNS: &[&str] = &[
        "connection",
        "timeout",
        // A bare "timed out" (e.g. an I/O timeout) does not contain "timeout".
        "timed out",
        "network",
        "temporary",
        "unavailable",
        "broken pipe",
        "host unreachable",
        "database is locked",
        "server has gone away",
    ];

    let error_lower = error.to_lowercase();
    TRANSIENT_PATTERNS
        .iter()
        .any(|&pattern| error_lower.contains(pattern))
}
117
118pub async fn retry_on_connection_error<T, F, Fut, E>(
120 config: &RetryConfig,
121 operation: F,
122 operation_name: &str,
123) -> Result<T, DataProfilerError>
124where
125 F: Fn() -> Fut,
126 Fut: std::future::Future<Output = Result<T, E>>,
127 E: std::fmt::Display + Send + Sync + 'static,
128{
129 let mut last_error_msg = None;
130 let mut delay = config.initial_delay;
131
132 for attempt in 0..=config.max_retries {
133 match operation().await {
134 Ok(result) => return Ok(result),
135 Err(error) => {
136 let error_str = error.to_string();
137
138 if !is_retryable_error(&error_str) {
139 return Err(DataProfilerError::database_query(&error_str));
140 }
141
142 last_error_msg = Some(error_str);
143
144 if attempt < config.max_retries {
145 let actual_delay = if config.use_jitter {
146 add_jitter(delay)
147 } else {
148 delay
149 };
150
151 let sanitized_error = sanitize_error_message(&error.to_string());
152 log::warn!(
153 "Retryable database error in '{}' (attempt {}/{}), retrying in {:?}: {}",
154 operation_name,
155 attempt + 1,
156 config.max_retries + 1,
157 actual_delay,
158 sanitized_error
159 );
160
161 sleep(actual_delay).await;
162 delay = std::cmp::min(
163 Duration::from_millis(
164 (delay.as_millis() as f32 * config.backoff_multiplier) as u64,
165 ),
166 config.max_delay,
167 );
168 }
169 }
170 }
171 }
172
173 Err(DataProfilerError::DatabaseRetryExhausted {
174 operation: operation_name.to_string(),
175 attempts: config.max_retries + 1,
176 last_error: last_error_msg.unwrap_or_else(|| "unknown error".to_string()),
177 })
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Jitter-free config with millisecond-scale delays so the async tests stay fast.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            use_jitter: false,
        }
    }

    #[tokio::test]
    async fn test_retry_success_after_failure() {
        let counter = Arc::new(AtomicU32::new(0));

        let result = retry_database_operation(
            &fast_config(2),
            || {
                let c = Arc::clone(&counter);
                async move {
                    if c.fetch_add(1, Ordering::SeqCst) < 2 {
                        Err("Connection failed")
                    } else {
                        Ok("Success")
                    }
                }
            },
            "test_operation",
        )
        .await;

        assert_eq!(result.expect("Expected successful result"), "Success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausted_reports_attempts_and_last_error() {
        let counter = Arc::new(AtomicU32::new(0));

        let result = retry_database_operation(
            &fast_config(2),
            || {
                let c = Arc::clone(&counter);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<&str, _>("Connection failed")
                }
            },
            "always_fails",
        )
        .await;

        match result {
            Err(DataProfilerError::DatabaseRetryExhausted {
                operation,
                attempts,
                last_error,
            }) => {
                assert_eq!(operation, "always_fails");
                assert_eq!(attempts, 3);
                assert_eq!(last_error, "Connection failed");
            }
            other => panic!("expected DatabaseRetryExhausted, got {other:?}"),
        }
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_connection_retry_stops_on_non_retryable_error() {
        let counter = Arc::new(AtomicU32::new(0));

        let result = retry_on_connection_error(
            &fast_config(3),
            || {
                let c = Arc::clone(&counter);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("Syntax error")
                }
            },
            "bad_query",
        )
        .await;

        assert!(result.is_err());
        assert!(!matches!(
            result,
            Err(DataProfilerError::DatabaseRetryExhausted { .. })
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_connection_retry_retries_transient_error() {
        let counter = Arc::new(AtomicU32::new(0));

        let result = retry_on_connection_error(
            &fast_config(3),
            || {
                let c = Arc::clone(&counter);
                async move {
                    if c.fetch_add(1, Ordering::SeqCst) < 2 {
                        Err("connection reset by peer")
                    } else {
                        Ok(42)
                    }
                }
            },
            "flaky_query",
        )
        .await;

        assert_eq!(result.expect("Expected successful result"), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error("Connection refused"));
        assert!(is_retryable_error("Database timeout"));
        assert!(is_retryable_error("Network error"));
        assert!(is_retryable_error("Too many connections"));
        assert!(is_retryable_error("database is locked"));

        assert!(!is_retryable_error("Syntax error"));
        assert!(!is_retryable_error("Permission denied"));
        assert!(!is_retryable_error("Table not found"));
    }
}