use anyhow::{Context, Result};
use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;
#[derive(Debug, Clone)]
pub struct DbRetryConfig {
pub max_attempts: u32,
pub initial_delay: Duration,
pub max_delay: Duration,
pub backoff_multiplier: f64,
}
impl Default for DbRetryConfig {
fn default() -> Self {
Self {
max_attempts: 5,
initial_delay: Duration::from_millis(50),
max_delay: Duration::from_secs(5),
backoff_multiplier: 2.0,
}
}
}
impl DbRetryConfig {
pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
Self {
max_attempts,
initial_delay,
..Default::default()
}
}
pub fn aggressive() -> Self {
Self {
max_attempts: 10,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(10),
backoff_multiplier: 1.5,
}
}
fn calculate_delay(&self, attempt: u32) -> Duration {
let base_delay = self.initial_delay.as_millis() as f64;
let exponential = base_delay * self.backoff_multiplier.powi(attempt as i32);
let max_delay_ms = self.max_delay.as_millis() as f64;
let delay = exponential.min(max_delay_ms);
Duration::from_millis(delay as u64)
}
}
fn is_database_locked(err: &rusqlite::Error) -> bool {
matches!(
err,
rusqlite::Error::SqliteFailure(
rusqlite::ffi::Error {
code: rusqlite::ffi::ErrorCode::DatabaseBusy,
..
},
_,
) | rusqlite::Error::SqliteFailure(
rusqlite::ffi::Error {
code: rusqlite::ffi::ErrorCode::DatabaseLocked,
..
},
_,
)
)
}
pub async fn retry_db_operation<F, Fut, T, E>(
mut operation: F,
config: &DbRetryConfig,
) -> std::result::Result<T, E>
where
F: FnMut() -> Fut,
Fut: Future<Output = std::result::Result<T, E>>,
E: std::fmt::Display,
{
let mut attempt = 0;
let mut last_error;
loop {
match operation().await {
Ok(result) => {
if attempt > 0 {
tracing::info!("Database operation succeeded after {} retries", attempt);
}
return Ok(result);
}
Err(err) => {
let error_msg = err.to_string();
last_error = err;
let is_locked = error_msg.to_lowercase().contains("locked")
|| error_msg.to_lowercase().contains("busy");
if !is_locked {
tracing::debug!("Database error is not retryable: {}", error_msg);
return Err(last_error);
}
if attempt >= config.max_attempts {
tracing::warn!(
"Max database retry attempts ({}) exceeded for lock error",
config.max_attempts
);
return Err(last_error);
}
let delay = config.calculate_delay(attempt);
tracing::info!(
"Database locked (attempt {}/{}), retrying after {}ms",
attempt + 1,
config.max_attempts,
delay.as_millis()
);
sleep(delay).await;
attempt += 1;
}
}
}
}
pub async fn retry_db_anyhow<F, Fut, T>(operation: F, config: &DbRetryConfig) -> Result<T>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T>>,
{
retry_db_operation(operation, config)
.await
.context("Database operation failed after retries")
}
pub async fn retry_db_rusqlite<F, Fut, T>(
mut operation: F,
config: &DbRetryConfig,
) -> std::result::Result<T, rusqlite::Error>
where
F: FnMut() -> Fut,
Fut: Future<Output = std::result::Result<T, rusqlite::Error>>,
{
let mut attempt = 0;
let mut last_error;
loop {
match operation().await {
Ok(result) => {
if attempt > 0 {
tracing::info!("Database operation succeeded after {} retries", attempt);
}
return Ok(result);
}
Err(err) => {
let is_locked = is_database_locked(&err);
if !is_locked {
tracing::debug!("Database error is not retryable: {}", err);
return Err(err);
}
last_error = err;
if attempt >= config.max_attempts {
tracing::warn!(
"Max database retry attempts ({}) exceeded for lock error",
config.max_attempts
);
return Err(last_error);
}
let delay = config.calculate_delay(attempt);
tracing::info!(
"Database locked (attempt {}/{}), retrying after {}ms",
attempt + 1,
config.max_attempts,
delay.as_millis()
);
sleep(delay).await;
attempt += 1;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_retry_config_defaults() {
let config = DbRetryConfig::default();
assert_eq!(config.max_attempts, 5);
assert_eq!(config.initial_delay, Duration::from_millis(50));
assert_eq!(config.max_delay, Duration::from_secs(5));
}
#[test]
fn test_retry_config_aggressive() {
let config = DbRetryConfig::aggressive();
assert_eq!(config.max_attempts, 10);
assert_eq!(config.initial_delay, Duration::from_millis(100));
}
#[test]
fn test_calculate_delay() {
let config = DbRetryConfig {
initial_delay: Duration::from_millis(50),
max_delay: Duration::from_secs(5),
backoff_multiplier: 2.0,
max_attempts: 5,
};
let delay0 = config.calculate_delay(0);
assert_eq!(delay0, Duration::from_millis(50));
let delay1 = config.calculate_delay(1);
assert_eq!(delay1, Duration::from_millis(100));
let delay2 = config.calculate_delay(2);
assert_eq!(delay2, Duration::from_millis(200));
let delay10 = config.calculate_delay(10);
assert_eq!(delay10, Duration::from_secs(5));
}
#[test]
fn test_is_database_locked() {
let err = rusqlite::Error::QueryReturnedNoRows;
assert!(!is_database_locked(&err));
}
#[tokio::test]
async fn test_retry_success_immediate() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
let config = DbRetryConfig::default();
let call_count = Arc::new(AtomicU32::new(0));
let call_count_clone = call_count.clone();
let result = retry_db_operation(
move || {
let count = call_count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Ok::<_, String>(42)
}
},
&config,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_retry_success_after_retries() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
let config = DbRetryConfig::new(3, Duration::from_millis(10));
let call_count = Arc::new(AtomicU32::new(0));
let call_count_clone = call_count.clone();
let result = retry_db_operation(
move || {
let count = call_count_clone.clone();
async move {
let current = count.fetch_add(1, Ordering::SeqCst) + 1;
if current < 3 {
Err("database is locked".to_string())
} else {
Ok::<_, String>(42)
}
}
},
&config,
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_retry_max_attempts_exceeded() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
let config = DbRetryConfig::new(2, Duration::from_millis(10));
let call_count = Arc::new(AtomicU32::new(0));
let call_count_clone = call_count.clone();
let result = retry_db_operation(
move || {
let count = call_count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Err::<i32, _>("database is locked".to_string())
}
},
&config,
)
.await;
assert!(result.is_err());
assert_eq!(call_count.load(Ordering::SeqCst), 3); }
#[tokio::test]
async fn test_retry_non_retryable_error() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
let config = DbRetryConfig::default();
let call_count = Arc::new(AtomicU32::new(0));
let call_count_clone = call_count.clone();
let result = retry_db_operation(
move || {
let count = call_count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Err::<i32, _>("constraint violation".to_string())
}
},
&config,
)
.await;
assert!(result.is_err());
assert_eq!(call_count.load(Ordering::SeqCst), 1); }
}