// agent_orchestrator/async_database.rs

use std::path::{Path, PathBuf};

use anyhow::{Context, Result};
use rusqlite::OpenFlags;

use crate::db::configure_conn;
6#[derive(Clone)]
12pub struct AsyncDatabase {
13 db_path: PathBuf,
14 writer: tokio_rusqlite::Connection,
15 reader: tokio_rusqlite::Connection,
16}
17
18impl AsyncDatabase {
19 pub async fn open(db_path: impl AsRef<Path>) -> Result<Self> {
24 let db_path = db_path.as_ref().to_path_buf();
25
26 let writer = tokio_rusqlite::Connection::open(&db_path)
28 .await
29 .map_err(flatten_err)?;
30 writer
31 .call(|conn| configure_conn(conn).map_err(|e| tokio_rusqlite::Error::Other(e.into())))
32 .await
33 .map_err(flatten_err)?;
34
35 let reader = tokio_rusqlite::Connection::open_with_flags(
37 &db_path,
38 OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
39 )
40 .await
41 .map_err(flatten_err)?;
42 reader
43 .call(|conn| configure_conn(conn).map_err(|e| tokio_rusqlite::Error::Other(e.into())))
44 .await
45 .map_err(flatten_err)?;
46
47 Ok(Self {
48 db_path,
49 writer,
50 reader,
51 })
52 }
53
54 pub fn path(&self) -> &Path {
56 &self.db_path
57 }
58
59 pub fn writer(&self) -> &tokio_rusqlite::Connection {
61 &self.writer
62 }
63
64 pub fn reader(&self) -> &tokio_rusqlite::Connection {
66 &self.reader
67 }
68}
69
70pub fn flatten_err(err: tokio_rusqlite::Error) -> anyhow::Error {
72 match err {
73 tokio_rusqlite::Error::ConnectionClosed => anyhow::anyhow!("db connection closed"),
74 tokio_rusqlite::Error::Close((_, e)) => e.into(),
75 tokio_rusqlite::Error::Rusqlite(e) => e.into(),
76 tokio_rusqlite::Error::Other(e) => anyhow::anyhow!(e),
77 _ => anyhow::anyhow!("unknown tokio-rusqlite error"),
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::init_schema;
    use tempfile::tempdir;

    /// Inserts a single row into `events`.
    const INSERT_EVENT_SQL: &str = "INSERT INTO events (task_id, event_type, payload_json, created_at) VALUES ('t1', 'test', '{}', '2026-01-01')";

    /// Creates a temp dir, initializes the schema in `file_name` inside it and
    /// opens an `AsyncDatabase` on that file.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive for
    /// the whole test; dropping it deletes the database.
    async fn open_fresh(file_name: &str) -> (tempfile::TempDir, PathBuf, AsyncDatabase) {
        let temp = tempdir().expect("temp dir");
        let db_path = temp.path().join(file_name);
        init_schema(&db_path).expect("init schema");
        let db = AsyncDatabase::open(&db_path).await.expect("open async db");
        (temp, db_path, db)
    }

    /// Reads an integer-valued PRAGMA (e.g. `busy_timeout`) through `conn`.
    async fn pragma_i64(conn: &tokio_rusqlite::Connection, name: &'static str) -> i64 {
        conn.call(move |c| {
            Ok(c.query_row(&format!("PRAGMA {};", name), [], |row| {
                row.get::<_, i64>(0)
            })?)
        })
        .await
        .unwrap_or_else(|e| panic!("query PRAGMA {}: {}", name, e))
    }

    /// Runs `INSERT_EVENT_SQL` through `conn`, returning the raw call result.
    async fn insert_event(
        conn: &tokio_rusqlite::Connection,
    ) -> std::result::Result<(), tokio_rusqlite::Error> {
        conn.call(|c| {
            c.execute(INSERT_EVENT_SQL, [])?;
            Ok(())
        })
        .await
    }

    /// Returns the number of rows in `events` as seen through `conn`.
    async fn count_events(conn: &tokio_rusqlite::Connection) -> i64 {
        conn.call(|c| {
            Ok(c.query_row("SELECT COUNT(*) FROM events", [], |row| {
                row.get::<_, i64>(0)
            })?)
        })
        .await
        .expect("count events")
    }

    #[tokio::test]
    async fn async_database_open_and_configure() {
        let (_temp, db_path, db) = open_fresh("async_test.db").await;
        assert_eq!(db.path(), db_path);

        assert_eq!(pragma_i64(db.writer(), "busy_timeout").await, 5000);
        assert_eq!(pragma_i64(db.writer(), "foreign_keys").await, 1);
        // `open` runs configure_conn on the reader too, so it must also wait
        // on locks instead of failing immediately with SQLITE_BUSY.
        assert_eq!(pragma_i64(db.reader(), "busy_timeout").await, 5000);
    }

    #[tokio::test]
    async fn async_database_read_write_roundtrip() {
        let (_temp, _db_path, db) = open_fresh("rw_test.db").await;

        insert_event(db.writer()).await.expect("write event");
        assert_eq!(count_events(db.reader()).await, 1);
    }

    #[tokio::test]
    async fn async_database_clone_shares_connections() {
        let (_temp, _db_path, db) = open_fresh("clone_test.db").await;
        let db2 = db.clone();

        insert_event(db2.writer()).await.expect("write via clone");
        assert_eq!(count_events(db.reader()).await, 1);
    }

    #[tokio::test]
    async fn async_database_reader_rejects_writes() {
        let (_temp, _db_path, db) = open_fresh("ro_test.db").await;

        assert!(
            insert_event(db.reader()).await.is_err(),
            "reader connection must be read-only"
        );
        assert_eq!(count_events(db.reader()).await, 0);
    }
}