use std::fmt::Display;
use std::time::Duration;
use rand::Rng;
/// Backoff policy used by `on_lock` when an operation fails because a lock is held elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LockRetryPolicy {
    /// Maximum number of retries after the first attempt, so the operation runs at most
    /// `retries + 1` times.
    pub retries: u32,
    /// Backoff before the first retry. The delay doubles on each later retry and is jittered.
    /// A zero `base_delay` disables retrying entirely.
    pub base_delay: Duration,
    /// Upper bound on the total time spent sleeping between attempts.
    pub max_total_wait: Duration,
}

impl LockRetryPolicy {
    /// Preset named for database access: up to 5 retries, starting at 100 ms, with at most
    /// 3 s of total waiting.
    #[must_use]
    pub const fn db_default() -> Self {
        Self {
            retries: 5,
            base_delay: Duration::from_millis(100),
            max_total_wait: Duration::from_millis(3_000),
        }
    }
}

impl Default for LockRetryPolicy {
    /// Same as [`LockRetryPolicy::db_default`].
    fn default() -> Self {
        Self::db_default()
    }
}
/// Failure returned by `on_lock`.
#[derive(Debug)]
pub enum LockRetryError<E> {
    /// The operation failed with an error that is not recognised as a lock error.
    /// It is returned as-is, without retrying.
    Operation(E),
    /// The operation kept failing with lock errors until the retry budget ran out.
    Exhausted {
        /// The last lock error returned by the operation.
        source: E,
        /// Number of retries performed, not counting the first attempt.
        retries: u32,
        /// Total time slept between attempts.
        total_wait: Duration,
    },
}

impl<E: Display> Display for LockRetryError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Transparent: show the operation's own message unchanged.
            Self::Operation(error) => Display::fmt(error, f),
            Self::Exhausted {
                source,
                retries,
                total_wait,
            } => write!(
                f,
                "lock still held after {retries} retries (waited {total_wait:?}): {source}"
            ),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for LockRetryError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // `Display` already embeds the inner error's message. Expose the next link down
        // instead of the inner error itself, so error-chain reporters don't print it twice.
        match self {
            Self::Operation(error) | Self::Exhausted { source: error, .. } => error.source(),
        }
    }
}
/// Runs `op`, retrying with jittered exponential backoff while it fails with a lock error.
///
/// An error counts as a lock error when its `Display` text is accepted by
/// `is_lock_error_message`. Between attempts this calls `std::thread::sleep`, so it blocks
/// the calling thread.
///
/// # Errors
///
/// * [`LockRetryError::Operation`]: `op` failed with a non-lock error. This is returned
///   immediately, without retrying.
/// * [`LockRetryError::Exhausted`]: `op` still reported a lock error after one of these:
///   `policy.retries` retries were used, `policy.max_total_wait` was fully spent, or the
///   computed backoff was zero (e.g. a zero `base_delay`).
pub fn on_lock<T, E, F>(
    mut op: F,
    policy: LockRetryPolicy,
) -> std::result::Result<T, LockRetryError<E>>
where
    E: Display,
    F: FnMut() -> std::result::Result<T, E>,
{
    let mut retries = 0u32;
    let mut total_wait = Duration::ZERO;
    loop {
        let error = match op() {
            Ok(value) => return Ok(value),
            Err(error) => error,
        };
        if !is_lock_error_message(&error.to_string()) {
            return Err(LockRetryError::Operation(error));
        }
        // `remaining` is zero exactly when the total-wait budget is spent. That makes a
        // separate `total_wait >= max_total_wait` check redundant.
        let remaining = policy.max_total_wait.saturating_sub(total_wait);
        let delay = if retries < policy.retries && !remaining.is_zero() {
            next_delay(policy, retries, remaining)
        } else {
            Duration::ZERO
        };
        // One exit for every way the budget can run out. A zero backoff counts too,
        // because sleeping zero would just spin.
        if delay.is_zero() {
            return Err(LockRetryError::Exhausted {
                source: error,
                retries,
                total_wait,
            });
        }
        std::thread::sleep(delay);
        // `delay <= remaining`, so this never exceeds `max_total_wait` or overflows.
        total_wait += delay;
        retries += 1;
    }
}
/// Reports whether `message` looks like a lock-contention failure.
///
/// This is a case-sensitive substring search against a fixed list of known
/// lock-failure phrases.
pub fn is_lock_error_message(message: &str) -> bool {
    const LOCK_MARKERS: [&str; 5] = [
        "already open",
        "Cannot acquire lock",
        "Locking error: Failed locking file",
        "File is locked by another process",
        "database is locked",
    ];
    LOCK_MARKERS
        .iter()
        .any(|&marker| message.contains(marker))
}
/// Computes the jittered backoff before retry number `attempt` (0-based), clamped to `remaining`.
///
/// The un-jittered delay is `base_delay * 2^attempt`, with the exponent capped at 20. The
/// returned delay is drawn uniformly from `[0.5x, 1.5x]` of that value. It is zero only when
/// `base_delay` is zero or `remaining` is zero.
fn next_delay(policy: LockRetryPolicy, attempt: u32, remaining: Duration) -> Duration {
    // Work in nanoseconds. Millisecond math truncated sub-millisecond base delays (e.g. 500µs)
    // to zero, which made `on_lock` give up without retrying.
    // `as_nanos()` is below 2^95 and the shift is at most 20, so this cannot overflow u128.
    let scaled = policy.base_delay.as_nanos() << attempt.min(20);
    let base_ns = scaled.min(u128::from(u64::MAX)) as u64;
    if base_ns == 0 {
        return Duration::ZERO;
    }
    let min_ns = (base_ns / 2).max(1);
    let max_ns = (base_ns.saturating_mul(3) / 2).max(min_ns);
    let jittered_ns = if min_ns == max_ns {
        min_ns
    } else {
        rand::thread_rng().gen_range(min_ns..=max_ns)
    };
    Duration::from_nanos(jittered_ns).min(remaining)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn on_lock_succeeds_after_transient_errors() {
        let mut attempts = 0;
        let value = on_lock(
            || {
                attempts += 1;
                if attempts < 3 {
                    return Err("Locking error: Failed locking file");
                }
                Ok::<_, &str>(42)
            },
            LockRetryPolicy {
                retries: 5,
                base_delay: Duration::from_millis(1),
                max_total_wait: Duration::from_millis(10),
            },
        )
        .expect("retries succeed");
        assert_eq!(value, 42);
        assert_eq!(attempts, 3);
    }

    #[test]
    fn on_lock_returns_final_error_after_exhaustion() {
        // A zero base delay yields a zero backoff. `on_lock` therefore reports exhaustion
        // on the first lock error without retrying, even though `retries` is 2.
        let err = on_lock::<(), _, _>(
            || Err("Locking error: Failed locking file"),
            LockRetryPolicy {
                retries: 2,
                base_delay: Duration::ZERO,
                max_total_wait: Duration::from_millis(1),
            },
        )
        .expect_err("lock retries exhaust");
        match err {
            LockRetryError::Exhausted {
                retries,
                total_wait,
                ..
            } => {
                assert_eq!(retries, 0);
                assert_eq!(total_wait, Duration::ZERO);
            }
            other => panic!("expected exhausted retry, got {other:?}"),
        }
    }

    #[test]
    fn on_lock_stops_after_retry_budget() {
        // The wait budget is generous, so the retry count is what ends the loop.
        let mut attempts = 0;
        let err = on_lock::<(), _, _>(
            || {
                attempts += 1;
                Err("database is locked")
            },
            LockRetryPolicy {
                retries: 2,
                base_delay: Duration::from_millis(1),
                max_total_wait: Duration::from_secs(1),
            },
        )
        .expect_err("lock retries exhaust");
        assert_eq!(attempts, 3);
        match err {
            LockRetryError::Exhausted {
                retries,
                total_wait,
                ..
            } => {
                assert_eq!(retries, 2);
                assert!(total_wait > Duration::ZERO);
                assert!(total_wait <= Duration::from_secs(1));
            }
            other => panic!("expected exhausted retry, got {other:?}"),
        }
    }

    #[test]
    fn on_lock_fails_fast_for_non_lock_errors() {
        let mut attempts = 0;
        let err = on_lock::<(), _, _>(
            || {
                attempts += 1;
                Err("disk full")
            },
            LockRetryPolicy::db_default(),
        )
        .expect_err("non-lock error");
        assert_eq!(attempts, 1);
        match err {
            LockRetryError::Operation(message) => assert_eq!(message, "disk full"),
            other => panic!("expected operation error, got {other:?}"),
        }
    }

    #[test]
    fn on_lock_honors_total_wait_cap() {
        let err = on_lock::<(), _, _>(
            || Err("database is locked"),
            LockRetryPolicy {
                retries: 5,
                base_delay: Duration::from_millis(100),
                max_total_wait: Duration::from_millis(120),
            },
        )
        .expect_err("lock retries exhaust");
        match err {
            LockRetryError::Exhausted {
                retries,
                total_wait,
                ..
            } => {
                assert!(retries >= 1);
                assert!(total_wait <= Duration::from_millis(120));
            }
            other => panic!("expected exhausted retry, got {other:?}"),
        }
    }

    #[test]
    fn recognizes_lock_error_messages() {
        assert!(is_lock_error_message("database is locked"));
        assert!(is_lock_error_message("IO error: Cannot acquire lock"));
        assert!(is_lock_error_message("File is locked by another process"));
        assert!(is_lock_error_message("Locking error: Failed locking file 'x'"));
        assert!(is_lock_error_message("Database already open."));
        assert!(!is_lock_error_message("disk full"));
        assert!(!is_lock_error_message(""));
    }
}