1use super::{FlushResult, SnapshotQuery, StorageBackend, StorageError};
2use crate::models::{DecisionSnapshot, Snapshot, SnapshotType};
3use rusqlite::{params, Connection, OptionalExtension};
4use serde_json;
5use std::path::Path;
6use std::sync::{Arc, Mutex};
7#[cfg(feature = "async")]
8use tokio::task;
9
/// Compression applied to stored snapshot payloads.
///
/// NOTE(review): this enum is not referenced by the SQLite backend code in this file;
/// confirm where (and whether) it is consumed before relying on these semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionType {
    /// Payloads are stored as-is.
    None,
    /// Payloads are gzip-compressed (presumably — confirm against the writer).
    Gzip,
}
15
16pub struct SqliteBackend {
17 pub conn: Arc<Mutex<Connection>>,
18}
19
20impl SqliteBackend {
21 pub fn new(path: impl AsRef<Path>) -> Result<Self, StorageError> {
23 let conn = Connection::open(path).map_err(|e| {
24 StorageError::ConnectionError(format!("Failed to open database: {}", e))
25 })?;
26
27 let backend = Self {
28 conn: Arc::new(Mutex::new(conn)),
29 };
30
31 {
33 let conn_guard = backend.conn.lock().unwrap();
34 Self::run_migrations(&conn_guard)?;
35 }
36
37 Ok(backend)
38 }
39
40 pub fn in_memory() -> Result<Self, StorageError> {
42 let conn = Connection::open(":memory:").map_err(|e| {
43 StorageError::ConnectionError(format!("Failed to create in-memory database: {}", e))
44 })?;
45
46 let backend = Self {
47 conn: Arc::new(Mutex::new(conn)),
48 };
49
50 {
52 let conn_guard = backend.conn.lock().unwrap();
53 Self::run_migrations(&conn_guard)?;
54 }
55
56 Ok(backend)
57 }
58
59 fn run_migrations(conn: &Connection) -> Result<(), StorageError> {
61 conn.pragma_update(None, "journal_mode", "WAL")
63 .map_err(|e| StorageError::ConnectionError(format!("Failed to set WAL mode: {}", e)))?;
64
65 conn.pragma_update(None, "foreign_keys", "ON")
67 .map_err(|e| {
68 StorageError::ConnectionError(format!("Failed to enable foreign keys: {}", e))
69 })?;
70
71 conn.execute(
73 r#"
74 CREATE TABLE IF NOT EXISTS snapshots (
75 id TEXT PRIMARY KEY,
76 snapshot_type TEXT NOT NULL,
77 data_json TEXT NOT NULL,
78 created_at DATETIME NOT NULL,
79 created_by TEXT,
80 checksum TEXT
81 )
82 "#,
83 [],
84 )
85 .map_err(|e| {
86 StorageError::ConnectionError(format!("Failed to create snapshots table: {}", e))
87 })?;
88
89 conn.execute(
91 "CREATE INDEX IF NOT EXISTS idx_snapshots_created_at ON snapshots(created_at)",
92 [],
93 )
94 .map_err(|e| StorageError::ConnectionError(format!("Failed to create index: {}", e)))?;
95
96 Ok(())
97 }
98
99 pub fn save_internal(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
100 let conn_guard = self.conn.lock().unwrap();
101 let snapshot_id = snapshot.metadata.snapshot_id.to_string();
102
103 let data_json = serde_json::to_string(snapshot)
104 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
105
106 conn_guard
107 .execute(
108 r#"
109 INSERT OR REPLACE INTO snapshots (
110 id, snapshot_type, data_json, created_at, created_by, checksum
111 ) VALUES (?, ?, ?, ?, ?, ?)
112 "#,
113 params![
114 snapshot_id,
115 format!("{:?}", snapshot.snapshot_type),
116 data_json,
117 snapshot
118 .metadata
119 .timestamp
120 .format("%Y-%m-%d %H:%M:%S%.3f")
121 .to_string(),
122 snapshot.metadata.created_by,
123 snapshot.metadata.checksum,
124 ],
125 )
126 .map_err(|e| {
127 StorageError::ConnectionError(format!("Failed to insert snapshot: {}", e))
128 })?;
129
130 Ok(snapshot_id)
131 }
132
133 pub fn save_decision_internal(
134 &self,
135 decision: &DecisionSnapshot,
136 ) -> Result<String, StorageError> {
137 let snapshot = Snapshot {
139 metadata: decision.metadata.clone(),
140 decisions: vec![decision.clone()],
141 snapshot_type: SnapshotType::Decision,
142 };
143
144 self.save_internal(&snapshot)
145 }
146
147 pub fn load_internal(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
148 let conn_guard = self.conn.lock().unwrap();
149
150 let row: Option<(String,)> = conn_guard
151 .query_row(
152 "SELECT data_json FROM snapshots WHERE id = ?",
153 params![snapshot_id],
154 |row| Ok((row.get(0)?,)),
155 )
156 .optional()
157 .map_err(|e| {
158 StorageError::ConnectionError(format!("Failed to query snapshot: {}", e))
159 })?;
160
161 match row {
162 Some((data_json,)) => {
163 let snapshot: Snapshot = serde_json::from_str(&data_json)
164 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
165 Ok(snapshot)
166 }
167 None => Err(StorageError::NotFound(format!(
168 "Snapshot {} not found",
169 snapshot_id
170 ))),
171 }
172 }
173
174 pub fn query_internal(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
175 let conn_guard = self.conn.lock().unwrap();
176
177 let mut sql = "SELECT data_json FROM snapshots WHERE 1=1".to_string();
178 let mut params_vec: Vec<String> = Vec::new();
179
180 if let Some(start_time) = query.start_time {
182 sql.push_str(" AND created_at >= ?");
183 params_vec.push(start_time.format("%Y-%m-%d %H:%M:%S%.3f").to_string());
184 }
185
186 if let Some(end_time) = query.end_time {
187 sql.push_str(" AND created_at <= ?");
188 params_vec.push(end_time.format("%Y-%m-%d %H:%M:%S%.3f").to_string());
189 }
190
191 sql.push_str(" ORDER BY created_at DESC");
193
194 if let Some(limit) = query.limit {
195 sql.push_str(" LIMIT ?");
196 params_vec.push(limit.to_string());
197 }
198
199 if let Some(offset) = query.offset {
200 sql.push_str(" OFFSET ?");
201 params_vec.push(offset.to_string());
202 }
203
204 let mut stmt = conn_guard
206 .prepare(&sql)
207 .map_err(|e| StorageError::InvalidQuery(format!("Invalid query: {}", e)))?;
208
209 let param_refs: Vec<&dyn rusqlite::ToSql> = params_vec
210 .iter()
211 .map(|p| p as &dyn rusqlite::ToSql)
212 .collect();
213
214 let rows = stmt
215 .query_map(param_refs.as_slice(), |row| row.get::<_, String>(0))
216 .map_err(|e| StorageError::ConnectionError(format!("Query failed: {}", e)))?;
217
218 let mut snapshots = Vec::new();
219 for row in rows {
220 let data_json =
221 row.map_err(|e| StorageError::ConnectionError(format!("Row error: {}", e)))?;
222 let snapshot: Snapshot = serde_json::from_str(&data_json)
223 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
224
225 if self.matches_query_filters(&snapshot, &query) {
227 snapshots.push(snapshot);
228 }
229 }
230
231 Ok(snapshots)
232 }
233
234 fn matches_query_filters(&self, snapshot: &Snapshot, query: &SnapshotQuery) -> bool {
235 if query.function_name.is_some()
237 || query.module_name.is_some()
238 || query.model_name.is_some()
239 || query.tags.is_some()
240 {
241 for decision in &snapshot.decisions {
242 if let Some(function_name) = &query.function_name {
243 if decision.function_name != *function_name {
244 continue;
245 }
246 }
247
248 if let Some(module_name) = &query.module_name {
249 if decision.module_name.as_ref() != Some(module_name) {
250 continue;
251 }
252 }
253
254 if let Some(model_name) = &query.model_name {
255 if let Some(model_params) = &decision.model_parameters {
256 if model_params.model_name != *model_name {
257 continue;
258 }
259 } else {
260 continue;
261 }
262 }
263
264 if let Some(query_tags) = &query.tags {
265 let mut all_tags_match = true;
266 for (key, value) in query_tags {
267 if decision.tags.get(key) != Some(value) {
268 all_tags_match = false;
269 break;
270 }
271 }
272 if !all_tags_match {
273 continue;
274 }
275 }
276
277 return true;
279 }
280
281 return false;
283 }
284
285 true
287 }
288}
289
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl StorageBackend for SqliteBackend {
    /// Saves `snapshot` on tokio's blocking pool and returns its id.
    async fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
        let snapshot_clone = snapshot.clone();
        let self_clone = self.clone();

        task::spawn_blocking(move || self_clone.save_internal(&snapshot_clone))
            .await
            .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
    }

    /// Saves a single decision (wrapped in a `Decision` snapshot) on the blocking pool.
    async fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError> {
        let decision_clone = decision.clone();
        let self_clone = self.clone();

        task::spawn_blocking(move || self_clone.save_decision_internal(&decision_clone))
            .await
            .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
    }

    /// Loads the snapshot stored under `snapshot_id` on the blocking pool.
    async fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
        let id = snapshot_id.to_string();
        let self_clone = self.clone();

        task::spawn_blocking(move || self_clone.load_internal(&id))
            .await
            .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
    }

    /// Loads the snapshot stored under `decision_id` and returns its first decision.
    ///
    /// Returns `NotFound` if the snapshot is missing or holds no decisions.
    async fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError> {
        let snapshot = self.load(decision_id).await?;
        snapshot.decisions.first().cloned().ok_or_else(|| {
            StorageError::NotFound(format!("Decision {} not found", decision_id))
        })
    }

    /// Runs `query` on the blocking pool.
    async fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
        let self_clone = self.clone();

        task::spawn_blocking(move || self_clone.query_internal(query))
            .await
            .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
    }

    /// Deletes the snapshot with `snapshot_id`; returns `true` if a row was removed.
    async fn delete(&self, snapshot_id: &str) -> Result<bool, StorageError> {
        let id = snapshot_id.to_string();
        let self_clone = self.clone();

        task::spawn_blocking(move || {
            // Report a poisoned mutex as an error rather than panicking the worker.
            let conn_guard = self_clone.conn.lock().map_err(|_| {
                StorageError::ConnectionError("SQLite connection mutex poisoned".to_string())
            })?;

            let rows_affected = conn_guard
                .execute("DELETE FROM snapshots WHERE id = ?", params![id])
                .map_err(|e| {
                    StorageError::ConnectionError(format!("Failed to delete snapshot: {}", e))
                })?;

            Ok(rows_affected > 0)
        })
        .await
        .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
    }

    /// Checkpoints the WAL (TRUNCATE mode) and reports the stored snapshot count.
    ///
    /// NOTE(review): `snapshots_written` is the total number of rows in `snapshots`, not the
    /// number written by this call, and `bytes_written` is always 0 — confirm this matches
    /// the `FlushResult` contract.
    async fn flush(&self) -> Result<FlushResult, StorageError> {
        let self_clone = self.clone();

        task::spawn_blocking(move || {
            let conn_guard = self_clone.conn.lock().map_err(|_| {
                StorageError::ConnectionError("SQLite connection mutex poisoned".to_string())
            })?;

            // `wal_checkpoint` always returns a (busy, log, checkpointed) row, and
            // `Connection::execute` rejects statements that return rows, so step it with
            // `query_row`. The row's values are not used.
            conn_guard
                .query_row("PRAGMA wal_checkpoint(TRUNCATE)", [], |_row| Ok(()))
                .map_err(|e| {
                    StorageError::ConnectionError(format!("Failed to checkpoint WAL: {}", e))
                })?;

            // Surface a failing count instead of silently reporting zero snapshots.
            let snapshot_count: i64 = conn_guard
                .query_row("SELECT COUNT(*) FROM snapshots", [], |row| row.get(0))
                .map_err(|e| {
                    StorageError::ConnectionError(format!("Failed to count snapshots: {}", e))
                })?;

            Ok(FlushResult {
                // COUNT(*) is never negative, so the fallback is unreachable in practice.
                snapshots_written: usize::try_from(snapshot_count).unwrap_or(0),
                bytes_written: 0,
                checkpoint_id: None,
            })
        })
        .await
        .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
    }

    /// Returns `Ok(true)` if a trivial `SELECT 1` succeeds on the connection.
    async fn health_check(&self) -> Result<bool, StorageError> {
        let self_clone = self.clone();

        task::spawn_blocking(move || {
            let conn_guard = self_clone.conn.lock().map_err(|_| {
                StorageError::ConnectionError("SQLite connection mutex poisoned".to_string())
            })?;

            let _: i64 = conn_guard
                .query_row("SELECT 1", [], |row| row.get(0))
                .map_err(|e| {
                    StorageError::ConnectionError(format!("Health check failed: {}", e))
                })?;

            Ok(true)
        })
        .await
        .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
    }
}
406
407impl Clone for SqliteBackend {
408 fn clone(&self) -> Self {
409 Self {
410 conn: Arc::clone(&self.conn),
411 }
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::*;
    use serde_json::json;

    /// Builds a session snapshot with one decision for `test_function` in `test_module`,
    /// carrying one input, one output, `gpt-4` model parameters and an `env=test` tag.
    ///
    /// Plain (non-async) helper: constructing the snapshot performs no I/O.
    fn create_test_snapshot() -> Snapshot {
        let input = Input::new("test_input", json!("value"), "string");
        let output = Output::new("test_output", json!("result"), "string");
        let model_params = ModelParameters::new("gpt-4");

        let decision = DecisionSnapshot::new("test_function")
            .with_module("test_module")
            .add_input(input)
            .add_output(output)
            .with_model_parameters(model_params)
            .add_tag("env", "test");

        let mut snapshot = Snapshot::new(SnapshotType::Session);
        snapshot.add_decision(decision);
        snapshot
    }

    #[tokio::test]
    async fn test_sqlite_in_memory() {
        let backend = SqliteBackend::in_memory().unwrap();
        assert!(backend.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_save_and_load_snapshot() {
        let backend = SqliteBackend::in_memory().unwrap();
        let snapshot = create_test_snapshot();

        let snapshot_id = backend.save(&snapshot).await.unwrap();
        let loaded_snapshot = backend.load(&snapshot_id).await.unwrap();

        assert_eq!(snapshot.decisions.len(), loaded_snapshot.decisions.len());
        assert_eq!(snapshot.snapshot_type, loaded_snapshot.snapshot_type);
    }

    #[tokio::test]
    async fn test_query_by_function_name() {
        let backend = SqliteBackend::in_memory().unwrap();
        let snapshot = create_test_snapshot();
        backend.save(&snapshot).await.unwrap();

        let query = SnapshotQuery::new().with_function_name("test_function");
        let results = backend.query(query).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].decisions[0].function_name, "test_function");
    }

    #[tokio::test]
    async fn test_query_with_unmatched_function_name_is_empty() {
        let backend = SqliteBackend::in_memory().unwrap();
        backend.save(&create_test_snapshot()).await.unwrap();

        let query = SnapshotQuery::new().with_function_name("other_function");
        let results = backend.query(query).await.unwrap();

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_delete_snapshot() {
        let backend = SqliteBackend::in_memory().unwrap();
        let snapshot = create_test_snapshot();

        let snapshot_id = backend.save(&snapshot).await.unwrap();
        assert!(backend.delete(&snapshot_id).await.unwrap());

        let result = backend.load(&snapshot_id).await;
        assert!(matches!(result, Err(StorageError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_delete_missing_snapshot_returns_false() {
        let backend = SqliteBackend::in_memory().unwrap();
        assert!(!backend.delete("does-not-exist").await.unwrap());
    }

    #[tokio::test]
    async fn test_load_missing_snapshot_is_not_found() {
        let backend = SqliteBackend::in_memory().unwrap();
        let result = backend.load("does-not-exist").await;
        assert!(matches!(result, Err(StorageError::NotFound(_))));
    }
}