use anyhow::{Context, Result};
use chrono::Utc;
use rusqlite::{Connection, OptionalExtension, params};
use std::path::PathBuf;
use std::sync::Mutex;
use std::time::Duration;
/// One row of the `locks` table: a lock on a resource held by an agent
/// running in a specific process on a specific host.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockRecord {
    /// Primary key; `LockStore` builds it as `"{lock_type}:{resource_path}"`.
    pub lock_id: String,
    /// Kind of lock (e.g. `"file_read"`, `"file_write"`, `"build"`, `"test"`).
    pub lock_type: String,
    /// Resource the lock protects.
    pub resource_path: String,
    /// Agent that holds the lock.
    pub agent_id: String,
    /// PID of the owning process, as recorded at acquisition time.
    pub process_id: i32,
    /// Acquisition time in Unix epoch milliseconds (UTC).
    pub acquired_at: i64,
    /// Expiry time in Unix epoch milliseconds; `None` means the lock never
    /// expires on its own.
    pub expires_at: Option<i64>,
    /// Hostname of the owning process.
    pub hostname: String,
}
/// Lock registry persisted in a SQLite database (`locks` table).
///
/// Each lock is keyed by lock type + resource path and records the owning
/// agent, PID and hostname, so that ownership and liveness can be checked.
pub struct LockStore {
/// The SQLite connection; every method locks this mutex for the duration of
/// its database work, serializing access from this process.
conn: Mutex<Connection>,
/// PID of this process, captured when the store is constructed.
current_pid: i32,
/// Hostname of this machine, captured when the store is constructed.
current_hostname: String,
}
impl LockStore {
pub async fn new_default() -> Result<Self> {
let db_path = Self::default_db_path()?;
Self::new_with_path(&db_path).await
}
fn default_db_path() -> Result<PathBuf> {
let home = dirs::home_dir().context("Could not determine home directory")?;
let brainwires_dir = home.join(".brainwires");
std::fs::create_dir_all(&brainwires_dir)
.context("Failed to create ~/.brainwires directory")?;
Ok(brainwires_dir.join("locks.db"))
}
pub async fn new_with_path(db_path: &PathBuf) -> Result<Self> {
let current_pid = std::process::id() as i32;
let current_hostname = gethostname::gethostname().to_string_lossy().to_string();
let conn = Connection::open(db_path)
.with_context(|| format!("Failed to open lock database at {:?}", db_path))?;
conn.execute_batch(
"PRAGMA journal_mode=WAL;
PRAGMA busy_timeout=5000;
PRAGMA synchronous=NORMAL;",
)
.context("Failed to configure SQLite")?;
let store = Self {
conn: Mutex::new(conn),
current_pid,
current_hostname,
};
store.ensure_table()?;
Ok(store)
}
fn ensure_table(&self) -> Result<()> {
let conn = self.conn.lock().expect("SQLite connection lock poisoned");
conn.execute(
"CREATE TABLE IF NOT EXISTS locks (
lock_id TEXT PRIMARY KEY,
lock_type TEXT NOT NULL,
resource_path TEXT NOT NULL,
agent_id TEXT NOT NULL,
process_id INTEGER NOT NULL,
acquired_at INTEGER NOT NULL,
expires_at INTEGER,
hostname TEXT NOT NULL
)",
[],
)
.context("Failed to create locks table")?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_locks_agent ON locks(agent_id, process_id, hostname)",
[],
)
.context("Failed to create locks index")?;
Ok(())
}
fn generate_lock_id(lock_type: &str, resource_path: &str) -> String {
format!("{}:{}", lock_type, resource_path)
}
pub async fn try_acquire(
&self,
lock_type: &str,
resource_path: &str,
agent_id: &str,
timeout: Option<Duration>,
) -> Result<bool> {
let lock_id = Self::generate_lock_id(lock_type, resource_path);
let conn = self.conn.lock().expect("SQLite connection lock poisoned");
let existing: Option<LockRecord> = conn
.query_row(
"SELECT lock_id, lock_type, resource_path, agent_id, process_id,
acquired_at, expires_at, hostname
FROM locks WHERE lock_id = ?",
[&lock_id],
|row| {
Ok(LockRecord {
lock_id: row.get(0)?,
lock_type: row.get(1)?,
resource_path: row.get(2)?,
agent_id: row.get(3)?,
process_id: row.get(4)?,
acquired_at: row.get(5)?,
expires_at: row.get(6)?,
hostname: row.get(7)?,
})
},
)
.ok();
if let Some(ref existing) = existing {
if existing.agent_id == agent_id
&& existing.process_id == self.current_pid
&& existing.hostname == self.current_hostname
{
return Ok(true);
}
if self.is_lock_stale(existing) {
conn.execute("DELETE FROM locks WHERE lock_id = ?", [&lock_id])
.context("Failed to remove stale lock")?;
} else {
return Ok(false);
}
}
let now = Utc::now().timestamp_millis();
let expires_at = timeout.map(|t| now + t.as_millis() as i64);
conn.execute(
"INSERT OR REPLACE INTO locks
(lock_id, lock_type, resource_path, agent_id, process_id, acquired_at, expires_at, hostname)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
params![
lock_id,
lock_type,
resource_path,
agent_id,
self.current_pid,
now,
expires_at,
self.current_hostname,
],
)
.context("Failed to acquire lock")?;
Ok(true)
}
pub async fn release(
&self,
lock_type: &str,
resource_path: &str,
agent_id: &str,
) -> Result<bool> {
let lock_id = Self::generate_lock_id(lock_type, resource_path);
let conn = self.conn.lock().expect("SQLite connection lock poisoned");
let deleted = conn.execute(
"DELETE FROM locks WHERE lock_id = ? AND agent_id = ? AND process_id = ? AND hostname = ?",
params![lock_id, agent_id, self.current_pid, self.current_hostname],
).context("Failed to release lock")?;
Ok(deleted > 0)
}
pub async fn release_all_for_agent(&self, agent_id: &str) -> Result<usize> {
let conn = self.conn.lock().expect("SQLite connection lock poisoned");
let deleted = conn
.execute(
"DELETE FROM locks WHERE agent_id = ? AND process_id = ? AND hostname = ?",
params![agent_id, self.current_pid, self.current_hostname],
)
.context("Failed to release agent locks")?;
Ok(deleted)
}
pub async fn is_locked(
&self,
lock_type: &str,
resource_path: &str,
) -> Result<Option<LockRecord>> {
let lock_id = Self::generate_lock_id(lock_type, resource_path);
let conn = self.conn.lock().expect("SQLite connection lock poisoned");
conn.query_row(
"SELECT lock_id, lock_type, resource_path, agent_id, process_id,
acquired_at, expires_at, hostname
FROM locks WHERE lock_id = ?",
[&lock_id],
|row| {
Ok(LockRecord {
lock_id: row.get(0)?,
lock_type: row.get(1)?,
resource_path: row.get(2)?,
agent_id: row.get(3)?,
process_id: row.get(4)?,
acquired_at: row.get(5)?,
expires_at: row.get(6)?,
hostname: row.get(7)?,
})
},
)
.optional()
.context("Failed to check lock status")
}
pub async fn cleanup_stale(&self) -> Result<usize> {
let now = Utc::now().timestamp_millis();
let conn = self.conn.lock().expect("SQLite connection lock poisoned");
let expired_count = conn
.execute(
"DELETE FROM locks WHERE expires_at IS NOT NULL AND expires_at < ?",
[now],
)
.context("Failed to cleanup expired locks")?;
let mut stmt = conn
.prepare(
"SELECT lock_id, lock_type, resource_path, agent_id, process_id,
acquired_at, expires_at, hostname
FROM locks WHERE hostname = ?",
)
.context("Failed to prepare stale lock query")?;
let locks: Vec<LockRecord> = stmt
.query_map([&self.current_hostname], |row| {
Ok(LockRecord {
lock_id: row.get(0)?,
lock_type: row.get(1)?,
resource_path: row.get(2)?,
agent_id: row.get(3)?,
process_id: row.get(4)?,
acquired_at: row.get(5)?,
expires_at: row.get(6)?,
hostname: row.get(7)?,
})
})
.context("Failed to query locks")?
.filter_map(|r| r.ok())
.collect();
drop(stmt);
let mut stale_count = 0;
for lock in locks {
if !Self::is_process_alive(lock.process_id) {
conn.execute("DELETE FROM locks WHERE lock_id = ?", [&lock.lock_id])
.ok();
stale_count += 1;
}
}
Ok(expired_count + stale_count)
}
pub async fn list_locks(&self) -> Result<Vec<LockRecord>> {
let conn = self.conn.lock().expect("SQLite connection lock poisoned");
let mut stmt = conn
.prepare(
"SELECT lock_id, lock_type, resource_path, agent_id, process_id,
acquired_at, expires_at, hostname
FROM locks",
)
.context("Failed to prepare list locks query")?;
let locks = stmt
.query_map([], |row| {
Ok(LockRecord {
lock_id: row.get(0)?,
lock_type: row.get(1)?,
resource_path: row.get(2)?,
agent_id: row.get(3)?,
process_id: row.get(4)?,
acquired_at: row.get(5)?,
expires_at: row.get(6)?,
hostname: row.get(7)?,
})
})
.context("Failed to query locks")?
.filter_map(|r| r.ok())
.collect();
Ok(locks)
}
pub async fn force_release(&self, lock_id: &str) -> Result<()> {
let conn = self.conn.lock().expect("SQLite connection lock poisoned");
conn.execute("DELETE FROM locks WHERE lock_id = ?", [lock_id])
.context("Failed to force release lock")?;
Ok(())
}
fn is_lock_stale(&self, lock: &LockRecord) -> bool {
let now = Utc::now().timestamp_millis();
if let Some(expires_at) = lock.expires_at
&& now > expires_at
{
return true;
}
if lock.hostname == self.current_hostname && !Self::is_process_alive(lock.process_id) {
return true;
}
false
}
#[cfg(unix)]
fn is_process_alive(pid: i32) -> bool {
unsafe { libc::kill(pid, 0) == 0 }
}
#[cfg(windows)]
fn is_process_alive(pid: i32) -> bool {
use windows_sys::Win32::Foundation::{CloseHandle, STILL_ACTIVE};
use windows_sys::Win32::System::Threading::{
GetExitCodeProcess, OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION,
};
unsafe {
let handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid as u32);
if handle == 0 {
return false;
}
let mut exit_code: u32 = 0;
let result = GetExitCodeProcess(handle, &mut exit_code);
CloseHandle(handle);
result != 0 && exit_code == STILL_ACTIVE
}
}
#[cfg(not(any(unix, windows)))]
fn is_process_alive(_pid: i32) -> bool {
true
}
pub async fn stats(&self) -> Result<LockStats> {
let locks = self.list_locks().await?;
let mut file_read_locks = 0;
let mut file_write_locks = 0;
let mut build_locks = 0;
let mut test_locks = 0;
let mut stale_locks = 0;
for lock in &locks {
match lock.lock_type.as_str() {
"file_read" => file_read_locks += 1,
"file_write" => file_write_locks += 1,
"build" => build_locks += 1,
"test" => test_locks += 1,
"build_test" => {
build_locks += 1;
test_locks += 1;
}
_ => {}
}
if self.is_lock_stale(lock) {
stale_locks += 1;
}
}
Ok(LockStats {
total_locks: locks.len(),
file_read_locks,
file_write_locks,
build_locks,
test_locks,
stale_locks,
})
}
}
/// Aggregate counts over the lock table, as produced by `LockStore::stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LockStats {
    /// Number of lock rows, of any type.
    pub total_locks: usize,
    /// Locks of type `"file_read"`.
    pub file_read_locks: usize,
    /// Locks of type `"file_write"`.
    pub file_write_locks: usize,
    /// Locks of type `"build"` plus those of type `"build_test"`.
    pub build_locks: usize,
    /// Locks of type `"test"` plus those of type `"build_test"`.
    pub test_locks: usize,
    /// Locks that were reclaimable (expired, or owned by a dead local
    /// process) at the time the stats were computed.
    pub stale_locks: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a store backed by a fresh database in a temporary directory.
    /// The `TempDir` must be kept alive for as long as the store is used.
    async fn create_test_store() -> (LockStore, TempDir) {
        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("test_locks.db");
        let store = LockStore::new_with_path(&db_path).await.unwrap();
        (store, temp)
    }

    /// Returns the PID of a child process that has already exited and been
    /// reaped, i.e. a PID that (barring immediate reuse) is not alive.
    #[cfg(unix)]
    fn dead_pid() -> i32 {
        let mut child = std::process::Command::new("sh")
            .args(["-c", "exit 0"])
            .spawn()
            .expect("failed to spawn `sh`");
        let pid = child.id() as i32;
        child.wait().expect("failed to reap child");
        pid
    }

    /// Inserts a non-expiring lock row owned by `process_id` on this host,
    /// bypassing `try_acquire` so tests can simulate foreign owners.
    #[cfg(unix)]
    fn insert_raw_lock(
        store: &LockStore,
        lock_type: &str,
        resource_path: &str,
        agent_id: &str,
        process_id: i32,
    ) {
        let conn = store.conn.lock().unwrap();
        conn.execute(
            "INSERT INTO locks
             (lock_id, lock_type, resource_path, agent_id, process_id, acquired_at, expires_at, hostname)
             VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            rusqlite::params![
                LockStore::generate_lock_id(lock_type, resource_path),
                lock_type,
                resource_path,
                agent_id,
                process_id,
                0i64,
                None::<i64>,
                store.current_hostname,
            ],
        )
        .unwrap();
    }

    #[tokio::test]
    async fn test_acquire_and_release_lock() {
        let (store, _temp) = create_test_store().await;
        let acquired = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        assert!(acquired);
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_some());
        assert_eq!(lock.unwrap().agent_id, "agent-1");
        let released = store
            .release("file_write", "/test/file.txt", "agent-1")
            .await
            .unwrap();
        assert!(released);
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_none());
    }

    #[tokio::test]
    async fn test_idempotent_acquire() {
        let (store, _temp) = create_test_store().await;
        let acquired1 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        let acquired2 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        assert!(acquired1);
        assert!(acquired2);
    }

    #[tokio::test]
    async fn test_lock_conflict() {
        let (store, _temp) = create_test_store().await;
        let acquired1 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        assert!(acquired1);
        let acquired2 = store
            .try_acquire("file_write", "/test/file.txt", "agent-2", None)
            .await
            .unwrap();
        assert!(!acquired2);
    }

    #[tokio::test]
    async fn test_release_requires_owner() {
        let (store, _temp) = create_test_store().await;
        store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        let released = store
            .release("file_write", "/test/file.txt", "agent-2")
            .await
            .unwrap();
        assert!(!released);
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert_eq!(lock.unwrap().agent_id, "agent-1");
    }

    #[tokio::test]
    async fn test_release_all_for_agent() {
        let (store, _temp) = create_test_store().await;
        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("build", "/test/project", "agent-1", None)
            .await
            .unwrap();
        let released = store.release_all_for_agent("agent-1").await.unwrap();
        assert_eq!(released, 3);
        let locks = store.list_locks().await.unwrap();
        assert!(locks.is_empty());
    }

    #[tokio::test]
    async fn test_expired_lock_cleanup() {
        let (store, _temp) = create_test_store().await;
        store
            .try_acquire(
                "file_write",
                "/test/file.txt",
                "agent-1",
                Some(Duration::from_millis(1)),
            )
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        let cleaned = store.cleanup_stale().await.unwrap();
        assert_eq!(cleaned, 1);
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_none());
    }

    #[tokio::test]
    async fn test_expired_lock_can_be_taken_over() {
        let (store, _temp) = create_test_store().await;
        store
            .try_acquire(
                "file_write",
                "/test/file.txt",
                "agent-1",
                Some(Duration::from_millis(1)),
            )
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        let acquired = store
            .try_acquire("file_write", "/test/file.txt", "agent-2", None)
            .await
            .unwrap();
        assert!(acquired);
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert_eq!(lock.unwrap().agent_id, "agent-2");
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_lock_of_dead_process_is_reclaimed() {
        let (store, _temp) = create_test_store().await;
        insert_raw_lock(&store, "file_write", "/test/file.txt", "agent-1", dead_pid());
        let acquired = store
            .try_acquire("file_write", "/test/file.txt", "agent-2", None)
            .await
            .unwrap();
        assert!(acquired);
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap()
            .expect("lock should be held");
        assert_eq!(lock.agent_id, "agent-2");
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_cleanup_removes_only_dead_process_locks() {
        let (store, _temp) = create_test_store().await;
        insert_raw_lock(&store, "build", "/test/project", "agent-x", dead_pid());
        store
            .try_acquire("file_write", "/test/live.txt", "agent-1", None)
            .await
            .unwrap();
        let cleaned = store.cleanup_stale().await.unwrap();
        assert_eq!(cleaned, 1);
        let locks = store.list_locks().await.unwrap();
        assert_eq!(locks.len(), 1);
        assert_eq!(locks[0].lock_id, "file_write:/test/live.txt");
    }

    #[tokio::test]
    async fn test_list_locks() {
        let (store, _temp) = create_test_store().await;
        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();
        let locks = store.list_locks().await.unwrap();
        assert_eq!(locks.len(), 2);
    }

    #[tokio::test]
    async fn test_stats() {
        let (store, _temp) = create_test_store().await;
        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("build", "/test/project", "agent-1", None)
            .await
            .unwrap();
        let stats = store.stats().await.unwrap();
        assert_eq!(stats.total_locks, 3);
        assert_eq!(stats.file_write_locks, 1);
        assert_eq!(stats.file_read_locks, 1);
        assert_eq!(stats.build_locks, 1);
        assert_eq!(stats.stale_locks, 0);
    }

    #[test]
    fn test_is_process_alive() {
        let current_pid = std::process::id() as i32;
        assert!(LockStore::is_process_alive(current_pid));
        #[cfg(unix)]
        {
            // A child that has exited and been reaped no longer exists.
            assert!(!LockStore::is_process_alive(dead_pid()));
        }
    }
}