use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::Mutex,
thread::{self, JoinHandle},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use attune_core::{BackendError, StorageBackend, StoredValue};
use crossbeam_channel::{Receiver, RecvTimeoutError, Sender, unbounded};
use rusqlite::{Connection, params};
/// Default period between `PRAGMA data_version` checks made by the cross-process watcher.
const POLL_INTERVAL: Duration = Duration::from_secs(1);

/// Tunables applied when opening a [`SqliteBackend`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SqliteOptions {
    /// When true, a background thread watches the database file for commits and reports
    /// them through `watch_changes`; when false, `watch_changes` returns `None`.
    pub cross_process: bool,
    /// How long the watcher thread waits between two `PRAGMA data_version` checks.
    pub poll_interval: Duration,
    /// Value written into `PRAGMA journal_mode` when the backend opens.
    pub journal_mode: SqliteJournalMode,
    /// Value written into `PRAGMA busy_timeout` (in milliseconds) when the backend opens.
    pub busy_timeout: Duration,
    /// Value written into `PRAGMA synchronous` when the backend opens.
    pub synchronous: SqliteSynchronous,
}

impl Default for SqliteOptions {
    /// Watching enabled and polled every [`POLL_INTERVAL`], `WAL` journaling,
    /// a five-second busy timeout, and `synchronous = NORMAL`.
    fn default() -> Self {
        let busy_timeout = Duration::from_secs(5);
        SqliteOptions {
            journal_mode: SqliteJournalMode::Wal,
            synchronous: SqliteSynchronous::Normal,
            busy_timeout,
            poll_interval: POLL_INTERVAL,
            cross_process: true,
        }
    }
}
/// Journal mode applied to the backend's connection through `PRAGMA journal_mode`.
///
/// Each variant corresponds to the SQLite keyword of the same name (see `as_pragma`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqliteJournalMode {
    /// `WAL`
    Wal,
    /// `DELETE`
    Delete,
    /// `TRUNCATE`
    Truncate,
    /// `PERSIST`
    Persist,
    /// `MEMORY`
    Memory,
    /// `OFF`
    Off,
}

impl SqliteJournalMode {
    /// The keyword spliced into `PRAGMA journal_mode = …` for this mode.
    fn as_pragma(self) -> &'static str {
        use SqliteJournalMode::{Delete, Memory, Off, Persist, Truncate, Wal};
        match self {
            Wal => "WAL",
            Delete => "DELETE",
            Truncate => "TRUNCATE",
            Persist => "PERSIST",
            Memory => "MEMORY",
            Off => "OFF",
        }
    }
}
/// Durability level applied to the backend's connection through `PRAGMA synchronous`.
///
/// Each variant corresponds to the SQLite keyword of the same name (see `as_pragma`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SqliteSynchronous {
    /// `OFF`
    Off,
    /// `NORMAL`
    Normal,
    /// `FULL`
    Full,
    /// `EXTRA`
    Extra,
}

impl SqliteSynchronous {
    /// The keyword spliced into `PRAGMA synchronous = …` for this level.
    fn as_pragma(self) -> &'static str {
        use SqliteSynchronous::{Extra, Full, Normal, Off};
        match self {
            Off => "OFF",
            Normal => "NORMAL",
            Full => "FULL",
            Extra => "EXTRA",
        }
    }
}
/// [`StorageBackend`] that keeps settings in a `settings` table of a SQLite database.
///
/// When opened with `cross_process` enabled (see [`SqliteOptions`]), a background
/// thread polls `PRAGMA data_version` on a separate connection and sends `()` through
/// the receiver returned by `watch_changes` each time the value changes.
pub struct SqliteBackend {
/// The backend's own connection; the mutex serialises use from the `&self` methods.
conn: Mutex<Connection>,
/// Change-signal receiver cloned out by `watch_changes`; `None` when `cross_process` is off.
commits_rx: Option<Receiver<()>>,
/// Used by `Drop` to ask the polling thread to stop; `None` when no thread was started.
shutdown_tx: Option<Sender<()>>,
/// Handle of the polling thread, joined in `Drop`; `None` when no thread was started.
poll_thread: Option<JoinHandle<()>>,
}
impl SqliteBackend {
pub fn open(path: impl AsRef<Path>) -> Result<Self, BackendError> {
Self::open_with_options(path, SqliteOptions::default())
}
pub fn open_with_options(
path: impl AsRef<Path>,
options: SqliteOptions,
) -> Result<Self, BackendError> {
let path = path.as_ref().to_path_buf();
let conn =
rusqlite::Connection::open(&path).map_err(|e| BackendError::Open(e.to_string()))?;
let pragmas = format!(
"PRAGMA journal_mode = {};\
PRAGMA busy_timeout = {};\
PRAGMA synchronous = {};\
PRAGMA foreign_keys = ON;",
options.journal_mode.as_pragma(),
options.busy_timeout.as_millis(),
options.synchronous.as_pragma(),
);
conn.execute_batch(&pragmas)
.map_err(|e| BackendError::Open(e.to_string()))?;
let settings_table_sql = "CREATE TABLE IF NOT EXISTS settings (
key TEXT PRIMARY KEY NOT NULL,
value TEXT NOT NULL,
updated_at INTEGER NOT NULL
)";
conn.execute(settings_table_sql, [])
.map_err(|e| BackendError::Open(e.to_string()))?;
let (commits_rx, shutdown_tx, poll_thread) = if options.cross_process {
let (commits_tx, commits_rx) = unbounded::<()>();
let (shutdown_tx, shutdown_rx) = unbounded::<()>();
let sidecar_path = path.clone();
let poll_interval = options.poll_interval;
let poll_thread = thread::spawn(move || {
polling_loop(sidecar_path, commits_tx, shutdown_rx, poll_interval);
});
(Some(commits_rx), Some(shutdown_tx), Some(poll_thread))
} else {
(None, None, None)
};
Ok(SqliteBackend {
conn: Mutex::new(conn),
commits_rx,
shutdown_tx,
poll_thread,
})
}
}
fn polling_loop(
path: PathBuf,
commits_tx: Sender<()>,
shutdown_rx: Receiver<()>,
poll_interval: Duration,
) {
let sidecar_conn = match Connection::open(&path) {
Ok(c) => c,
Err(_) => return,
};
let mut last_version: i64 =
match sidecar_conn.query_row("PRAGMA data_version", [], |row| row.get(0)) {
Ok(v) => v,
Err(_) => return,
};
loop {
match shutdown_rx.recv_timeout(poll_interval) {
Ok(()) => return, Err(RecvTimeoutError::Disconnected) => return, Err(RecvTimeoutError::Timeout) => {} }
let version: i64 = match sidecar_conn.query_row("PRAGMA data_version", [], |row| row.get(0))
{
Ok(v) => v,
Err(_) => return,
};
if version != last_version {
last_version = version;
if commits_tx.send(()).is_err() {
return;
}
}
}
}
impl StorageBackend for SqliteBackend {
fn load_all(&self) -> Result<HashMap<String, StoredValue>, BackendError> {
let conn = self.conn.lock().unwrap();
let sql = "SELECT key, value FROM settings";
let mut stmt = conn
.prepare(sql)
.map_err(|e| BackendError::Read(e.to_string()))?;
let rows = stmt
.query_map([], |row| {
let key = row.get(0)?;
let raw = row.get(1)?;
Ok((key, StoredValue::from_raw(raw)))
})
.map_err(|e| BackendError::Read(e.to_string()))?;
let mut result = HashMap::new();
for row in rows {
let (k, v) = row.map_err(|e| BackendError::Read(e.to_string()))?;
result.insert(k, v);
}
Ok(result)
}
fn set(&self, key: &str, value: &StoredValue) -> Result<(), BackendError> {
let conn = self.conn.lock().unwrap();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
let sql = "INSERT OR REPLACE INTO settings (key, value, updated_at) VALUES (?, ?, ?)";
conn.execute(sql, params![key, value.as_str(), now])
.map_err(|e| BackendError::Write(e.to_string()))?;
Ok(())
}
fn delete(&self, key: &str) -> Result<(), BackendError> {
let conn = self.conn.lock().unwrap();
let sql = "DELETE FROM settings WHERE key = ?";
conn.execute(sql, params![key])
.map_err(|e| BackendError::Write(e.to_string()))?;
Ok(())
}
fn watch_changes(&self) -> Option<Receiver<()>> {
self.commits_rx.clone()
}
}
impl Drop for SqliteBackend {
fn drop(&mut self) {
if let Some(shutdown_tx) = &self.shutdown_tx {
let _ = shutdown_tx.send(());
}
if let Some(handle) = self.poll_thread.take() {
let _ = handle.join();
}
}
}
#[cfg(test)]
mod test {
    use super::*;
    use rusqlite::OptionalExtension;

    #[test]
    fn test_open_correctly_inits_sqlite_db() {
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let conn = sqlite_be.conn.lock().unwrap();
        let name: Option<String> = conn
            .query_row(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='settings'",
                [],
                |row| row.get(0),
            )
            .optional()
            .unwrap();
        assert_eq!(name.as_deref(), Some("settings"));
    }

    #[test]
    fn test_load_all_returns_a_hashmap_with_no_values() {
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        assert!(sqlite_be.load_all().unwrap().is_empty());
    }

    #[test]
    fn test_set_successfully_writes_a_setting_to_the_db() {
        let key = "theme";
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let sv = StoredValue::encode(&"dark").unwrap();
        sqlite_be.set(key, &sv).unwrap();
        let stored_values = sqlite_be.load_all().unwrap();
        assert_eq!(stored_values.len(), 1);
        let loaded = stored_values.get(key).expect("theme should be stored");
        assert_eq!(sv.as_str(), loaded.as_str());
    }

    #[test]
    fn test_set_overwrites_an_existing_setting() {
        let key = "theme";
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        sqlite_be
            .set(key, &StoredValue::encode(&"dark").unwrap())
            .unwrap();
        let light = StoredValue::encode(&"light").unwrap();
        sqlite_be.set(key, &light).unwrap();
        let stored_values = sqlite_be.load_all().unwrap();
        assert_eq!(stored_values.len(), 1);
        assert_eq!(stored_values[key].as_str(), light.as_str());
    }

    #[test]
    fn test_delete_successfully_removes_a_setting_from_the_db() {
        let key = "theme";
        let sqlite_be = SqliteBackend::open(":memory:").unwrap();
        let sv = StoredValue::encode(&"dark").unwrap();
        sqlite_be.set(key, &sv).unwrap();
        assert_eq!(sqlite_be.load_all().unwrap().len(), 1);
        sqlite_be.delete(key).unwrap();
        assert!(sqlite_be.load_all().unwrap().is_empty());
    }

    #[test]
    fn test_watch_changes_signals_on_external_commit() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        let backend = SqliteBackend::open(&path).unwrap();
        let rx = backend
            .watch_changes()
            .expect("polling thread should be running");
        let other = rusqlite::Connection::open(&path).unwrap();
        other
            .execute(
                "INSERT INTO settings (key, value, updated_at) VALUES ('theme', '\"dark\"', 0)",
                [],
            )
            .unwrap();
        rx.recv_timeout(Duration::from_millis(3000))
            .expect("expected a change signal within 3s");
    }

    #[test]
    fn test_watch_changes_times_out() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        let backend = SqliteBackend::open(&path).unwrap();
        let rx = backend
            .watch_changes()
            .expect("polling thread should be running");
        // Require a genuine timeout: `Disconnected` would mean the poller has died.
        let outcome = rx.recv_timeout(Duration::from_millis(1));
        assert!(
            matches!(outcome, Err(RecvTimeoutError::Timeout)),
            "expected a timeout, got {:?}",
            outcome
        );
    }

    #[test]
    fn test_open_with_options_can_disable_watch_changes() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        let options = SqliteOptions {
            cross_process: false,
            ..SqliteOptions::default()
        };
        let backend = SqliteBackend::open_with_options(&path, options).unwrap();
        assert!(backend.watch_changes().is_none());
    }

    #[test]
    fn test_open_with_options_applies_pragmas() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let path = tmp.path().to_path_buf();
        let options = SqliteOptions {
            cross_process: false,
            journal_mode: SqliteJournalMode::Delete,
            busy_timeout: Duration::from_millis(1234),
            synchronous: SqliteSynchronous::Full,
            ..SqliteOptions::default()
        };
        let backend = SqliteBackend::open_with_options(&path, options).unwrap();
        let conn = backend.conn.lock().unwrap();
        let journal_mode: String = conn
            .query_row("PRAGMA journal_mode", [], |row| row.get(0))
            .unwrap();
        let busy_timeout: i64 = conn
            .query_row("PRAGMA busy_timeout", [], |row| row.get(0))
            .unwrap();
        let synchronous: i64 = conn
            .query_row("PRAGMA synchronous", [], |row| row.get(0))
            .unwrap();
        assert_eq!(journal_mode, "delete");
        assert_eq!(busy_timeout, 1234);
        assert_eq!(synchronous, 2);
    }
}