1use anyhow::{Context, Result};
10use chrono::Utc;
11use rusqlite::{Connection, OptionalExtension, params};
12use std::path::PathBuf;
13use std::sync::Mutex;
14use std::time::Duration;
15
/// A single lock row as stored in the `locks` table.
///
/// Records compare equal only when every field matches, which makes snapshots returned by
/// [`LockStore::list_locks`] / [`LockStore::is_locked`] easy to compare and assert on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockRecord {
    /// Primary key, built as `"{lock_type}:{resource_path}"`.
    pub lock_id: String,
    /// Kind of lock; [`LockStore::stats`] recognizes `"file_read"`, `"file_write"`,
    /// `"build"`, `"test"` and `"build_test"`.
    pub lock_type: String,
    /// Path of the resource the lock protects.
    pub resource_path: String,
    /// Agent that holds the lock.
    pub agent_id: String,
    /// OS process ID of the holder; used to detect locks left behind by dead processes.
    pub process_id: i32,
    /// Acquisition time in milliseconds since the Unix epoch (UTC).
    pub acquired_at: i64,
    /// Absolute expiry in milliseconds since the Unix epoch; `None` means it never expires.
    pub expires_at: Option<i64>,
    /// Hostname of the machine the holding process runs on.
    pub hostname: String,
}
36
/// SQLite-backed store of lock records.
///
/// Every record is stamped with the owning agent, process ID and hostname, so locks taken by
/// other processes (or other hosts) can be told apart from this process's own locks.
pub struct LockStore {
    /// Database connection; every method locks this mutex for the duration of its SQL work.
    conn: Mutex<Connection>,
    /// ID of this process, stamped on every lock this store acquires.
    current_pid: i32,
    /// Hostname of this machine, stamped on acquired locks. PID liveness checks are only
    /// applied to locks whose hostname matches this value.
    current_hostname: String,
}
46
47impl LockStore {
48 pub async fn new_default() -> Result<Self> {
50 let db_path = Self::default_db_path()?;
51 Self::new_with_path(&db_path).await
52 }
53
54 fn default_db_path() -> Result<PathBuf> {
56 let home = dirs::home_dir().context("Could not determine home directory")?;
57 let brainwires_dir = home.join(".brainwires");
58 std::fs::create_dir_all(&brainwires_dir)
59 .context("Failed to create ~/.brainwires directory")?;
60 Ok(brainwires_dir.join("locks.db"))
61 }
62
63 pub async fn new_with_path(db_path: &PathBuf) -> Result<Self> {
65 let current_pid = std::process::id() as i32;
66 let current_hostname = gethostname::gethostname().to_string_lossy().to_string();
67
68 let conn = Connection::open(db_path)
70 .with_context(|| format!("Failed to open lock database at {:?}", db_path))?;
71
72 conn.execute_batch(
74 "PRAGMA journal_mode=WAL;
75 PRAGMA busy_timeout=5000;
76 PRAGMA synchronous=NORMAL;",
77 )
78 .context("Failed to configure SQLite")?;
79
80 let store = Self {
81 conn: Mutex::new(conn),
82 current_pid,
83 current_hostname,
84 };
85
86 store.ensure_table()?;
88
89 Ok(store)
90 }
91
92 fn ensure_table(&self) -> Result<()> {
94 let conn = self.conn.lock().expect("SQLite connection lock poisoned");
95 conn.execute(
96 "CREATE TABLE IF NOT EXISTS locks (
97 lock_id TEXT PRIMARY KEY,
98 lock_type TEXT NOT NULL,
99 resource_path TEXT NOT NULL,
100 agent_id TEXT NOT NULL,
101 process_id INTEGER NOT NULL,
102 acquired_at INTEGER NOT NULL,
103 expires_at INTEGER,
104 hostname TEXT NOT NULL
105 )",
106 [],
107 )
108 .context("Failed to create locks table")?;
109
110 conn.execute(
112 "CREATE INDEX IF NOT EXISTS idx_locks_agent ON locks(agent_id, process_id, hostname)",
113 [],
114 )
115 .context("Failed to create locks index")?;
116
117 Ok(())
118 }
119
/// Builds the primary key for a lock: `"{lock_type}:{resource_path}"`.
///
/// NOTE(review): the parts are not escaped, so a `lock_type` containing `:` could collide with
/// a different (type, path) pair; the lock types used in this file contain no `:`.
fn generate_lock_id(lock_type: &str, resource_path: &str) -> String {
    [lock_type, resource_path].join(":")
}
124
125 pub async fn try_acquire(
127 &self,
128 lock_type: &str,
129 resource_path: &str,
130 agent_id: &str,
131 timeout: Option<Duration>,
132 ) -> Result<bool> {
133 let lock_id = Self::generate_lock_id(lock_type, resource_path);
134 let conn = self.conn.lock().expect("SQLite connection lock poisoned");
135
136 let existing: Option<LockRecord> = conn
138 .query_row(
139 "SELECT lock_id, lock_type, resource_path, agent_id, process_id,
140 acquired_at, expires_at, hostname
141 FROM locks WHERE lock_id = ?",
142 [&lock_id],
143 |row| {
144 Ok(LockRecord {
145 lock_id: row.get(0)?,
146 lock_type: row.get(1)?,
147 resource_path: row.get(2)?,
148 agent_id: row.get(3)?,
149 process_id: row.get(4)?,
150 acquired_at: row.get(5)?,
151 expires_at: row.get(6)?,
152 hostname: row.get(7)?,
153 })
154 },
155 )
156 .ok();
157
158 if let Some(ref existing) = existing {
159 if existing.agent_id == agent_id
161 && existing.process_id == self.current_pid
162 && existing.hostname == self.current_hostname
163 {
164 return Ok(true);
165 }
166
167 if self.is_lock_stale(existing) {
169 conn.execute("DELETE FROM locks WHERE lock_id = ?", [&lock_id])
171 .context("Failed to remove stale lock")?;
172 } else {
173 return Ok(false);
175 }
176 }
177
178 let now = Utc::now().timestamp_millis();
180 let expires_at = timeout.map(|t| now + t.as_millis() as i64);
181
182 conn.execute(
183 "INSERT OR REPLACE INTO locks
184 (lock_id, lock_type, resource_path, agent_id, process_id, acquired_at, expires_at, hostname)
185 VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
186 params![
187 lock_id,
188 lock_type,
189 resource_path,
190 agent_id,
191 self.current_pid,
192 now,
193 expires_at,
194 self.current_hostname,
195 ],
196 )
197 .context("Failed to acquire lock")?;
198
199 Ok(true)
200 }
201
202 pub async fn release(
204 &self,
205 lock_type: &str,
206 resource_path: &str,
207 agent_id: &str,
208 ) -> Result<bool> {
209 let lock_id = Self::generate_lock_id(lock_type, resource_path);
210 let conn = self.conn.lock().expect("SQLite connection lock poisoned");
211
212 let deleted = conn.execute(
214 "DELETE FROM locks WHERE lock_id = ? AND agent_id = ? AND process_id = ? AND hostname = ?",
215 params![lock_id, agent_id, self.current_pid, self.current_hostname],
216 ).context("Failed to release lock")?;
217
218 Ok(deleted > 0)
219 }
220
221 pub async fn release_all_for_agent(&self, agent_id: &str) -> Result<usize> {
223 let conn = self.conn.lock().expect("SQLite connection lock poisoned");
224
225 let deleted = conn
226 .execute(
227 "DELETE FROM locks WHERE agent_id = ? AND process_id = ? AND hostname = ?",
228 params![agent_id, self.current_pid, self.current_hostname],
229 )
230 .context("Failed to release agent locks")?;
231
232 Ok(deleted)
233 }
234
235 pub async fn is_locked(
237 &self,
238 lock_type: &str,
239 resource_path: &str,
240 ) -> Result<Option<LockRecord>> {
241 let lock_id = Self::generate_lock_id(lock_type, resource_path);
242 let conn = self.conn.lock().expect("SQLite connection lock poisoned");
243
244 conn.query_row(
245 "SELECT lock_id, lock_type, resource_path, agent_id, process_id,
246 acquired_at, expires_at, hostname
247 FROM locks WHERE lock_id = ?",
248 [&lock_id],
249 |row| {
250 Ok(LockRecord {
251 lock_id: row.get(0)?,
252 lock_type: row.get(1)?,
253 resource_path: row.get(2)?,
254 agent_id: row.get(3)?,
255 process_id: row.get(4)?,
256 acquired_at: row.get(5)?,
257 expires_at: row.get(6)?,
258 hostname: row.get(7)?,
259 })
260 },
261 )
262 .optional()
263 .context("Failed to check lock status")
264 }
265
266 pub async fn cleanup_stale(&self) -> Result<usize> {
268 let now = Utc::now().timestamp_millis();
269 let conn = self.conn.lock().expect("SQLite connection lock poisoned");
270
271 let expired_count = conn
273 .execute(
274 "DELETE FROM locks WHERE expires_at IS NOT NULL AND expires_at < ?",
275 [now],
276 )
277 .context("Failed to cleanup expired locks")?;
278
279 let mut stmt = conn
281 .prepare(
282 "SELECT lock_id, lock_type, resource_path, agent_id, process_id,
283 acquired_at, expires_at, hostname
284 FROM locks WHERE hostname = ?",
285 )
286 .context("Failed to prepare stale lock query")?;
287
288 let locks: Vec<LockRecord> = stmt
289 .query_map([&self.current_hostname], |row| {
290 Ok(LockRecord {
291 lock_id: row.get(0)?,
292 lock_type: row.get(1)?,
293 resource_path: row.get(2)?,
294 agent_id: row.get(3)?,
295 process_id: row.get(4)?,
296 acquired_at: row.get(5)?,
297 expires_at: row.get(6)?,
298 hostname: row.get(7)?,
299 })
300 })
301 .context("Failed to query locks")?
302 .filter_map(|r| r.ok())
303 .collect();
304
305 drop(stmt);
306
307 let mut stale_count = 0;
309 for lock in locks {
310 if !Self::is_process_alive(lock.process_id) {
311 conn.execute("DELETE FROM locks WHERE lock_id = ?", [&lock.lock_id])
312 .ok();
313 stale_count += 1;
314 }
315 }
316
317 Ok(expired_count + stale_count)
318 }
319
320 pub async fn list_locks(&self) -> Result<Vec<LockRecord>> {
322 let conn = self.conn.lock().expect("SQLite connection lock poisoned");
323
324 let mut stmt = conn
325 .prepare(
326 "SELECT lock_id, lock_type, resource_path, agent_id, process_id,
327 acquired_at, expires_at, hostname
328 FROM locks",
329 )
330 .context("Failed to prepare list locks query")?;
331
332 let locks = stmt
333 .query_map([], |row| {
334 Ok(LockRecord {
335 lock_id: row.get(0)?,
336 lock_type: row.get(1)?,
337 resource_path: row.get(2)?,
338 agent_id: row.get(3)?,
339 process_id: row.get(4)?,
340 acquired_at: row.get(5)?,
341 expires_at: row.get(6)?,
342 hostname: row.get(7)?,
343 })
344 })
345 .context("Failed to query locks")?
346 .filter_map(|r| r.ok())
347 .collect();
348
349 Ok(locks)
350 }
351
352 pub async fn force_release(&self, lock_id: &str) -> Result<()> {
354 let conn = self.conn.lock().expect("SQLite connection lock poisoned");
355 conn.execute("DELETE FROM locks WHERE lock_id = ?", [lock_id])
356 .context("Failed to force release lock")?;
357 Ok(())
358 }
359
360 fn is_lock_stale(&self, lock: &LockRecord) -> bool {
362 let now = Utc::now().timestamp_millis();
363
364 if let Some(expires_at) = lock.expires_at
366 && now > expires_at
367 {
368 return true;
369 }
370
371 if lock.hostname == self.current_hostname && !Self::is_process_alive(lock.process_id) {
373 return true;
374 }
375
376 false
377 }
378
379 #[cfg(unix)]
381 fn is_process_alive(pid: i32) -> bool {
382 unsafe { libc::kill(pid, 0) == 0 }
385 }
386
/// Windows: reports whether the process with ID `pid` is still running.
///
/// Opens the process with `PROCESS_QUERY_LIMITED_INFORMATION`. It is considered alive only if
/// the open succeeds, `GetExitCodeProcess` succeeds, and the reported exit code is
/// `STILL_ACTIVE`.
// NOTE(review): two caveats to confirm.
// - A process whose real exit code equals STILL_ACTIVE (259) would be reported alive.
// - Any `OpenProcess` failure counts as "dead", including access-denied for another user's
//   process, which would make that holder's locks look stale.
#[cfg(windows)]
fn is_process_alive(pid: i32) -> bool {
    use windows_sys::Win32::Foundation::{CloseHandle, STILL_ACTIVE};
    use windows_sys::Win32::System::Threading::{
        GetExitCodeProcess, OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION,
    };

    // SAFETY (review): the handle is checked for failure before use and closed exactly once.
    unsafe {
        let handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid as u32);
        // A zero handle means the process could not be opened.
        if handle == 0 {
            return false;
        }

        let mut exit_code: u32 = 0;
        let result = GetExitCodeProcess(handle, &mut exit_code);
        // Close before inspecting the result so the handle is never leaked.
        CloseHandle(handle);

        result != 0 && exit_code == STILL_ACTIVE
    }
}
407
/// Fallback for targets that are neither Unix nor Windows: liveness cannot be checked here,
/// so every process is assumed to be alive.
///
/// As a result, locks on this host never become stale because of a dead holder. Only expiry
/// or an explicit release frees them.
#[cfg(not(any(unix, windows)))]
fn is_process_alive(_pid: i32) -> bool {
    true
}
413
414 pub async fn stats(&self) -> Result<LockStats> {
416 let locks = self.list_locks().await?;
417
418 let mut file_read_locks = 0;
419 let mut file_write_locks = 0;
420 let mut build_locks = 0;
421 let mut test_locks = 0;
422 let mut stale_locks = 0;
423
424 for lock in &locks {
425 match lock.lock_type.as_str() {
426 "file_read" => file_read_locks += 1,
427 "file_write" => file_write_locks += 1,
428 "build" => build_locks += 1,
429 "test" => test_locks += 1,
430 "build_test" => {
431 build_locks += 1;
432 test_locks += 1;
433 }
434 _ => {}
435 }
436
437 if self.is_lock_stale(lock) {
438 stale_locks += 1;
439 }
440 }
441
442 Ok(LockStats {
443 total_locks: locks.len(),
444 file_read_locks,
445 file_write_locks,
446 build_locks,
447 test_locks,
448 stale_locks,
449 })
450 }
451}
452
/// Aggregate lock counts, as produced by [`LockStore::stats`].
///
/// `LockStats::default()` is the all-zero summary of an empty lock table.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LockStats {
    /// Total number of lock rows, of any type.
    pub total_locks: usize,
    /// Locks of type `"file_read"`.
    pub file_read_locks: usize,
    /// Locks of type `"file_write"`.
    pub file_write_locks: usize,
    /// Locks of type `"build"` plus those of type `"build_test"`.
    pub build_locks: usize,
    /// Locks of type `"test"` plus those of type `"build_test"`.
    pub test_locks: usize,
    /// Locks that are expired, or held on this host by a process that is no longer running.
    pub stale_locks: usize,
}
469
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a store backed by a fresh database inside a temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps it alive; dropping it removes the database.
    async fn create_test_store() -> (LockStore, TempDir) {
        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("test_locks.db");
        let store = LockStore::new_with_path(&db_path).await.unwrap();
        (store, temp)
    }

    #[tokio::test]
    async fn test_acquire_and_release_lock() {
        let (store, _temp) = create_test_store().await;

        let acquired = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        assert!(acquired);

        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_some());
        assert_eq!(lock.unwrap().agent_id, "agent-1");

        let released = store
            .release("file_write", "/test/file.txt", "agent-1")
            .await
            .unwrap();
        assert!(released);

        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_none());
    }

    #[tokio::test]
    async fn test_idempotent_acquire() {
        let (store, _temp) = create_test_store().await;

        let acquired1 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        let acquired2 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();

        assert!(acquired1);
        assert!(acquired2);
    }

    #[tokio::test]
    async fn test_lock_conflict() {
        let (store, _temp) = create_test_store().await;

        let acquired1 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        assert!(acquired1);

        // Same process and host, but a different agent: the live holder keeps the lock.
        let acquired2 = store
            .try_acquire("file_write", "/test/file.txt", "agent-2", None)
            .await
            .unwrap();
        assert!(!acquired2);
    }

    #[tokio::test]
    async fn test_release_by_other_agent_is_noop() {
        let (store, _temp) = create_test_store().await;

        assert!(store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap());

        let released = store
            .release("file_write", "/test/file.txt", "agent-2")
            .await
            .unwrap();
        assert!(!released);

        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert_eq!(lock.unwrap().agent_id, "agent-1");
    }

    #[tokio::test]
    async fn test_release_all_for_agent() {
        let (store, _temp) = create_test_store().await;

        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("build", "/test/project", "agent-1", None)
            .await
            .unwrap();

        let released = store.release_all_for_agent("agent-1").await.unwrap();
        assert_eq!(released, 3);

        let locks = store.list_locks().await.unwrap();
        assert!(locks.is_empty());
    }

    #[tokio::test]
    async fn test_expired_lock_cleanup() {
        let (store, _temp) = create_test_store().await;

        store
            .try_acquire(
                "file_write",
                "/test/file.txt",
                "agent-1",
                Some(Duration::from_millis(1)),
            )
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(10)).await;

        let cleaned = store.cleanup_stale().await.unwrap();
        assert_eq!(cleaned, 1);

        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_none());
    }

    #[tokio::test]
    async fn test_expired_lock_can_be_taken_over() {
        let (store, _temp) = create_test_store().await;

        assert!(store
            .try_acquire(
                "file_write",
                "/test/file.txt",
                "agent-1",
                Some(Duration::from_millis(1)),
            )
            .await
            .unwrap());

        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(store
            .try_acquire("file_write", "/test/file.txt", "agent-2", None)
            .await
            .unwrap());

        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(lock.agent_id, "agent-2");
        assert_eq!(lock.expires_at, None);
    }

    #[tokio::test]
    async fn test_force_release() {
        let (store, _temp) = create_test_store().await;

        store
            .try_acquire("build", "/test/project", "agent-1", None)
            .await
            .unwrap();

        let lock_id = LockStore::generate_lock_id("build", "/test/project");
        store.force_release(&lock_id).await.unwrap();
        assert!(store
            .is_locked("build", "/test/project")
            .await
            .unwrap()
            .is_none());

        // Force-releasing a lock that no longer exists is not an error.
        store.force_release(&lock_id).await.unwrap();
    }

    #[tokio::test]
    async fn test_list_locks() {
        let (store, _temp) = create_test_store().await;

        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();

        let locks = store.list_locks().await.unwrap();
        assert_eq!(locks.len(), 2);
    }

    #[tokio::test]
    async fn test_stats() {
        let (store, _temp) = create_test_store().await;

        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("build", "/test/project", "agent-1", None)
            .await
            .unwrap();

        let stats = store.stats().await.unwrap();
        assert_eq!(stats.total_locks, 3);
        assert_eq!(stats.file_write_locks, 1);
        assert_eq!(stats.file_read_locks, 1);
        assert_eq!(stats.build_locks, 1);
        assert_eq!(stats.test_locks, 0);
        // All held by this (live) process with no expiry.
        assert_eq!(stats.stale_locks, 0);
    }

    #[test]
    fn test_is_process_alive() {
        let current_pid = std::process::id() as i32;
        assert!(LockStore::is_process_alive(current_pid));
    }
}