1use std::path::Path;
2
3use rusqlite::{Connection, params};
4
5use crate::error::HawkError;
6
7pub fn init_database(path: &Path) -> Result<Connection, HawkError> {
8 let conn = Connection::open(path)
9 .map_err(|e| HawkError::Database(e.to_string()))?;
10
11 conn.execute_batch("PRAGMA journal_mode=WAL;")
12 .map_err(|e| HawkError::Database(e.to_string()))?;
13
14 conn.execute_batch(SCHEMA)
15 .map_err(|e| HawkError::Database(e.to_string()))?;
16
17 migrate(&conn)?;
18
19 Ok(conn)
20}
21
22fn migrate(conn: &Connection) -> Result<(), HawkError> {
23 let version: i64 = conn
24 .query_row(
25 "SELECT version FROM schema_version ORDER BY version DESC LIMIT 1",
26 [],
27 |row| row.get(0),
28 )
29 .unwrap_or(0);
30
31 if version < 1 {
32 conn.execute(
33 "INSERT INTO schema_version (version, applied_at) VALUES (?1, datetime('now'))",
34 params![1],
35 )
36 .map_err(|e| HawkError::Database(e.to_string()))?;
37 }
38
39 Ok(())
40}
41
/// DDL for every table this module creates.
///
/// [`init_database`] applies this string with `execute_batch` each time a
/// database is opened. Every statement is `CREATE TABLE IF NOT EXISTS`, so
/// re-applying it leaves existing tables (and their data) untouched. The flip
/// side is that editing a column definition here does NOT alter a table that
/// already exists. Such changes need an explicit step in `migrate`.
///
/// NOTE(review): several columns declare `REFERENCES ...`, but SQLite only
/// enforces foreign keys when `PRAGMA foreign_keys = ON` is set on each
/// connection, and `init_database` does not set it. Confirm whether enforcement
/// is intended.
pub const SCHEMA: &str = "
CREATE TABLE IF NOT EXISTS schema_version (
    version INTEGER PRIMARY KEY,
    applied_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS agents (
    pid INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    state TEXT NOT NULL CHECK(state IN ('Starting','Running','Paused','Stopping','Stopped','Failed')),
    manifest_path TEXT NOT NULL,
    started_at TEXT NOT NULL,
    stopped_at TEXT,
    session_id TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS sessions (
    id TEXT PRIMARY KEY,
    started_at TEXT NOT NULL,
    ended_at TEXT,
    status TEXT NOT NULL DEFAULT 'Active'
);

CREATE TABLE IF NOT EXISTS session_actions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL REFERENCES sessions(id),
    step_number INTEGER NOT NULL,
    timestamp TEXT NOT NULL,
    action_type TEXT NOT NULL,
    agent_pid INTEGER NOT NULL,
    payload TEXT NOT NULL,
    UNIQUE(session_id, step_number)
);

CREATE TABLE IF NOT EXISTS snapshots (
    id TEXT PRIMARY KEY,
    timestamp TEXT NOT NULL,
    agent_pid INTEGER NOT NULL,
    task_description TEXT,
    file_count INTEGER NOT NULL,
    strategy TEXT NOT NULL,
    working_dir TEXT NOT NULL,
    session_id TEXT NOT NULL REFERENCES sessions(id)
);

CREATE TABLE IF NOT EXISTS snapshot_files (
    snapshot_id TEXT NOT NULL REFERENCES snapshots(id),
    file_path TEXT NOT NULL,
    hash TEXT NOT NULL,
    size_bytes INTEGER NOT NULL,
    PRIMARY KEY (snapshot_id, file_path)
);

CREATE TABLE IF NOT EXISTS token_usage (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    agent_pid INTEGER NOT NULL,
    timestamp TEXT NOT NULL,
    provider TEXT NOT NULL,
    prompt_tokens INTEGER NOT NULL,
    completion_tokens INTEGER NOT NULL,
    estimated_cost REAL
);

CREATE TABLE IF NOT EXISTS verification_reports (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL REFERENCES sessions(id),
    timestamp TEXT NOT NULL,
    overall_status TEXT NOT NULL,
    report_json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS healing_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    agent_pid INTEGER NOT NULL,
    timestamp TEXT NOT NULL,
    original_error TEXT NOT NULL,
    adjustment TEXT NOT NULL,
    outcome TEXT NOT NULL,
    attempt_number INTEGER NOT NULL
);

CREATE TABLE IF NOT EXISTS patterns (
    id TEXT PRIMARY KEY,
    action_sequence TEXT NOT NULL,
    occurrence_count INTEGER NOT NULL DEFAULT 0,
    last_occurrence TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'Detected',
    created_at TEXT NOT NULL,
    expires_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS installed_packages (
    name TEXT PRIMARY KEY,
    version TEXT NOT NULL,
    package_type TEXT NOT NULL,
    signature TEXT NOT NULL,
    installed_at TEXT NOT NULL,
    capabilities TEXT
);

CREATE TABLE IF NOT EXISTS sync_peers (
    device_id TEXT PRIMARY KEY,
    last_sync TEXT,
    status TEXT NOT NULL DEFAULT 'Disconnected',
    pending_changes INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE IF NOT EXISTS sync_queue (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    target_device TEXT NOT NULL,
    data_type TEXT NOT NULL,
    data_key TEXT NOT NULL,
    payload BLOB NOT NULL,
    created_at TEXT NOT NULL,
    synced_at TEXT
);

CREATE TABLE IF NOT EXISTS bus_queue (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    target_pid INTEGER NOT NULL,
    topic TEXT,
    message_json TEXT NOT NULL,
    created_at TEXT NOT NULL,
    expires_at TEXT NOT NULL,
    delivered INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE IF NOT EXISTS watch_alerts (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    alert_type TEXT NOT NULL,
    timestamp TEXT NOT NULL,
    agent_name TEXT,
    details_json TEXT NOT NULL,
    acknowledged INTEGER NOT NULL DEFAULT 0
);
";
178
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Initialises a database in a fresh temp file.
    ///
    /// The file handle is returned alongside the connection so the temp file
    /// is not deleted while the test is still using it.
    fn temp_db() -> (NamedTempFile, Connection) {
        let f = NamedTempFile::new().unwrap();
        let conn = init_database(f.path()).unwrap();
        (f, conn)
    }

    /// Number of rows of type `table` in `sqlite_master`.
    ///
    /// This also counts SQLite-internal tables such as `sqlite_sequence`, which
    /// `AUTOINCREMENT` columns create, so compare it across runs rather than
    /// against a hard-coded number.
    fn table_count(conn: &Connection) -> i64 {
        conn.query_row(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table'",
            [],
            |row| row.get(0),
        )
        .unwrap()
    }

    #[test]
    fn test_all_tables_created() {
        let (_f, conn) = temp_db();

        let expected = [
            "schema_version",
            "agents",
            "sessions",
            "session_actions",
            "snapshots",
            "snapshot_files",
            "token_usage",
            "verification_reports",
            "healing_events",
            "patterns",
            "installed_packages",
            "sync_peers",
            "sync_queue",
            "bus_queue",
            "watch_alerts",
        ];

        for table in &expected {
            let count: i64 = conn
                .query_row(
                    "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
                    rusqlite::params![table],
                    |row| row.get(0),
                )
                .unwrap();
            assert_eq!(count, 1, "missing table: {table}");
        }
    }

    #[test]
    fn test_fresh_database_records_schema_version() {
        let (_f, conn) = temp_db();

        let version: i64 = conn
            .query_row("SELECT MAX(version) FROM schema_version", [], |row| {
                row.get(0)
            })
            .unwrap();
        assert_eq!(version, 1);
    }

    #[test]
    fn test_idempotent_reinit() {
        let f = NamedTempFile::new().unwrap();
        // The first connection is dropped at the end of this statement.
        let tables_after_first_init = table_count(&init_database(f.path()).unwrap());

        let conn = init_database(f.path()).unwrap();

        // Re-applying the schema must neither add nor remove tables...
        assert_eq!(table_count(&conn), tables_after_first_init);

        // ...and re-running the migration must not record the version twice.
        let (rows, max_version): (i64, i64) = conn
            .query_row(
                "SELECT COUNT(*), MAX(version) FROM schema_version",
                [],
                |row| Ok((row.get(0)?, row.get(1)?)),
            )
            .unwrap();
        assert_eq!((rows, max_version), (1, 1));
    }

    #[test]
    fn test_wal_mode_enabled() {
        let (_f, conn) = temp_db();

        let mode: String = conn
            .query_row("PRAGMA journal_mode", [], |row| row.get(0))
            .unwrap();
        assert_eq!(mode, "wal");
    }
}