use anyhow::Result;
use fs2::FileExt;
use std::fs::{self, File};
use std::path::PathBuf;
/// Handle to an exclusive file lock used to serialize access to the workspace.
///
/// The lock file is `workspace.lock` in the agent state directory when built via
/// [`WorkspaceLock::new`]. Constructing the handle does not take the lock; call
/// [`WorkspaceLock::acquire`] or [`WorkspaceLock::try_acquire`] for that. Each call
/// opens its own handle to the lock file.
#[derive(Clone, Debug)]
pub struct WorkspaceLock {
    /// Location of the lock file.
    path: PathBuf,
}
/// Proof that the workspace lock is held; the lock is released when this guard is dropped.
///
/// Marked `#[must_use]` because discarding the guard (e.g. `lock.acquire()?;`)
/// would release the lock immediately, which is almost certainly a bug.
#[must_use = "the workspace lock is released as soon as the guard is dropped"]
#[derive(Debug)]
pub struct WorkspaceLockGuard {
    /// Open handle on which the exclusive lock was taken.
    file: File,
}
impl Drop for WorkspaceLockGuard {
    /// Releases the exclusive lock held on `file` when the guard goes out of scope.
    fn drop(&mut self) {
        // `drop` has no way to report failure, so the unlock result is discarded on
        // purpose; `file` itself is closed right after this as the guard is torn down.
        let _: std::io::Result<()> = self.file.unlock();
    }
}
impl WorkspaceLock {
pub fn new() -> Result<Self> {
let state_dir = crate::agent::get_state_dir()?;
let path = state_dir.join("workspace.lock");
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
Ok(Self { path })
}
pub fn acquire(&self) -> Result<WorkspaceLockGuard> {
let file = File::create(&self.path)?;
file.lock_exclusive()?;
Ok(WorkspaceLockGuard { file })
}
pub fn try_acquire(&self) -> Result<Option<WorkspaceLockGuard>> {
let file = File::create(&self.path)?;
match file.try_lock_exclusive() {
Ok(()) => Ok(Some(WorkspaceLockGuard { file })),
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
#[cfg(unix)]
Err(ref e) if e.raw_os_error() == Some(35) || e.raw_os_error() == Some(11) => {
Ok(None)
}
Err(e) => Err(e.into()),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Barrier};
    use std::time::Duration;

    /// Builds a lock whose file lives in `dir`, bypassing `WorkspaceLock::new`
    /// so tests never touch the real state directory.
    fn test_lock(dir: &std::path::Path) -> WorkspaceLock {
        WorkspaceLock {
            path: dir.join("test.lock"),
        }
    }

    #[test]
    fn acquire_and_release() {
        let tmp = tempfile::tempdir().unwrap();
        let lock = test_lock(tmp.path());
        let guard = lock.acquire().unwrap();
        drop(guard);
        let _guard2 = lock.acquire().unwrap();
    }

    #[test]
    fn try_acquire_returns_none_when_held() {
        let tmp = tempfile::tempdir().unwrap();
        let lock_path = tmp.path().join("test.lock");
        let file = File::create(&lock_path).unwrap();
        file.lock_exclusive().unwrap();
        let lock = WorkspaceLock {
            path: lock_path.clone(),
        };
        let result = lock.try_acquire().unwrap();
        assert!(result.is_none(), "try_acquire should return None when held");
        file.unlock().unwrap();
        drop(file);
        let result = lock.try_acquire().unwrap();
        assert!(result.is_some(), "try_acquire should succeed after release");
    }

    #[test]
    fn try_acquire_returns_none_while_guard_alive() {
        let tmp = tempfile::tempdir().unwrap();
        let lock = test_lock(tmp.path());
        let guard = lock.acquire().unwrap();
        assert!(
            lock.try_acquire().unwrap().is_none(),
            "try_acquire should return None while a guard is alive"
        );
        drop(guard);
        assert!(
            lock.try_acquire().unwrap().is_some(),
            "try_acquire should succeed once the guard is dropped"
        );
    }

    #[test]
    fn guard_drop_releases_lock() {
        let tmp = tempfile::tempdir().unwrap();
        let lock = test_lock(tmp.path());
        {
            let _guard = lock.acquire().unwrap();
        }
        let _guard2 = lock.acquire().unwrap();
    }

    #[test]
    fn concurrent_threads_serialize() {
        const THREADS: u32 = 3;
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().to_path_buf();
        let completed = Arc::new(AtomicU32::new(0));
        // Number of threads currently inside the critical section; must never exceed 1.
        let inside = Arc::new(AtomicU32::new(0));
        let barrier = Arc::new(Barrier::new(THREADS as usize));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let path = path.clone();
                let completed = Arc::clone(&completed);
                let inside = Arc::clone(&inside);
                let barrier = Arc::clone(&barrier);
                std::thread::spawn(move || {
                    let lock = test_lock(&path);
                    barrier.wait();
                    let _guard = lock.acquire().unwrap();
                    assert_eq!(
                        inside.fetch_add(1, Ordering::SeqCst),
                        0,
                        "two threads held the lock at the same time"
                    );
                    // Widen the window so overlapping holders would be observed.
                    std::thread::sleep(Duration::from_millis(10));
                    inside.fetch_sub(1, Ordering::SeqCst);
                    completed.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(completed.load(Ordering::SeqCst), THREADS);
    }
}