1use std::fmt::Display;
2use std::time::Duration;
3
4use rand::Rng;
5
/// Controls how [`on_lock`] retries an operation that fails because a lock is held.
///
/// Each retry sleeps for an exponentially growing, jittered delay derived from
/// `base_delay`. Retrying stops after `retries` sleeps or once the accumulated
/// sleep time reaches `max_total_wait`, whichever happens first.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LockRetryPolicy {
    /// Maximum number of retries (sleeps) after the first failed attempt.
    pub retries: u32,
    /// Nominal delay before the first retry; later retries scale it up.
    pub base_delay: Duration,
    /// Upper bound on the total time spent sleeping between attempts.
    pub max_total_wait: Duration,
}

impl LockRetryPolicy {
    /// Policy used for database access: up to five retries starting from a
    /// 100 ms nominal delay, sleeping no more than three seconds in total.
    pub const fn db_default() -> Self {
        let base_delay = Duration::from_millis(100);
        let max_total_wait = Duration::from_secs(3);
        Self {
            retries: 5,
            base_delay,
            max_total_wait,
        }
    }
}
22
/// Error returned by [`on_lock`].
#[derive(Debug)]
pub enum LockRetryError<E> {
    /// The operation failed with an error that is not a lock error; it was not retried.
    Operation(E),
    /// The operation kept failing with lock errors until the retry policy ran out.
    Exhausted {
        /// The error produced by the last attempt.
        source: E,
        /// Number of retries (sleeps) performed before giving up.
        retries: u32,
        /// Total time spent sleeping between attempts.
        total_wait: Duration,
    },
}

impl<E: Display> Display for LockRetryError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Non-lock failures pass through unchanged so callers see the original message.
            Self::Operation(source) => Display::fmt(source, f),
            Self::Exhausted {
                source,
                retries,
                total_wait,
            } => write!(
                f,
                "lock retries exhausted after {retries} retries ({total_wait:?} waited): {source}"
            ),
        }
    }
}

impl<E> std::error::Error for LockRetryError<E>
where
    E: std::error::Error + 'static,
{
    // The wrapped error's message is already part of `Display`, so expose its
    // own cause (rather than the wrapped error itself) to avoid printing the
    // same message twice in error-chain reports.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Operation(source) | Self::Exhausted { source, .. } => source.source(),
        }
    }
}
32
33pub fn on_lock<T, E, F>(
34 mut op: F,
35 policy: LockRetryPolicy,
36) -> std::result::Result<T, LockRetryError<E>>
37where
38 E: Display,
39 F: FnMut() -> std::result::Result<T, E>,
40{
41 let mut retries = 0u32;
42 let mut total_wait = Duration::ZERO;
43 loop {
44 match op() {
45 Ok(value) => return Ok(value),
46 Err(error) => {
47 if !is_lock_error_message(&error.to_string()) {
48 return Err(LockRetryError::Operation(error));
49 }
50 if retries >= policy.retries || total_wait >= policy.max_total_wait {
51 return Err(LockRetryError::Exhausted {
52 source: error,
53 retries,
54 total_wait,
55 });
56 }
57 let remaining = policy.max_total_wait.saturating_sub(total_wait);
58 if remaining.is_zero() {
59 return Err(LockRetryError::Exhausted {
60 source: error,
61 retries,
62 total_wait,
63 });
64 }
65 let delay = next_delay(policy, retries, remaining);
66 if delay.is_zero() {
67 return Err(LockRetryError::Exhausted {
68 source: error,
69 retries,
70 total_wait,
71 });
72 }
73 std::thread::sleep(delay);
74 total_wait += delay;
75 retries += 1;
76 }
77 }
78 }
79}
80
/// Reports whether `message` looks like a lock-contention failure worth retrying.
///
/// Matching is a case-sensitive substring search against a fixed list of
/// known lock-error phrases.
pub fn is_lock_error_message(message: &str) -> bool {
    const LOCK_MARKERS: [&str; 5] = [
        "already open",
        "Cannot acquire lock",
        "Locking error: Failed locking file",
        "File is locked by another process",
        "database is locked",
    ];
    LOCK_MARKERS.iter().any(|&marker| message.contains(marker))
}
88
89fn next_delay(policy: LockRetryPolicy, attempt: u32, remaining: Duration) -> Duration {
90 let exp = (1u128 << attempt.min(20)) * policy.base_delay.as_millis();
91 let base_ms = exp.min(u64::MAX as u128) as u64;
92 if base_ms == 0 {
93 return Duration::ZERO;
94 }
95 let min_ms = (base_ms / 2).max(1);
96 let max_ms = base_ms.saturating_mul(3).saturating_div(2).max(min_ms);
97 let jittered_ms = if min_ms == max_ms {
98 min_ms
99 } else {
100 rand::thread_rng().gen_range(min_ms..=max_ms)
101 };
102 Duration::from_millis(jittered_ms).min(remaining)
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn on_lock_succeeds_after_transient_errors() {
        let mut attempts = 0;
        let value = on_lock(
            || {
                attempts += 1;
                if attempts < 3 {
                    return Err("Locking error: Failed locking file");
                }
                Ok::<_, &str>(42)
            },
            LockRetryPolicy {
                retries: 5,
                base_delay: Duration::from_millis(1),
                max_total_wait: Duration::from_millis(10),
            },
        )
        .expect("retries succeed");
        assert_eq!(value, 42);
        assert_eq!(attempts, 3);
    }

    #[test]
    fn on_lock_returns_final_error_after_exhaustion() {
        let mut attempts = 0;
        let err = on_lock::<(), _, _>(
            || {
                attempts += 1;
                Err("Locking error: Failed locking file")
            },
            LockRetryPolicy {
                retries: 2,
                base_delay: Duration::ZERO,
                max_total_wait: Duration::from_millis(1),
            },
        )
        .expect_err("lock retries exhaust");
        // A zero base delay produces a zero backoff, so `on_lock` gives up
        // after the first attempt instead of spinning.
        assert_eq!(attempts, 1);
        match err {
            LockRetryError::Exhausted {
                source,
                retries,
                total_wait,
            } => {
                assert_eq!(source, "Locking error: Failed locking file");
                assert_eq!(retries, 0);
                assert_eq!(total_wait, Duration::ZERO);
            }
            other => panic!("expected exhausted retry, got {other:?}"),
        }
    }

    #[test]
    fn on_lock_with_zero_retries_calls_op_once() {
        let mut attempts = 0;
        let err = on_lock::<(), _, _>(
            || {
                attempts += 1;
                Err("database is locked")
            },
            LockRetryPolicy {
                retries: 0,
                ..LockRetryPolicy::db_default()
            },
        )
        .expect_err("no retries allowed");
        assert_eq!(attempts, 1);
        match err {
            LockRetryError::Exhausted {
                retries,
                total_wait,
                ..
            } => {
                assert_eq!(retries, 0);
                assert_eq!(total_wait, Duration::ZERO);
            }
            other => panic!("expected exhausted retry, got {other:?}"),
        }
    }

    #[test]
    fn on_lock_fails_fast_for_non_lock_errors() {
        let mut attempts = 0;
        let err = on_lock::<(), _, _>(
            || {
                attempts += 1;
                Err("disk full")
            },
            LockRetryPolicy::db_default(),
        )
        .expect_err("non-lock error");
        assert_eq!(attempts, 1);
        match err {
            LockRetryError::Operation(message) => assert_eq!(message, "disk full"),
            other => panic!("expected operation error, got {other:?}"),
        }
    }

    #[test]
    fn on_lock_honors_total_wait_cap() {
        let err = on_lock::<(), _, _>(
            || Err("database is locked"),
            LockRetryPolicy {
                retries: 5,
                base_delay: Duration::from_millis(100),
                max_total_wait: Duration::from_millis(120),
            },
        )
        .expect_err("lock retries exhaust");
        match err {
            LockRetryError::Exhausted {
                retries,
                total_wait,
                ..
            } => {
                assert!(retries >= 1);
                assert!(total_wait <= Duration::from_millis(120));
            }
            other => panic!("expected exhausted retry, got {other:?}"),
        }
    }

    #[test]
    fn is_lock_error_message_matches_known_markers() {
        for message in &[
            "IO error: already open",
            "Cannot acquire lock on file",
            "Locking error: Failed locking file 'db'",
            "File is locked by another process",
            "database is locked",
        ] {
            assert!(is_lock_error_message(message), "{message}");
        }
        assert!(!is_lock_error_message("disk full"));
        assert!(!is_lock_error_message(""));
    }
}