1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4 sync::Mutex,
5 thread::{self, JoinHandle},
6 time::{Duration, SystemTime, UNIX_EPOCH},
7};
8
9use attune_core::{BackendError, StorageBackend, StoredValue};
10use crossbeam_channel::{Receiver, RecvTimeoutError, Sender, unbounded};
11use rusqlite::{Connection, params};
12
/// Default pause between two consecutive `PRAGMA data_version` checks made by
/// the polling thread (one second).
const POLL_INTERVAL: Duration = Duration::from_secs(1);
14
/// Configuration consumed by [`SqliteBackend::open_with_options`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SqliteOptions {
    /// When `true`, opening the backend starts a polling thread and
    /// `watch_changes` returns a receiver. When `false`, no thread is started
    /// and `watch_changes` returns `None`.
    pub cross_process: bool,
    /// How long the polling thread waits between `PRAGMA data_version` checks.
    /// Only used when `cross_process` is `true`.
    pub poll_interval: Duration,
    /// Value written to `PRAGMA journal_mode` when the connection is opened.
    pub journal_mode: SqliteJournalMode,
    /// Value written to `PRAGMA busy_timeout`, passed to SQLite in whole
    /// milliseconds.
    pub busy_timeout: Duration,
    /// Value written to `PRAGMA synchronous` when the connection is opened.
    pub synchronous: SqliteSynchronous,
}
23
24impl Default for SqliteOptions {
25 fn default() -> Self {
26 Self {
27 cross_process: true,
28 poll_interval: POLL_INTERVAL,
29 journal_mode: SqliteJournalMode::Wal,
30 busy_timeout: Duration::from_millis(5000),
31 synchronous: SqliteSynchronous::Normal,
32 }
33 }
34}
35
/// Journal mode requested through `PRAGMA journal_mode` when a database is
/// opened. Each variant corresponds to the SQLite keyword of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqliteJournalMode {
    /// `PRAGMA journal_mode = WAL`
    Wal,
    /// `PRAGMA journal_mode = DELETE`
    Delete,
    /// `PRAGMA journal_mode = TRUNCATE`
    Truncate,
    /// `PRAGMA journal_mode = PERSIST`
    Persist,
    /// `PRAGMA journal_mode = MEMORY`
    Memory,
    /// `PRAGMA journal_mode = OFF`
    Off,
}

impl SqliteJournalMode {
    /// Returns the upper-case keyword to place on the right-hand side of
    /// `PRAGMA journal_mode = …`.
    fn as_pragma(self) -> &'static str {
        match self {
            SqliteJournalMode::Wal => "WAL",
            SqliteJournalMode::Delete => "DELETE",
            SqliteJournalMode::Truncate => "TRUNCATE",
            SqliteJournalMode::Persist => "PERSIST",
            SqliteJournalMode::Memory => "MEMORY",
            SqliteJournalMode::Off => "OFF",
        }
    }
}
58
/// Synchronous level requested through `PRAGMA synchronous` when a database
/// is opened. Each variant corresponds to the SQLite keyword of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqliteSynchronous {
    /// `PRAGMA synchronous = OFF`
    Off,
    /// `PRAGMA synchronous = NORMAL`
    Normal,
    /// `PRAGMA synchronous = FULL`
    Full,
    /// `PRAGMA synchronous = EXTRA`
    Extra,
}

impl SqliteSynchronous {
    /// Returns the upper-case keyword to place on the right-hand side of
    /// `PRAGMA synchronous = …`.
    fn as_pragma(self) -> &'static str {
        match self {
            SqliteSynchronous::Off => "OFF",
            SqliteSynchronous::Normal => "NORMAL",
            SqliteSynchronous::Full => "FULL",
            SqliteSynchronous::Extra => "EXTRA",
        }
    }
}
77
/// Settings store backed by a SQLite connection, optionally paired with a
/// background thread that polls the database for changes.
pub struct SqliteBackend {
    /// Primary connection; every read and write goes through this mutex.
    conn: Mutex<Connection>,
    /// Receiving end of the change-signal channel fed by the polling thread.
    /// `None` when the backend was opened with `cross_process: false`.
    commits_rx: Option<Receiver<()>>,
    /// Sender used on drop to ask the polling thread to exit.
    /// `None` when no polling thread was started.
    shutdown_tx: Option<Sender<()>>,
    /// Handle of the polling thread, joined on drop.
    /// `None` when no polling thread was started.
    poll_thread: Option<JoinHandle<()>>,
}
84
85impl SqliteBackend {
86 pub fn open(path: impl AsRef<Path>) -> Result<Self, BackendError> {
97 Self::open_with_options(path, SqliteOptions::default())
98 }
99
100 pub fn open_with_options(
111 path: impl AsRef<Path>,
112 options: SqliteOptions,
113 ) -> Result<Self, BackendError> {
114 let path = path.as_ref().to_path_buf();
115
116 let conn =
118 rusqlite::Connection::open(&path).map_err(|e| BackendError::Open(e.to_string()))?;
119
120 let pragmas = format!(
122 "PRAGMA journal_mode = {};\
123 PRAGMA busy_timeout = {};\
124 PRAGMA synchronous = {};\
125 PRAGMA foreign_keys = ON;",
126 options.journal_mode.as_pragma(),
127 options.busy_timeout.as_millis(),
128 options.synchronous.as_pragma(),
129 );
130 conn.execute_batch(&pragmas)
131 .map_err(|e| BackendError::Open(e.to_string()))?;
132
133 let settings_table_sql = "CREATE TABLE IF NOT EXISTS settings (
135 key TEXT PRIMARY KEY NOT NULL,
136 value TEXT NOT NULL,
137 updated_at INTEGER NOT NULL
138 )";
139 conn.execute(settings_table_sql, [])
140 .map_err(|e| BackendError::Open(e.to_string()))?;
141
142 let (commits_rx, shutdown_tx, poll_thread) = if options.cross_process {
144 let (commits_tx, commits_rx) = unbounded::<()>();
145 let (shutdown_tx, shutdown_rx) = unbounded::<()>();
146 let sidecar_path = path.clone();
147 let poll_interval = options.poll_interval;
148 let poll_thread = thread::spawn(move || {
149 polling_loop(sidecar_path, commits_tx, shutdown_rx, poll_interval);
150 });
151
152 (Some(commits_rx), Some(shutdown_tx), Some(poll_thread))
153 } else {
154 (None, None, None)
155 };
156
157 Ok(SqliteBackend {
158 conn: Mutex::new(conn),
159 commits_rx,
160 shutdown_tx,
161 poll_thread,
162 })
163 }
164}
165
166fn polling_loop(
187 path: PathBuf,
188 commits_tx: Sender<()>,
189 shutdown_rx: Receiver<()>,
190 poll_interval: Duration,
191) {
192 let sidecar_conn = match Connection::open(&path) {
194 Ok(c) => c,
195 Err(_) => return,
196 };
197
198 let mut last_version: i64 =
200 match sidecar_conn.query_row("PRAGMA data_version", [], |row| row.get(0)) {
201 Ok(v) => v,
202 Err(_) => return,
203 };
204
205 loop {
206 match shutdown_rx.recv_timeout(poll_interval) {
208 Ok(()) => return, Err(RecvTimeoutError::Disconnected) => return, Err(RecvTimeoutError::Timeout) => {} }
212
213 let version: i64 = match sidecar_conn.query_row("PRAGMA data_version", [], |row| row.get(0))
215 {
216 Ok(v) => v,
217 Err(_) => return,
218 };
219
220 if version != last_version {
221 last_version = version;
222
223 if commits_tx.send(()).is_err() {
225 return;
226 }
227 }
228 }
229}
230
231impl StorageBackend for SqliteBackend {
232 fn load_all(&self) -> Result<HashMap<String, StoredValue>, BackendError> {
233 let conn = self.conn.lock().unwrap();
235
236 let sql = "SELECT key, value FROM settings";
238 let mut stmt = conn
239 .prepare(sql)
240 .map_err(|e| BackendError::Read(e.to_string()))?;
241 let rows = stmt
242 .query_map([], |row| {
243 let key = row.get(0)?;
244 let raw = row.get(1)?;
245 Ok((key, StoredValue::from_raw(raw)))
246 })
247 .map_err(|e| BackendError::Read(e.to_string()))?;
248
249 let mut result = HashMap::new();
251 for row in rows {
252 let (k, v) = row.map_err(|e| BackendError::Read(e.to_string()))?;
253 result.insert(k, v);
254 }
255
256 Ok(result)
257 }
258
259 fn set(&self, key: &str, value: &StoredValue) -> Result<(), BackendError> {
260 let conn = self.conn.lock().unwrap();
262
263 let now = SystemTime::now()
265 .duration_since(UNIX_EPOCH)
266 .unwrap_or_default()
267 .as_secs() as i64;
268 let sql = "INSERT OR REPLACE INTO settings (key, value, updated_at) VALUES (?, ?, ?)";
269 conn.execute(sql, params![key, value.as_str(), now])
270 .map_err(|e| BackendError::Write(e.to_string()))?;
271
272 Ok(())
273 }
274
275 fn delete(&self, key: &str) -> Result<(), BackendError> {
276 let conn = self.conn.lock().unwrap();
278
279 let sql = "DELETE FROM settings WHERE key = ?";
281 conn.execute(sql, params![key])
282 .map_err(|e| BackendError::Write(e.to_string()))?;
283
284 Ok(())
285 }
286
287 fn watch_changes(&self) -> Option<Receiver<()>> {
288 self.commits_rx.clone()
289 }
290}
291
292impl Drop for SqliteBackend {
293 fn drop(&mut self) {
294 if let Some(shutdown_tx) = &self.shutdown_tx {
296 let _ = shutdown_tx.send(());
297 }
298
299 if let Some(handle) = self.poll_thread.take() {
301 let _ = handle.join();
302 }
303 }
304}
305
#[cfg(test)]
mod test {
    use super::*;
    use rusqlite::OptionalExtension;

    /// Opening a backend must create the `settings` table.
    #[test]
    fn test_open_correctly_inits_sqlite_db() {
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let conn = sqlite_be.conn.lock().unwrap();

        let mut stmt = conn
            .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='settings'")
            .unwrap();
        let table: Option<String> = stmt.query_row([], |row| row.get(0)).optional().unwrap();
        assert_eq!(table.as_deref(), Some("settings"));
    }

    /// A freshly created database holds no settings.
    #[test]
    fn test_load_all_returns_a_hashmap_with_no_values() {
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let stored_values = sqlite_be.load_all().unwrap();
        assert!(stored_values.is_empty());
    }

    /// A value written with `set` is returned unchanged by `load_all`.
    #[test]
    fn test_set_successfully_writes_a_setting_to_the_db() {
        let key = "theme";
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let sv = StoredValue::encode(&"dark").unwrap();
        sqlite_be.set(key, &sv).unwrap();

        let stored_values = sqlite_be.load_all().unwrap();
        assert_eq!(stored_values.len(), 1);
        let loaded = stored_values
            .get(key)
            .expect("the key written by `set` should be present");
        assert_eq!(sv.as_str(), loaded.as_str());
    }

    /// `delete` removes a previously written setting.
    #[test]
    fn test_delete_successfully_removes_a_setting_from_the_db() {
        let key = "theme";
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let sv = StoredValue::encode(&"dark").unwrap();
        sqlite_be.set(key, &sv).unwrap();

        let stored_values = sqlite_be.load_all().unwrap();
        assert_eq!(stored_values.len(), 1);

        sqlite_be.delete(key).unwrap();
        let stored_values = sqlite_be.load_all().unwrap();
        assert!(stored_values.is_empty());
    }

    /// A commit made through an unrelated connection must produce a signal.
    #[test]
    fn test_watch_changes_signals_on_external_commit() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        let backend = SqliteBackend::open(&path).unwrap();
        let rx = backend
            .watch_changes()
            .expect("polling thread should be running");

        let other = rusqlite::Connection::open(&path).unwrap();
        other
            .execute(
                "INSERT INTO settings (key, value, updated_at) VALUES ('theme', '\"dark\"', 0)",
                [],
            )
            .unwrap();

        if let Err(e) = rx.recv_timeout(Duration::from_millis(3000)) {
            panic!("expected a change signal within 3s, got {:?}", e);
        }
    }

    /// Without any commit the receiver must time out. `Disconnected` is
    /// rejected explicitly: it would mean the polling thread has died, which
    /// a bare `Err(_)` match would wrongly accept as a pass.
    #[test]
    fn test_watch_changes_times_out() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        let backend = SqliteBackend::open(&path).unwrap();
        let rx = backend
            .watch_changes()
            .expect("polling thread should be running");

        assert_eq!(
            rx.recv_timeout(Duration::from_millis(1)),
            Err(RecvTimeoutError::Timeout),
            "expected a timeout with the polling thread still alive"
        );
    }

    /// `cross_process: false` disables the watcher entirely.
    #[test]
    fn test_open_with_options_can_disable_watch_changes() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        let options = SqliteOptions {
            cross_process: false,
            ..SqliteOptions::default()
        };

        let backend = SqliteBackend::open_with_options(&path, options).unwrap();

        assert!(backend.watch_changes().is_none());
    }

    /// The pragmas requested in `SqliteOptions` are visible on the connection.
    #[test]
    fn test_open_with_options_applies_pragmas() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        let options = SqliteOptions {
            cross_process: false,
            journal_mode: SqliteJournalMode::Delete,
            busy_timeout: Duration::from_millis(1234),
            synchronous: SqliteSynchronous::Full,
            ..SqliteOptions::default()
        };

        let backend = SqliteBackend::open_with_options(&path, options).unwrap();
        let conn = backend.conn.lock().unwrap();

        let journal_mode: String = conn
            .query_row("PRAGMA journal_mode", [], |row| row.get(0))
            .unwrap();
        let busy_timeout: i64 = conn
            .query_row("PRAGMA busy_timeout", [], |row| row.get(0))
            .unwrap();
        let synchronous: i64 = conn
            .query_row("PRAGMA synchronous", [], |row| row.get(0))
            .unwrap();

        assert_eq!(journal_mode, "delete");
        assert_eq!(busy_timeout, 1234);
        // SQLite reports `synchronous` numerically; this test pins FULL to 2.
        assert_eq!(synchronous, 2);
    }
}