1use crate::store::{PersistenceStore, RetentionPolicy, Sample};
9use anyhow::{Context, Result};
10use rusqlite::{params, Connection};
11use std::sync::Mutex;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14pub struct SqliteStore {
36 conn: Mutex<Connection>,
37}
38
39impl SqliteStore {
40 pub fn new(path: &str) -> Result<Self> {
42 let conn = Connection::open(path)
43 .with_context(|| format!("Failed to open SQLite database at {}", path))?;
44
45 let store = Self {
46 conn: Mutex::new(conn),
47 };
48 store.init_schema()?;
49 Ok(store)
50 }
51
52 pub fn new_in_memory() -> Result<Self> {
54 let conn =
55 Connection::open_in_memory().context("Failed to create in-memory SQLite database")?;
56
57 let store = Self {
58 conn: Mutex::new(conn),
59 };
60 store.init_schema()?;
61 Ok(store)
62 }
63
64 fn init_schema(&self) -> Result<()> {
66 let conn = self.conn.lock().unwrap();
67
68 conn.execute(
69 "CREATE TABLE IF NOT EXISTS samples (
70 id INTEGER PRIMARY KEY AUTOINCREMENT,
71 topic TEXT NOT NULL,
72 type_name TEXT NOT NULL,
73 payload BLOB NOT NULL,
74 timestamp_ns INTEGER NOT NULL,
75 sequence INTEGER NOT NULL,
76 source_guid BLOB NOT NULL
77 )",
78 [],
79 )?;
80
81 conn.execute("CREATE INDEX IF NOT EXISTS idx_topic ON samples(topic)", [])?;
82
83 conn.execute(
84 "CREATE INDEX IF NOT EXISTS idx_timestamp ON samples(timestamp_ns)",
85 [],
86 )?;
87
88 Ok(())
89 }
90
91 fn row_to_sample(row: &rusqlite::Row) -> rusqlite::Result<Sample> {
93 let source_guid_blob: Vec<u8> = row.get(5)?;
94 let mut source_guid = [0u8; 16];
95 source_guid.copy_from_slice(&source_guid_blob);
96
97 Ok(Sample {
98 topic: row.get(0)?,
99 type_name: row.get(1)?,
100 payload: row.get(2)?,
101 timestamp_ns: row.get::<_, i64>(3)? as u64,
102 sequence: row.get::<_, i64>(4)? as u64,
103 source_guid,
104 })
105 }
106}
107
108impl PersistenceStore for SqliteStore {
109 fn save(&self, sample: &Sample) -> Result<()> {
110 let conn = self.conn.lock().unwrap();
111 conn.execute(
112 "INSERT INTO samples (topic, type_name, payload, timestamp_ns, sequence, source_guid)
113 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
114 params![
115 sample.topic,
116 sample.type_name,
117 sample.payload,
118 sample.timestamp_ns as i64,
119 sample.sequence as i64,
120 &sample.source_guid[..],
121 ],
122 )?;
123
124 Ok(())
125 }
126
127 fn load(&self, topic: &str) -> Result<Vec<Sample>> {
128 let conn = self.conn.lock().unwrap();
129 let mut stmt = conn.prepare(
130 "SELECT topic, type_name, payload, timestamp_ns, sequence, source_guid
131 FROM samples
132 WHERE topic = ?1
133 ORDER BY timestamp_ns ASC",
134 )?;
135
136 let samples = stmt
137 .query_map([topic], Self::row_to_sample)?
138 .collect::<Result<Vec<_>, _>>()?;
139
140 Ok(samples)
141 }
142
143 fn query_range(&self, topic: &str, start_ns: u64, end_ns: u64) -> Result<Vec<Sample>> {
144 let conn = self.conn.lock().unwrap();
145
146 let start_i64 = start_ns.min(i64::MAX as u64) as i64;
148 let end_i64 = end_ns.min(i64::MAX as u64) as i64;
149
150 let prefix = topic.strip_suffix("/*");
152 let query = if let Some(prefix) = prefix {
153 format!(
154 "SELECT topic, type_name, payload, timestamp_ns, sequence, source_guid
155 FROM samples
156 WHERE topic LIKE '{}/%' AND timestamp_ns BETWEEN ?1 AND ?2
157 ORDER BY timestamp_ns ASC",
158 prefix
159 )
160 } else if topic == "*" {
161 "SELECT topic, type_name, payload, timestamp_ns, sequence, source_guid
162 FROM samples
163 WHERE timestamp_ns BETWEEN ?1 AND ?2
164 ORDER BY timestamp_ns ASC"
165 .to_string()
166 } else {
167 "SELECT topic, type_name, payload, timestamp_ns, sequence, source_guid
168 FROM samples
169 WHERE topic = ?3 AND timestamp_ns BETWEEN ?1 AND ?2
170 ORDER BY timestamp_ns ASC"
171 .to_string()
172 };
173
174 let mut stmt = conn.prepare(&query)?;
175
176 let rows = if prefix.is_some() || topic == "*" {
177 stmt.query_map(params![start_i64, end_i64], Self::row_to_sample)?
178 } else {
179 stmt.query_map(params![start_i64, end_i64, topic], Self::row_to_sample)?
180 };
181
182 let samples = rows.collect::<Result<Vec<_>, _>>()?;
183 Ok(samples)
184 }
185
186 fn apply_retention(&self, topic: &str, keep_count: usize) -> Result<()> {
187 let conn = self.conn.lock().unwrap();
188
189 conn.execute(
191 "DELETE FROM samples
192 WHERE topic = ?1
193 AND id NOT IN (
194 SELECT id FROM samples
195 WHERE topic = ?1
196 ORDER BY timestamp_ns DESC
197 LIMIT ?2
198 )",
199 params![topic, keep_count],
200 )?;
201
202 Ok(())
203 }
204
205 fn apply_retention_policy(&self, topic: &str, policy: &RetentionPolicy) -> Result<()> {
206 if policy.is_noop() {
207 return Ok(());
208 }
209
210 let mut conn = self.conn.lock().unwrap();
211
212 if policy.keep_count > 0 {
213 conn.execute(
214 "DELETE FROM samples
215 WHERE topic = ?1
216 AND id NOT IN (
217 SELECT id FROM samples
218 WHERE topic = ?1
219 ORDER BY timestamp_ns DESC
220 LIMIT ?2
221 )",
222 params![topic, policy.keep_count as i64],
223 )?;
224 }
225
226 if let Some(max_age_ns) = policy.max_age_ns {
227 let now_ns = SystemTime::now()
228 .duration_since(UNIX_EPOCH)
229 .unwrap_or_default()
230 .as_nanos() as u64;
231 let cutoff = now_ns.saturating_sub(max_age_ns);
232 let cutoff_i64 = cutoff.min(i64::MAX as u64) as i64;
233 conn.execute(
234 "DELETE FROM samples
235 WHERE topic = ?1 AND timestamp_ns < ?2",
236 params![topic, cutoff_i64],
237 )?;
238 }
239
240 if let Some(max_bytes) = policy.max_bytes {
241 let ids_to_delete = {
242 let mut stmt = conn.prepare(
243 "SELECT id, length(payload) FROM samples
244 WHERE topic = ?1
245 ORDER BY timestamp_ns DESC",
246 )?;
247 let rows = stmt.query_map([topic], |row| {
248 let id: i64 = row.get(0)?;
249 let len: i64 = row.get(1)?;
250 Ok((id, len))
251 })?;
252
253 let mut total = 0u64;
254 let mut ids: Vec<i64> = Vec::new();
255 for row in rows {
256 let (id, len) = row?;
257 let len_u64 = if len < 0 { 0 } else { len as u64 };
258 if total.saturating_add(len_u64) <= max_bytes {
259 total = total.saturating_add(len_u64);
260 } else {
261 ids.push(id);
262 }
263 }
264 ids
265 };
266
267 if !ids_to_delete.is_empty() {
268 let tx = conn.transaction()?;
269 {
270 let mut del = tx.prepare("DELETE FROM samples WHERE id = ?1")?;
271 for id in ids_to_delete {
272 del.execute([id])?;
273 }
274 }
275 tx.commit()?;
276 }
277 }
278
279 Ok(())
280 }
281
282 fn count(&self) -> Result<usize> {
283 let conn = self.conn.lock().unwrap();
284 let count: i64 = conn.query_row("SELECT COUNT(*) FROM samples", [], |row| row.get(0))?;
285
286 Ok(count as usize)
287 }
288
289 fn clear(&self) -> Result<()> {
290 let conn = self.conn.lock().unwrap();
291 conn.execute("DELETE FROM samples", [])?;
292 Ok(())
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sample with type name `TestType` and a GUID filled with `guid_byte`.
    fn make_sample(
        topic: &str,
        payload: Vec<u8>,
        timestamp_ns: u64,
        sequence: u64,
        guid_byte: u8,
    ) -> Sample {
        Sample {
            topic: topic.to_string(),
            type_name: "TestType".to_string(),
            payload,
            timestamp_ns,
            sequence,
            source_guid: [guid_byte; 16],
        }
    }

    #[test]
    fn test_sqlite_store_save_and_load() {
        let store = SqliteStore::new_in_memory().unwrap();
        store
            .save(&make_sample("test/topic", vec![0x01, 0x02, 0x03], 1000, 1, 0xAA))
            .unwrap();

        let loaded = store.load("test/topic").unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].topic, "test/topic");
        assert_eq!(loaded[0].sequence, 1);
    }

    #[test]
    fn test_sqlite_store_query_range() {
        let store = SqliteStore::new_in_memory().unwrap();
        for seq in 0..10u64 {
            store
                .save(&make_sample("test/topic", vec![seq as u8], seq * 1000, seq, 0xBB))
                .unwrap();
        }

        let range = store.query_range("test/topic", 2000, 5000).unwrap();
        assert_eq!(range.len(), 4);
        assert_eq!(range[0].sequence, 2);
        assert_eq!(range[3].sequence, 5);
    }

    #[test]
    fn test_sqlite_store_wildcard_query() {
        let store = SqliteStore::new_in_memory().unwrap();
        let topics = ["State/Temperature", "State/Pressure", "Command/Set"];
        for (idx, name) in topics.iter().enumerate() {
            store
                .save(&make_sample(name, vec![idx as u8], 1000, idx as u64, 0xCC))
                .unwrap();
        }

        assert_eq!(store.query_range("State/*", 0, 10000).unwrap().len(), 2);
        assert_eq!(store.query_range("*", 0, 10000).unwrap().len(), 3);
    }

    #[test]
    fn test_sqlite_store_retention() {
        let store = SqliteStore::new_in_memory().unwrap();
        for seq in 0..10u64 {
            store
                .save(&make_sample("test/topic", vec![seq as u8], seq * 1000, seq, 0xDD))
                .unwrap();
        }
        assert_eq!(store.count().unwrap(), 10);

        store.apply_retention("test/topic", 5).unwrap();
        assert_eq!(store.count().unwrap(), 5);

        let survivors = store.load("test/topic").unwrap();
        assert_eq!(survivors[0].sequence, 5);
    }

    #[test]
    fn test_sqlite_store_clear() {
        let store = SqliteStore::new_in_memory().unwrap();
        store
            .save(&make_sample("test/topic", vec![0x01], 1000, 1, 0xEE))
            .unwrap();
        assert_eq!(store.count().unwrap(), 1);

        store.clear().unwrap();
        assert_eq!(store.count().unwrap(), 0);
    }
}