Skip to main content

agent_orchestrator/
async_database.rs

1use crate::db::configure_conn;
2use anyhow::Result;
3use rusqlite::OpenFlags;
4use std::path::{Path, PathBuf};
5
6/// Async wrapper around SQLite using `tokio_rusqlite`.
7///
8/// Uses two named connections (not a pool):
9/// - **writer**: all write operations, serialized to match SQLite WAL single-writer model
10/// - **reader**: read-only queries, avoids contention with writer lock
11#[derive(Clone)]
12pub struct AsyncDatabase {
13    db_path: PathBuf,
14    writer: tokio_rusqlite::Connection,
15    reader: tokio_rusqlite::Connection,
16}
17
18impl AsyncDatabase {
19    /// Opens the database and configures paired writer and reader connections.
20    ///
21    /// The writer connection uses default read-write flags, while the reader
22    /// connection is opened read-only to reduce contention.
23    pub async fn open(db_path: impl AsRef<Path>) -> Result<Self> {
24        let db_path = db_path.as_ref().to_path_buf();
25
26        // Writer: read-write (default flags)
27        let writer = tokio_rusqlite::Connection::open(&db_path)
28            .await
29            .map_err(flatten_err)?;
30        writer
31            .call(|conn| configure_conn(conn).map_err(|e| tokio_rusqlite::Error::Other(e.into())))
32            .await
33            .map_err(flatten_err)?;
34
35        // Reader: read-only
36        let reader = tokio_rusqlite::Connection::open_with_flags(
37            &db_path,
38            OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
39        )
40        .await
41        .map_err(flatten_err)?;
42        reader
43            .call(|conn| configure_conn(conn).map_err(|e| tokio_rusqlite::Error::Other(e.into())))
44            .await
45            .map_err(flatten_err)?;
46
47        Ok(Self {
48            db_path,
49            writer,
50            reader,
51        })
52    }
53
54    /// Returns the filesystem path for the database file.
55    pub fn path(&self) -> &Path {
56        &self.db_path
57    }
58
59    /// Returns the write-capable SQLite connection.
60    pub fn writer(&self) -> &tokio_rusqlite::Connection {
61        &self.writer
62    }
63
64    /// Returns the read-only SQLite connection.
65    pub fn reader(&self) -> &tokio_rusqlite::Connection {
66        &self.reader
67    }
68}
69
70/// Flatten `tokio_rusqlite::Error` into `anyhow::Error`.
71pub fn flatten_err(err: tokio_rusqlite::Error) -> anyhow::Error {
72    match err {
73        tokio_rusqlite::Error::ConnectionClosed => anyhow::anyhow!("db connection closed"),
74        tokio_rusqlite::Error::Close((_, e)) => e.into(),
75        tokio_rusqlite::Error::Rusqlite(e) => e.into(),
76        tokio_rusqlite::Error::Other(e) => anyhow::anyhow!(e),
77        _ => anyhow::anyhow!("unknown tokio-rusqlite error"),
78    }
79}
80
81#[cfg(test)]
82mod tests {
83    use super::*;
84    use crate::db::init_schema;
85    use tempfile::tempdir;
86
87    #[tokio::test]
88    async fn async_database_open_and_configure() {
89        let temp = tempdir().expect("temp dir");
90        let db_path = temp.path().join("async_test.db");
91        init_schema(&db_path).expect("init schema");
92
93        let db = AsyncDatabase::open(&db_path).await.expect("open async db");
94        assert_eq!(db.path(), db_path);
95
96        // Verify writer pragmas
97        let busy_timeout: i64 = db
98            .writer()
99            .call(|conn| {
100                conn.query_row("PRAGMA busy_timeout;", [], |row| row.get(0))
101                    .map_err(|e| e.into())
102            })
103            .await
104            .expect("query busy_timeout");
105        assert_eq!(busy_timeout, 5000);
106
107        let foreign_keys: i64 = db
108            .writer()
109            .call(|conn| {
110                conn.query_row("PRAGMA foreign_keys;", [], |row| row.get(0))
111                    .map_err(|e| e.into())
112            })
113            .await
114            .expect("query foreign_keys");
115        assert_eq!(foreign_keys, 1);
116    }
117
118    #[tokio::test]
119    async fn async_database_read_write_roundtrip() {
120        let temp = tempdir().expect("temp dir");
121        let db_path = temp.path().join("rw_test.db");
122        init_schema(&db_path).expect("init schema");
123
124        let db = AsyncDatabase::open(&db_path).await.expect("open async db");
125
126        // Write via writer
127        db.writer()
128            .call(|conn| {
129                conn.execute(
130                    "INSERT INTO events (task_id, event_type, payload_json, created_at) VALUES ('t1', 'test', '{}', '2026-01-01')",
131                    [],
132                )?;
133                Ok(())
134            })
135            .await
136            .expect("write event");
137
138        // Read via reader
139        let count: i64 = db
140            .reader()
141            .call(|conn| Ok(conn.query_row("SELECT COUNT(*) FROM events", [], |row| row.get(0))?))
142            .await
143            .expect("read count");
144        assert_eq!(count, 1);
145    }
146
147    #[tokio::test]
148    async fn async_database_clone_shares_connections() {
149        let temp = tempdir().expect("temp dir");
150        let db_path = temp.path().join("clone_test.db");
151        init_schema(&db_path).expect("init schema");
152
153        let db = AsyncDatabase::open(&db_path).await.expect("open async db");
154        let db2 = db.clone();
155
156        // Write through clone
157        db2.writer()
158            .call(|conn| {
159                conn.execute(
160                    "INSERT INTO events (task_id, event_type, payload_json, created_at) VALUES ('t1', 'test', '{}', '2026-01-01')",
161                    [],
162                )?;
163                Ok(())
164            })
165            .await
166            .expect("write via clone");
167
168        // Read through original
169        let count: i64 = db
170            .reader()
171            .call(|conn| Ok(conn.query_row("SELECT COUNT(*) FROM events", [], |row| row.get(0))?))
172            .await
173            .expect("read via original");
174        assert_eq!(count, 1);
175    }
176}