1use anyhow::Result;
2use rusqlite::Connection;
3
/// Schema version stored in the `schema_version` table when [`initialize`]
/// runs against a database whose `schema_version` table is empty.
///
/// NOTE(review): no migration step is visible in this module — a database
/// created with an older value keeps that value; confirm whether migrations
/// live elsewhere before bumping this.
const CURRENT_SCHEMA_VERSION: i32 = 1;
5
6pub fn initialize(conn: &Connection) -> Result<()> {
7 conn.execute_batch("PRAGMA journal_mode=WAL;")?;
8 conn.execute_batch("PRAGMA foreign_keys=ON;")?;
9
10 conn.execute_batch(
11 "CREATE TABLE IF NOT EXISTS schema_version (
12 version INTEGER NOT NULL
13 );
14
15 CREATE TABLE IF NOT EXISTS sessions (
16 id TEXT PRIMARY KEY,
17 started_at INTEGER NOT NULL,
18 ended_at INTEGER,
19 cwd TEXT NOT NULL,
20 model TEXT,
21 permission_mode TEXT
22 );
23
24 CREATE TABLE IF NOT EXISTS events (
25 id INTEGER PRIMARY KEY AUTOINCREMENT,
26 session_id TEXT NOT NULL REFERENCES sessions(id),
27 timestamp INTEGER NOT NULL,
28 event_type TEXT NOT NULL,
29 tool_name TEXT,
30 tool_use_id TEXT,
31 agent_id TEXT,
32 agent_type TEXT,
33 input_json BLOB,
34 output_json BLOB
35 );
36
37 CREATE TABLE IF NOT EXISTS snapshots (
38 id INTEGER PRIMARY KEY AUTOINCREMENT,
39 event_id INTEGER NOT NULL REFERENCES events(id),
40 file_path TEXT NOT NULL,
41 content_before BLOB,
42 content_after BLOB,
43 diff_unified TEXT NOT NULL
44 );
45
46 CREATE INDEX IF NOT EXISTS idx_events_session_ts ON events(session_id, timestamp);
47 CREATE INDEX IF NOT EXISTS idx_events_tool_use_id ON events(tool_use_id);
48 CREATE INDEX IF NOT EXISTS idx_snapshots_event ON snapshots(event_id);
49 CREATE INDEX IF NOT EXISTS idx_snapshots_file_event ON snapshots(file_path, event_id);",
50 )?;
51
52 let count: i32 = conn.query_row(
53 "SELECT COUNT(*) FROM schema_version",
54 [],
55 |row| row.get(0),
56 )?;
57 if count == 0 {
58 conn.execute(
59 "INSERT INTO schema_version (version) VALUES (?1)",
60 [CURRENT_SCHEMA_VERSION],
61 )?;
62 }
63
64 Ok(())
65}
66
67pub fn get_version(conn: &Connection) -> Result<i32> {
68 let version = conn.query_row(
69 "SELECT version FROM schema_version",
70 [],
71 |row| row.get(0),
72 )?;
73 Ok(version)
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects `name` from `sqlite_master` for the given object type,
    /// failing the test on any row error instead of silently dropping it.
    fn object_names(conn: &Connection, kind: &str) -> Vec<String> {
        let mut stmt = conn
            .prepare("SELECT name FROM sqlite_master WHERE type = ?1 ORDER BY name")
            .unwrap();
        let names = stmt
            .query_map([kind], |row| row.get(0))
            .unwrap()
            .collect::<rusqlite::Result<Vec<String>>>()
            .unwrap();
        names
    }

    /// Number of rows currently stored in `schema_version`.
    fn version_row_count(conn: &Connection) -> i64 {
        conn.query_row("SELECT COUNT(*) FROM schema_version", [], |row| row.get(0))
            .unwrap()
    }

    #[test]
    fn test_initialize_creates_tables() {
        let conn = Connection::open_in_memory().unwrap();
        initialize(&conn).unwrap();

        let tables = object_names(&conn, "table");
        for expected in ["sessions", "events", "snapshots", "schema_version"] {
            assert!(
                tables.iter().any(|t| t == expected),
                "missing table {expected}; found {tables:?}"
            );
        }
    }

    #[test]
    fn test_initialize_creates_indexes() {
        let conn = Connection::open_in_memory().unwrap();
        initialize(&conn).unwrap();

        let indexes = object_names(&conn, "index");
        for expected in [
            "idx_events_session_ts",
            "idx_events_tool_use_id",
            "idx_snapshots_event",
            "idx_snapshots_file_event",
        ] {
            assert!(
                indexes.iter().any(|i| i == expected),
                "missing index {expected}; found {indexes:?}"
            );
        }
    }

    #[test]
    fn test_initialize_enables_foreign_keys() {
        let conn = Connection::open_in_memory().unwrap();
        initialize(&conn).unwrap();

        let enabled: i64 = conn
            .query_row("PRAGMA foreign_keys", [], |row| row.get(0))
            .unwrap();
        assert_eq!(enabled, 1);
    }

    #[test]
    fn test_initialize_is_idempotent() {
        let conn = Connection::open_in_memory().unwrap();
        initialize(&conn).unwrap();
        initialize(&conn).unwrap();

        // Re-running must not seed a second version row.
        assert_eq!(version_row_count(&conn), 1);
        assert_eq!(get_version(&conn).unwrap(), CURRENT_SCHEMA_VERSION);
    }

    #[test]
    fn test_schema_version() {
        let conn = Connection::open_in_memory().unwrap();
        initialize(&conn).unwrap();
        assert_eq!(get_version(&conn).unwrap(), CURRENT_SCHEMA_VERSION);
    }

    #[test]
    fn test_get_version_errors_when_uninitialized() {
        let conn = Connection::open_in_memory().unwrap();
        // No table at all.
        assert!(get_version(&conn).is_err());

        // Table present but never seeded.
        conn.execute_batch("CREATE TABLE schema_version (version INTEGER NOT NULL);")
            .unwrap();
        assert!(get_version(&conn).is_err());
    }
}