1use crate::error::StorageError;
4use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
5use std::path::Path;
6use std::str::FromStr;
7
/// Schema migrations built by `sqlx::migrate!` from the `./migrations`
/// directory; every [`Storage`] constructor runs them before returning.
static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("./migrations");
9
10#[derive(Debug, Clone)]
12pub struct Storage {
13 pool: sqlx::SqlitePool,
14}
15
16impl Storage {
17 pub async fn open(path: impl AsRef<Path>) -> Result<Self, StorageError> {
20 let options = SqliteConnectOptions::new()
21 .filename(path)
22 .create_if_missing(true)
23 .foreign_keys(true);
24 let pool = SqlitePoolOptions::new()
25 .max_connections(5)
26 .connect_with(options)
27 .await?;
28 MIGRATOR.run(&pool).await?;
29 Ok(Self { pool })
30 }
31
32 pub async fn in_memory_for_tests() -> Result<Self, StorageError> {
35 let options = SqliteConnectOptions::from_str("sqlite::memory:")?.foreign_keys(true);
36 let pool = SqlitePoolOptions::new()
37 .max_connections(1) .connect_with(options)
39 .await?;
40 MIGRATOR.run(&pool).await?;
41 Ok(Self { pool })
42 }
43
44 pub fn pool(&self) -> &sqlx::SqlitePool {
46 &self.pool
47 }
48}
49
#[cfg(test)]
mod tests {
    use super::*;

    /// Tables the migrations must create, in alphabetical order.
    const EXPECTED_TABLES: [&str; 5] = [
        "agent_configs",
        "experiments",
        "projects",
        "sessions",
        "signals",
    ];

    #[tokio::test]
    async fn in_memory_storage_applies_migrations() {
        let storage = Storage::in_memory_for_tests().await.unwrap();
        let sql = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='projects'";
        let (count,): (i64,) = sqlx::query_as(sql)
            .fetch_one(storage.pool())
            .await
            .unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn foreign_keys_pragma_is_enabled() {
        let storage = Storage::in_memory_for_tests().await.unwrap();
        let (enabled,): (i64,) = sqlx::query_as("PRAGMA foreign_keys")
            .fetch_one(storage.pool())
            .await
            .unwrap();
        assert_eq!(enabled, 1, "foreign_keys pragma must be ON");
    }

    #[tokio::test]
    async fn all_five_tables_exist_after_migration() {
        let storage = Storage::in_memory_for_tests().await.unwrap();
        let rows: Vec<(String,)> =
            sqlx::query_as("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
                .fetch_all(storage.pool())
                .await
                .unwrap();
        let got: Vec<&str> = rows.iter().map(|row| row.0.as_str()).collect();
        for expected in EXPECTED_TABLES {
            assert!(
                got.contains(&expected),
                "missing table {expected}; got {got:?}",
            );
        }
    }
}