Skip to main content

openhawk_core/
db.rs

1use std::path::Path;
2
3use rusqlite::{Connection, params};
4
5use crate::error::HawkError;
6
7pub fn init_database(path: &Path) -> Result<Connection, HawkError> {
8    let conn = Connection::open(path)
9        .map_err(|e| HawkError::Database(e.to_string()))?;
10
11    conn.execute_batch("PRAGMA journal_mode=WAL;")
12        .map_err(|e| HawkError::Database(e.to_string()))?;
13
14    conn.execute_batch(SCHEMA)
15        .map_err(|e| HawkError::Database(e.to_string()))?;
16
17    migrate(&conn)?;
18
19    Ok(conn)
20}
21
22fn migrate(conn: &Connection) -> Result<(), HawkError> {
23    let version: i64 = conn
24        .query_row(
25            "SELECT version FROM schema_version ORDER BY version DESC LIMIT 1",
26            [],
27            |row| row.get(0),
28        )
29        .unwrap_or(0);
30
31    if version < 1 {
32        conn.execute(
33            "INSERT INTO schema_version (version, applied_at) VALUES (?1, datetime('now'))",
34            params![1],
35        )
36        .map_err(|e| HawkError::Database(e.to_string()))?;
37    }
38
39    Ok(())
40}
41
/// DDL for every table in the database, executed by `init_database` each time a
/// database is opened.
///
/// Every statement is `CREATE TABLE IF NOT EXISTS`, so executing it against an
/// existing database leaves tables that already exist untouched. This also
/// means that edits to a table definition here are NOT applied to databases
/// created from an older version of this string. Such changes need their own
/// step in `migrate`.
///
/// Timestamp-like columns are declared as `TEXT`; the schema itself does not
/// enforce a format (NOTE(review): presumably ISO-8601 strings from callers —
/// confirm). The `REFERENCES` clauses are only enforced by SQLite on
/// connections that enable `PRAGMA foreign_keys`.
pub const SCHEMA: &str = "
CREATE TABLE IF NOT EXISTS schema_version (
    version     INTEGER PRIMARY KEY,
    applied_at  TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS agents (
    pid             INTEGER PRIMARY KEY,
    name            TEXT NOT NULL,
    state           TEXT NOT NULL CHECK(state IN ('Starting','Running','Paused','Stopping','Stopped','Failed')),
    manifest_path   TEXT NOT NULL,
    started_at      TEXT NOT NULL,
    stopped_at      TEXT,
    session_id      TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS sessions (
    id          TEXT PRIMARY KEY,
    started_at  TEXT NOT NULL,
    ended_at    TEXT,
    status      TEXT NOT NULL DEFAULT 'Active'
);

CREATE TABLE IF NOT EXISTS session_actions (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id  TEXT NOT NULL REFERENCES sessions(id),
    step_number INTEGER NOT NULL,
    timestamp   TEXT NOT NULL,
    action_type TEXT NOT NULL,
    agent_pid   INTEGER NOT NULL,
    payload     TEXT NOT NULL,
    UNIQUE(session_id, step_number)
);

CREATE TABLE IF NOT EXISTS snapshots (
    id               TEXT PRIMARY KEY,
    timestamp        TEXT NOT NULL,
    agent_pid        INTEGER NOT NULL,
    task_description TEXT,
    file_count       INTEGER NOT NULL,
    strategy         TEXT NOT NULL,
    working_dir      TEXT NOT NULL,
    session_id       TEXT NOT NULL REFERENCES sessions(id)
);

CREATE TABLE IF NOT EXISTS snapshot_files (
    snapshot_id TEXT NOT NULL REFERENCES snapshots(id),
    file_path   TEXT NOT NULL,
    hash        TEXT NOT NULL,
    size_bytes  INTEGER NOT NULL,
    PRIMARY KEY (snapshot_id, file_path)
);

CREATE TABLE IF NOT EXISTS token_usage (
    id                INTEGER PRIMARY KEY AUTOINCREMENT,
    agent_pid         INTEGER NOT NULL,
    timestamp         TEXT NOT NULL,
    provider          TEXT NOT NULL,
    prompt_tokens     INTEGER NOT NULL,
    completion_tokens INTEGER NOT NULL,
    estimated_cost    REAL
);

CREATE TABLE IF NOT EXISTS verification_reports (
    id             INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id     TEXT NOT NULL REFERENCES sessions(id),
    timestamp      TEXT NOT NULL,
    overall_status TEXT NOT NULL,
    report_json    TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS healing_events (
    id             INTEGER PRIMARY KEY AUTOINCREMENT,
    agent_pid      INTEGER NOT NULL,
    timestamp      TEXT NOT NULL,
    original_error TEXT NOT NULL,
    adjustment     TEXT NOT NULL,
    outcome        TEXT NOT NULL,
    attempt_number INTEGER NOT NULL
);

CREATE TABLE IF NOT EXISTS patterns (
    id               TEXT PRIMARY KEY,
    action_sequence  TEXT NOT NULL,
    occurrence_count INTEGER NOT NULL DEFAULT 0,
    last_occurrence  TEXT NOT NULL,
    status           TEXT NOT NULL DEFAULT 'Detected',
    created_at       TEXT NOT NULL,
    expires_at       TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS installed_packages (
    name         TEXT PRIMARY KEY,
    version      TEXT NOT NULL,
    package_type TEXT NOT NULL,
    signature    TEXT NOT NULL,
    installed_at TEXT NOT NULL,
    capabilities TEXT
);

CREATE TABLE IF NOT EXISTS sync_peers (
    device_id       TEXT PRIMARY KEY,
    last_sync       TEXT,
    status          TEXT NOT NULL DEFAULT 'Disconnected',
    pending_changes INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE IF NOT EXISTS sync_queue (
    id            INTEGER PRIMARY KEY AUTOINCREMENT,
    target_device TEXT NOT NULL,
    data_type     TEXT NOT NULL,
    data_key      TEXT NOT NULL,
    payload       BLOB NOT NULL,
    created_at    TEXT NOT NULL,
    synced_at     TEXT
);

CREATE TABLE IF NOT EXISTS bus_queue (
    id           INTEGER PRIMARY KEY AUTOINCREMENT,
    target_pid   INTEGER NOT NULL,
    topic        TEXT,
    message_json TEXT NOT NULL,
    created_at   TEXT NOT NULL,
    expires_at   TEXT NOT NULL,
    delivered    INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE IF NOT EXISTS watch_alerts (
    id           INTEGER PRIMARY KEY AUTOINCREMENT,
    alert_type   TEXT NOT NULL,
    timestamp    TEXT NOT NULL,
    agent_name   TEXT,
    details_json TEXT NOT NULL,
    acknowledged INTEGER NOT NULL DEFAULT 0
);
";
178
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Creates a fresh database in a temp file. The `NamedTempFile` is returned
    /// alongside the connection so the file is not deleted mid-test.
    fn temp_db() -> (NamedTempFile, Connection) {
        let f = NamedTempFile::new().unwrap();
        let conn = init_database(f.path()).unwrap();
        (f, conn)
    }

    /// Runs a query that yields a single integer in its first column.
    fn scalar(conn: &Connection, sql: &str) -> i64 {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    #[test]
    fn test_all_tables_created() {
        let (_f, conn) = temp_db();

        let expected = [
            "schema_version",
            "agents",
            "sessions",
            "session_actions",
            "snapshots",
            "snapshot_files",
            "token_usage",
            "verification_reports",
            "healing_events",
            "patterns",
            "installed_packages",
            "sync_peers",
            "sync_queue",
            "bus_queue",
            "watch_alerts",
        ];

        for table in &expected {
            let count: i64 = conn
                .query_row(
                    "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
                    rusqlite::params![table],
                    |row| row.get(0),
                )
                .unwrap();
            assert_eq!(count, 1, "missing table: {table}");
        }
    }

    #[test]
    fn test_idempotent_reinit() {
        const TABLE_COUNT_SQL: &str = "SELECT COUNT(*) FROM sqlite_master WHERE type='table'";

        let f = NamedTempFile::new().unwrap();

        let first = init_database(f.path()).unwrap();
        let tables_before = scalar(&first, TABLE_COUNT_SQL);
        drop(first);

        let conn = init_database(f.path()).unwrap();
        let tables_after = scalar(&conn, TABLE_COUNT_SQL);

        assert!(tables_before > 0);
        assert_eq!(
            tables_before, tables_after,
            "re-init must not add or drop tables"
        );

        // The migration must be recorded exactly once, not on every open.
        assert_eq!(scalar(&conn, "SELECT COUNT(*) FROM schema_version"), 1);
        assert_eq!(scalar(&conn, "SELECT MAX(version) FROM schema_version"), 1);
    }

    #[test]
    fn test_wal_mode_enabled() {
        let (_f, conn) = temp_db();

        let mode: String = conn
            .query_row("PRAGMA journal_mode", [], |row| row.get(0))
            .unwrap();
        assert_eq!(mode, "wal");
    }
}