Skip to main content

attune_sqlite/
backend.rs

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4    sync::Mutex,
5    thread::{self, JoinHandle},
6    time::{Duration, SystemTime, UNIX_EPOCH},
7};
8
9use attune_core::{BackendError, StorageBackend, StoredValue};
10use crossbeam_channel::{Receiver, RecvTimeoutError, Sender, unbounded};
11use rusqlite::{Connection, params};
12
/// Default interval between two `PRAGMA data_version` polls of the
/// cross-process watcher; used by [`SqliteOptions::default`].
const POLL_INTERVAL: Duration = Duration::from_secs(1);
14
15#[derive(Clone, Debug, Eq, PartialEq)]
16pub struct SqliteOptions {
17    pub cross_process: bool,
18    pub poll_interval: Duration,
19    pub journal_mode: SqliteJournalMode,
20    pub busy_timeout: Duration,
21    pub synchronous: SqliteSynchronous,
22}
23
24impl Default for SqliteOptions {
25    fn default() -> Self {
26        Self {
27            cross_process: true,
28            poll_interval: POLL_INTERVAL,
29            journal_mode: SqliteJournalMode::Wal,
30            busy_timeout: Duration::from_millis(5000),
31            synchronous: SqliteSynchronous::Normal,
32        }
33    }
34}
35
/// Journal mode requested through `PRAGMA journal_mode` when the backend
/// opens its writer connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqliteJournalMode {
    /// Requests `WAL` (write-ahead log) journaling.
    Wal,
    /// Requests `DELETE` journaling.
    Delete,
    /// Requests `TRUNCATE` journaling.
    Truncate,
    /// Requests `PERSIST` journaling.
    Persist,
    /// Requests `MEMORY` journaling.
    Memory,
    /// Requests `OFF`, i.e. no journal.
    Off,
}

impl SqliteJournalMode {
    /// Returns the upper-case keyword that is spliced into the
    /// `PRAGMA journal_mode = …` statement for this mode.
    fn as_pragma(self) -> &'static str {
        let keyword = match self {
            SqliteJournalMode::Off => "OFF",
            SqliteJournalMode::Memory => "MEMORY",
            SqliteJournalMode::Persist => "PERSIST",
            SqliteJournalMode::Truncate => "TRUNCATE",
            SqliteJournalMode::Delete => "DELETE",
            SqliteJournalMode::Wal => "WAL",
        };
        keyword
    }
}
58
/// Durability level requested through `PRAGMA synchronous` when the backend
/// opens its writer connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqliteSynchronous {
    /// Requests `OFF`.
    Off,
    /// Requests `NORMAL`.
    Normal,
    /// Requests `FULL`.
    Full,
    /// Requests `EXTRA`.
    Extra,
}

impl SqliteSynchronous {
    /// Returns the upper-case keyword that is spliced into the
    /// `PRAGMA synchronous = …` statement for this level.
    fn as_pragma(self) -> &'static str {
        let keyword = match self {
            SqliteSynchronous::Extra => "EXTRA",
            SqliteSynchronous::Full => "FULL",
            SqliteSynchronous::Normal => "NORMAL",
            SqliteSynchronous::Off => "OFF",
        };
        keyword
    }
}
77
78pub struct SqliteBackend {
79    conn: Mutex<Connection>,
80    commits_rx: Option<Receiver<()>>,
81    shutdown_tx: Option<Sender<()>>,
82    poll_thread: Option<JoinHandle<()>>,
83}
84
85impl SqliteBackend {
86    /// Opens a connection to the SQLite database.
87    ///
88    /// Uses [`SqliteOptions::default`] for SQLite pragmas and cross-process
89    /// change detection.
90    ///
91    /// ## Errors
92    ///
93    /// Returns [`BackendError::Open`] when the database cannot be opened,
94    /// SQLite pragmas cannot be applied, or the settings table cannot be
95    /// created.
96    pub fn open(path: impl AsRef<Path>) -> Result<Self, BackendError> {
97        Self::open_with_options(path, SqliteOptions::default())
98    }
99
100    /// Opens a connection to the SQLite database with explicit backend options.
101    ///
102    /// Applies the configured SQLite pragmas, creates the settings table, and
103    /// starts the cross-process watcher when `options.cross_process` is `true`.
104    ///
105    /// ## Errors
106    ///
107    /// Returns [`BackendError::Open`] when the database cannot be opened,
108    /// SQLite pragmas cannot be applied, or the settings table cannot be
109    /// created.
110    pub fn open_with_options(
111        path: impl AsRef<Path>,
112        options: SqliteOptions,
113    ) -> Result<Self, BackendError> {
114        let path = path.as_ref().to_path_buf();
115
116        // 1. Create the SQLite connection.
117        let conn =
118            rusqlite::Connection::open(&path).map_err(|e| BackendError::Open(e.to_string()))?;
119
120        // 2. Set PRAGMAs.
121        let pragmas = format!(
122            "PRAGMA journal_mode = {};\
123            PRAGMA busy_timeout = {};\
124            PRAGMA synchronous = {};\
125            PRAGMA foreign_keys = ON;",
126            options.journal_mode.as_pragma(),
127            options.busy_timeout.as_millis(),
128            options.synchronous.as_pragma(),
129        );
130        conn.execute_batch(&pragmas)
131            .map_err(|e| BackendError::Open(e.to_string()))?;
132
133        // 3. Create settings table.
134        let settings_table_sql = "CREATE TABLE IF NOT EXISTS settings (
135            key TEXT PRIMARY KEY NOT NULL,
136            value TEXT NOT NULL,
137            updated_at INTEGER NOT NULL
138        )";
139        conn.execute(settings_table_sql, [])
140            .map_err(|e| BackendError::Open(e.to_string()))?;
141
142        // 4. Setup optional sidecar thread for cross-process change detection.
143        let (commits_rx, shutdown_tx, poll_thread) = if options.cross_process {
144            let (commits_tx, commits_rx) = unbounded::<()>();
145            let (shutdown_tx, shutdown_rx) = unbounded::<()>();
146            let sidecar_path = path.clone();
147            let poll_interval = options.poll_interval;
148            let poll_thread = thread::spawn(move || {
149                polling_loop(sidecar_path, commits_tx, shutdown_rx, poll_interval);
150            });
151
152            (Some(commits_rx), Some(shutdown_tx), Some(poll_thread))
153        } else {
154            (None, None, None)
155        };
156
157        Ok(SqliteBackend {
158            conn: Mutex::new(conn),
159            commits_rx,
160            shutdown_tx,
161            poll_thread,
162        })
163    }
164}
165
166/// Run the cross-process change-detection loop on a sidecar SQLite connection.
167///
168/// Opens its own [`Connection`] to `path`, separate from the writer connection,
169/// then polls `PRAGMA data_version` every `poll_interval`. Whenever the
170/// version counter ticks (signalling that some connection committed to the
171/// database), pushes `()` onto `commits_tx`. Own-process commits also tick the
172/// counter; the diff loop in `attune-core` is responsible for dedup.
173///
174/// Returns silently, never panics, and never returns a `Result`. The loop exits
175/// when any of these conditions occur:
176///
177/// - A shutdown value is received on `shutdown_rx` (clean shutdown requested
178///   by the owning [`SqliteBackend`] being dropped).
179/// - `shutdown_rx` becomes disconnected because its sender was dropped (also
180///   indicates the backend has been dropped).
181/// - `commits_tx` becomes disconnected because no receiver is listening for
182///   change signals.
183/// - Opening the sidecar connection fails, or any `PRAGMA data_version` query
184///   fails. Cross-process detection degrades gracefully on storage errors
185///   rather than propagating.
186fn polling_loop(
187    path: PathBuf,
188    commits_tx: Sender<()>,
189    shutdown_rx: Receiver<()>,
190    poll_interval: Duration,
191) {
192    // 1. Open the sidecar connection.
193    let sidecar_conn = match Connection::open(&path) {
194        Ok(c) => c,
195        Err(_) => return,
196    };
197
198    // 2. Get the initial data version.
199    let mut last_version: i64 =
200        match sidecar_conn.query_row("PRAGMA data_version", [], |row| row.get(0)) {
201            Ok(v) => v,
202            Err(_) => return,
203        };
204
205    loop {
206        // Wait up to the poll interval for a shutdown signal.
207        match shutdown_rx.recv_timeout(poll_interval) {
208            Ok(()) => return,                              // Shutdown requested.
209            Err(RecvTimeoutError::Disconnected) => return, // Sender dropped. (backend dropped)
210            Err(RecvTimeoutError::Timeout) => {}           // Time to poll.
211        }
212
213        // Check the current data version.
214        let version: i64 = match sidecar_conn.query_row("PRAGMA data_version", [], |row| row.get(0))
215        {
216            Ok(v) => v,
217            Err(_) => return,
218        };
219
220        if version != last_version {
221            last_version = version;
222
223            // If the send fails, no one is listening. Silently exit.
224            if commits_tx.send(()).is_err() {
225                return;
226            }
227        }
228    }
229}
230
231impl StorageBackend for SqliteBackend {
232    fn load_all(&self) -> Result<HashMap<String, StoredValue>, BackendError> {
233        // 1. Obtain lock to the DB connection.
234        let conn = self.conn.lock().unwrap();
235
236        // 2. Query DB for settings and deserialize values into StoredValue.
237        let sql = "SELECT key, value FROM settings";
238        let mut stmt = conn
239            .prepare(sql)
240            .map_err(|e| BackendError::Read(e.to_string()))?;
241        let rows = stmt
242            .query_map([], |row| {
243                let key = row.get(0)?;
244                let raw = row.get(1)?;
245                Ok((key, StoredValue::from_raw(raw)))
246            })
247            .map_err(|e| BackendError::Read(e.to_string()))?;
248
249        // 3. Add deserialized values to HashMap.
250        let mut result = HashMap::new();
251        for row in rows {
252            let (k, v) = row.map_err(|e| BackendError::Read(e.to_string()))?;
253            result.insert(k, v);
254        }
255
256        Ok(result)
257    }
258
259    fn set(&self, key: &str, value: &StoredValue) -> Result<(), BackendError> {
260        // 1. Obtain lock to the DB connection.
261        let conn = self.conn.lock().unwrap();
262
263        // 2. Write setting to DB.
264        let now = SystemTime::now()
265            .duration_since(UNIX_EPOCH)
266            .unwrap_or_default()
267            .as_secs() as i64;
268        let sql = "INSERT OR REPLACE INTO settings (key, value, updated_at) VALUES (?, ?, ?)";
269        conn.execute(sql, params![key, value.as_str(), now])
270            .map_err(|e| BackendError::Write(e.to_string()))?;
271
272        Ok(())
273    }
274
275    fn delete(&self, key: &str) -> Result<(), BackendError> {
276        // 1. Obtain lock to the DB connection.
277        let conn = self.conn.lock().unwrap();
278
279        // 2. Delete setting from DB.
280        let sql = "DELETE FROM settings WHERE key = ?";
281        conn.execute(sql, params![key])
282            .map_err(|e| BackendError::Write(e.to_string()))?;
283
284        Ok(())
285    }
286
287    fn watch_changes(&self) -> Option<Receiver<()>> {
288        self.commits_rx.clone()
289    }
290}
291
292impl Drop for SqliteBackend {
293    fn drop(&mut self) {
294        // Signal the polling thread to exit.
295        if let Some(shutdown_tx) = &self.shutdown_tx {
296            let _ = shutdown_tx.send(());
297        }
298
299        // Take the join handle out of Option and move it into `.join()`.
300        if let Some(handle) = self.poll_thread.take() {
301            let _ = handle.join();
302        }
303    }
304}
305
#[cfg(test)]
mod test {
    use super::*;
    use rusqlite::OptionalExtension;

    /// Poll interval for watcher tests that wait for a signal, so they do not
    /// have to sit through a full default interval.
    const FAST_POLL: Duration = Duration::from_millis(25);

    /// `open` must create the `settings` table.
    #[test]
    fn test_open_correctly_inits_sqlite_db() {
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let conn = sqlite_be.conn.lock().unwrap();

        let table: Option<String> = conn
            .query_row(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='settings'",
                [],
                |row| row.get(0),
            )
            .optional()
            .unwrap();
        assert_eq!(table.as_deref(), Some("settings"));
    }

    /// A freshly created database holds no settings.
    #[test]
    fn test_load_all_returns_a_hashmap_with_no_values() {
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let stored_values = sqlite_be.load_all().unwrap();
        assert!(stored_values.is_empty());
    }

    /// A value written with `set` is returned unchanged by `load_all`.
    #[test]
    fn test_set_successfully_writes_a_setting_to_the_db() {
        let key = "theme";
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let sv = StoredValue::encode(&"dark").unwrap();
        sqlite_be.set(key, &sv).unwrap();

        let stored_values = sqlite_be.load_all().unwrap();
        assert_eq!(stored_values.len(), 1);
        let loaded = stored_values
            .get(key)
            .expect("the key that was just written should be present");
        assert_eq!(sv.as_str(), loaded.as_str());
    }

    /// `delete` removes a previously written setting.
    #[test]
    fn test_delete_successfully_removes_a_setting_from_the_db() {
        // 1. Write setting.
        let key = "theme";
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let sv = StoredValue::encode(&"dark").unwrap();
        sqlite_be.set(key, &sv).unwrap();

        // 2. Ensure setting persists.
        let stored_values = sqlite_be.load_all().unwrap();
        assert_eq!(stored_values.len(), 1);

        // 3. Remove setting and ensure it's no longer in the DB.
        sqlite_be.delete(key).unwrap();
        let stored_values = sqlite_be.load_all().unwrap();
        assert!(stored_values.is_empty());
    }

    /// A commit made through an unrelated connection produces a change signal.
    #[test]
    fn test_watch_changes_signals_on_external_commit() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        // Open the backend with a short poll interval to keep the test fast.
        let options = SqliteOptions {
            poll_interval: FAST_POLL,
            ..SqliteOptions::default()
        };
        let backend = SqliteBackend::open_with_options(&path, options).unwrap();
        let rx = backend
            .watch_changes()
            .expect("polling thread should be running");

        // Write through a second connection. This simulates an external process.
        let other = rusqlite::Connection::open(&path).unwrap();
        other
            .execute(
                "INSERT INTO settings (key, value, updated_at) VALUES ('theme', '\"dark\"', 0)",
                [],
            )
            .unwrap();

        // Detect the commit through the polling thread.
        if let Err(e) = rx.recv_timeout(Duration::from_millis(3000)) {
            panic!("expected a change signal within 3s, got {:?}", e);
        }
    }

    /// Without any commit the receiver must time out. A disconnected channel
    /// would mean the polling thread has died, so it is not accepted here.
    #[test]
    fn test_watch_changes_times_out() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();

        // Open the backend with default options.
        let backend = SqliteBackend::open(&path).unwrap();
        let rx = backend
            .watch_changes()
            .expect("polling thread should be running");

        // Nothing has been committed, so no signal may arrive.
        assert_eq!(
            rx.recv_timeout(Duration::from_millis(1)),
            Err(RecvTimeoutError::Timeout)
        );
    }

    /// `cross_process: false` disables the watcher entirely.
    #[test]
    fn test_open_with_options_can_disable_watch_changes() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        let options = SqliteOptions {
            cross_process: false,
            ..SqliteOptions::default()
        };

        let backend = SqliteBackend::open_with_options(&path, options).unwrap();

        assert!(backend.watch_changes().is_none());
    }

    /// The pragmas selected in `SqliteOptions` are visible on the writer connection.
    #[test]
    fn test_open_with_options_applies_pragmas() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        let options = SqliteOptions {
            cross_process: false,
            journal_mode: SqliteJournalMode::Delete,
            busy_timeout: Duration::from_millis(1234),
            synchronous: SqliteSynchronous::Full,
            ..SqliteOptions::default()
        };

        let backend = SqliteBackend::open_with_options(&path, options).unwrap();
        let conn = backend.conn.lock().unwrap();

        let journal_mode: String = conn
            .query_row("PRAGMA journal_mode", [], |row| row.get(0))
            .unwrap();
        let busy_timeout: i64 = conn
            .query_row("PRAGMA busy_timeout", [], |row| row.get(0))
            .unwrap();
        let synchronous: i64 = conn
            .query_row("PRAGMA synchronous", [], |row| row.get(0))
            .unwrap();

        assert_eq!(journal_mode, "delete");
        assert_eq!(busy_timeout, 1234);
        // `PRAGMA synchronous` reports FULL as the integer 2.
        assert_eq!(synchronous, 2);
    }
}