1pub mod error;
11pub mod events;
12pub mod hooks;
13pub mod pages;
14pub mod robots;
15pub mod servers;
16pub mod summaries;
17pub mod system;
18pub mod tasks;
19
20pub use error::StorageError;
21pub use hooks::{UpdateHookGuard, register_tasks_update_hook};
22
23use std::path::{Path, PathBuf};
24use std::sync::{Arc, Mutex};
25use std::time::Duration;
26
27use tokio_rusqlite::Connection;
28
/// Sending half of an unbounded tokio mpsc channel of `String`s, stored on [`Db`] via
/// [`Db::set_new_task_sender`].
///
/// Judging by the name it announces newly created tasks. NOTE(review): what the `String`
/// payload holds (a task id?) is not visible in this module — confirm at the send sites.
pub type NewTaskNotify = tokio::sync::mpsc::UnboundedSender<String>;
34
/// Handle to the SQLite database, obtained from [`Db::open`].
///
/// `Clone` is derived. `new_task_tx` sits behind an `Arc`, so every clone shares the same
/// notification slot. NOTE(review): whether clones of `conn` share one underlying connection
/// depends on `tokio_rusqlite::Connection`'s `Clone` impl — presumably yes, confirm.
#[derive(Debug, Clone)]
pub struct Db {
    /// Async `tokio_rusqlite` connection; queries run through `Connection::call`.
    pub(crate) conn: Connection,
    /// Optional notification sender: `None` when opened, replaced by
    /// [`Db::set_new_task_sender`].
    pub(crate) new_task_tx: Arc<Mutex<Option<NewTaskNotify>>>,
    /// Path the database was opened from, as returned by [`Db::path`].
    path: PathBuf,
}
48
49impl Db {
50 pub fn set_new_task_sender(&self, tx: NewTaskNotify) {
53 *self.new_task_tx.lock().expect("new_task_tx mutex poisoned") = Some(tx);
54 }
55
56 pub fn path(&self) -> &Path {
61 &self.path
62 }
63}
64
/// Schema migrations applied in order by `Db::open`, as `(name, sql)` pairs.
///
/// The entry at index `i` upgrades the database to schema version `i + 1`. Only that version
/// number is recorded in the database, and entries at or below it are skipped. So append new
/// migrations at the end and never reorder or edit existing ones: already-applied edits would
/// never run. The name is what `StorageError::Migration` reports on failure. The SQL is
/// embedded at compile time with `include_str!`.
const MIGRATIONS: &[(&str, &str)] = &[
    (
        "001_initial.sql",
        include_str!("migrations/001_initial.sql"),
    ),
    (
        "002_servers.sql",
        include_str!("migrations/002_servers.sql"),
    ),
    (
        "003_robots_state.sql",
        include_str!("migrations/003_robots_state.sql"),
    ),
    ("004_tasks.sql", include_str!("migrations/004_tasks.sql")),
    (
        "005_summary_cache.sql",
        include_str!("migrations/005_summary_cache.sql"),
    ),
    (
        "006_render_reason.sql",
        include_str!("migrations/006_render_reason.sql"),
    ),
];
94
/// Result of a migration run, returned out of the `Connection::call` closure so the failing
/// migration's name can be attached to the error once back on the async side.
enum MigrationOutcome {
    /// Every pending migration was applied (or none were pending).
    Ok,
    /// The named migration failed and its transaction was not committed.
    FailedAt { name: String, err: rusqlite::Error },
}
101
102impl Db {
103 pub async fn open(path: impl AsRef<Path>) -> Result<Self, StorageError> {
105 Self::open_with_migrations(path, MIGRATIONS).await
106 }
107
108 pub(crate) async fn open_with_migrations(
112 path: impl AsRef<Path>,
113 migrations: &'static [(&'static str, &'static str)],
114 ) -> Result<Self, StorageError> {
115 let path_owned = path.as_ref().to_path_buf();
116 let path_str = path_owned.display().to_string();
117 let conn = Connection::open(&path_owned)
118 .await
119 .map_err(|source| StorageError::Open {
120 path: path_str.clone(),
121 source: tokio_rusqlite::Error::Error(source),
122 })?;
123
124 conn.call(|c| {
125 c.pragma_update(None, "journal_mode", "WAL")?;
126 c.busy_timeout(Duration::from_secs(5))?;
127 Ok::<_, rusqlite::Error>(())
128 })
129 .await?;
130
131 let db = Self {
132 conn,
133 new_task_tx: Arc::new(Mutex::new(None)),
134 path: path_owned,
135 };
136 db.run_migrations(migrations).await?;
137 Ok(db)
138 }
139
140 async fn run_migrations(
141 &self,
142 migrations: &'static [(&'static str, &'static str)],
143 ) -> Result<(), StorageError> {
144 let outcome = self
145 .conn
146 .call(move |c| {
147 let current = system::read_schema_version(c).map_err(unwrap_storage_err)?;
148 for (idx, (name, sql)) in migrations.iter().enumerate() {
149 let target = (idx + 1) as u32;
150 if current >= target {
151 continue;
152 }
153 let tx = c.unchecked_transaction()?;
156 if let Err(err) = tx
157 .execute_batch(sql)
158 .and_then(|()| {
159 system::write_schema_version(&tx, target).map_err(unwrap_storage_err)
160 })
161 .and_then(|()| tx.commit())
162 {
163 return Ok(MigrationOutcome::FailedAt {
164 name: (*name).to_string(),
165 err,
166 });
167 }
168 tracing::info!(target: "rover::storage", migration = name, "applied migration");
169 }
170 Ok::<_, rusqlite::Error>(MigrationOutcome::Ok)
171 })
172 .await?;
173
174 match outcome {
175 MigrationOutcome::Ok => Ok(()),
176 MigrationOutcome::FailedAt { name, err } => Err(StorageError::Migration {
177 name,
178 source: tokio_rusqlite::Error::Error(err),
179 }),
180 }
181 }
182
183 pub async fn schema_version(&self) -> Result<u32, StorageError> {
185 self.conn
186 .call(|c| Ok::<_, rusqlite::Error>(system::read_schema_version(c)))
187 .await?
188 }
189}
190
191fn unwrap_storage_err(e: StorageError) -> rusqlite::Error {
197 match e {
198 StorageError::Backend(tokio_rusqlite::Error::Error(inner)) => inner,
199 other => rusqlite::Error::ToSqlConversionFailure(Box::new(StringErr(other.to_string()))),
200 }
201}
202
/// Minimal error type that carries nothing but a message.
///
/// `Display` writes the message verbatim and ignores formatter flags such as width. The
/// error reports no `source`.
#[derive(Debug)]
pub(crate) struct StringErr(pub(crate) String);

impl std::fmt::Display for StringErr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}

impl std::error::Error for StringErr {}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema version a fully migrated database should report.
    fn latest_version() -> u32 {
        MIGRATIONS.len() as u32
    }

    /// Creates a fresh temporary directory and returns it along with the path of a
    /// `rover.db` file inside it. The directory is removed when the `TempDir` is dropped,
    /// so callers must keep it bound for the whole test.
    fn scratch_db_path() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("rover.db");
        (dir, db_path)
    }

    /// Runs a query returning a single integer (e.g. `SELECT COUNT(*) ...`).
    async fn query_i64(db: &Db, sql: &'static str) -> i64 {
        db.conn
            .call(move |c| c.query_row(sql, [], |row| row.get::<_, i64>(0)))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn open_creates_db_and_applies_migrations() {
        let (_dir, db_path) = scratch_db_path();
        let db = Db::open(&db_path).await.unwrap();
        assert_eq!(db.schema_version().await.unwrap(), latest_version());
    }

    #[tokio::test]
    async fn open_is_idempotent() {
        let (_dir, db_path) = scratch_db_path();
        // Keep the first handle alive so the second open runs against an in-use database.
        let _first = Db::open(&db_path).await.unwrap();
        let second = Db::open(&db_path).await.unwrap();
        assert_eq!(second.schema_version().await.unwrap(), latest_version());
    }

    #[tokio::test]
    async fn open_creates_pages_table() {
        let (_dir, db_path) = scratch_db_path();
        let db = Db::open(&db_path).await.unwrap();
        assert_eq!(query_i64(&db, "SELECT COUNT(*) FROM pages").await, 0);
    }

    const BROKEN_MIGRATIONS: &[(&str, &str)] =
        &[("001_broken.sql", "CREATE TABLE oops(SYNTAX ERROR);")];

    #[tokio::test]
    async fn broken_migration_surfaces_named_migration_error() {
        let (_dir, db_path) = scratch_db_path();
        let outcome = Db::open_with_migrations(&db_path, BROKEN_MIGRATIONS).await;
        match outcome.expect_err("broken migration must fail") {
            StorageError::Migration { name, .. } => assert_eq!(name, "001_broken.sql"),
            other => panic!("expected StorageError::Migration, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn migration_003_adds_state_column_to_robots_cache() {
        let (_dir, db_path) = scratch_db_path();
        let db = Db::open(&db_path).await.unwrap();

        // Column 1 of each `PRAGMA table_info` row is the column name.
        let cols: Vec<String> = db
            .conn
            .call(|c| {
                let mut stmt = c.prepare("PRAGMA table_info(robots_cache)")?;
                let names = stmt
                    .query_map([], |row| row.get::<_, String>(1))?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok::<_, rusqlite::Error>(names)
            })
            .await
            .unwrap();
        assert!(cols.contains(&"state".to_string()), "cols = {cols:?}");
        assert_eq!(db.schema_version().await.unwrap(), latest_version());
    }

    #[tokio::test]
    async fn migration_005_adds_summary_cache_table() {
        let (_dir, db_path) = scratch_db_path();
        let db = Db::open(&db_path).await.unwrap();
        assert_eq!(query_i64(&db, "SELECT COUNT(*) FROM summary_cache").await, 0);
        assert_eq!(db.schema_version().await.unwrap(), latest_version());
    }
}