1use std::{
20 collections::HashMap,
21 fs::File,
22 io,
23 path::{Path, PathBuf},
24 sync::{Arc, Condvar, Mutex, MutexGuard, OnceLock},
25 thread::{self, ThreadId},
26};
27
28use fs2::FileExt;
29use thiserror::Error;
30
31#[derive(Debug, Error)]
32pub enum LockError {
33 #[error("failed to acquire lock: {0}")]
34 Acquire(#[source] io::Error),
35 #[error("lock file not accessible: {0}")]
36 Io(#[source] io::Error),
37}
38
39pub type Result<T> = std::result::Result<T, LockError>;
40
/// In-process bookkeeping for the write lock of one lock path.
#[derive(Default)]
struct GateState {
    /// Thread that currently owns the write lock, if any.
    owner: Option<ThreadId>,
    /// How many write guards the owning thread currently holds (nesting depth).
    depth: usize,
    /// Open lock-file handle kept alive while the write lock is held.
    flock: Option<File>,
}

/// Registry slot for one lock path: the gate state plus the condition
/// variable that blocked writers wait on.
#[derive(Default)]
struct Entry {
    gate: Mutex<GateState>,
    cv: Condvar,
}

impl Entry {
    /// Builds an entry with no owner, zero depth and no file handle.
    fn new() -> Self {
        Self::default()
    }
}

/// Process-wide map from lock-path key to the shared [`Entry`] for that path.
static REGISTRY: OnceLock<Mutex<HashMap<PathBuf, Arc<Entry>>>> = OnceLock::new();

/// Returns the process-wide registry, initialising it to an empty map on first use.
fn registry() -> &'static Mutex<HashMap<PathBuf, Arc<Entry>>> {
    REGISTRY.get_or_init(Default::default)
}

/// Returns the shared entry for `key`, inserting a fresh one if none exists yet.
///
/// A poisoned registry mutex is recovered rather than propagated.
fn entry_for(key: PathBuf) -> Arc<Entry> {
    let mut entries = match registry().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    let slot = entries.entry(key).or_insert_with(|| Arc::new(Entry::new()));
    Arc::clone(slot)
}

/// Locks `entry`'s gate, recovering the guard if the mutex was poisoned.
fn lock_gate(entry: &Entry) -> MutexGuard<'_, GateState> {
    match entry.gate.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
84
85pub struct ReadLockGuard {
86 _file: Option<File>,
89}
90
91impl Drop for ReadLockGuard {
92 fn drop(&mut self) {
93 if let Some(file) = &self._file {
94 let _ = file.unlock();
95 }
96 }
97}
98
99pub struct WriteLockGuard {
100 entry: Arc<Entry>,
101}
102
103impl Drop for WriteLockGuard {
104 fn drop(&mut self) {
105 let mut state = lock_gate(&self.entry);
106 if state.depth > 0 {
107 state.depth -= 1;
108 }
109 if state.depth == 0 {
110 state.owner = None;
111 state.flock = None;
113 self.entry.cv.notify_one();
114 }
115 }
116}
117
/// Handle to a repository lock file.
///
/// Holding a `RepoLock` locks nothing by itself. Call `read`, `write` or their
/// `try_` variants to obtain a guard.
#[derive(Debug)]
pub struct RepoLock {
    /// Path of the file that carries the OS-level lock.
    lock_path: PathBuf,
}
121
122impl RepoLock {
123 pub fn new(repo_root: &Path) -> Self {
124 let lock_path = repo_root.join(".heddle/locks/repo.lock");
125 Self { lock_path }
126 }
127
128 pub fn at(lock_path: PathBuf) -> Self {
129 Self { lock_path }
130 }
131
132 pub fn read(&self) -> Result<ReadLockGuard> {
133 self.ensure_lock_dir()?;
134 let entry = entry_for(self.registry_key());
135
136 {
140 let state = lock_gate(&entry);
141 if state.owner == Some(thread::current().id()) {
142 return Ok(ReadLockGuard { _file: None });
143 }
144 }
145
146 let file = self.open_lock_file()?;
147 file.lock_shared().map_err(LockError::Acquire)?;
148 Ok(ReadLockGuard { _file: Some(file) })
149 }
150
151 pub fn write(&self) -> Result<WriteLockGuard> {
152 self.ensure_lock_dir()?;
153 let entry = entry_for(self.registry_key());
154 let tid = thread::current().id();
155 let mut state = lock_gate(&entry);
156 loop {
157 match state.owner {
158 Some(owner) if owner == tid => {
159 state.depth += 1;
160 return Ok(WriteLockGuard {
161 entry: Arc::clone(&entry),
162 });
163 }
164 None => {
165 let file = self.open_lock_file()?;
170 file.lock_exclusive().map_err(LockError::Acquire)?;
171 state.owner = Some(tid);
172 state.depth = 1;
173 state.flock = Some(file);
174 return Ok(WriteLockGuard {
175 entry: Arc::clone(&entry),
176 });
177 }
178 Some(_) => {
179 state = entry.cv.wait(state).unwrap_or_else(|e| e.into_inner());
180 }
181 }
182 }
183 }
184
185 pub fn try_read(&self) -> Result<Option<ReadLockGuard>> {
186 self.ensure_lock_dir()?;
187 let file = self.open_lock_file()?;
188
189 match file.try_lock_shared() {
190 Ok(()) => Ok(Some(ReadLockGuard { _file: Some(file) })),
191 Err(_) => Ok(None),
192 }
193 }
194
195 pub fn try_write(&self) -> Result<Option<WriteLockGuard>> {
196 self.ensure_lock_dir()?;
197 let entry = entry_for(self.registry_key());
198 let mut state = lock_gate(&entry);
199 match state.owner {
206 Some(_) => Ok(None),
207 None => {
208 let file = self.open_lock_file()?;
209 match file.try_lock_exclusive() {
210 Ok(()) => {
211 state.owner = Some(thread::current().id());
212 state.depth = 1;
213 state.flock = Some(file);
214 Ok(Some(WriteLockGuard {
215 entry: Arc::clone(&entry),
216 }))
217 }
218 Err(_) => Ok(None),
219 }
220 }
221 }
222 }
223
224 fn ensure_lock_dir(&self) -> Result<()> {
225 if let Some(parent) = self.lock_path.parent() {
226 std::fs::create_dir_all(parent).map_err(LockError::Io)?;
227 }
228 Ok(())
229 }
230
231 fn registry_key(&self) -> PathBuf {
235 match self.lock_path.parent() {
236 Some(parent) => {
237 let canon_parent = parent
238 .canonicalize()
239 .unwrap_or_else(|_| parent.to_path_buf());
240 match self.lock_path.file_name() {
241 Some(name) => canon_parent.join(name),
242 None => canon_parent,
243 }
244 }
245 None => self.lock_path.clone(),
246 }
247 }
248
249 fn open_lock_file(&self) -> Result<File> {
250 File::create(&self.lock_path).map_err(LockError::Io)
251 }
252}
253
/// Implemented by types that can produce a [`RepoLock`], presumably for the
/// repository they represent (confirm against implementors).
pub trait RepositoryLockExt {
    /// Returns a lock handle; acquiring the lock is up to the caller.
    fn locker(&self) -> RepoLock;
}
257
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            mpsc::{self},
        },
        thread,
    };

    use tempfile::TempDir;

    use super::*;

    /// A read guard can be obtained on a fresh repository directory.
    #[test]
    fn test_read_lock_acquired() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let guard = lock.read().unwrap();
        // Uses the guard binding; the lock is held until the end of the scope.
        assert!(std::mem::size_of_val(&guard) > 0);
    }

    /// A write guard can be obtained on a fresh repository directory.
    #[test]
    fn test_write_lock_acquired() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let guard = lock.write().unwrap();
        assert!(std::mem::size_of_val(&guard) > 0);
    }

    /// Ten threads take read locks concurrently. The test passes only if every
    /// `read()` succeeds and every thread finishes.
    #[test]
    fn test_multiple_readers() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));

        let mut handles = vec![];
        for _ in 0..10 {
            let lock = Arc::clone(&lock);
            let handle = thread::spawn(move || {
                let _guard = lock.read().unwrap();
                thread::sleep(std::time::Duration::from_millis(10));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    /// `try_read` does not consult the in-process gate, so it reports contention
    /// even on the thread that owns the write lock.
    #[test]
    fn test_writer_excludes_reader() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));

        let _write_guard = lock.write().unwrap();
        let read_result = lock.try_read().unwrap();
        assert!(read_result.is_none(), "Reader should be blocked by writer");
    }

    /// A held read lock makes `try_write` report contention.
    #[test]
    fn test_reader_excludes_writer() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));

        let _read_guard = lock.read().unwrap();
        let write_result = lock.try_write().unwrap();
        assert!(write_result.is_none(), "Writer should be blocked by reader");
    }

    /// After the write guard's scope ends, a read lock can still be taken.
    #[test]
    fn test_lock_released_on_drop() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        {
            let _guard = lock.write().unwrap();
        }

        let _guard2 = lock.read().unwrap();
    }

    /// A thread that owns the write lock can call `write()` again without
    /// deadlocking on itself.
    #[test]
    fn same_thread_write_is_reentrant() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _a = lock.write().unwrap();
        let _b = lock.write().unwrap();
    }

    /// `read()` on the thread that owns the write lock returns immediately
    /// instead of waiting on that thread's own exclusive lock.
    #[test]
    fn same_thread_read_under_write_does_not_deadlock() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _w = lock.write().unwrap();
        let _r = lock.read().unwrap();
    }

    /// Reentrancy is per thread. While a spawned thread holds the write lock,
    /// this thread's `try_write` reports contention. Once the spawned thread
    /// releases, `try_write` succeeds. The channels sequence "acquired" and
    /// "release" between the two threads.
    #[test]
    fn distinct_threads_still_exclude() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));

        let (acquired_tx, acquired_rx) = mpsc::channel();
        let (release_tx, release_rx) = mpsc::channel();
        let lock_a = Arc::clone(&lock);
        let handle = thread::spawn(move || {
            let _g = lock_a.write().unwrap();
            acquired_tx.send(()).unwrap();
            release_rx.recv().unwrap();
        });

        acquired_rx.recv().unwrap();
        assert!(
            lock.try_write().unwrap().is_none(),
            "a second thread must not acquire the write lock"
        );

        release_tx.send(()).unwrap();
        handle.join().unwrap();

        assert!(
            lock.try_write().unwrap().is_some(),
            "write lock is available once the owning thread releases"
        );
    }

    /// Dropping an inner reentrant guard must not release the lock; only the
    /// outermost drop does. Availability is probed from fresh threads, because
    /// `try_write` on the owner thread always reports contention and so could
    /// not tell the two states apart.
    #[test]
    fn reentrant_release_keeps_lock_until_outermost_drop() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));

        let a1 = lock.write().unwrap();
        let a2 = lock.write().unwrap();

        // Returns true when another thread's `try_write` is refused.
        let other = |lock: &Arc<RepoLock>| {
            let lock = Arc::clone(lock);
            thread::spawn(move || lock.try_write().unwrap().is_none())
                .join()
                .unwrap()
        };

        assert!(other(&lock), "excluded while held at depth 2");
        drop(a2);
        assert!(other(&lock), "still excluded while held at depth 1");
        drop(a1);

        let lock_b = Arc::clone(&lock);
        let now_available = thread::spawn(move || lock_b.try_write().unwrap().is_some())
            .join()
            .unwrap();
        assert!(now_available, "available after the outermost guard drops");
    }

    /// `try_write` is deliberately not reentrant: it reports contention even to
    /// the thread that already owns the write lock, unlike `write()`.
    #[test]
    fn try_write_is_non_reentrant_even_for_owner() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());

        let _held = lock.write().unwrap();
        assert!(
            lock.try_write().unwrap().is_none(),
            "try_write must report contention even for the lock's own owner thread"
        );
    }
}