use std::{
collections::HashMap,
fs::File,
io,
path::{Path, PathBuf},
sync::{Arc, Condvar, Mutex, MutexGuard, OnceLock},
thread::{self, ThreadId},
};
use fs2::FileExt;
use thiserror::Error;
/// Errors returned by [`RepoLock`] operations.
///
/// NOTE(review): each message embeds the inner `io::Error` via `{0}` *and*
/// exposes it through `#[source]`, so error reporters that walk the source
/// chain print the I/O error twice. Consider keeping only one of the two
/// (changing the message text is user-visible, so confirm with callers first).
#[derive(Debug, Error)]
pub enum LockError {
    /// An OS-level file-lock call (shared or exclusive) failed.
    #[error("failed to acquire lock: {0}")]
    Acquire(#[source] io::Error),
    /// The lock directory or the lock file itself could not be created or opened.
    #[error("lock file not accessible: {0}")]
    Io(#[source] io::Error),
}
/// Result alias used by this module, with [`LockError`] as the error type.
pub type Result<T> = std::result::Result<T, LockError>;
/// In-process bookkeeping for one lock path's write lock, guarded by
/// [`Entry::gate`].
struct GateState {
    /// Thread currently owning the write lock, or `None` when unowned.
    owner: Option<ThreadId>,
    /// How many live write guards `owner` holds (reentrancy depth).
    depth: usize,
    /// Open lock file carrying the exclusive OS lock while the lock is owned.
    flock: Option<File>,
}

/// Registry entry shared by every `RepoLock` whose lock path maps to the same
/// registry key: the mutex-protected [`GateState`] plus a condition variable
/// that writers wait on until the current owner releases.
struct Entry {
    gate: Mutex<GateState>,
    cv: Condvar,
}

impl Entry {
    /// Builds an entry describing an unowned lock: no owner, depth zero, no file.
    fn new() -> Self {
        let unowned = GateState {
            owner: None,
            depth: 0,
            flock: None,
        };
        let cv = Condvar::new();
        Entry {
            gate: Mutex::new(unowned),
            cv,
        }
    }
}
/// Process-wide map from a lock's registry key to its shared [`Entry`].
///
/// Nothing in this file ever removes an entry, so the map grows by one entry
/// per distinct lock path used during the process lifetime.
static REGISTRY: OnceLock<Mutex<HashMap<PathBuf, Arc<Entry>>>> = OnceLock::new();

/// Returns the process-wide registry, creating the empty map on first use.
fn registry() -> &'static Mutex<HashMap<PathBuf, Arc<Entry>>> {
    REGISTRY.get_or_init(Default::default)
}
/// Returns the shared [`Entry`] registered under `key`, inserting a fresh,
/// unowned one the first time the key is seen.
///
/// A poisoned registry mutex is recovered (its inner map is used as-is)
/// rather than propagated as a panic.
fn entry_for(key: PathBuf) -> Arc<Entry> {
    let mut entries = match registry().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    let slot = entries.entry(key).or_insert_with(|| Arc::new(Entry::new()));
    Arc::clone(slot)
}
/// Locks `entry`'s gate and returns the guard over its [`GateState`].
///
/// Poisoning is deliberately ignored: if an earlier holder panicked, its guard
/// is recovered instead of making every later call panic too.
fn lock_gate(entry: &Entry) -> MutexGuard<'_, GateState> {
    match entry.gate.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Guard for a shared repository lock, returned by [`RepoLock::read`] and
/// [`RepoLock::try_read`].
///
/// While alive it keeps the lock file open with a shared lock; dropping it
/// unlocks and closes the file. When `RepoLock::read` is called by the thread
/// that already owns the write lock, the guard holds no file (`None`) and
/// dropping it does nothing.
#[derive(Debug)]
#[must_use = "the read lock is released as soon as the guard is dropped"]
pub struct ReadLockGuard {
    _file: Option<File>,
}

impl Drop for ReadLockGuard {
    fn drop(&mut self) {
        if let Some(file) = &self._file {
            // Best-effort: the `File` is closed right after this, which should
            // release the lock as well, so an unlock error is not actionable.
            let _ = file.unlock();
        }
    }
}
/// Guard for the exclusive repository lock, returned by [`RepoLock::write`]
/// and [`RepoLock::try_write`].
///
/// Write guards nest on the owning thread. Each drop decrements the entry's
/// depth, and the drop that brings it to zero clears the owner, unlocks and
/// closes the lock file, and wakes threads waiting in [`RepoLock::write`].
#[must_use = "the write lock is released as soon as the guard is dropped"]
pub struct WriteLockGuard {
    entry: Arc<Entry>,
}

impl std::fmt::Debug for WriteLockGuard {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WriteLockGuard").finish_non_exhaustive()
    }
}

impl Drop for WriteLockGuard {
    fn drop(&mut self) {
        let mut state = lock_gate(&self.entry);
        if state.depth > 0 {
            state.depth -= 1;
        }
        if state.depth == 0 {
            state.owner = None;
            if let Some(file) = state.flock.take() {
                // Best-effort, mirroring `ReadLockGuard`: the file is closed at
                // the end of this scope, which releases the lock regardless.
                let _ = file.unlock();
            }
            // Wake *all* waiters. A woken writer may fail (e.g. opening the lock
            // file) and return without passing the wake-up on; with
            // `notify_one` the remaining waiters would then sleep forever even
            // though the lock is free. Every waiter re-checks ownership in a loop.
            self.entry.cv.notify_all();
        }
    }
}
/// Handle to a repository lock file.
///
/// Combines an OS file lock, which excludes other processes, with an
/// in-process write gate that is reentrant per thread. The gate is shared by
/// every `RepoLock` whose lock path maps to the same registry key. The handle
/// itself holds no lock and is cheap to clone; locks are held by the guards
/// its methods return.
#[derive(Debug, Clone)]
pub struct RepoLock {
    lock_path: PathBuf,
}
impl RepoLock {
pub fn new(repo_root: &Path) -> Self {
let lock_path = repo_root.join(".heddle/locks/repo.lock");
Self { lock_path }
}
pub fn at(lock_path: PathBuf) -> Self {
Self { lock_path }
}
pub fn read(&self) -> Result<ReadLockGuard> {
self.ensure_lock_dir()?;
let entry = entry_for(self.registry_key());
{
let state = lock_gate(&entry);
if state.owner == Some(thread::current().id()) {
return Ok(ReadLockGuard { _file: None });
}
}
let file = self.open_lock_file()?;
file.lock_shared().map_err(LockError::Acquire)?;
Ok(ReadLockGuard { _file: Some(file) })
}
pub fn write(&self) -> Result<WriteLockGuard> {
self.ensure_lock_dir()?;
let entry = entry_for(self.registry_key());
let tid = thread::current().id();
let mut state = lock_gate(&entry);
loop {
match state.owner {
Some(owner) if owner == tid => {
state.depth += 1;
return Ok(WriteLockGuard {
entry: Arc::clone(&entry),
});
}
None => {
let file = self.open_lock_file()?;
file.lock_exclusive().map_err(LockError::Acquire)?;
state.owner = Some(tid);
state.depth = 1;
state.flock = Some(file);
return Ok(WriteLockGuard {
entry: Arc::clone(&entry),
});
}
Some(_) => {
state = entry.cv.wait(state).unwrap_or_else(|e| e.into_inner());
}
}
}
}
pub fn try_read(&self) -> Result<Option<ReadLockGuard>> {
self.ensure_lock_dir()?;
let file = self.open_lock_file()?;
match file.try_lock_shared() {
Ok(()) => Ok(Some(ReadLockGuard { _file: Some(file) })),
Err(_) => Ok(None),
}
}
pub fn try_write(&self) -> Result<Option<WriteLockGuard>> {
self.ensure_lock_dir()?;
let entry = entry_for(self.registry_key());
let mut state = lock_gate(&entry);
match state.owner {
Some(_) => Ok(None),
None => {
let file = self.open_lock_file()?;
match file.try_lock_exclusive() {
Ok(()) => {
state.owner = Some(thread::current().id());
state.depth = 1;
state.flock = Some(file);
Ok(Some(WriteLockGuard {
entry: Arc::clone(&entry),
}))
}
Err(_) => Ok(None),
}
}
}
}
fn ensure_lock_dir(&self) -> Result<()> {
if let Some(parent) = self.lock_path.parent() {
std::fs::create_dir_all(parent).map_err(LockError::Io)?;
}
Ok(())
}
fn registry_key(&self) -> PathBuf {
match self.lock_path.parent() {
Some(parent) => {
let canon_parent = parent
.canonicalize()
.unwrap_or_else(|_| parent.to_path_buf());
match self.lock_path.file_name() {
Some(name) => canon_parent.join(name),
None => canon_parent,
}
}
None => self.lock_path.clone(),
}
}
fn open_lock_file(&self) -> Result<File> {
File::create(&self.lock_path).map_err(LockError::Io)
}
}
/// Extension trait whose implementors hand out a [`RepoLock`] via
/// [`RepositoryLockExt::locker`].
///
/// NOTE(review): no implementation is visible in this file; presumably a
/// repository type implements it elsewhere. Confirm.
pub trait RepositoryLockExt {
    /// Returns a lock handle; see [`RepoLock`] for the locking semantics.
    fn locker(&self) -> RepoLock;
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            mpsc::{self},
        },
        thread,
    };
    use tempfile::TempDir;
    use super::*;
    // A read lock can be taken on a fresh temp directory, where the
    // `.heddle/locks` directory does not exist yet. The size assertion is
    // trivial and mainly uses the guard.
    #[test]
    fn test_read_lock_acquired() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());
        let guard = lock.read().unwrap();
        assert!(std::mem::size_of_val(&guard) > 0);
    }
    // Same as above for the exclusive lock.
    #[test]
    fn test_write_lock_acquired() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());
        let guard = lock.write().unwrap();
        assert!(std::mem::size_of_val(&guard) > 0);
    }
    // Ten threads take the shared lock; every acquisition must succeed. The
    // 10 ms hold makes overlap likely, but overlap itself is not asserted.
    #[test]
    fn test_multiple_readers() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));
        let mut handles = vec![];
        for _ in 0..10 {
            let lock = Arc::clone(&lock);
            let handle = thread::spawn(move || {
                let _guard = lock.read().unwrap();
                thread::sleep(std::time::Duration::from_millis(10));
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }
    // While this thread holds the write lock, a non-blocking read on the
    // same thread is refused: `try_read` has no reentrant shortcut.
    #[test]
    fn test_writer_excludes_reader() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));
        let _write_guard = lock.write().unwrap();
        let read_result = lock.try_read().unwrap();
        assert!(read_result.is_none(), "Reader should be blocked by writer");
    }
    // While this thread holds a read lock, a non-blocking write attempt
    // reports contention instead of erroring.
    #[test]
    fn test_reader_excludes_writer() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));
        let _read_guard = lock.read().unwrap();
        let write_result = lock.try_write().unwrap();
        assert!(write_result.is_none(), "Writer should be blocked by reader");
    }
    // Dropping the write guard must release the file lock; a leaked
    // exclusive lock would make the following `read` block.
    #[test]
    fn test_lock_released_on_drop() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());
        {
            let _guard = lock.write().unwrap();
        }
        let _guard2 = lock.read().unwrap();
    }
    // A nested `write` on the owning thread succeeds instead of deadlocking.
    #[test]
    fn same_thread_write_is_reentrant() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());
        let _a = lock.write().unwrap();
        let _b = lock.write().unwrap();
    }
    // `read` on the thread that owns the write lock returns immediately.
    #[test]
    fn same_thread_read_under_write_does_not_deadlock() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());
        let _w = lock.write().unwrap();
        let _r = lock.read().unwrap();
    }
    // Reentrancy is per thread. While the spawned thread holds the write
    // lock, the main thread's `try_write` is refused; after it releases, it
    // succeeds. Two channels sequence the handshake: "acquired" and "release".
    #[test]
    fn distinct_threads_still_exclude() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));
        let (acquired_tx, acquired_rx) = mpsc::channel();
        let (release_tx, release_rx) = mpsc::channel();
        let lock_a = Arc::clone(&lock);
        let handle = thread::spawn(move || {
            let _g = lock_a.write().unwrap();
            acquired_tx.send(()).unwrap();
            release_rx.recv().unwrap();
        });
        acquired_rx.recv().unwrap();
        assert!(
            lock.try_write().unwrap().is_none(),
            "a second thread must not acquire the write lock"
        );
        release_tx.send(()).unwrap();
        handle.join().unwrap();
        assert!(
            lock.try_write().unwrap().is_some(),
            "write lock is available once the owning thread releases"
        );
    }
    // Nested write guards: the lock stays held (as seen from another thread)
    // until the outermost guard is dropped.
    #[test]
    fn reentrant_release_keeps_lock_until_outermost_drop() {
        let temp = TempDir::new().unwrap();
        let lock = Arc::new(RepoLock::new(temp.path()));
        let a1 = lock.write().unwrap();
        let a2 = lock.write().unwrap();
        // Probes from a fresh thread; returns true if that thread is excluded.
        let other = |lock: &Arc<RepoLock>| {
            let lock = Arc::clone(lock);
            thread::spawn(move || lock.try_write().unwrap().is_none())
                .join()
                .unwrap()
        };
        assert!(other(&lock), "excluded while held at depth 2");
        drop(a2);
        assert!(other(&lock), "still excluded while held at depth 1");
        drop(a1);
        let lock_b = Arc::clone(&lock);
        let now_available = thread::spawn(move || lock_b.try_write().unwrap().is_some())
            .join()
            .unwrap();
        assert!(now_available, "available after the outermost guard drops");
    }
    // `try_write` deliberately ignores reentrancy: the owner gets `None` too.
    #[test]
    fn try_write_is_non_reentrant_even_for_owner() {
        let temp = TempDir::new().unwrap();
        let lock = RepoLock::new(temp.path());
        let _held = lock.write().unwrap();
        assert!(
            lock.try_write().unwrap().is_none(),
            "try_write must report contention even for the lock's own owner thread"
        );
    }
}