Skip to main content

objects/
lock.rs

// SPDX-License-Identifier: Apache-2.0
//! Repository locking for concurrent access.

4use std::{
5    fs::File,
6    io,
7    path::{Path, PathBuf},
8};
9
10use fs2::FileExt;
11use thiserror::Error;
12
13#[derive(Debug, Error)]
14pub enum LockError {
15    #[error("failed to acquire lock: {0}")]
16    Acquire(#[source] io::Error),
17    #[error("lock file not accessible: {0}")]
18    Io(#[source] io::Error),
19}
20
21pub type Result<T> = std::result::Result<T, LockError>;
22
23pub struct ReadLockGuard {
24    _file: File,
25}
26
27impl Drop for ReadLockGuard {
28    fn drop(&mut self) {
29        let _ = self._file.unlock();
30    }
31}
32
33pub struct WriteLockGuard {
34    _file: File,
35}
36
37impl Drop for WriteLockGuard {
38    fn drop(&mut self) {
39        let _ = self._file.unlock();
40    }
41}
42
43pub struct RepoLock {
44    lock_path: PathBuf,
45}
46
47impl RepoLock {
48    pub fn new(repo_root: &Path) -> Self {
49        let lock_path = repo_root.join(".heddle/locks/repo.lock");
50        Self { lock_path }
51    }
52
53    pub fn at(lock_path: PathBuf) -> Self {
54        Self { lock_path }
55    }
56
57    pub fn read(&self) -> Result<ReadLockGuard> {
58        self.ensure_lock_dir()?;
59        let file = self.open_lock_file()?;
60        file.lock_shared().map_err(LockError::Acquire)?;
61        Ok(ReadLockGuard { _file: file })
62    }
63
64    pub fn write(&self) -> Result<WriteLockGuard> {
65        self.ensure_lock_dir()?;
66        let file = self.open_lock_file()?;
67        file.lock_exclusive().map_err(LockError::Acquire)?;
68        Ok(WriteLockGuard { _file: file })
69    }
70
71    pub fn try_read(&self) -> Result<Option<ReadLockGuard>> {
72        self.ensure_lock_dir()?;
73        let file = self.open_lock_file()?;
74
75        match file.try_lock_shared() {
76            Ok(()) => Ok(Some(ReadLockGuard { _file: file })),
77            Err(_) => Ok(None),
78        }
79    }
80
81    pub fn try_write(&self) -> Result<Option<WriteLockGuard>> {
82        self.ensure_lock_dir()?;
83        let file = self.open_lock_file()?;
84
85        match file.try_lock_exclusive() {
86            Ok(()) => Ok(Some(WriteLockGuard { _file: file })),
87            Err(_) => Ok(None),
88        }
89    }
90
91    fn ensure_lock_dir(&self) -> Result<()> {
92        if let Some(parent) = self.lock_path.parent() {
93            std::fs::create_dir_all(parent).map_err(LockError::Io)?;
94        }
95        Ok(())
96    }
97
98    fn open_lock_file(&self) -> Result<File> {
99        File::create(&self.lock_path).map_err(LockError::Io)
100    }
101}
102
103pub trait RepositoryLockExt {
104    fn locker(&self) -> RepoLock;
105}
106
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread, time::Duration};

    use tempfile::TempDir;

    use super::*;

    #[test]
    fn test_read_lock_acquired() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _guard = lock.read().unwrap();
        // A held read lock admits further readers but keeps writers out.
        assert!(
            lock.try_read().unwrap().is_some(),
            "Second reader should be admitted"
        );
        assert!(
            lock.try_write().unwrap().is_none(),
            "Writer should be blocked by reader"
        );
    }

    #[test]
    fn test_write_lock_acquired() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _guard = lock.write().unwrap();
        // Acquiring must have created the lock directory and file.
        assert!(temp.path().join(".heddle/locks/repo.lock").is_file());
        assert!(
            lock.try_read().unwrap().is_none(),
            "Reader should be blocked by writer"
        );
        assert!(
            lock.try_write().unwrap().is_none(),
            "Second writer should be blocked by writer"
        );
    }

    #[test]
    fn test_multiple_readers() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));
        // Hold a read lock for the whole test so every thread has to share it.
        let _held = lock.read().unwrap();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let lock = Arc::clone(&lock);
                thread::spawn(move || {
                    let guard = lock.try_read().unwrap();
                    assert!(guard.is_some(), "Readers should not exclude each other");
                    thread::sleep(Duration::from_millis(10));
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_writer_excludes_reader() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _write_guard = lock.write().unwrap();
        let read_result = lock.try_read().unwrap();
        assert!(read_result.is_none(), "Reader should be blocked by writer");
    }

    #[test]
    fn test_reader_excludes_writer() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _read_guard = lock.read().unwrap();
        let write_result = lock.try_write().unwrap();
        assert!(write_result.is_none(), "Writer should be blocked by reader");
    }

    #[test]
    fn test_lock_released_on_drop() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        {
            let _guard = lock.write().unwrap();
            assert!(lock.try_read().unwrap().is_none());
        }

        // Non-blocking check: a leaked lock fails the test instead of hanging it.
        assert!(
            lock.try_write().unwrap().is_some(),
            "Lock should be free after the guard is dropped"
        );
    }

    #[test]
    fn test_at_uses_custom_path() {
        let temp = TempDir::new().unwrap();
        let path = temp.path().join("nested/dir/custom.lock");
        let lock = RepoLock::at(path.clone());

        let _guard = lock.write().unwrap();
        assert!(path.is_file(), "Custom lock file and its parents should be created");
    }
}
184}