Skip to main content

aimdb_persistence_sqlite/
lib.rs

1//! # aimdb-persistence-sqlite
2//!
3//! SQLite persistence backend for AimDB.
4//!
//! Owns a dedicated OS thread that holds the `rusqlite::Connection`. All async
//! callers send `DbCommand` messages to that thread over a bounded channel and
//! await a `tokio::sync::oneshot` reply. The async executor is never blocked;
//! the writer thread is never awaited.
9//!
10//! **Runtime requirement:** This crate requires a Tokio runtime for the
11//! `oneshot` reply channel. Do **not** use `SqliteBackend` with the Embassy
12//! adapter — it will not compile without a Tokio executor.
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! use aimdb_persistence_sqlite::SqliteBackend;
18//! use std::sync::Arc;
19//!
20//! let backend = Arc::new(SqliteBackend::new("./data/history.db")?);
21//! ```
22
23use std::path::Path;
24
25use aimdb_persistence::backend::{BoxFuture, PersistenceBackend, QueryParams, StoredValue};
26use aimdb_persistence::error::PersistenceError;
27use rusqlite::{params, Connection};
28use serde_json::Value;
29
30// ---------------------------------------------------------------------------
31// Command enum — sent from async callers to the writer thread
32// ---------------------------------------------------------------------------
33
34enum DbCommand {
35    Store {
36        record_name: String,
37        json: String,
38        timestamp: u64,
39        reply: tokio::sync::oneshot::Sender<Result<(), PersistenceError>>,
40    },
41    Query {
42        pattern: String,
43        params: QueryParams,
44        reply: tokio::sync::oneshot::Sender<Result<Vec<StoredValue>, PersistenceError>>,
45    },
46    Cleanup {
47        older_than: u64,
48        reply: tokio::sync::oneshot::Sender<Result<u64, PersistenceError>>,
49    },
50}
51
52// ---------------------------------------------------------------------------
53// SqliteBackend — the public API
54// ---------------------------------------------------------------------------
55
56/// SQLite persistence backend.
57///
58/// `Clone` is cheap — it only clones the `mpsc::SyncSender` handle.
59///
60/// The writer thread shuts down automatically when all `SyncSender` handles
61/// (i.e. all `SqliteBackend` clones) are dropped.
62#[derive(Clone)]
63pub struct SqliteBackend {
64    tx: std::sync::mpsc::SyncSender<DbCommand>,
65}
66
67impl SqliteBackend {
68    /// Opens (or creates) a SQLite database at `path` and starts the writer thread.
69    ///
70    /// Schema and WAL mode are configured **synchronously** here — no `block_on`
71    /// needed, no async runtime required at construction time.
72    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, PersistenceError> {
73        // Bound of 64 provides backpressure without being too aggressive.
74        let (tx, rx) = std::sync::mpsc::sync_channel::<DbCommand>(64);
75
76        let conn = Connection::open(path).map_err(|e| PersistenceError::Backend(e.to_string()))?;
77
78        // Enable WAL mode: readers and the single writer proceed concurrently.
79        conn.pragma_update(None, "journal_mode", "WAL")
80            .map_err(|e| PersistenceError::Backend(e.to_string()))?;
81
82        // Initialize schema before the writer thread is spawned.
83        conn.execute_batch(
84            "CREATE TABLE IF NOT EXISTS record_history (
85                id          INTEGER PRIMARY KEY AUTOINCREMENT,
86                record_name TEXT    NOT NULL,
87                value_json  TEXT    NOT NULL,
88                stored_at   INTEGER NOT NULL
89            );
90            CREATE INDEX IF NOT EXISTS idx_record_time
91                ON record_history(record_name, stored_at DESC);",
92        )
93        .map_err(|e| PersistenceError::Backend(e.to_string()))?;
94
95        std::thread::Builder::new()
96            .name("aimdb-sqlite".to_string())
97            .spawn(move || run_db_thread(conn, rx))
98            .map_err(|e| PersistenceError::Backend(e.to_string()))?;
99
100        Ok(Self { tx })
101    }
102}
103
104// ---------------------------------------------------------------------------
105// Writer thread — blocking event loop
106// ---------------------------------------------------------------------------
107
108fn run_db_thread(conn: Connection, rx: std::sync::mpsc::Receiver<DbCommand>) {
109    while let Ok(cmd) = rx.recv() {
110        match cmd {
111            DbCommand::Store {
112                record_name,
113                json,
114                timestamp,
115                reply,
116            } => {
117                let ts = match i64::try_from(timestamp) {
118                    Ok(v) => v,
119                    Err(_) => {
120                        let _ = reply.send(Err(PersistenceError::Backend(format!(
121                            "timestamp {timestamp} overflows i64"
122                        ))));
123                        continue;
124                    }
125                };
126                let result = conn
127                    .prepare_cached(
128                        "INSERT INTO record_history (record_name, value_json, stored_at)
129                         VALUES (?1, ?2, ?3)",
130                    )
131                    .and_then(|mut stmt| stmt.execute(params![record_name, json, ts]))
132                    .map(|_| ())
133                    .map_err(|e| PersistenceError::Backend(e.to_string()));
134                let _ = reply.send(result);
135            }
136
137            DbCommand::Query {
138                pattern,
139                params,
140                reply,
141            } => {
142                let result = query_sync(&conn, &pattern, params);
143                let _ = reply.send(result);
144            }
145
146            DbCommand::Cleanup { older_than, reply } => {
147                let cutoff = match i64::try_from(older_than) {
148                    Ok(v) => v,
149                    Err(_) => {
150                        let _ = reply.send(Err(PersistenceError::Backend(format!(
151                            "cleanup cutoff {older_than} overflows i64"
152                        ))));
153                        continue;
154                    }
155                };
156                let result = conn
157                    .prepare_cached("DELETE FROM record_history WHERE stored_at < ?1")
158                    .and_then(|mut stmt| stmt.execute(params![cutoff]))
159                    .map(|n| n as u64)
160                    .map_err(|e| PersistenceError::Backend(e.to_string()));
161                let _ = reply.send(result);
162            }
163        }
164    }
165    // All SyncSender handles dropped → exit cleanly.
166}
167
168// ---------------------------------------------------------------------------
169// SQL helpers
170// ---------------------------------------------------------------------------
171
172/// Escape SQL LIKE special characters, then replace `*` with `%`.
173fn sanitize_pattern(pattern: &str) -> String {
174    pattern
175        .replace('\\', "\\\\")
176        .replace('%', "\\%")
177        .replace('_', "\\_")
178        .replace('*', "%")
179}
180
181fn query_sync(
182    conn: &Connection,
183    pattern: &str,
184    params: QueryParams,
185) -> Result<Vec<StoredValue>, PersistenceError> {
186    // `None` means "no limit" — the SQL uses `(?4 IS NULL OR rn <= ?4)`.
187    let limit: Option<i64> = params
188        .limit_per_record
189        .map(|l| {
190            i64::try_from(l).map_err(|_| {
191                PersistenceError::Backend("limit_per_record overflows i64".to_string())
192            })
193        })
194        .transpose()?;
195    let sql_pattern = sanitize_pattern(pattern);
196
197    // Checked conversion: timestamps must fit in SQLite's signed i64.
198    let start_time: Option<i64> = params
199        .start_time
200        .map(i64::try_from)
201        .transpose()
202        .map_err(|_| PersistenceError::Backend("start_time overflows i64".to_string()))?;
203    let end_time: Option<i64> = params
204        .end_time
205        .map(i64::try_from)
206        .transpose()
207        .map_err(|_| PersistenceError::Backend("end_time overflows i64".to_string()))?;
208
209    let mut stmt = conn
210        .prepare_cached(
211            "WITH ranked AS (
212                SELECT record_name, value_json, stored_at,
213                       ROW_NUMBER() OVER (
214                           PARTITION BY record_name
215                           ORDER BY stored_at DESC, id DESC
216                       ) AS rn
217                FROM record_history
218                WHERE record_name LIKE ?1 ESCAPE '\\'
219                  AND (?2 IS NULL OR stored_at >= ?2)
220                  AND (?3 IS NULL OR stored_at <= ?3)
221            )
222            SELECT record_name, value_json, stored_at
223            FROM ranked WHERE (?4 IS NULL OR rn <= ?4)
224            ORDER BY record_name, stored_at DESC",
225        )
226        .map_err(|e| PersistenceError::Backend(e.to_string()))?;
227
228    let rows = stmt
229        .query_map(
230            rusqlite::params![sql_pattern, start_time, end_time, limit],
231            |row| {
232                let value_str: String = row.get(1)?;
233                Ok(StoredValue {
234                    record_name: row.get(0)?,
235                    value: serde_json::from_str(&value_str).unwrap_or_else(|e| {
236                        #[cfg(feature = "tracing")]
237                        tracing::warn!(
238                            "SQLite: corrupted JSON in record_history row, \
239                             substituting null: {e}"
240                        );
241                        #[cfg(not(feature = "tracing"))]
242                        let _ = e;
243                        Value::Null
244                    }),
245                    stored_at: row.get::<_, i64>(2).map(|v| v.max(0) as u64)?,
246                })
247            },
248        )
249        .map_err(|e| PersistenceError::Backend(e.to_string()))?;
250
251    rows.collect::<Result<Vec<_>, _>>()
252        .map_err(|e| PersistenceError::Backend(e.to_string()))
253}
254
255// ---------------------------------------------------------------------------
256// send_cmd! macro — enqueue + await oneshot
257// ---------------------------------------------------------------------------
258
259macro_rules! send_cmd {
260    ($tx:expr, $cmd:expr) => {{
261        let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
262        $tx.send($cmd(reply_tx))
263            .map_err(|_| PersistenceError::BackendShutdown)?;
264        reply_rx
265            .await
266            .map_err(|_| PersistenceError::BackendShutdown)?
267    }};
268}
269
270// ---------------------------------------------------------------------------
271// PersistenceBackend impl
272// ---------------------------------------------------------------------------
273
274impl PersistenceBackend for SqliteBackend {
275    // initialize() — uses trait default (no-op); schema was created in ::new().
276
277    fn store<'a>(
278        &'a self,
279        record_name: &'a str,
280        value: &'a Value,
281        timestamp: u64,
282    ) -> BoxFuture<'a, Result<(), PersistenceError>> {
283        let json = match serde_json::to_string(value) {
284            Ok(j) => j,
285            Err(e) => return Box::pin(async move { Err(PersistenceError::from(e)) }),
286        };
287        let record_name = record_name.to_string();
288        let tx = self.tx.clone();
289        Box::pin(async move {
290            send_cmd!(tx, |reply| DbCommand::Store {
291                record_name,
292                json,
293                timestamp,
294                reply,
295            })
296        })
297    }
298
299    fn query<'a>(
300        &'a self,
301        record_pattern: &'a str,
302        params: QueryParams,
303    ) -> BoxFuture<'a, Result<Vec<StoredValue>, PersistenceError>> {
304        let pattern = record_pattern.to_string();
305        let tx = self.tx.clone();
306        Box::pin(async move {
307            send_cmd!(tx, |reply| DbCommand::Query {
308                pattern,
309                params,
310                reply,
311            })
312        })
313    }
314
315    fn cleanup(&self, older_than: u64) -> BoxFuture<'_, Result<u64, PersistenceError>> {
316        let tx = self.tx.clone();
317        Box::pin(async move { send_cmd!(tx, |reply| DbCommand::Cleanup { older_than, reply }) })
318    }
319}
320
321// ---------------------------------------------------------------------------
322// Tests
323// ---------------------------------------------------------------------------
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a backend on a fresh database file inside a new temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps the directory (and the
    /// database file) alive for the duration of the test.
    fn fresh_backend(file_name: &str) -> (tempfile::TempDir, SqliteBackend) {
        let dir = tempfile::tempdir().unwrap();
        let backend = SqliteBackend::new(dir.path().join(file_name)).unwrap();
        (dir, backend)
    }

    #[tokio::test]
    async fn test_store_and_query() {
        let (_dir, backend) = fresh_backend("test.db");

        let reading = serde_json::json!({"celsius": 21.5, "city": "vienna"});
        for (name, ts) in [
            ("temp::vienna", 1000),
            ("temp::vienna", 2000),
            ("temp::berlin", 1500),
        ] {
            backend.store(name, &reading, ts).await.unwrap();
        }

        // Wildcard query keeping only the newest row of each record.
        let latest = backend
            .query(
                "temp::*",
                QueryParams {
                    limit_per_record: Some(1),
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        assert_eq!(latest.len(), 2); // one row per city
        for city in ["temp::vienna", "temp::berlin"] {
            assert!(latest.iter().any(|row| row.record_name == city));
        }

        // Vienna's surviving row must be the newer of its two writes.
        let vienna = latest
            .iter()
            .find(|row| row.record_name == "temp::vienna")
            .unwrap();
        assert_eq!(vienna.stored_at, 2000);
    }

    #[tokio::test]
    async fn test_time_range_query() {
        let (_dir, backend) = fresh_backend("test_range.db");

        let reading = serde_json::json!({"celsius": 20.0});
        for ts in (1..=5).map(|step| step * 1000) {
            backend.store("temp::vienna", &reading, ts).await.unwrap();
        }

        let window = backend
            .query(
                "temp::vienna",
                QueryParams {
                    start_time: Some(2000),
                    end_time: Some(4000),
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        // Both bounds are inclusive: 2000, 3000 and 4000.
        assert_eq!(window.len(), 3);
    }

    #[tokio::test]
    async fn test_cleanup() {
        let (_dir, backend) = fresh_backend("test_cleanup.db");

        let reading = serde_json::json!({"celsius": 20.0});
        for (name, ts) in [("temp::a", 1000), ("temp::b", 2000), ("temp::c", 3000)] {
            backend.store(name, &reading, ts).await.unwrap();
        }

        // Rows stamped before 2500 (i.e. 1000 and 2000) are removed.
        assert_eq!(backend.cleanup(2500).await.unwrap(), 2);

        // Only the 3000 row is left.
        let remaining = backend
            .query("temp::*", QueryParams::default())
            .await
            .unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].stored_at, 3000);
    }

    #[tokio::test]
    async fn test_pattern_escaping() {
        let (_dir, backend) = fresh_backend("test_escape.db");

        let flag = serde_json::json!({"ok": true});
        for name in ["test_record", "testXrecord"] {
            backend.store(name, &flag, 1000).await.unwrap();
        }

        // `_` must be taken literally, so this exact-name query may not match `testXrecord`.
        let hits = backend
            .query("test_record", QueryParams::default())
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].record_name, "test_record");
    }
}