Skip to main content

briefcase_core/storage/
sqlite.rs

1use super::{FlushResult, SnapshotQuery, StorageBackend, StorageError};
2use crate::models::{DecisionSnapshot, Snapshot, SnapshotType};
3use rusqlite::{params, Connection, OptionalExtension};
4use serde_json;
5use std::path::Path;
6use std::sync::{Arc, Mutex};
7#[cfg(feature = "async")]
8use tokio::task;
9
/// Compression scheme selector.
///
/// Nothing else in this file reads this type; it is presumably consumed by other
/// storage code — NOTE(review): confirm the intended use.
///
/// The enum is fieldless, so it is `Copy`/`Eq` and can be passed by value and
/// compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
    /// No compression.
    None,
    /// Gzip compression.
    Gzip,
}
15
/// Storage backend that keeps snapshots in a SQLite database.
///
/// The single connection is wrapped in `Arc<Mutex<..>>`; handles produced by the
/// `Clone` impl in this file share that one connection and take turns locking it.
pub struct SqliteBackend {
    /// Shared, mutex-guarded SQLite connection used by every operation.
    pub conn: Arc<Mutex<Connection>>,
}
19
20impl SqliteBackend {
21    /// Create or open a SQLite database at the given path
22    pub fn new(path: impl AsRef<Path>) -> Result<Self, StorageError> {
23        let conn = Connection::open(path).map_err(|e| {
24            StorageError::ConnectionError(format!("Failed to open database: {}", e))
25        })?;
26
27        let backend = Self {
28            conn: Arc::new(Mutex::new(conn)),
29        };
30
31        // Run migrations
32        {
33            let conn_guard = backend.conn.lock().unwrap();
34            Self::run_migrations(&conn_guard)?;
35        }
36
37        Ok(backend)
38    }
39
40    /// Create an in-memory database (for testing)
41    pub fn in_memory() -> Result<Self, StorageError> {
42        let conn = Connection::open(":memory:").map_err(|e| {
43            StorageError::ConnectionError(format!("Failed to create in-memory database: {}", e))
44        })?;
45
46        let backend = Self {
47            conn: Arc::new(Mutex::new(conn)),
48        };
49
50        // Run migrations
51        {
52            let conn_guard = backend.conn.lock().unwrap();
53            Self::run_migrations(&conn_guard)?;
54        }
55
56        Ok(backend)
57    }
58
59    /// Run database migrations
60    fn run_migrations(conn: &Connection) -> Result<(), StorageError> {
61        // Enable WAL mode for better concurrent access
62        conn.pragma_update(None, "journal_mode", "WAL")
63            .map_err(|e| StorageError::ConnectionError(format!("Failed to set WAL mode: {}", e)))?;
64
65        // Enable foreign keys
66        conn.pragma_update(None, "foreign_keys", "ON")
67            .map_err(|e| {
68                StorageError::ConnectionError(format!("Failed to enable foreign keys: {}", e))
69            })?;
70
71        // Create snapshots table
72        conn.execute(
73            r#"
74            CREATE TABLE IF NOT EXISTS snapshots (
75                id TEXT PRIMARY KEY,
76                snapshot_type TEXT NOT NULL,
77                data_json TEXT NOT NULL,
78                created_at DATETIME NOT NULL,
79                created_by TEXT,
80                checksum TEXT
81            )
82            "#,
83            [],
84        )
85        .map_err(|e| {
86            StorageError::ConnectionError(format!("Failed to create snapshots table: {}", e))
87        })?;
88
89        // Create index
90        conn.execute(
91            "CREATE INDEX IF NOT EXISTS idx_snapshots_created_at ON snapshots(created_at)",
92            [],
93        )
94        .map_err(|e| StorageError::ConnectionError(format!("Failed to create index: {}", e)))?;
95
96        Ok(())
97    }
98
99    pub fn save_internal(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
100        let conn_guard = self.conn.lock().unwrap();
101        let snapshot_id = snapshot.metadata.snapshot_id.to_string();
102
103        let data_json = serde_json::to_string(snapshot)
104            .map_err(|e| StorageError::SerializationError(e.to_string()))?;
105
106        conn_guard
107            .execute(
108                r#"
109            INSERT OR REPLACE INTO snapshots (
110                id, snapshot_type, data_json, created_at, created_by, checksum
111            ) VALUES (?, ?, ?, ?, ?, ?)
112            "#,
113                params![
114                    snapshot_id,
115                    format!("{:?}", snapshot.snapshot_type),
116                    data_json,
117                    snapshot
118                        .metadata
119                        .timestamp
120                        .format("%Y-%m-%d %H:%M:%S%.3f")
121                        .to_string(),
122                    snapshot.metadata.created_by,
123                    snapshot.metadata.checksum,
124                ],
125            )
126            .map_err(|e| {
127                StorageError::ConnectionError(format!("Failed to insert snapshot: {}", e))
128            })?;
129
130        Ok(snapshot_id)
131    }
132
133    pub fn save_decision_internal(
134        &self,
135        decision: &DecisionSnapshot,
136    ) -> Result<String, StorageError> {
137        // For simplicity, we'll save decisions as individual snapshots
138        let snapshot = Snapshot {
139            metadata: decision.metadata.clone(),
140            decisions: vec![decision.clone()],
141            snapshot_type: SnapshotType::Decision,
142        };
143
144        self.save_internal(&snapshot)
145    }
146
147    pub fn load_internal(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
148        let conn_guard = self.conn.lock().unwrap();
149
150        let row: Option<(String,)> = conn_guard
151            .query_row(
152                "SELECT data_json FROM snapshots WHERE id = ?",
153                params![snapshot_id],
154                |row| Ok((row.get(0)?,)),
155            )
156            .optional()
157            .map_err(|e| {
158                StorageError::ConnectionError(format!("Failed to query snapshot: {}", e))
159            })?;
160
161        match row {
162            Some((data_json,)) => {
163                let snapshot: Snapshot = serde_json::from_str(&data_json)
164                    .map_err(|e| StorageError::SerializationError(e.to_string()))?;
165                Ok(snapshot)
166            }
167            None => Err(StorageError::NotFound(format!(
168                "Snapshot {} not found",
169                snapshot_id
170            ))),
171        }
172    }
173
174    pub fn query_internal(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
175        let conn_guard = self.conn.lock().unwrap();
176
177        let mut sql = "SELECT data_json FROM snapshots WHERE 1=1".to_string();
178        let mut params_vec: Vec<String> = Vec::new();
179
180        // Build WHERE clause
181        if let Some(start_time) = query.start_time {
182            sql.push_str(" AND created_at >= ?");
183            params_vec.push(start_time.format("%Y-%m-%d %H:%M:%S%.3f").to_string());
184        }
185
186        if let Some(end_time) = query.end_time {
187            sql.push_str(" AND created_at <= ?");
188            params_vec.push(end_time.format("%Y-%m-%d %H:%M:%S%.3f").to_string());
189        }
190
191        // Add ordering and pagination
192        sql.push_str(" ORDER BY created_at DESC");
193
194        if let Some(limit) = query.limit {
195            sql.push_str(" LIMIT ?");
196            params_vec.push(limit.to_string());
197        }
198
199        if let Some(offset) = query.offset {
200            sql.push_str(" OFFSET ?");
201            params_vec.push(offset.to_string());
202        }
203
204        // Execute query
205        let mut stmt = conn_guard
206            .prepare(&sql)
207            .map_err(|e| StorageError::InvalidQuery(format!("Invalid query: {}", e)))?;
208
209        let param_refs: Vec<&dyn rusqlite::ToSql> = params_vec
210            .iter()
211            .map(|p| p as &dyn rusqlite::ToSql)
212            .collect();
213
214        let rows = stmt
215            .query_map(param_refs.as_slice(), |row| row.get::<_, String>(0))
216            .map_err(|e| StorageError::ConnectionError(format!("Query failed: {}", e)))?;
217
218        let mut snapshots = Vec::new();
219        for row in rows {
220            let data_json =
221                row.map_err(|e| StorageError::ConnectionError(format!("Row error: {}", e)))?;
222            let snapshot: Snapshot = serde_json::from_str(&data_json)
223                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
224
225            // Apply additional filters that require checking the snapshot content
226            if self.matches_query_filters(&snapshot, &query) {
227                snapshots.push(snapshot);
228            }
229        }
230
231        Ok(snapshots)
232    }
233
234    fn matches_query_filters(&self, snapshot: &Snapshot, query: &SnapshotQuery) -> bool {
235        // Check function name, module name, model name, tags in decisions
236        if query.function_name.is_some()
237            || query.module_name.is_some()
238            || query.model_name.is_some()
239            || query.tags.is_some()
240        {
241            for decision in &snapshot.decisions {
242                if let Some(function_name) = &query.function_name {
243                    if decision.function_name != *function_name {
244                        continue;
245                    }
246                }
247
248                if let Some(module_name) = &query.module_name {
249                    if decision.module_name.as_ref() != Some(module_name) {
250                        continue;
251                    }
252                }
253
254                if let Some(model_name) = &query.model_name {
255                    if let Some(model_params) = &decision.model_parameters {
256                        if model_params.model_name != *model_name {
257                            continue;
258                        }
259                    } else {
260                        continue;
261                    }
262                }
263
264                if let Some(query_tags) = &query.tags {
265                    let mut all_tags_match = true;
266                    for (key, value) in query_tags {
267                        if decision.tags.get(key) != Some(value) {
268                            all_tags_match = false;
269                            break;
270                        }
271                    }
272                    if !all_tags_match {
273                        continue;
274                    }
275                }
276
277                // If we get here, this decision matches all filters
278                return true;
279            }
280
281            // No decisions matched the filters
282            return false;
283        }
284
285        // No content filters, so it matches
286        true
287    }
288}
289
290#[cfg(feature = "async")]
291#[async_trait::async_trait]
292impl StorageBackend for SqliteBackend {
293    async fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
294        let snapshot_clone = snapshot.clone();
295        let self_clone = self.clone();
296
297        task::spawn_blocking(move || self_clone.save_internal(&snapshot_clone))
298            .await
299            .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
300    }
301
302    async fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError> {
303        let decision_clone = decision.clone();
304        let self_clone = self.clone();
305
306        task::spawn_blocking(move || self_clone.save_decision_internal(&decision_clone))
307            .await
308            .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
309    }
310
311    async fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
312        let id = snapshot_id.to_string();
313        let self_clone = self.clone();
314
315        task::spawn_blocking(move || self_clone.load_internal(&id))
316            .await
317            .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
318    }
319
320    async fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError> {
321        let snapshot = self.load(decision_id).await?;
322        if let Some(decision) = snapshot.decisions.first() {
323            Ok(decision.clone())
324        } else {
325            Err(StorageError::NotFound(format!(
326                "Decision {} not found",
327                decision_id
328            )))
329        }
330    }
331
332    async fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
333        let self_clone = self.clone();
334
335        task::spawn_blocking(move || self_clone.query_internal(query))
336            .await
337            .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
338    }
339
340    async fn delete(&self, snapshot_id: &str) -> Result<bool, StorageError> {
341        let id = snapshot_id.to_string();
342        let self_clone = self.clone();
343
344        task::spawn_blocking(move || {
345            let conn_guard = self_clone.conn.lock().unwrap();
346
347            let rows_affected = conn_guard
348                .execute("DELETE FROM snapshots WHERE id = ?", params![id])
349                .map_err(|e| {
350                    StorageError::ConnectionError(format!("Failed to delete snapshot: {}", e))
351                })?;
352
353            Ok(rows_affected > 0)
354        })
355        .await
356        .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
357    }
358
359    async fn flush(&self) -> Result<FlushResult, StorageError> {
360        let self_clone = self.clone();
361
362        task::spawn_blocking(move || {
363            let conn_guard = self_clone.conn.lock().unwrap();
364
365            // Force WAL checkpoint
366            conn_guard
367                .execute("PRAGMA wal_checkpoint(TRUNCATE)", [])
368                .map_err(|e| {
369                    StorageError::ConnectionError(format!("Failed to checkpoint WAL: {}", e))
370                })?;
371
372            // Get stats
373            let snapshot_count: i64 = conn_guard
374                .query_row("SELECT COUNT(*) FROM snapshots", [], |row| row.get(0))
375                .unwrap_or(0);
376
377            Ok(FlushResult {
378                snapshots_written: snapshot_count as usize,
379                bytes_written: 0, // SQLite doesn't easily report this
380                checkpoint_id: None,
381            })
382        })
383        .await
384        .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
385    }
386
387    async fn health_check(&self) -> Result<bool, StorageError> {
388        let self_clone = self.clone();
389
390        task::spawn_blocking(move || {
391            let conn_guard = self_clone.conn.lock().unwrap();
392
393            // Simple query to check connection
394            let _: i64 = conn_guard
395                .query_row("SELECT 1", [], |row| row.get(0))
396                .map_err(|e| {
397                    StorageError::ConnectionError(format!("Health check failed: {}", e))
398                })?;
399
400            Ok(true)
401        })
402        .await
403        .map_err(|e| StorageError::ConnectionError(format!("Task join error: {}", e)))?
404    }
405}
406
407impl Clone for SqliteBackend {
408    fn clone(&self) -> Self {
409        Self {
410            conn: Arc::clone(&self.conn),
411        }
412    }
413}
414
// These tests call the async `StorageBackend` methods, which this file only
// implements when the `async` feature is enabled, so the module is gated on
// that feature too; otherwise it would not compile with the feature disabled.
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::models::*;
    use serde_json::json;

    /// Builds a session snapshot holding one decision (`test_function` in
    /// `test_module`, model `gpt-4`, tag `env=test`).
    fn create_test_snapshot() -> Snapshot {
        let input = Input::new("test_input", json!("value"), "string");
        let output = Output::new("test_output", json!("result"), "string");
        let model_params = ModelParameters::new("gpt-4");

        let decision = DecisionSnapshot::new("test_function")
            .with_module("test_module")
            .add_input(input)
            .add_output(output)
            .with_model_parameters(model_params)
            .add_tag("env", "test");

        let mut snapshot = Snapshot::new(SnapshotType::Session);
        snapshot.add_decision(decision);
        snapshot
    }

    #[tokio::test]
    async fn test_sqlite_in_memory() {
        let backend = SqliteBackend::in_memory().unwrap();
        assert!(backend.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_save_and_load_snapshot() {
        let backend = SqliteBackend::in_memory().unwrap();
        let snapshot = create_test_snapshot();

        let snapshot_id = backend.save(&snapshot).await.unwrap();
        let loaded_snapshot = backend.load(&snapshot_id).await.unwrap();

        assert_eq!(snapshot.decisions.len(), loaded_snapshot.decisions.len());
        assert_eq!(snapshot.snapshot_type, loaded_snapshot.snapshot_type);
    }

    #[tokio::test]
    async fn test_query_by_function_name() {
        let backend = SqliteBackend::in_memory().unwrap();
        let snapshot = create_test_snapshot();
        backend.save(&snapshot).await.unwrap();

        let query = SnapshotQuery::new().with_function_name("test_function");
        let results = backend.query(query).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].decisions[0].function_name, "test_function");
    }

    #[tokio::test]
    async fn test_query_by_function_name_no_match() {
        let backend = SqliteBackend::in_memory().unwrap();
        let snapshot = create_test_snapshot();
        backend.save(&snapshot).await.unwrap();

        // A filter that no stored decision satisfies must yield nothing.
        let query = SnapshotQuery::new().with_function_name("some_other_function");
        let results = backend.query(query).await.unwrap();

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_delete_snapshot() {
        let backend = SqliteBackend::in_memory().unwrap();
        let snapshot = create_test_snapshot();

        let snapshot_id = backend.save(&snapshot).await.unwrap();
        assert!(backend.delete(&snapshot_id).await.unwrap());

        let result = backend.load(&snapshot_id).await;
        assert!(matches!(result, Err(StorageError::NotFound(_))));
    }
}