aimdb_persistence_sqlite/
lib.rs1use std::path::Path;
24
25use aimdb_persistence::backend::{BoxFuture, PersistenceBackend, QueryParams, StoredValue};
26use aimdb_persistence::error::PersistenceError;
27use rusqlite::{params, Connection};
28use serde_json::Value;
29
30enum DbCommand {
35 Store {
36 record_name: String,
37 json: String,
38 timestamp: u64,
39 reply: tokio::sync::oneshot::Sender<Result<(), PersistenceError>>,
40 },
41 Query {
42 pattern: String,
43 params: QueryParams,
44 reply: tokio::sync::oneshot::Sender<Result<Vec<StoredValue>, PersistenceError>>,
45 },
46 Cleanup {
47 older_than: u64,
48 reply: tokio::sync::oneshot::Sender<Result<u64, PersistenceError>>,
49 },
50}
51
52#[derive(Clone)]
63pub struct SqliteBackend {
64 tx: std::sync::mpsc::SyncSender<DbCommand>,
65}
66
67impl SqliteBackend {
68 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, PersistenceError> {
73 let (tx, rx) = std::sync::mpsc::sync_channel::<DbCommand>(64);
75
76 let conn = Connection::open(path).map_err(|e| PersistenceError::Backend(e.to_string()))?;
77
78 conn.pragma_update(None, "journal_mode", "WAL")
80 .map_err(|e| PersistenceError::Backend(e.to_string()))?;
81
82 conn.execute_batch(
84 "CREATE TABLE IF NOT EXISTS record_history (
85 id INTEGER PRIMARY KEY AUTOINCREMENT,
86 record_name TEXT NOT NULL,
87 value_json TEXT NOT NULL,
88 stored_at INTEGER NOT NULL
89 );
90 CREATE INDEX IF NOT EXISTS idx_record_time
91 ON record_history(record_name, stored_at DESC);",
92 )
93 .map_err(|e| PersistenceError::Backend(e.to_string()))?;
94
95 std::thread::Builder::new()
96 .name("aimdb-sqlite".to_string())
97 .spawn(move || run_db_thread(conn, rx))
98 .map_err(|e| PersistenceError::Backend(e.to_string()))?;
99
100 Ok(Self { tx })
101 }
102}
103
104fn run_db_thread(conn: Connection, rx: std::sync::mpsc::Receiver<DbCommand>) {
109 while let Ok(cmd) = rx.recv() {
110 match cmd {
111 DbCommand::Store {
112 record_name,
113 json,
114 timestamp,
115 reply,
116 } => {
117 let ts = match i64::try_from(timestamp) {
118 Ok(v) => v,
119 Err(_) => {
120 let _ = reply.send(Err(PersistenceError::Backend(format!(
121 "timestamp {timestamp} overflows i64"
122 ))));
123 continue;
124 }
125 };
126 let result = conn
127 .prepare_cached(
128 "INSERT INTO record_history (record_name, value_json, stored_at)
129 VALUES (?1, ?2, ?3)",
130 )
131 .and_then(|mut stmt| stmt.execute(params![record_name, json, ts]))
132 .map(|_| ())
133 .map_err(|e| PersistenceError::Backend(e.to_string()));
134 let _ = reply.send(result);
135 }
136
137 DbCommand::Query {
138 pattern,
139 params,
140 reply,
141 } => {
142 let result = query_sync(&conn, &pattern, params);
143 let _ = reply.send(result);
144 }
145
146 DbCommand::Cleanup { older_than, reply } => {
147 let cutoff = match i64::try_from(older_than) {
148 Ok(v) => v,
149 Err(_) => {
150 let _ = reply.send(Err(PersistenceError::Backend(format!(
151 "cleanup cutoff {older_than} overflows i64"
152 ))));
153 continue;
154 }
155 };
156 let result = conn
157 .prepare_cached("DELETE FROM record_history WHERE stored_at < ?1")
158 .and_then(|mut stmt| stmt.execute(params![cutoff]))
159 .map(|n| n as u64)
160 .map_err(|e| PersistenceError::Backend(e.to_string()));
161 let _ = reply.send(result);
162 }
163 }
164 }
165 }
167
/// Converts a record-name pattern into a SQL `LIKE` pattern that uses `\` as
/// its escape character.
///
/// `*` is the only wildcard and becomes `%` (any run of characters). The `LIKE`
/// metacharacters `%` and `_`, and the escape character `\` itself, are escaped
/// so they match literally; every other character is copied unchanged.
///
/// NOTE(review): SQLite's `LIKE` ignores ASCII letter case by default, so
/// matching is case-insensitive unless `case_sensitive_like` is enabled.
fn sanitize_pattern(pattern: &str) -> String {
    let mut like = String::with_capacity(pattern.len());
    for ch in pattern.chars() {
        match ch {
            '\\' => like.push_str("\\\\"),
            '%' => like.push_str("\\%"),
            '_' => like.push_str("\\_"),
            '*' => like.push('%'),
            other => like.push(other),
        }
    }
    like
}
180
181fn query_sync(
182 conn: &Connection,
183 pattern: &str,
184 params: QueryParams,
185) -> Result<Vec<StoredValue>, PersistenceError> {
186 let limit: Option<i64> = params
188 .limit_per_record
189 .map(|l| {
190 i64::try_from(l).map_err(|_| {
191 PersistenceError::Backend("limit_per_record overflows i64".to_string())
192 })
193 })
194 .transpose()?;
195 let sql_pattern = sanitize_pattern(pattern);
196
197 let start_time: Option<i64> = params
199 .start_time
200 .map(i64::try_from)
201 .transpose()
202 .map_err(|_| PersistenceError::Backend("start_time overflows i64".to_string()))?;
203 let end_time: Option<i64> = params
204 .end_time
205 .map(i64::try_from)
206 .transpose()
207 .map_err(|_| PersistenceError::Backend("end_time overflows i64".to_string()))?;
208
209 let mut stmt = conn
210 .prepare_cached(
211 "WITH ranked AS (
212 SELECT record_name, value_json, stored_at,
213 ROW_NUMBER() OVER (
214 PARTITION BY record_name
215 ORDER BY stored_at DESC, id DESC
216 ) AS rn
217 FROM record_history
218 WHERE record_name LIKE ?1 ESCAPE '\\'
219 AND (?2 IS NULL OR stored_at >= ?2)
220 AND (?3 IS NULL OR stored_at <= ?3)
221 )
222 SELECT record_name, value_json, stored_at
223 FROM ranked WHERE (?4 IS NULL OR rn <= ?4)
224 ORDER BY record_name, stored_at DESC",
225 )
226 .map_err(|e| PersistenceError::Backend(e.to_string()))?;
227
228 let rows = stmt
229 .query_map(
230 rusqlite::params![sql_pattern, start_time, end_time, limit],
231 |row| {
232 let value_str: String = row.get(1)?;
233 Ok(StoredValue {
234 record_name: row.get(0)?,
235 value: serde_json::from_str(&value_str).unwrap_or_else(|e| {
236 #[cfg(feature = "tracing")]
237 tracing::warn!(
238 "SQLite: corrupted JSON in record_history row, \
239 substituting null: {e}"
240 );
241 #[cfg(not(feature = "tracing"))]
242 let _ = e;
243 Value::Null
244 }),
245 stored_at: row.get::<_, i64>(2).map(|v| v.max(0) as u64)?,
246 })
247 },
248 )
249 .map_err(|e| PersistenceError::Backend(e.to_string()))?;
250
251 rows.collect::<Result<Vec<_>, _>>()
252 .map_err(|e| PersistenceError::Backend(e.to_string()))
253}
254
/// Sends a command to the database thread and awaits its reply.
///
/// `$tx` is the command sender. `$cmd` is a closure that receives the one-shot
/// reply sender and builds the `DbCommand` to enqueue.
///
/// Must be expanded inside an `async` block whose output is
/// `Result<_, PersistenceError>`. The enclosing block returns
/// `Err(PersistenceError::BackendShutdown)` in two cases:
///
/// * the command cannot be enqueued, or
/// * the reply sender is dropped without an answer.
///
/// Otherwise the macro evaluates to the database thread's own `Result`.
///
/// NOTE(review): `SyncSender::send` blocks the calling thread while the bounded
/// queue is full, which stalls the async executor thread running this future.
macro_rules! send_cmd {
    ($tx:expr, $cmd:expr) => {{
        let (reply_sender, reply_receiver) = tokio::sync::oneshot::channel();
        let command = $cmd(reply_sender);
        if $tx.send(command).is_err() {
            return Err(PersistenceError::BackendShutdown);
        }
        match reply_receiver.await {
            Ok(outcome) => outcome,
            Err(_) => return Err(PersistenceError::BackendShutdown),
        }
    }};
}
269
270impl PersistenceBackend for SqliteBackend {
275 fn store<'a>(
278 &'a self,
279 record_name: &'a str,
280 value: &'a Value,
281 timestamp: u64,
282 ) -> BoxFuture<'a, Result<(), PersistenceError>> {
283 let json = match serde_json::to_string(value) {
284 Ok(j) => j,
285 Err(e) => return Box::pin(async move { Err(PersistenceError::from(e)) }),
286 };
287 let record_name = record_name.to_string();
288 let tx = self.tx.clone();
289 Box::pin(async move {
290 send_cmd!(tx, |reply| DbCommand::Store {
291 record_name,
292 json,
293 timestamp,
294 reply,
295 })
296 })
297 }
298
299 fn query<'a>(
300 &'a self,
301 record_pattern: &'a str,
302 params: QueryParams,
303 ) -> BoxFuture<'a, Result<Vec<StoredValue>, PersistenceError>> {
304 let pattern = record_pattern.to_string();
305 let tx = self.tx.clone();
306 Box::pin(async move {
307 send_cmd!(tx, |reply| DbCommand::Query {
308 pattern,
309 params,
310 reply,
311 })
312 })
313 }
314
315 fn cleanup(&self, older_than: u64) -> BoxFuture<'_, Result<u64, PersistenceError>> {
316 let tx = self.tx.clone();
317 Box::pin(async move { send_cmd!(tx, |reply| DbCommand::Cleanup { older_than, reply }) })
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_store_and_query() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test.db");
        let backend = SqliteBackend::new(&db_path).unwrap();

        let value = serde_json::json!({"celsius": 21.5, "city": "vienna"});
        backend.store("temp::vienna", &value, 1000).await.unwrap();
        backend.store("temp::vienna", &value, 2000).await.unwrap();
        backend.store("temp::berlin", &value, 1500).await.unwrap();

        // With a limit of one, each matching record contributes only its newest row.
        let results = backend
            .query(
                "temp::*",
                QueryParams {
                    limit_per_record: Some(1),
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|r| r.record_name == "temp::vienna"));
        assert!(results.iter().any(|r| r.record_name == "temp::berlin"));

        let vienna = results
            .iter()
            .find(|r| r.record_name == "temp::vienna")
            .unwrap();
        assert_eq!(vienna.stored_at, 2000);
    }

    #[tokio::test]
    async fn test_limit_per_record_returns_newest_first() {
        let dir = tempfile::tempdir().unwrap();
        let backend = SqliteBackend::new(dir.path().join("test_limit.db")).unwrap();

        let value = serde_json::json!({"celsius": 20.0});
        // Inserted out of timestamp order on purpose.
        for ts in [1000u64, 3000, 2000] {
            backend.store("temp::vienna", &value, ts).await.unwrap();
        }

        let results = backend
            .query(
                "temp::vienna",
                QueryParams {
                    limit_per_record: Some(2),
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        let stamps: Vec<u64> = results.iter().map(|r| r.stored_at).collect();
        assert_eq!(stamps, vec![3000, 2000]);
    }

    #[tokio::test]
    async fn test_time_range_query() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test_range.db");
        let backend = SqliteBackend::new(&db_path).unwrap();

        let value = serde_json::json!({"celsius": 20.0});
        for ts in [1000u64, 2000, 3000, 4000, 5000] {
            backend.store("temp::vienna", &value, ts).await.unwrap();
        }

        // Both bounds are inclusive: 2000, 3000 and 4000 qualify.
        let results = backend
            .query(
                "temp::vienna",
                QueryParams {
                    start_time: Some(2000),
                    end_time: Some(4000),
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_cleanup() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test_cleanup.db");
        let backend = SqliteBackend::new(&db_path).unwrap();

        let value = serde_json::json!({"celsius": 20.0});
        backend.store("temp::a", &value, 1000).await.unwrap();
        backend.store("temp::b", &value, 2000).await.unwrap();
        backend.store("temp::c", &value, 3000).await.unwrap();

        let deleted = backend.cleanup(2500).await.unwrap();
        assert_eq!(deleted, 2);

        let results = backend
            .query("temp::*", QueryParams::default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].stored_at, 3000);
    }

    #[tokio::test]
    async fn test_out_of_range_timestamps_are_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let backend = SqliteBackend::new(dir.path().join("test_overflow.db")).unwrap();

        let value = serde_json::json!({"celsius": 20.0});
        let stored = backend.store("temp::vienna", &value, u64::MAX).await;
        assert!(matches!(stored, Err(PersistenceError::Backend(_))));

        let cleaned = backend.cleanup(u64::MAX).await;
        assert!(matches!(cleaned, Err(PersistenceError::Backend(_))));

        // The rejected store must not have written anything.
        let results = backend
            .query("temp::vienna", QueryParams::default())
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_corrupted_json_reads_as_null() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test_corrupt.db");
        let backend = SqliteBackend::new(&db_path).unwrap();

        // Write a malformed row behind the backend's back; the schema already
        // exists because `SqliteBackend::new` creates it before returning.
        let raw = Connection::open(&db_path).unwrap();
        raw.execute(
            "INSERT INTO record_history (record_name, value_json, stored_at) \
             VALUES (?1, ?2, ?3)",
            params!["temp::broken", "{not json", 1000i64],
        )
        .unwrap();
        drop(raw);

        let results = backend
            .query("temp::broken", QueryParams::default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].value, Value::Null);
        assert_eq!(results[0].stored_at, 1000);
    }

    #[tokio::test]
    async fn test_pattern_escaping() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("test_escape.db");
        let backend = SqliteBackend::new(&db_path).unwrap();

        let value = serde_json::json!({"ok": true});
        backend.store("test_record", &value, 1000).await.unwrap();
        backend.store("testXrecord", &value, 1000).await.unwrap();

        // `_` must match only a literal underscore, not any single character.
        let results = backend
            .query("test_record", QueryParams::default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].record_name, "test_record");
    }

    #[tokio::test]
    async fn test_pattern_percent_is_literal() {
        let dir = tempfile::tempdir().unwrap();
        let backend = SqliteBackend::new(dir.path().join("test_percent.db")).unwrap();

        let value = serde_json::json!({"ok": true});
        backend.store("load%cpu", &value, 1000).await.unwrap();
        backend.store("loadXcpu", &value, 1000).await.unwrap();

        // `%` must match only a literal percent sign, not any run of characters.
        let results = backend
            .query("load%cpu", QueryParams::default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].record_name, "load%cpu");
    }
}