use rusqlite::{params, Connection};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Mutex;
use sysinfo::{Pid, ProcessRefreshKind, System};
/// One tracked process, held in memory by the registry and persisted as a row
/// of the `tracked_processes` table.
///
/// An entry is identified by the pair `(pid, created_at_ms)`.
#[derive(Debug, Clone)]
pub struct TrackedEntry {
/// OS process id; first half of the entry's identity.
pub pid: u32,
/// Process start time in milliseconds; second half of the entry's identity.
/// Staleness checks compare it with the OS-reported start time (seconds
/// multiplied by 1000) and tolerate a difference of up to 2000 ms.
pub created_at_ms: u64,
/// Free-form category string stored in the `kind` column (the tests in this
/// file use `"subprocess"`).
pub kind: String,
/// Command string stored with the entry and included in the registration log.
pub command: String,
/// Stored in the `cwd` column, which defaults to the empty string.
/// NOTE(review): presumably the process's working directory — confirm with callers.
pub cwd: String,
/// Origin tag. `list_by_originator(tool)` matches entries whose originator
/// starts with `"{tool}:"`.
pub originator: String,
/// Stored in the `containment` column, which defaults to `'contained'`.
pub containment: String,
/// Registration timestamp; the sort key used by `list_all`.
/// NOTE(review): units are not fixed by this file (the tests use seconds) — confirm.
pub registered_at: f64,
}
/// Registry of tracked processes: two in-memory maps mirrored into a SQLite
/// table, each guarded by its own mutex.
///
/// Methods that take more than one of these locks acquire them in the order
/// `by_pid`, `processes`, `db`.
pub struct Registry {
/// All tracked entries, keyed by `(pid, created_at_ms)`.
processes: Mutex<HashMap<(u32, u64), TrackedEntry>>,
/// For each pid, the `created_at_ms` of the entry currently indexed under it.
by_pid: Mutex<HashMap<u32, u64>>,
/// Connection to the database holding the `tracked_processes` table.
db: Mutex<Connection>,
}
/// Converts a timestamp in fractional seconds to whole milliseconds.
///
/// The product is rounded to the nearest millisecond instead of being
/// truncated. Many decimal fractions have no exact `f64` representation, so a
/// value such as `1.001` multiplies out to just below `1001.0`; truncation
/// would report `1000` for it, one millisecond short.
///
/// The float-to-integer cast saturates: negative and NaN inputs yield `0`, and
/// inputs too large for `u64` yield `u64::MAX`.
pub fn created_at_to_ms(created_at: f64) -> u64 {
    (created_at * 1000.0).round() as u64
}
/// The two in-memory indexes rebuilt from the database: entries keyed by
/// `(pid, created_at_ms)`, and the pid -> `created_at_ms` lookup.
type RegistryMaps = (HashMap<(u32, u64), TrackedEntry>, HashMap<u32, u64>);

/// Reads every row of `tracked_processes` and builds both registry maps.
///
/// # Errors
///
/// Returns the first `rusqlite::Error` raised while preparing the statement,
/// running the query, or decoding a column.
fn load_from_db(conn: &Connection) -> Result<RegistryMaps, rusqlite::Error> {
    let mut select_all = conn.prepare(
        "SELECT pid, created_at_ms, kind, command, cwd, originator, containment, registered_at \
         FROM tracked_processes",
    )?;

    // Columns are decoded positionally, in the order of the SELECT list.
    let decoded_rows = select_all.query_map([], |r| {
        let pid = r.get(0)?;
        let created_at_ms = r.get(1)?;
        let kind = r.get(2)?;
        let command = r.get(3)?;
        let cwd = r.get(4)?;
        let originator = r.get(5)?;
        let containment = r.get(6)?;
        let registered_at = r.get(7)?;
        Ok(TrackedEntry {
            pid,
            created_at_ms,
            kind,
            command,
            cwd,
            originator,
            containment,
            registered_at,
        })
    })?;

    let (mut entries_by_key, mut created_by_pid): RegistryMaps =
        (HashMap::new(), HashMap::new());
    for decoded in decoded_rows {
        let tracked = decoded?;
        created_by_pid.insert(tracked.pid, tracked.created_at_ms);
        entries_by_key.insert((tracked.pid, tracked.created_at_ms), tracked);
    }
    Ok((entries_by_key, created_by_pid))
}
impl Registry {
pub fn open(db_path: &Path) -> Result<Self, rusqlite::Error> {
if let Some(parent) = db_path.parent() {
let _ = std::fs::create_dir_all(parent);
}
let conn = Connection::open(db_path)?;
conn.pragma_update(None, "journal_mode", "WAL")?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS tracked_processes (
pid INTEGER NOT NULL,
created_at_ms INTEGER NOT NULL,
kind TEXT NOT NULL,
command TEXT NOT NULL,
cwd TEXT NOT NULL DEFAULT '',
originator TEXT NOT NULL DEFAULT '',
containment TEXT NOT NULL DEFAULT 'contained',
registered_at REAL NOT NULL,
PRIMARY KEY (pid, created_at_ms)
);",
)?;
let (mut processes, mut by_pid) = load_from_db(&conn)?;
let mut stale_keys = Vec::new();
let mut system = System::new();
for (&(pid, created_at_ms), _entry) in processes.iter() {
let sysinfo_pid = Pid::from_u32(pid);
system.refresh_process_specifics(sysinfo_pid, ProcessRefreshKind::new());
match system.process(sysinfo_pid) {
Some(proc) => {
let proc_start_ms = proc.start_time() * 1000;
if proc_start_ms.abs_diff(created_at_ms) > 2000 {
stale_keys.push((pid, created_at_ms)); }
}
None => {
stale_keys.push((pid, created_at_ms)); }
}
}
for key in &stale_keys {
processes.remove(key);
by_pid.remove(&key.0);
conn.execute(
"DELETE FROM tracked_processes WHERE pid = ?1 AND created_at_ms = ?2",
params![key.0, key.1],
)
.ok();
}
tracing::info!(
"registry recovery: loaded {} processes, purged {} stale",
processes.len(),
stale_keys.len()
);
Ok(Self {
processes: Mutex::new(processes),
by_pid: Mutex::new(by_pid),
db: Mutex::new(conn),
})
}
pub fn register(&self, entry: TrackedEntry) -> Result<(), rusqlite::Error> {
let mut by_pid = self.by_pid.lock().unwrap();
let mut processes = self.processes.lock().unwrap();
let db = self.db.lock().unwrap();
if let Some(&old_created) = by_pid.get(&entry.pid) {
if old_created != entry.created_at_ms {
tracing::warn!(
pid = entry.pid,
old_created_at_ms = old_created,
new_created_at_ms = entry.created_at_ms,
"PID reuse detected, replacing old entry"
);
processes.remove(&(entry.pid, old_created));
db.execute(
"DELETE FROM tracked_processes WHERE pid = ?1 AND created_at_ms = ?2",
params![entry.pid, old_created],
)?;
}
}
db.execute(
"INSERT OR REPLACE INTO tracked_processes \
(pid, created_at_ms, kind, command, cwd, originator, containment, registered_at) \
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
params![
entry.pid,
entry.created_at_ms,
entry.kind,
entry.command,
entry.cwd,
entry.originator,
entry.containment,
entry.registered_at,
],
)?;
tracing::debug!(
pid = entry.pid,
created_at_ms = entry.created_at_ms,
command = %entry.command,
"registered process"
);
by_pid.insert(entry.pid, entry.created_at_ms);
processes.insert((entry.pid, entry.created_at_ms), entry);
Ok(())
}
pub fn unregister(&self, pid: u32) -> bool {
let mut by_pid = self.by_pid.lock().unwrap();
let mut processes = self.processes.lock().unwrap();
let db = self.db.lock().unwrap();
let Some(created_at_ms) = by_pid.remove(&pid) else {
return false;
};
processes.remove(&(pid, created_at_ms));
db.execute(
"DELETE FROM tracked_processes WHERE pid = ?1 AND created_at_ms = ?2",
params![pid, created_at_ms],
)
.ok();
true
}
pub fn unregister_exact(&self, pid: u32, created_at_ms: u64) -> bool {
let mut by_pid = self.by_pid.lock().unwrap();
let mut processes = self.processes.lock().unwrap();
let db = self.db.lock().unwrap();
let removed = processes.remove(&(pid, created_at_ms)).is_some();
if !removed {
return false;
}
if by_pid.get(&pid) == Some(&created_at_ms) {
by_pid.remove(&pid);
}
db.execute(
"DELETE FROM tracked_processes WHERE pid = ?1 AND created_at_ms = ?2",
params![pid, created_at_ms],
)
.ok();
true
}
pub fn list_all(&self) -> Vec<TrackedEntry> {
let processes = self.processes.lock().unwrap();
let mut entries: Vec<TrackedEntry> = processes.values().cloned().collect();
entries.sort_by(|a, b| {
a.registered_at
.partial_cmp(&b.registered_at)
.unwrap_or(std::cmp::Ordering::Equal)
});
entries
}
pub fn list_by_originator(&self, tool: &str) -> Vec<TrackedEntry> {
let prefix = format!("{tool}:");
self.list_all()
.into_iter()
.filter(|e| e.originator.starts_with(&prefix))
.collect()
}
pub fn count(&self) -> usize {
self.processes.lock().unwrap().len()
}
pub fn validate_against_os(&self) {
let mut system = System::new();
let mut stale_keys = Vec::new();
{
let processes = self.processes.lock().unwrap();
for (&(pid, created_at_ms), _entry) in processes.iter() {
let sysinfo_pid = Pid::from_u32(pid);
system.refresh_process_specifics(sysinfo_pid, ProcessRefreshKind::new());
match system.process(sysinfo_pid) {
Some(proc) => {
let proc_start_ms = proc.start_time() * 1000;
if proc_start_ms.abs_diff(created_at_ms) > 2000 {
stale_keys.push((pid, created_at_ms));
}
}
None => {
stale_keys.push((pid, created_at_ms));
}
}
}
}
if stale_keys.is_empty() {
return;
}
let mut by_pid = self.by_pid.lock().unwrap();
let mut processes = self.processes.lock().unwrap();
let db = self.db.lock().unwrap();
for &(pid, created_at_ms) in &stale_keys {
processes.remove(&(pid, created_at_ms));
if by_pid.get(&pid) == Some(&created_at_ms) {
by_pid.remove(&pid);
}
db.execute(
"DELETE FROM tracked_processes WHERE pid = ?1 AND created_at_ms = ?2",
params![pid, created_at_ms],
)
.ok();
tracing::info!(pid, created_at_ms, "purged stale process from registry");
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a registry in a fresh temporary directory. The `TempDir` is
    /// returned so the directory outlives the registry.
    fn open_temp_registry() -> (TempDir, Registry) {
        let tmp = TempDir::new().unwrap();
        let reg = Registry::open(&tmp.path().join("test.db")).unwrap();
        (tmp, reg)
    }

    /// Builds an entry with fixed metadata; `registered_at` is derived from
    /// `created_at_ms`.
    fn make_entry(pid: u32, created_at_ms: u64, command: &str) -> TrackedEntry {
        TrackedEntry {
            pid,
            created_at_ms,
            kind: "subprocess".to_string(),
            command: command.to_string(),
            cwd: "/tmp".to_string(),
            originator: "test:unit".to_string(),
            containment: "contained".to_string(),
            registered_at: created_at_ms as f64 / 1000.0,
        }
    }

    #[test]
    fn test_register_and_list() {
        let (_tmp, reg) = open_temp_registry();
        reg.register(make_entry(1, 1000, "cmd1")).unwrap();
        reg.register(make_entry(2, 2000, "cmd2")).unwrap();
        reg.register(make_entry(3, 3000, "cmd3")).unwrap();
        let all = reg.list_all();
        assert_eq!(all.len(), 3);
        assert_eq!(reg.count(), 3);
        // `list_all` sorts by `registered_at`, which grows with the pid here.
        let pids: Vec<u32> = all.iter().map(|e| e.pid).collect();
        assert_eq!(pids, vec![1, 2, 3]);
    }

    #[test]
    fn test_unregister_removes() {
        let (_tmp, reg) = open_temp_registry();
        reg.register(make_entry(42, 5000, "ls -la")).unwrap();
        assert_eq!(reg.count(), 1);
        let removed = reg.unregister(42);
        assert!(removed);
        assert_eq!(reg.count(), 0);
        assert_eq!(reg.list_all().len(), 0);
        let removed_again = reg.unregister(42);
        assert!(!removed_again);
    }

    #[test]
    fn test_unregister_exact_preserves_newer_pid_reuse() {
        let (_tmp, reg) = open_temp_registry();
        reg.register(make_entry(77, 1000, "old")).unwrap();
        reg.register(make_entry(77, 2000, "new")).unwrap();
        // The old incarnation was already replaced, so this must be a no-op.
        assert!(!reg.unregister_exact(77, 1000));
        let all = reg.list_all();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].created_at_ms, 2000);
        assert!(reg.unregister_exact(77, 2000));
        assert_eq!(reg.count(), 0);
    }

    #[test]
    fn test_pid_reuse_replaces_old() {
        let (_tmp, reg) = open_temp_registry();
        reg.register(make_entry(100, 1000, "old-cmd")).unwrap();
        assert_eq!(reg.count(), 1);
        reg.register(make_entry(100, 2000, "new-cmd")).unwrap();
        assert_eq!(reg.count(), 1);
        let all = reg.list_all();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].created_at_ms, 2000);
        assert_eq!(all[0].command, "new-cmd");
    }

    #[test]
    fn test_list_by_originator_filters() {
        let (_tmp, reg) = open_temp_registry();
        let mut e1 = make_entry(1, 1000, "cmd1");
        e1.originator = "codeup:abc".to_string();
        let mut e2 = make_entry(2, 2000, "cmd2");
        e2.originator = "codeup:def".to_string();
        let mut e3 = make_entry(3, 3000, "cmd3");
        e3.originator = "other:xyz".to_string();
        reg.register(e1).unwrap();
        reg.register(e2).unwrap();
        reg.register(e3).unwrap();
        let codeup = reg.list_by_originator("codeup");
        assert_eq!(codeup.len(), 2);
        let other = reg.list_by_originator("other");
        assert_eq!(other.len(), 1);
        let none = reg.list_by_originator("nonexistent");
        assert_eq!(none.len(), 0);
    }

    #[test]
    fn test_persistence_across_reopen() {
        let tmp = TempDir::new().unwrap();
        let db = tmp.path().join("test.db");
        {
            let reg = Registry::open(&db).unwrap();
            // Register this test process itself: reopening purges entries
            // that do not match a live process, so a real pid and start time
            // are required for the entry to survive.
            let pid = std::process::id();
            let mut sys = System::new();
            let sysinfo_pid = Pid::from_u32(pid);
            sys.refresh_process_specifics(sysinfo_pid, ProcessRefreshKind::new());
            let created_at_ms = sys
                .process(sysinfo_pid)
                .map(|p| p.start_time() * 1000)
                .expect("the current process must be visible to sysinfo");
            reg.register(TrackedEntry {
                pid,
                created_at_ms,
                kind: "subprocess".to_string(),
                command: "persist-test".to_string(),
                cwd: "/tmp".to_string(),
                originator: "test:persist".to_string(),
                containment: "contained".to_string(),
                registered_at: 1000.0,
            })
            .unwrap();
            assert_eq!(reg.count(), 1);
        }
        {
            let reg = Registry::open(&db).unwrap();
            let all = reg.list_all();
            assert_eq!(all.len(), 1);
            assert_eq!(all[0].command, "persist-test");
        }
    }

    #[test]
    fn test_created_at_to_ms() {
        assert_eq!(created_at_to_ms(1234.567), 1234567);
        assert_eq!(created_at_to_ms(0.0), 0);
        assert_eq!(created_at_to_ms(1.0), 1000);
    }
}