Skip to main content

brainwires_stores/
lock_store.rs

1//! SQLite-backed persistent lock storage for inter-process coordination
2//!
3//! Enables multiple brainwires-cli instances to coordinate file access
4//! and build/test operations through a shared SQLite database.
5//!
6//! SQLite provides ACID compliance and immediate consistency, making it
7//! ideal for lock coordination where eventual consistency would cause bugs.
8
9use anyhow::{Context, Result};
10use chrono::Utc;
11use rusqlite::{Connection, OptionalExtension, params};
12use std::path::PathBuf;
13use std::sync::Mutex;
14use std::time::Duration;
15
/// One row of the `locks` table: a single lock held by a single agent.
///
/// Derives `PartialEq`/`Eq` so records can be compared directly (for example
/// in assertions or when checking that a re-read row is unchanged).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockRecord {
    /// Lock identifier; primary key of the `locks` table
    pub lock_id: String,
    /// Type of lock: "file_read", "file_write", "build", "test", "build_test"
    pub lock_type: String,
    /// Resource being locked (file path or project path)
    pub resource_path: String,
    /// ID of the agent holding the lock
    pub agent_id: String,
    /// Process ID of the holder, used for stale lock detection
    pub process_id: i32,
    /// When the lock was acquired (Unix timestamp in milliseconds)
    pub acquired_at: i64,
    /// When the lock expires (Unix timestamp in milliseconds); `None` means no expiry
    pub expires_at: Option<i64>,
    /// Hostname of the machine holding the lock
    pub hostname: String,
}
36
37/// SQLite-backed persistent lock storage
38pub struct LockStore {
39    /// SQLite connection (wrapped in Mutex for thread safety)
40    conn: Mutex<Connection>,
41    /// Current process ID (cached for efficiency)
42    current_pid: i32,
43    /// Current hostname (cached for efficiency)
44    current_hostname: String,
45}
46
47impl LockStore {
48    /// Create a new lock store with default database path (~/.brainwires/locks.db)
49    pub async fn new_default() -> Result<Self> {
50        let db_path = Self::default_db_path()?;
51        Self::new_with_path(&db_path).await
52    }
53
54    /// Get the default database path
55    fn default_db_path() -> Result<PathBuf> {
56        let home = dirs::home_dir().context("Could not determine home directory")?;
57        let brainwires_dir = home.join(".brainwires");
58        std::fs::create_dir_all(&brainwires_dir)
59            .context("Failed to create ~/.brainwires directory")?;
60        Ok(brainwires_dir.join("locks.db"))
61    }
62
63    /// Create a new lock store with a custom database path
64    pub async fn new_with_path(db_path: &PathBuf) -> Result<Self> {
65        let current_pid = std::process::id() as i32;
66        let current_hostname = gethostname::gethostname().to_string_lossy().to_string();
67
68        // Open SQLite connection with WAL mode for better concurrent access
69        let conn = Connection::open(db_path)
70            .with_context(|| format!("Failed to open lock database at {:?}", db_path))?;
71
72        // Enable WAL mode for better concurrent read/write performance
73        conn.execute_batch(
74            "PRAGMA journal_mode=WAL;
75             PRAGMA busy_timeout=5000;
76             PRAGMA synchronous=NORMAL;",
77        )
78        .context("Failed to configure SQLite")?;
79
80        let store = Self {
81            conn: Mutex::new(conn),
82            current_pid,
83            current_hostname,
84        };
85
86        // Ensure the locks table exists
87        store.ensure_table()?;
88
89        Ok(store)
90    }
91
92    /// Ensure the locks table exists
93    fn ensure_table(&self) -> Result<()> {
94        let conn = self.conn.lock().expect("SQLite connection lock poisoned");
95        conn.execute(
96            "CREATE TABLE IF NOT EXISTS locks (
97                lock_id TEXT PRIMARY KEY,
98                lock_type TEXT NOT NULL,
99                resource_path TEXT NOT NULL,
100                agent_id TEXT NOT NULL,
101                process_id INTEGER NOT NULL,
102                acquired_at INTEGER NOT NULL,
103                expires_at INTEGER,
104                hostname TEXT NOT NULL
105            )",
106            [],
107        )
108        .context("Failed to create locks table")?;
109
110        // Create index for faster queries
111        conn.execute(
112            "CREATE INDEX IF NOT EXISTS idx_locks_agent ON locks(agent_id, process_id, hostname)",
113            [],
114        )
115        .context("Failed to create locks index")?;
116
117        Ok(())
118    }
119
120    /// Generate a unique lock ID
121    fn generate_lock_id(lock_type: &str, resource_path: &str) -> String {
122        format!("{}:{}", lock_type, resource_path)
123    }
124
125    /// Try to acquire a lock. Returns true if acquired, false if already held by another.
126    pub async fn try_acquire(
127        &self,
128        lock_type: &str,
129        resource_path: &str,
130        agent_id: &str,
131        timeout: Option<Duration>,
132    ) -> Result<bool> {
133        let lock_id = Self::generate_lock_id(lock_type, resource_path);
134        let conn = self.conn.lock().expect("SQLite connection lock poisoned");
135
136        // Check if lock already exists
137        let existing: Option<LockRecord> = conn
138            .query_row(
139                "SELECT lock_id, lock_type, resource_path, agent_id, process_id,
140                        acquired_at, expires_at, hostname
141                 FROM locks WHERE lock_id = ?",
142                [&lock_id],
143                |row| {
144                    Ok(LockRecord {
145                        lock_id: row.get(0)?,
146                        lock_type: row.get(1)?,
147                        resource_path: row.get(2)?,
148                        agent_id: row.get(3)?,
149                        process_id: row.get(4)?,
150                        acquired_at: row.get(5)?,
151                        expires_at: row.get(6)?,
152                        hostname: row.get(7)?,
153                    })
154                },
155            )
156            .ok();
157
158        if let Some(ref existing) = existing {
159            // If held by same agent in same process, allow (idempotent)
160            if existing.agent_id == agent_id
161                && existing.process_id == self.current_pid
162                && existing.hostname == self.current_hostname
163            {
164                return Ok(true);
165            }
166
167            // Check if the lock is stale
168            if self.is_lock_stale(existing) {
169                // Remove stale lock and proceed
170                conn.execute("DELETE FROM locks WHERE lock_id = ?", [&lock_id])
171                    .context("Failed to remove stale lock")?;
172            } else {
173                // Lock is held by another active process
174                return Ok(false);
175            }
176        }
177
178        // Acquire the lock
179        let now = Utc::now().timestamp_millis();
180        let expires_at = timeout.map(|t| now + t.as_millis() as i64);
181
182        conn.execute(
183            "INSERT OR REPLACE INTO locks
184             (lock_id, lock_type, resource_path, agent_id, process_id, acquired_at, expires_at, hostname)
185             VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
186            params![
187                lock_id,
188                lock_type,
189                resource_path,
190                agent_id,
191                self.current_pid,
192                now,
193                expires_at,
194                self.current_hostname,
195            ],
196        )
197        .context("Failed to acquire lock")?;
198
199        Ok(true)
200    }
201
202    /// Release a lock
203    pub async fn release(
204        &self,
205        lock_type: &str,
206        resource_path: &str,
207        agent_id: &str,
208    ) -> Result<bool> {
209        let lock_id = Self::generate_lock_id(lock_type, resource_path);
210        let conn = self.conn.lock().expect("SQLite connection lock poisoned");
211
212        // Delete only if owned by this agent in this process
213        let deleted = conn.execute(
214            "DELETE FROM locks WHERE lock_id = ? AND agent_id = ? AND process_id = ? AND hostname = ?",
215            params![lock_id, agent_id, self.current_pid, self.current_hostname],
216        ).context("Failed to release lock")?;
217
218        Ok(deleted > 0)
219    }
220
221    /// Release all locks held by a specific agent in the current process
222    pub async fn release_all_for_agent(&self, agent_id: &str) -> Result<usize> {
223        let conn = self.conn.lock().expect("SQLite connection lock poisoned");
224
225        let deleted = conn
226            .execute(
227                "DELETE FROM locks WHERE agent_id = ? AND process_id = ? AND hostname = ?",
228                params![agent_id, self.current_pid, self.current_hostname],
229            )
230            .context("Failed to release agent locks")?;
231
232        Ok(deleted)
233    }
234
235    /// Check if a lock is held and by whom
236    pub async fn is_locked(
237        &self,
238        lock_type: &str,
239        resource_path: &str,
240    ) -> Result<Option<LockRecord>> {
241        let lock_id = Self::generate_lock_id(lock_type, resource_path);
242        let conn = self.conn.lock().expect("SQLite connection lock poisoned");
243
244        conn.query_row(
245            "SELECT lock_id, lock_type, resource_path, agent_id, process_id,
246                    acquired_at, expires_at, hostname
247             FROM locks WHERE lock_id = ?",
248            [&lock_id],
249            |row| {
250                Ok(LockRecord {
251                    lock_id: row.get(0)?,
252                    lock_type: row.get(1)?,
253                    resource_path: row.get(2)?,
254                    agent_id: row.get(3)?,
255                    process_id: row.get(4)?,
256                    acquired_at: row.get(5)?,
257                    expires_at: row.get(6)?,
258                    hostname: row.get(7)?,
259                })
260            },
261        )
262        .optional()
263        .context("Failed to check lock status")
264    }
265
266    /// Cleanup expired and stale locks
267    pub async fn cleanup_stale(&self) -> Result<usize> {
268        let now = Utc::now().timestamp_millis();
269        let conn = self.conn.lock().expect("SQLite connection lock poisoned");
270
271        // First, delete expired locks
272        let expired_count = conn
273            .execute(
274                "DELETE FROM locks WHERE expires_at IS NOT NULL AND expires_at < ?",
275                [now],
276            )
277            .context("Failed to cleanup expired locks")?;
278
279        // Then, get remaining locks to check for dead processes
280        let mut stmt = conn
281            .prepare(
282                "SELECT lock_id, lock_type, resource_path, agent_id, process_id,
283                        acquired_at, expires_at, hostname
284                 FROM locks WHERE hostname = ?",
285            )
286            .context("Failed to prepare stale lock query")?;
287
288        let locks: Vec<LockRecord> = stmt
289            .query_map([&self.current_hostname], |row| {
290                Ok(LockRecord {
291                    lock_id: row.get(0)?,
292                    lock_type: row.get(1)?,
293                    resource_path: row.get(2)?,
294                    agent_id: row.get(3)?,
295                    process_id: row.get(4)?,
296                    acquired_at: row.get(5)?,
297                    expires_at: row.get(6)?,
298                    hostname: row.get(7)?,
299                })
300            })
301            .context("Failed to query locks")?
302            .filter_map(|r| r.ok())
303            .collect();
304
305        drop(stmt);
306
307        // Delete locks from dead processes
308        let mut stale_count = 0;
309        for lock in locks {
310            if !Self::is_process_alive(lock.process_id) {
311                conn.execute("DELETE FROM locks WHERE lock_id = ?", [&lock.lock_id])
312                    .ok();
313                stale_count += 1;
314            }
315        }
316
317        Ok(expired_count + stale_count)
318    }
319
320    /// List all active locks
321    pub async fn list_locks(&self) -> Result<Vec<LockRecord>> {
322        let conn = self.conn.lock().expect("SQLite connection lock poisoned");
323
324        let mut stmt = conn
325            .prepare(
326                "SELECT lock_id, lock_type, resource_path, agent_id, process_id,
327                        acquired_at, expires_at, hostname
328                 FROM locks",
329            )
330            .context("Failed to prepare list locks query")?;
331
332        let locks = stmt
333            .query_map([], |row| {
334                Ok(LockRecord {
335                    lock_id: row.get(0)?,
336                    lock_type: row.get(1)?,
337                    resource_path: row.get(2)?,
338                    agent_id: row.get(3)?,
339                    process_id: row.get(4)?,
340                    acquired_at: row.get(5)?,
341                    expires_at: row.get(6)?,
342                    hostname: row.get(7)?,
343                })
344            })
345            .context("Failed to query locks")?
346            .filter_map(|r| r.ok())
347            .collect();
348
349        Ok(locks)
350    }
351
352    /// Force release a lock by ID (admin operation)
353    pub async fn force_release(&self, lock_id: &str) -> Result<()> {
354        let conn = self.conn.lock().expect("SQLite connection lock poisoned");
355        conn.execute("DELETE FROM locks WHERE lock_id = ?", [lock_id])
356            .context("Failed to force release lock")?;
357        Ok(())
358    }
359
360    /// Check if a lock is stale (expired or from dead process)
361    fn is_lock_stale(&self, lock: &LockRecord) -> bool {
362        let now = Utc::now().timestamp_millis();
363
364        // Check if expired
365        if let Some(expires_at) = lock.expires_at
366            && now > expires_at
367        {
368            return true;
369        }
370
371        // Check if process is dead (only if same hostname)
372        if lock.hostname == self.current_hostname && !Self::is_process_alive(lock.process_id) {
373            return true;
374        }
375
376        false
377    }
378
379    /// Check if a process is still running
380    #[cfg(unix)]
381    fn is_process_alive(pid: i32) -> bool {
382        // On Unix, we can use kill with signal 0 to check if process exists
383        // This doesn't actually send a signal, just checks if process exists
384        unsafe { libc::kill(pid, 0) == 0 }
385    }
386
387    #[cfg(windows)]
388    fn is_process_alive(pid: i32) -> bool {
389        use windows_sys::Win32::Foundation::{CloseHandle, STILL_ACTIVE};
390        use windows_sys::Win32::System::Threading::{
391            GetExitCodeProcess, OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION,
392        };
393
394        unsafe {
395            let handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid as u32);
396            if handle == 0 {
397                return false;
398            }
399
400            let mut exit_code: u32 = 0;
401            let result = GetExitCodeProcess(handle, &mut exit_code);
402            CloseHandle(handle);
403
404            result != 0 && exit_code == STILL_ACTIVE
405        }
406    }
407
408    #[cfg(not(any(unix, windows)))]
409    fn is_process_alive(_pid: i32) -> bool {
410        // On other platforms, assume process is alive to be safe
411        true
412    }
413
414    /// Get lock statistics
415    pub async fn stats(&self) -> Result<LockStats> {
416        let locks = self.list_locks().await?;
417
418        let mut file_read_locks = 0;
419        let mut file_write_locks = 0;
420        let mut build_locks = 0;
421        let mut test_locks = 0;
422        let mut stale_locks = 0;
423
424        for lock in &locks {
425            match lock.lock_type.as_str() {
426                "file_read" => file_read_locks += 1,
427                "file_write" => file_write_locks += 1,
428                "build" => build_locks += 1,
429                "test" => test_locks += 1,
430                "build_test" => {
431                    build_locks += 1;
432                    test_locks += 1;
433                }
434                _ => {}
435            }
436
437            if self.is_lock_stale(lock) {
438                stale_locks += 1;
439            }
440        }
441
442        Ok(LockStats {
443            total_locks: locks.len(),
444            file_read_locks,
445            file_write_locks,
446            build_locks,
447            test_locks,
448            stale_locks,
449        })
450    }
451}
452
/// Statistics about the locks currently stored in the database.
///
/// Derives `PartialEq`/`Eq` for direct comparison and `Default` for an
/// all-zero starting value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LockStats {
    /// Total number of locks
    pub total_locks: usize,
    /// Number of file read locks
    pub file_read_locks: usize,
    /// Number of file write locks
    pub file_write_locks: usize,
    /// Number of build locks
    pub build_locks: usize,
    /// Number of test locks
    pub test_locks: usize,
    /// Number of stale locks (expired or from dead processes)
    pub stale_locks: usize,
}
469
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a store backed by a fresh database in a temporary directory.
    ///
    /// The `TempDir` is returned so it stays alive (and the database file stays
    /// on disk) for the duration of the test.
    async fn create_test_store() -> (LockStore, TempDir) {
        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("test_locks.db");
        let store = LockStore::new_with_path(&db_path).await.unwrap();
        (store, temp)
    }

    #[tokio::test]
    async fn test_acquire_and_release_lock() {
        let (store, _temp) = create_test_store().await;

        // Acquire lock
        let acquired = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        assert!(acquired);

        // Verify lock exists
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_some());
        assert_eq!(lock.unwrap().agent_id, "agent-1");

        // Release lock
        let released = store
            .release("file_write", "/test/file.txt", "agent-1")
            .await
            .unwrap();
        assert!(released);

        // Verify lock is gone
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_none());

        // Releasing again finds nothing to delete
        let released_again = store
            .release("file_write", "/test/file.txt", "agent-1")
            .await
            .unwrap();
        assert!(!released_again);
    }

    #[tokio::test]
    async fn test_idempotent_acquire() {
        let (store, _temp) = create_test_store().await;

        // Acquire lock twice - should succeed both times
        let acquired1 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        let acquired2 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();

        assert!(acquired1);
        assert!(acquired2);

        // Still exactly one row for that resource
        let locks = store.list_locks().await.unwrap();
        assert_eq!(locks.len(), 1);
    }

    #[tokio::test]
    async fn test_lock_conflict() {
        let (store, _temp) = create_test_store().await;

        // Acquire lock as agent-1
        let acquired1 = store
            .try_acquire("file_write", "/test/file.txt", "agent-1", None)
            .await
            .unwrap();
        assert!(acquired1);

        // agent-2 shares our PID, but the agent ID differs, so this is not an
        // idempotent re-acquire; and because our process is alive and the lock has
        // no expiry, it is not stale either. The attempt must be refused.
        let acquired2 = store
            .try_acquire("file_write", "/test/file.txt", "agent-2", None)
            .await
            .unwrap();
        assert!(!acquired2);

        // agent-2 must not be able to release agent-1's lock
        let released_by_other = store
            .release("file_write", "/test/file.txt", "agent-2")
            .await
            .unwrap();
        assert!(!released_by_other);

        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert_eq!(lock.unwrap().agent_id, "agent-1");
    }

    #[tokio::test]
    async fn test_release_all_for_agent() {
        let (store, _temp) = create_test_store().await;

        // Acquire multiple locks
        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("build", "/test/project", "agent-1", None)
            .await
            .unwrap();

        // Release all for agent-1
        let released = store.release_all_for_agent("agent-1").await.unwrap();
        assert_eq!(released, 3);

        // Verify all locks are gone
        let locks = store.list_locks().await.unwrap();
        assert!(locks.is_empty());
    }

    #[tokio::test]
    async fn test_expired_lock_cleanup() {
        let (store, _temp) = create_test_store().await;

        // Acquire lock with a 1 ms lifetime so it expires almost immediately
        store
            .try_acquire(
                "file_write",
                "/test/file.txt",
                "agent-1",
                Some(Duration::from_millis(1)),
            )
            .await
            .unwrap();

        // Wait for expiration
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Cleanup should remove expired lock
        let cleaned = store.cleanup_stale().await.unwrap();
        assert_eq!(cleaned, 1);

        // Lock should be gone
        let lock = store
            .is_locked("file_write", "/test/file.txt")
            .await
            .unwrap();
        assert!(lock.is_none());
    }

    #[tokio::test]
    async fn test_list_locks() {
        let (store, _temp) = create_test_store().await;

        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();

        let locks = store.list_locks().await.unwrap();
        assert_eq!(locks.len(), 2);
    }

    #[tokio::test]
    async fn test_stats() {
        let (store, _temp) = create_test_store().await;

        store
            .try_acquire("file_write", "/test/file1.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("file_read", "/test/file2.txt", "agent-1", None)
            .await
            .unwrap();
        store
            .try_acquire("build", "/test/project", "agent-1", None)
            .await
            .unwrap();

        let stats = store.stats().await.unwrap();
        assert_eq!(stats.total_locks, 3);
        assert_eq!(stats.file_write_locks, 1);
        assert_eq!(stats.file_read_locks, 1);
        assert_eq!(stats.build_locks, 1);
        // No test locks were taken, and locks owned by this live process
        // without an expiry are not stale.
        assert_eq!(stats.test_locks, 0);
        assert_eq!(stats.stale_locks, 0);
    }

    #[test]
    fn test_is_process_alive() {
        // The current process is the one PID that is guaranteed to exist and
        // that we are always permitted to probe.
        let current_pid = std::process::id() as i32;
        assert!(LockStore::is_process_alive(current_pid));
    }
}