use crate::core::errors::DataProfilerError;
use crate::database::security::sanitize_error_message;
use std::time::Duration;
use tokio::time::sleep;
/// Tuning knobs for the retry helpers in this module.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries allowed after the first attempt; an operation is
    /// therefore invoked at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Delay used before the first retry.
    pub initial_delay: Duration,
    /// Upper bound for the backoff delay. Jitter is applied after this cap,
    /// so an individual sleep may still exceed it.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after every failed attempt.
    pub backoff_multiplier: f32,
    /// When `true`, each sleep is scaled by a random jitter factor.
    pub use_jitter: bool,
}

impl Default for RetryConfig {
    /// Three retries, starting at 100 ms and doubling up to 10 s, with jitter.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(10);

        RetryConfig {
            max_retries: 3,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
            use_jitter: true,
        }
    }
}
/// Runs `operation` repeatedly until it succeeds or the retry budget is spent.
///
/// Every error is treated as retryable. The operation is invoked at most
/// `config.max_retries + 1` times. Between attempts the task sleeps for the
/// current backoff delay (passed through `add_jitter` when
/// `config.use_jitter` is set), and the delay is then multiplied by
/// `config.backoff_multiplier` and capped at `config.max_delay`.
///
/// Each failed attempt that will be retried is logged at `warn` level with
/// the error text passed through `sanitize_error_message`.
///
/// # Errors
///
/// Returns [`DataProfilerError::DatabaseRetryExhausted`] once every attempt
/// has failed. Its `last_error` field holds the `Display` text of the final
/// error exactly as rendered (it is not passed through
/// `sanitize_error_message`).
pub async fn retry_database_operation<T, F, Fut, E>(
    config: &RetryConfig,
    operation: F,
    operation_name: &str,
) -> Result<T, DataProfilerError>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: std::fmt::Display + Send + Sync + 'static,
{
    // Saturate so that `max_retries == u32::MAX` cannot overflow the counter
    // (which would panic in debug builds and wrap to 0 in release builds).
    let total_attempts = config.max_retries.saturating_add(1);
    let mut last_error_msg = None;
    let mut delay = config.initial_delay;

    for attempt in 0..=config.max_retries {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(error) => {
                // Render the error once; the same text feeds both the log
                // line and the final error value.
                let error_str = error.to_string();

                if attempt < config.max_retries {
                    let actual_delay = if config.use_jitter {
                        add_jitter(delay)
                    } else {
                        delay
                    };
                    let sanitized_error = sanitize_error_message(&error_str);
                    log::warn!(
                        "Database operation '{}' failed on attempt {}/{}, retrying in {:?}: {}",
                        operation_name,
                        attempt + 1,
                        total_attempts,
                        actual_delay,
                        sanitized_error
                    );
                    sleep(actual_delay).await;

                    // Float-to-integer `as` casts saturate (NaN becomes 0), so
                    // an extreme multiplier cannot panic here; the result is
                    // clamped to `max_delay` right after.
                    delay = std::cmp::min(
                        Duration::from_millis(
                            (delay.as_millis() as f32 * config.backoff_multiplier) as u64,
                        ),
                        config.max_delay,
                    );
                }

                last_error_msg = Some(error_str);
            }
        }
    }

    Err(DataProfilerError::DatabaseRetryExhausted {
        operation: operation_name.to_string(),
        attempts: total_attempts,
        last_error: last_error_msg.unwrap_or_else(|| "unknown error".to_string()),
    })
}
/// Scales `delay` by a random factor drawn from `0.5..1.5`.
///
/// The result is computed at millisecond resolution and truncated, so any
/// sub-millisecond part of the scaled delay is dropped.
fn add_jitter(delay: Duration) -> Duration {
    use rand::Rng;

    let base_millis = delay.as_millis() as f64;
    let scale: f64 = rand::rng().random_range(0.5..1.5);
    Duration::from_millis((base_millis * scale) as u64)
}
/// Reports whether an error message looks like a transient failure.
///
/// The message is lower-cased and checked for a fixed set of substrings
/// (connection, timeout, network, lock and availability wording). Returns
/// `true` as soon as any of them is found, `false` otherwise.
pub fn is_retryable_error(error: &str) -> bool {
    // Substrings that mark an error as transient; matched against the
    // lower-cased message.
    const TRANSIENT_MARKERS: &[&str] = &[
        "connection",
        "timeout",
        "network",
        "temporary",
        "unavailable",
        "broken pipe",
        "connection reset",
        "connection refused",
        "host unreachable",
        "too many connections",
        "database is locked",
        "server has gone away",
        "connection timed out",
    ];

    let haystack = error.to_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|marker| haystack.contains(*marker))
}
/// Runs `operation`, retrying only when the failure looks transient.
///
/// After each failure the error's `Display` text is classified with
/// `is_retryable_error`. A non-retryable error aborts immediately. A
/// retryable one is logged at `warn` level (text passed through
/// `sanitize_error_message`), the task sleeps for the current backoff delay
/// (passed through `add_jitter` when `config.use_jitter` is set), and the
/// delay is multiplied by `config.backoff_multiplier` and capped at
/// `config.max_delay`. The operation is invoked at most
/// `config.max_retries + 1` times.
///
/// # Errors
///
/// * The value built by `DataProfilerError::database_query` from the error
///   text, as soon as a non-retryable error is seen.
/// * [`DataProfilerError::DatabaseRetryExhausted`] when every attempt failed
///   with a retryable error; `last_error` holds the final error text exactly
///   as rendered (not passed through `sanitize_error_message`).
pub async fn retry_on_connection_error<T, F, Fut, E>(
    config: &RetryConfig,
    operation: F,
    operation_name: &str,
) -> Result<T, DataProfilerError>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: std::fmt::Display + Send + Sync + 'static,
{
    // Saturate so that `max_retries == u32::MAX` cannot overflow the counter
    // (which would panic in debug builds and wrap to 0 in release builds).
    let total_attempts = config.max_retries.saturating_add(1);
    let mut last_error_msg = None;
    let mut delay = config.initial_delay;

    for attempt in 0..=config.max_retries {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(error) => {
                // Render the error once; the same text is used for
                // classification, logging and the final error value.
                let error_str = error.to_string();
                if !is_retryable_error(&error_str) {
                    return Err(DataProfilerError::database_query(&error_str));
                }

                if attempt < config.max_retries {
                    let actual_delay = if config.use_jitter {
                        add_jitter(delay)
                    } else {
                        delay
                    };
                    let sanitized_error = sanitize_error_message(&error_str);
                    log::warn!(
                        "Retryable database error in '{}' (attempt {}/{}), retrying in {:?}: {}",
                        operation_name,
                        attempt + 1,
                        total_attempts,
                        actual_delay,
                        sanitized_error
                    );
                    sleep(actual_delay).await;

                    // Float-to-integer `as` casts saturate (NaN becomes 0), so
                    // an extreme multiplier cannot panic here; the result is
                    // clamped to `max_delay` right after.
                    delay = std::cmp::min(
                        Duration::from_millis(
                            (delay.as_millis() as f32 * config.backoff_multiplier) as u64,
                        ),
                        config.max_delay,
                    );
                }

                last_error_msg = Some(error_str);
            }
        }
    }

    Err(DataProfilerError::DatabaseRetryExhausted {
        operation: operation_name.to_string(),
        attempts: total_attempts,
        last_error: last_error_msg.unwrap_or_else(|| "unknown error".to_string()),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Two failures followed by a success must be absorbed by the retry loop,
    /// with the operation invoked exactly three times.
    #[tokio::test]
    async fn test_retry_success_after_failure() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            use_jitter: false,
        };
        let calls = Arc::new(AtomicU32::new(0));

        // Fails on the first two invocations, succeeds on the third.
        let operation = {
            let calls = Arc::clone(&calls);
            move || {
                let calls = Arc::clone(&calls);
                async move {
                    match calls.fetch_add(1, Ordering::SeqCst) {
                        0 | 1 => Err("Connection failed"),
                        _ => Ok("Success"),
                    }
                }
            }
        };

        let result = retry_database_operation(&config, operation, "test_operation").await;

        assert!(result.is_ok());
        assert_eq!(result.expect("Expected successful result"), "Success");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// Transient-looking messages are retryable; logical errors are not.
    #[test]
    fn test_is_retryable_error() {
        // Messages that contain one of the transient markers.
        assert!(is_retryable_error("Connection refused"));
        assert!(is_retryable_error("Database timeout"));
        assert!(is_retryable_error("Network error"));
        assert!(is_retryable_error("Too many connections"));
        assert!(is_retryable_error("database is locked"));

        // Messages that contain none of them.
        assert!(!is_retryable_error("Syntax error"));
        assert!(!is_retryable_error("Permission denied"));
        assert!(!is_retryable_error("Table not found"));
    }
}