Skip to main content

aft/db/
mod.rs

1use rusqlite::{Connection, TransactionBehavior};
2use std::fmt;
3use std::fs;
4use std::path::Path;
5
6pub mod backups;
7pub mod bash_tasks;
8pub mod compression_events;
9pub mod state;
10
/// Newest schema version this build can migrate a database to.
///
/// `run_migrations` upgrades older databases up to this version and refuses
/// databases whose recorded version is higher.
pub const CURRENT_SCHEMA_VERSION: u32 = 2;
12
/// Schema version 1: the initial layout.
///
/// Creates the `schema_version`, `bash_tasks`, `compression_events`,
/// `backups`, `harness_state` and `host_state` tables together with their
/// indexes. Every statement uses `IF NOT EXISTS`, so running the script against
/// a database that already has these objects changes nothing.
const MIGRATION_V1: &str = r#"
CREATE TABLE IF NOT EXISTS schema_version (
  version INTEGER NOT NULL PRIMARY KEY
);

CREATE TABLE IF NOT EXISTS bash_tasks (
  harness      TEXT NOT NULL,
  session_id   TEXT NOT NULL,
  task_id      TEXT NOT NULL,
  project_key  TEXT NOT NULL,
  command      TEXT NOT NULL,
  cwd          TEXT NOT NULL,
  status       TEXT NOT NULL,
  exit_code    INTEGER,
  pid          INTEGER,
  pgid         INTEGER,
  started_at   INTEGER NOT NULL,
  completed_at INTEGER,
  stdout_path  TEXT,
  stderr_path  TEXT,
  compressed   INTEGER NOT NULL DEFAULT 1,
  timeout_ms   INTEGER,
  completion_delivered INTEGER NOT NULL DEFAULT 0,
  output_bytes INTEGER,
  metadata     TEXT,
  PRIMARY KEY (harness, session_id, task_id)
);
CREATE INDEX IF NOT EXISTS idx_bash_tasks_project_key ON bash_tasks(project_key);
CREATE INDEX IF NOT EXISTS idx_bash_tasks_status      ON bash_tasks(status);
CREATE INDEX IF NOT EXISTS idx_bash_tasks_session_status ON bash_tasks(harness, session_id, status);

CREATE TABLE IF NOT EXISTS compression_events (
  id                INTEGER PRIMARY KEY AUTOINCREMENT,
  harness           TEXT NOT NULL,
  session_id        TEXT,
  project_key       TEXT NOT NULL,
  tool              TEXT NOT NULL,
  task_id           TEXT,
  command           TEXT,
  compressor        TEXT NOT NULL,
  original_bytes    INTEGER NOT NULL,
  compressed_bytes  INTEGER NOT NULL,
  original_tokens   INTEGER NOT NULL,
  compressed_tokens INTEGER NOT NULL,
  created_at        INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_compression_session         ON compression_events(harness, session_id);
CREATE INDEX IF NOT EXISTS idx_compression_session_created ON compression_events(harness, session_id, created_at);
CREATE INDEX IF NOT EXISTS idx_compression_project_key     ON compression_events(project_key);

CREATE TABLE IF NOT EXISTS backups (
  id            INTEGER PRIMARY KEY AUTOINCREMENT,
  backup_id     TEXT,
  harness       TEXT NOT NULL,
  session_id    TEXT NOT NULL,
  project_key   TEXT NOT NULL,
  op_id         TEXT,
  order_blob    BLOB NOT NULL,
  file_path     TEXT NOT NULL,
  path_hash     TEXT NOT NULL,
  backup_path   TEXT,
  kind          TEXT NOT NULL,
  description   TEXT,
  created_at    INTEGER NOT NULL,
  is_tombstone  INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_backups_session_path  ON backups(harness, session_id, path_hash);
CREATE INDEX IF NOT EXISTS idx_backups_session_op    ON backups(harness, session_id, op_id) WHERE op_id IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_backups_session_order ON backups(harness, session_id, order_blob DESC);
CREATE INDEX IF NOT EXISTS idx_backups_session_path_order ON backups(harness, session_id, path_hash, order_blob DESC);

CREATE TABLE IF NOT EXISTS harness_state (
  harness    TEXT NOT NULL,
  key        TEXT NOT NULL,
  value      TEXT NOT NULL,
  updated_at INTEGER NOT NULL,
  PRIMARY KEY (harness, key)
);

CREATE TABLE IF NOT EXISTS host_state (
  key        TEXT NOT NULL PRIMARY KEY,
  value      TEXT NOT NULL,
  updated_at INTEGER NOT NULL
);
"#;
98
/// Schema version 2: make compression events unique per identity.
///
/// First deletes duplicate `compression_events` rows, keeping the lowest `id`
/// for each `(harness, session_id, project_key, tool, task_id)` group, then
/// creates a unique index over the same key. `NULL` values of `session_id` and
/// `task_id` are mapped to `char(0)` in both statements, so rows with a missing
/// value group together and are covered by the uniqueness rule.
const MIGRATION_V2: &str = r#"
DELETE FROM compression_events
WHERE id NOT IN (
  SELECT MIN(id)
  FROM compression_events
  GROUP BY
    harness,
    COALESCE(session_id, char(0)),
    project_key,
    tool,
    COALESCE(task_id, char(0))
);

CREATE UNIQUE INDEX IF NOT EXISTS idx_compression_event_identity
ON compression_events (
  harness,
  COALESCE(session_id, char(0)),
  project_key,
  tool,
  COALESCE(task_id, char(0))
);
"#;
121
/// Failure modes of opening and migrating the AFT database.
#[derive(Debug)]
pub enum OpenError {
    /// A filesystem operation failed (converted from `std::io::Error`).
    Io(std::io::Error),
    /// SQLite reported an error outside of a specific migration step.
    Sqlite(rusqlite::Error),
    /// The database records a schema version newer than this build supports.
    DowngradeRefused {
        /// Version found in the database.
        db_version: u32,
        /// Highest version this build can handle.
        supported: u32,
    },
    /// A single migration step could not be applied.
    MigrationFailed {
        /// Schema version the step started from.
        from: u32,
        /// Schema version the step was trying to reach.
        to: u32,
        /// Underlying SQLite error.
        error: rusqlite::Error,
    },
}
136
137impl fmt::Display for OpenError {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        match self {
140            OpenError::Io(error) => write!(f, "database I/O error: {error}"),
141            OpenError::Sqlite(error) => write!(f, "sqlite error: {error}"),
142            OpenError::DowngradeRefused {
143                db_version,
144                supported,
145            } => write!(
146                f,
147                "database schema version {db_version} is newer than supported version {supported}"
148            ),
149            OpenError::MigrationFailed { from, to, error } => {
150                write!(f, "database migration {from}->{to} failed: {error}")
151            }
152        }
153    }
154}
155
156impl std::error::Error for OpenError {}
157
158impl From<std::io::Error> for OpenError {
159    fn from(error: std::io::Error) -> Self {
160        OpenError::Io(error)
161    }
162}
163
164impl From<rusqlite::Error> for OpenError {
165    fn from(error: rusqlite::Error) -> Self {
166        OpenError::Sqlite(error)
167    }
168}
169
170/// Open or create the AFT SQLite database at the given path.
171///
172/// Applies per-connection PRAGMAs, runs schema migrations from the DB's
173/// current schema version up to [`CURRENT_SCHEMA_VERSION`], and returns the
174/// configured connection.
175pub fn open(path: &Path) -> Result<Connection, OpenError> {
176    if let Some(parent) = path.parent() {
177        if !parent.as_os_str().is_empty() {
178            fs::create_dir_all(parent)?;
179        }
180    }
181
182    let mut conn = Connection::open(path)?;
183    apply_pragmas(&conn)?;
184    run_migrations(&mut conn)?;
185    Ok(conn)
186}
187
188/// Apply the per-connection PRAGMAs required for every AFT SQLite connection.
189pub fn apply_pragmas(conn: &Connection) -> Result<(), rusqlite::Error> {
190    conn.pragma_update(None, "foreign_keys", "ON")?;
191    conn.pragma_update(None, "journal_mode", "WAL")?;
192    conn.pragma_update(None, "busy_timeout", 5000)?;
193    conn.pragma_update(None, "synchronous", "NORMAL")?;
194    Ok(())
195}
196
197/// Run forward-only migrations up to [`CURRENT_SCHEMA_VERSION`].
198///
199/// Returns the post-migration schema version. Refuses to open databases created
200/// by newer AFT versions.
201pub fn run_migrations(conn: &mut Connection) -> Result<u32, OpenError> {
202    conn.execute_batch(
203        "CREATE TABLE IF NOT EXISTS schema_version (version INTEGER NOT NULL PRIMARY KEY);",
204    )?;
205
206    let db_version = current_schema_version(conn)?;
207    if db_version > CURRENT_SCHEMA_VERSION {
208        return Err(OpenError::DowngradeRefused {
209            db_version,
210            supported: CURRENT_SCHEMA_VERSION,
211        });
212    }
213
214    for version in (db_version + 1)..=CURRENT_SCHEMA_VERSION {
215        apply_migration(conn, version)?;
216    }
217
218    Ok(current_schema_version(conn)?)
219}
220
221fn current_schema_version(conn: &Connection) -> Result<u32, rusqlite::Error> {
222    conn.query_row(
223        "SELECT COALESCE(MAX(version), 0) FROM schema_version",
224        [],
225        |row| row.get::<_, u32>(0),
226    )
227}
228
229fn apply_migration(conn: &mut Connection, version: u32) -> Result<(), OpenError> {
230    let from = version - 1;
231    let tx = conn
232        .transaction_with_behavior(TransactionBehavior::Immediate)
233        .map_err(|error| OpenError::MigrationFailed {
234            from,
235            to: version,
236            error,
237        })?;
238
239    let result = match version {
240        1 => tx.execute_batch(MIGRATION_V1),
241        2 => tx.execute_batch(MIGRATION_V2),
242        _ => Ok(()),
243    }
244    .and_then(|()| {
245        tx.execute("DELETE FROM schema_version", [])?;
246        tx.execute(
247            "INSERT OR REPLACE INTO schema_version (version) VALUES (?1)",
248            [version],
249        )?;
250        tx.commit()
251    });
252
253    result.map_err(|error| OpenError::MigrationFailed {
254        from,
255        to: version,
256        error,
257    })
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::params;
    use tempfile::tempdir;

    /// Tables a fully migrated database must contain.
    const EXPECTED_TABLES: &[&str] = &[
        "schema_version",
        "bash_tasks",
        "compression_events",
        "backups",
        "harness_state",
        "host_state",
    ];

    /// Named indexes a fully migrated database must contain.
    const EXPECTED_INDEXES: &[&str] = &[
        "idx_bash_tasks_project_key",
        "idx_bash_tasks_status",
        "idx_bash_tasks_session_status",
        "idx_compression_session",
        "idx_compression_session_created",
        "idx_compression_project_key",
        "idx_compression_event_identity",
        "idx_backups_session_path",
        "idx_backups_session_op",
        "idx_backups_session_order",
        "idx_backups_session_path_order",
    ];

    // ---------------------------------------------------------------------
    // Schema creation and versioning
    // ---------------------------------------------------------------------

    #[test]
    fn open_fresh_db_creates_all_tables() {
        let (_dir, conn) = fresh_db();

        let present = sqlite_names(&conn, "table");
        for table in EXPECTED_TABLES {
            assert!(
                present.iter().any(|name| name == table),
                "missing table {table}"
            );
        }
    }

    #[test]
    fn open_fresh_db_creates_all_indexes() {
        let (_dir, conn) = fresh_db();

        let present = sqlite_names(&conn, "index");
        for index in EXPECTED_INDEXES {
            assert!(
                present.iter().any(|name| name == index),
                "missing index {index}"
            );
        }
    }

    #[test]
    fn open_existing_db_is_idempotent() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("aft.db");

        // The first connection is closed before the file is opened again.
        let first_version = {
            let conn = open(&path).unwrap();
            schema_version(&conn)
        };

        let reopened = open(&path).unwrap();
        assert_eq!(schema_version(&reopened), first_version);
    }

    #[test]
    fn pragmas_applied_correctly() {
        let (_dir, conn) = fresh_db();

        let foreign_keys: i64 = pragma(&conn, "PRAGMA foreign_keys");
        let journal_mode: String = pragma(&conn, "PRAGMA journal_mode");
        let busy_timeout: i64 = pragma(&conn, "PRAGMA busy_timeout");
        let synchronous: i64 = pragma(&conn, "PRAGMA synchronous");

        assert_eq!(foreign_keys, 1);
        assert_eq!(journal_mode, "wal");
        assert_eq!(busy_timeout, 5000);
        assert_eq!(synchronous, 1);
    }

    #[test]
    fn downgrade_refused() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("aft.db");

        // Pretend a newer build has already migrated this file.
        {
            let conn = open(&path).unwrap();
            conn.execute("INSERT OR REPLACE INTO schema_version VALUES (999)", [])
                .unwrap();
        }

        match open(&path).unwrap_err() {
            OpenError::DowngradeRefused {
                db_version,
                supported,
            } => {
                assert_eq!(db_version, 999);
                assert_eq!(supported, CURRENT_SCHEMA_VERSION);
            }
            error => panic!("expected downgrade refusal, got {error:?}"),
        }
    }

    #[test]
    fn migration_runner_advances_version() {
        let (_dir, conn) = fresh_db();

        assert_eq!(schema_version(&conn), CURRENT_SCHEMA_VERSION);
    }

    #[test]
    fn migration_runner_no_op_when_current() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("aft.db");

        {
            let conn = open(&path).unwrap();
            assert_eq!(schema_version_row_count(&conn), 1);
        }

        let reopened = open(&path).unwrap();
        assert_eq!(schema_version(&reopened), CURRENT_SCHEMA_VERSION);
        assert_eq!(schema_version_row_count(&reopened), 1);
    }

    // ---------------------------------------------------------------------
    // Key constraints and ordering
    // ---------------------------------------------------------------------

    #[test]
    fn harness_state_compound_pk_works() {
        let (_dir, conn) = fresh_db();

        let insert = |harness: &str, updated_at: i64| {
            conn.execute(
                "INSERT INTO harness_state (harness, key, value, updated_at) VALUES (?1, ?2, ?3, ?4)",
                params![harness, "warned_tools", "{}", updated_at],
            )
        };

        insert("opencode", 1).unwrap();
        // Same (harness, key) pair must be rejected ...
        assert_unique_constraint(insert("opencode", 2));
        // ... while the same key under another harness is fine.
        insert("pi", 3).unwrap();
    }

    #[test]
    fn host_state_simple_pk_works() {
        let (_dir, conn) = fresh_db();

        let insert = |updated_at: i64| {
            conn.execute(
                "INSERT INTO host_state (key, value, updated_at) VALUES (?1, ?2, ?3)",
                params!["trusted_filter_projects", "[]", updated_at],
            )
        };

        insert(1).unwrap();
        assert_unique_constraint(insert(2));
    }

    #[test]
    fn bash_tasks_compound_pk_works() {
        let (_dir, conn) = fresh_db();

        insert_bash_task(&conn, "opencode", "session-1", "bash-12345678").unwrap();
        assert_unique_constraint(insert_bash_task(
            &conn,
            "opencode",
            "session-1",
            "bash-12345678",
        ));

        // A different harness may reuse the same session and task ids.
        insert_bash_task(&conn, "pi", "session-1", "bash-12345678").unwrap();
    }

    #[test]
    fn backups_order_blob_sort() {
        let (_dir, conn) = fresh_db();

        let fixtures: [(&str, [u8; 16]); 3] = [
            ("one", order_blob(1)),
            ("two", order_blob(2)),
            ("max", [0xFF; 16]),
        ];
        for (backup_id, blob) in &fixtures {
            insert_backup(&conn, backup_id, blob).unwrap();
        }

        assert_eq!(backup_ids_ordered(&conn, "ASC"), vec!["one", "two", "max"]);
        assert_eq!(backup_ids_ordered(&conn, "DESC"), vec!["max", "two", "one"]);
    }

    // ---------------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------------

    /// Opens a migrated database inside a new temporary directory.
    ///
    /// The directory guard is returned alongside the connection so the file
    /// stays on disk for as long as the test holds on to it.
    fn fresh_db() -> (tempfile::TempDir, Connection) {
        let dir = tempdir().unwrap();
        let conn = open(&dir.path().join("aft.db")).unwrap();
        (dir, conn)
    }

    /// Runs a single-value PRAGMA query and returns its first column.
    fn pragma<T: rusqlite::types::FromSql>(conn: &Connection, sql: &str) -> T {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    /// Lists user-defined tables or indexes recorded in `sqlite_master`.
    fn sqlite_names(conn: &Connection, kind: &str) -> Vec<String> {
        let sql = match kind {
            "table" => "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
            "index" => "SELECT name FROM sqlite_master WHERE type='index' AND name NOT LIKE 'sqlite_%' ORDER BY name",
            _ => panic!("unsupported sqlite_master kind: {kind}"),
        };
        let mut stmt = conn.prepare(sql).unwrap();
        let rows = stmt.query_map([], |row| row.get::<_, String>(0)).unwrap();
        rows.map(|name| name.unwrap()).collect()
    }

    /// Returns the version stored in `schema_version`.
    fn schema_version(conn: &Connection) -> u32 {
        let version = conn.query_row("SELECT version FROM schema_version", [], |row| row.get(0));
        version.unwrap()
    }

    /// Returns how many rows `schema_version` holds.
    fn schema_version_row_count(conn: &Connection) -> i64 {
        let count = conn.query_row("SELECT COUNT(*) FROM schema_version", [], |row| row.get(0));
        count.unwrap()
    }

    /// Asserts that `result` failed with SQLite's unique-constraint message.
    fn assert_unique_constraint(result: rusqlite::Result<usize>) {
        let error = result.expect_err("expected a unique constraint violation");
        let rendered = error.to_string();
        assert!(
            rendered.contains("UNIQUE constraint failed"),
            "expected UNIQUE constraint failure, got {error}"
        );
    }

    /// Inserts a minimal running `bash_tasks` row with the given key columns.
    fn insert_bash_task(
        conn: &Connection,
        harness: &str,
        session_id: &str,
        task_id: &str,
    ) -> rusqlite::Result<usize> {
        let values = params![
            harness,
            session_id,
            task_id,
            "project-key",
            "echo ok",
            "/tmp",
            "running",
            1_i64
        ];
        conn.execute(
            "INSERT INTO bash_tasks (
                harness, session_id, task_id, project_key, command, cwd, status, started_at
             ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
            values,
        )
    }

    /// Inserts a minimal `backups` row with the given id and ordering blob.
    fn insert_backup(
        conn: &Connection,
        backup_id: &str,
        order_blob: &[u8],
    ) -> rusqlite::Result<usize> {
        let values = params![
            backup_id,
            "opencode",
            "session-1",
            "project-key",
            order_blob,
            "/tmp/file.txt",
            "path-hash",
            "content",
            1_i64
        ];
        conn.execute(
            "INSERT INTO backups (
                backup_id, harness, session_id, project_key, order_blob, file_path,
                path_hash, kind, created_at
             ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
            values,
        )
    }

    /// Encodes `value` as 16 big-endian bytes.
    fn order_blob(value: u128) -> [u8; 16] {
        u128::to_be_bytes(value)
    }

    /// Returns all backup ids sorted by `order_blob` in the given direction.
    fn backup_ids_ordered(conn: &Connection, direction: &str) -> Vec<String> {
        let sql = match direction {
            "ASC" => "SELECT backup_id FROM backups ORDER BY order_blob ASC",
            "DESC" => "SELECT backup_id FROM backups ORDER BY order_blob DESC",
            _ => panic!("unsupported order direction: {direction}"),
        };
        let mut stmt = conn.prepare(sql).unwrap();
        let rows = stmt.query_map([], |row| row.get::<_, String>(0)).unwrap();
        rows.map(|id| id.unwrap()).collect()
    }
}