1use std::{
5 fs::File,
6 io,
7 path::{Path, PathBuf},
8};
9
10use fs2::FileExt;
11use thiserror::Error;
12
13#[derive(Debug, Error)]
14pub enum LockError {
15 #[error("failed to acquire lock: {0}")]
16 Acquire(#[source] io::Error),
17 #[error("lock file not accessible: {0}")]
18 Io(#[source] io::Error),
19}
20
21pub type Result<T> = std::result::Result<T, LockError>;
22
23pub struct ReadLockGuard {
24 _file: File,
25}
26
27impl Drop for ReadLockGuard {
28 fn drop(&mut self) {
29 let _ = self._file.unlock();
30 }
31}
32
33pub struct WriteLockGuard {
34 _file: File,
35}
36
37impl Drop for WriteLockGuard {
38 fn drop(&mut self) {
39 let _ = self._file.unlock();
40 }
41}
42
/// Handle to the on-disk lock file used to coordinate readers and writers of a
/// repository.
///
/// Constructing a `RepoLock` does not touch the filesystem; the lock file and
/// its parent directory are created when a lock is requested.
#[derive(Debug, Clone)]
pub struct RepoLock {
    lock_path: PathBuf,
}
46
47impl RepoLock {
48 pub fn new(repo_root: &Path) -> Self {
49 let lock_path = repo_root.join(".heddle/locks/repo.lock");
50 Self { lock_path }
51 }
52
53 pub fn at(lock_path: PathBuf) -> Self {
54 Self { lock_path }
55 }
56
57 pub fn read(&self) -> Result<ReadLockGuard> {
58 self.ensure_lock_dir()?;
59 let file = self.open_lock_file()?;
60 file.lock_shared().map_err(LockError::Acquire)?;
61 Ok(ReadLockGuard { _file: file })
62 }
63
64 pub fn write(&self) -> Result<WriteLockGuard> {
65 self.ensure_lock_dir()?;
66 let file = self.open_lock_file()?;
67 file.lock_exclusive().map_err(LockError::Acquire)?;
68 Ok(WriteLockGuard { _file: file })
69 }
70
71 pub fn try_read(&self) -> Result<Option<ReadLockGuard>> {
72 self.ensure_lock_dir()?;
73 let file = self.open_lock_file()?;
74
75 match file.try_lock_shared() {
76 Ok(()) => Ok(Some(ReadLockGuard { _file: file })),
77 Err(_) => Ok(None),
78 }
79 }
80
81 pub fn try_write(&self) -> Result<Option<WriteLockGuard>> {
82 self.ensure_lock_dir()?;
83 let file = self.open_lock_file()?;
84
85 match file.try_lock_exclusive() {
86 Ok(()) => Ok(Some(WriteLockGuard { _file: file })),
87 Err(_) => Ok(None),
88 }
89 }
90
91 fn ensure_lock_dir(&self) -> Result<()> {
92 if let Some(parent) = self.lock_path.parent() {
93 std::fs::create_dir_all(parent).map_err(LockError::Io)?;
94 }
95 Ok(())
96 }
97
98 fn open_lock_file(&self) -> Result<File> {
99 File::create(&self.lock_path).map_err(LockError::Io)
100 }
101}
102
103pub trait RepositoryLockExt {
104 fn locker(&self) -> RepoLock;
105}
106
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
        thread,
    };

    use tempfile::TempDir;

    use super::*;

    /// Location where `RepoLock::new` places the lock file under `root`.
    fn lock_file_path(root: &Path) -> PathBuf {
        root.join(".heddle/locks/repo.lock")
    }

    #[test]
    fn test_read_lock_acquired() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _guard = lock.read().unwrap();
        // Acquiring the lock must have created the lock directory and file.
        assert!(lock_file_path(temp.path()).is_file());
    }

    #[test]
    fn test_write_lock_acquired() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _guard = lock.write().unwrap();
        assert!(lock_file_path(temp.path()).is_file());
    }

    #[test]
    fn test_at_uses_given_path() {
        let temp = TempDir::new().unwrap();
        let path = temp.path().join("custom").join("my.lock");
        let lock = RepoLock::at(path.clone());

        let _guard = lock.write().unwrap();
        assert!(path.is_file());
    }

    #[test]
    fn test_multiple_readers() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));

        // Deterministic check first, so a regression fails instead of hanging.
        {
            let _first = lock.read().unwrap();
            assert!(
                lock.try_read().unwrap().is_some(),
                "Readers should not block each other"
            );
        }

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let lock = Arc::clone(&lock);
                thread::spawn(move || {
                    let _guard = lock.read().unwrap();
                    thread::sleep(std::time::Duration::from_millis(10));
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_writer_excludes_reader() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _write_guard = lock.write().unwrap();
        assert!(
            lock.try_read().unwrap().is_none(),
            "Reader should be blocked by writer"
        );
        assert!(
            lock.try_write().unwrap().is_none(),
            "Writer should be blocked by writer"
        );
    }

    #[test]
    fn test_reader_excludes_writer() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _read_guard = lock.read().unwrap();
        assert!(
            lock.try_write().unwrap().is_none(),
            "Writer should be blocked by reader"
        );
    }

    #[test]
    fn test_try_locks_succeed_when_free() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        assert!(lock.try_write().unwrap().is_some());
        assert!(lock.try_read().unwrap().is_some());
    }

    #[test]
    fn test_lock_released_on_drop() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        {
            let _guard = lock.write().unwrap();
        }

        // Non-blocking check: a leaked lock fails the test instead of hanging it.
        assert!(
            lock.try_write().unwrap().is_some(),
            "Write lock should be released on drop"
        );
        let _guard2 = lock.read().unwrap();
    }
}