1use crate::error::ErrorCode;
2use fs2::FileExt;
3use std::{
4 fs::{self, File, OpenOptions},
5 io,
6 path::{Path, PathBuf},
7 thread,
8 time::{Duration, Instant},
9};
10
/// Ways acquiring one of this module's file locks can fail.
#[derive(Debug)]
pub enum LockError {
    /// The lock at `path` could not be obtained before the caller's timeout ran out.
    Timeout {
        /// Lock file that was being acquired.
        path: PathBuf,
        /// Time spent trying before giving up.
        waited: Duration,
    },
    /// A filesystem operation needed to set up or take the lock failed.
    IoError(io::Error),
}
17
18impl From<io::Error> for LockError {
19 fn from(err: io::Error) -> Self {
20 Self::IoError(err)
21 }
22}
23
24impl LockError {
25 #[must_use]
27 pub const fn code(&self) -> ErrorCode {
28 match self {
29 Self::Timeout { .. } => ErrorCode::LockContention,
30 Self::IoError(_) => ErrorCode::EventFileWriteFailed,
31 }
32 }
33
34 #[must_use]
36 pub const fn hint(&self) -> Option<&'static str> {
37 self.code().hint()
38 }
39}
40
41impl std::fmt::Display for LockError {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 Self::Timeout { path, waited } => {
45 write!(
46 f,
47 "{}: lock timed out after {:?} at {}",
48 self.code().code(),
49 waited,
50 path.display()
51 )
52 }
53 Self::IoError(err) => write!(f, "{}: {}", self.code().code(), err),
54 }
55 }
56}
57
58impl std::error::Error for LockError {}
59
/// Which flavour of file lock [`FileGuard::acquire`] should take.
#[derive(Copy, Clone)]
enum LockKind {
    /// Coexists with other shared locks on the same file.
    Shared,
    /// Excludes every other lock on the same file.
    Exclusive,
}
65
/// An open lock file together with the lock currently held on it.
///
/// The lock is dropped when the guard is dropped (see its `Drop` impl).
#[derive(Debug)]
struct FileGuard {
    /// Handle the lock was taken through.
    file: File,
    /// Path the handle was opened from, kept for diagnostics.
    path: PathBuf,
}
71
72impl FileGuard {
73 fn acquire(path: &Path, timeout: Duration, kind: LockKind) -> Result<Self, LockError> {
74 let parent = path.parent().ok_or_else(|| {
75 io::Error::new(io::ErrorKind::InvalidInput, "lock path has no parent")
76 })?;
77 fs::create_dir_all(parent)?;
78
79 let start = Instant::now();
80 loop {
81 let file = OpenOptions::new()
82 .create(true)
83 .read(true)
84 .write(true)
85 .truncate(false)
86 .open(path)?;
87
88 let locked = match kind {
89 LockKind::Shared => file.try_lock_shared().is_err(),
90 LockKind::Exclusive => file.try_lock_exclusive().is_err(),
91 };
92
93 if !locked {
94 return Ok(Self {
95 file,
96 path: path.to_path_buf(),
97 });
98 }
99
100 if start.elapsed() >= timeout {
101 return Err(LockError::Timeout {
102 path: path.to_path_buf(),
103 waited: start.elapsed(),
104 });
105 }
106
107 thread::sleep(Duration::from_millis(10));
108 }
109 }
110
111 fn release(self) {
112 let _ = self.file.unlock();
113 }
114
115 fn path(&self) -> &Path {
116 &self.path
117 }
118}
119
120impl Drop for FileGuard {
121 fn drop(&mut self) {
122 let _ = self.file.unlock();
123 }
124}
125
126#[derive(Debug)]
128pub struct ShardLock {
129 guard: FileGuard,
130}
131
132impl ShardLock {
133 pub fn acquire(path: &Path, timeout: Duration) -> Result<Self, LockError> {
140 Ok(Self {
141 guard: FileGuard::acquire(path, timeout, LockKind::Exclusive)?,
142 })
143 }
144
145 pub fn release(self) {
147 self.guard.release();
148 }
149
150 #[must_use]
152 pub fn path(&self) -> &Path {
153 self.guard.path()
154 }
155}
156
157pub struct DbReadLock {
159 guard: FileGuard,
160}
161
162impl DbReadLock {
163 pub fn acquire(path: &Path, timeout: Duration) -> Result<Self, LockError> {
170 Ok(Self {
171 guard: FileGuard::acquire(path, timeout, LockKind::Shared)?,
172 })
173 }
174
175 pub fn release(self) {
177 self.guard.release();
178 }
179
180 #[must_use]
182 pub fn path(&self) -> &Path {
183 self.guard.path()
184 }
185}
186
187pub struct DbWriteLock {
189 guard: FileGuard,
190}
191
192impl DbWriteLock {
193 pub fn acquire(path: &Path, timeout: Duration) -> Result<Self, LockError> {
200 Ok(Self {
201 guard: FileGuard::acquire(path, timeout, LockKind::Exclusive)?,
202 })
203 }
204
205 pub fn release(self) {
207 self.guard.release();
208 }
209
210 #[must_use]
212 pub fn path(&self) -> &Path {
213 self.guard.path()
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::{DbReadLock, DbWriteLock, LockError, ShardLock};
    use crate::error::ErrorCode;
    use std::{
        path::PathBuf,
        sync::{Arc, Barrier},
        thread,
        time::{Duration, Instant},
    };

    /// Per-test lock file under `<temp dir>/bones_lock_tests/`.
    fn lock_path(name: &str) -> PathBuf {
        std::env::temp_dir().join("bones_lock_tests").join(name)
    }

    #[test]
    fn shard_lock_allows_acquire_and_release() -> Result<(), LockError> {
        let target = lock_path("basic.lock");
        let shard = ShardLock::acquire(&target, Duration::from_millis(50))?;
        assert_eq!(shard.path(), target.as_path());
        shard.release();
        Ok(())
    }

    #[test]
    fn shard_lock_times_out_when_held() {
        let target = lock_path("timeout.lock");
        let _held = ShardLock::acquire(&target, Duration::from_millis(50)).unwrap();

        match ShardLock::acquire(&target, Duration::from_millis(20)) {
            Err(LockError::Timeout { path, .. }) => assert_eq!(path, target),
            other => panic!("expected a timeout, got {:?}", other),
        }
    }

    #[test]
    fn lock_error_maps_to_machine_code() {
        let err = LockError::Timeout {
            path: lock_path("code.lock"),
            waited: Duration::from_millis(10),
        };
        assert_eq!(err.code(), ErrorCode::LockContention);
        assert!(err.hint().is_some());
    }

    #[test]
    fn sqlite_read_locks_are_compatible() -> Result<(), LockError> {
        let target = lock_path("read-share.lock");
        let wait = Duration::from_millis(50);
        let readers = [
            DbReadLock::acquire(&target, wait)?,
            DbReadLock::acquire(&target, wait)?,
        ];

        for reader in readers {
            reader.release();
        }
        Ok(())
    }

    #[test]
    fn sqlite_write_blocks_readers() {
        let target = lock_path("write-blocks-read.lock");
        let _writer = DbWriteLock::acquire(&target, Duration::from_millis(50)).unwrap();

        let budget = Duration::from_millis(20);
        let began = Instant::now();
        let outcome = DbReadLock::acquire(&target, budget);

        assert!(matches!(outcome, Err(LockError::Timeout { .. })));
        assert!(began.elapsed() >= budget);
    }

    #[test]
    fn lock_release_allows_follow_up_lock() -> Result<(), LockError> {
        let target = lock_path("release-followup.lock");
        drop(ShardLock::acquire(&target, Duration::from_millis(50))?);

        let _second = ShardLock::acquire(&target, Duration::from_millis(50))?;
        Ok(())
    }

    #[test]
    fn contention_is_resolved_after_writer_releases() -> Result<(), LockError> {
        let target = lock_path("thread.lock");

        // `held`: the writer thread owns the lock; `done`: the main thread has
        // finished probing and the writer may drop its lock.
        let held = Arc::new(Barrier::new(2));
        let done = Arc::new(Barrier::new(2));

        let writer = {
            let held = Arc::clone(&held);
            let done = Arc::clone(&done);
            let target = target.clone();
            thread::spawn(move || {
                let _exclusive = ShardLock::acquire(&target, Duration::from_millis(200)).unwrap();
                held.wait();
                done.wait();
            })
        };

        held.wait();
        let blocked = DbReadLock::acquire(&target, Duration::from_millis(20));
        assert!(matches!(blocked, Err(LockError::Timeout { .. })));
        done.wait();
        writer.join().unwrap();

        let follow_up = ShardLock::acquire(&target, Duration::from_millis(50))?;
        follow_up.release();
        Ok(())
    }
}