Skip to main content

engram/storage/
sqlite_backend.rs

//! SQLite implementation of the `StorageBackend` trait (ENG-15).
//!
//! This module provides a SQLite-based storage backend that implements
//! the `StorageBackend` trait, so the existing SQLite storage can be
//! used through the abstracted backend interface.
6
7use std::collections::HashMap;
8use std::time::Instant;
9
10use crate::error::Result;
11use crate::types::{
12    CreateCrossRefInput, CreateMemoryInput, CrossReference, EdgeType, ListOptions, Memory,
13    MemoryId, SearchOptions, SearchResult, StorageConfig, UpdateMemoryInput, WorkspaceStats,
14};
15
16use super::backend::{
17    BatchCreateResult, BatchDeleteResult, CloudSyncBackend, HealthStatus, StorageBackend,
18    StorageStats, SyncDelta, SyncResult, SyncState, TransactionalBackend,
19};
20use super::connection::Storage;
21use super::queries::{
22    self, delete_memory_batch, get_related, get_sync_delta, get_sync_version, list_tags,
23};
24use crate::search::{hybrid_search, SearchConfig};
25
26/// SQLite-based storage backend
27///
28/// This implements the `StorageBackend` trait using SQLite as the
29/// underlying database. It wraps the existing `Storage` struct and
30/// delegates to the functions in `queries.rs`.
31pub struct SqliteBackend {
32    storage: Storage,
33}
34
35impl SqliteBackend {
36    /// Create a new SQLite backend with the given configuration
37    pub fn new(config: StorageConfig) -> Result<Self> {
38        let storage = Storage::open(config)?;
39        Ok(Self { storage })
40    }
41
42    /// Create an in-memory SQLite backend (useful for testing)
43    pub fn in_memory() -> Result<Self> {
44        let storage = Storage::open_in_memory()?;
45        Ok(Self { storage })
46    }
47
48    /// Get a reference to the underlying Storage
49    pub fn storage(&self) -> &Storage {
50        &self.storage
51    }
52
53    /// Get a mutable reference to the underlying Storage
54    pub fn storage_mut(&mut self) -> &mut Storage {
55        &mut self.storage
56    }
57}
58
59impl StorageBackend for SqliteBackend {
60    fn create_memory(&self, input: CreateMemoryInput) -> Result<Memory> {
61        self.storage
62            .with_transaction(|conn| queries::create_memory(conn, &input))
63    }
64
65    fn get_memory(&self, id: MemoryId) -> Result<Option<Memory>> {
66        self.storage
67            .with_connection(|conn| match queries::get_memory(conn, id) {
68                Ok(memory) => Ok(Some(memory)),
69                Err(crate::error::EngramError::NotFound(_)) => Ok(None),
70                Err(e) => Err(e),
71            })
72    }
73
74    fn update_memory(&self, id: MemoryId, input: UpdateMemoryInput) -> Result<Memory> {
75        self.storage
76            .with_transaction(|conn| queries::update_memory(conn, id, &input))
77    }
78
79    fn delete_memory(&self, id: MemoryId) -> Result<()> {
80        self.storage
81            .with_transaction(|conn| queries::delete_memory(conn, id))
82    }
83
84    fn create_memories_batch(&self, inputs: Vec<CreateMemoryInput>) -> Result<BatchCreateResult> {
85        let start = Instant::now();
86        let mut created = Vec::new();
87        let mut failed = Vec::new();
88
89        self.storage.with_transaction(|conn| {
90            for (idx, input) in inputs.into_iter().enumerate() {
91                match queries::create_memory(conn, &input) {
92                    Ok(memory) => created.push(memory),
93                    Err(e) => failed.push((idx, e.to_string())),
94                }
95            }
96            Ok(())
97        })?;
98
99        Ok(BatchCreateResult {
100            created,
101            failed,
102            elapsed_ms: start.elapsed().as_secs_f64() * 1000.0,
103        })
104    }
105
106    fn delete_memories_batch(&self, ids: Vec<MemoryId>) -> Result<BatchDeleteResult> {
107        self.storage.with_transaction(|conn| {
108            let result = delete_memory_batch(conn, &ids)?;
109            let mut not_found = Vec::new();
110            let mut failed = Vec::new();
111
112            for err in &result.failed {
113                if let Some(id) = err.id {
114                    let msg = err.error.clone();
115                    // Heuristic to detect not found errors from bulk operation
116                    if msg.to_lowercase().contains("notfound")
117                        || msg.to_lowercase().contains("not found")
118                    {
119                        not_found.push(id);
120                    } else {
121                        failed.push((id, msg));
122                    }
123                }
124            }
125
126            Ok(BatchDeleteResult {
127                deleted_count: result.total_deleted,
128                not_found,
129                failed,
130            })
131        })
132    }
133
134    fn list_memories(&self, options: ListOptions) -> Result<Vec<Memory>> {
135        self.storage
136            .with_connection(|conn| queries::list_memories(conn, &options))
137    }
138
139    fn count_memories(&self, options: ListOptions) -> Result<i64> {
140        self.storage.with_connection(|conn| {
141            let now = chrono::Utc::now().to_rfc3339();
142
143            let mut sql = String::from("SELECT COUNT(DISTINCT m.id) FROM memories m");
144            let mut conditions = vec!["m.valid_to IS NULL".to_string()];
145            let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
146
147            // Exclude expired memories
148            conditions.push("(m.expires_at IS NULL OR m.expires_at > ?)".to_string());
149            params.push(Box::new(now));
150
151            // Tag filter (requires join)
152            if let Some(ref tags) = options.tags {
153                if !tags.is_empty() {
154                    sql.push_str(
155                        " JOIN memory_tags mt ON m.id = mt.memory_id
156                          JOIN tags t ON mt.tag_id = t.id",
157                    );
158                    let placeholders: Vec<String> = tags.iter().map(|_| "?".to_string()).collect();
159                    conditions.push(format!("t.name IN ({})", placeholders.join(", ")));
160                    for tag in tags {
161                        params.push(Box::new(tag.clone()));
162                    }
163                }
164            }
165
166            // Type filter
167            if let Some(ref memory_type) = options.memory_type {
168                conditions.push("m.memory_type = ?".to_string());
169                params.push(Box::new(memory_type.as_str().to_string()));
170            }
171
172            // Metadata filter (JSON)
173            if let Some(ref metadata_filter) = options.metadata_filter {
174                for (key, value) in metadata_filter {
175                    queries::metadata_value_to_param(key, value, &mut conditions, &mut params)?;
176                }
177            }
178
179            // Scope filter
180            if let Some(ref scope) = options.scope {
181                conditions.push("m.scope_type = ?".to_string());
182                params.push(Box::new(scope.scope_type().to_string()));
183                if let Some(scope_id) = scope.scope_id() {
184                    conditions.push("m.scope_id = ?".to_string());
185                    params.push(Box::new(scope_id.to_string()));
186                } else {
187                    conditions.push("m.scope_id IS NULL".to_string());
188                }
189            }
190
191            // Workspace filter
192            if let Some(ref workspace) = options.workspace {
193                conditions.push("m.workspace = ?".to_string());
194                params.push(Box::new(workspace.clone()));
195            }
196
197            // Tier filter
198            if let Some(ref tier) = options.tier {
199                conditions.push("m.tier = ?".to_string());
200                params.push(Box::new(tier.as_str().to_string()));
201            }
202
203            // Archived filter
204            if !options.include_archived {
205                conditions.push(
206                    "(m.lifecycle_state IS NULL OR m.lifecycle_state != 'archived')".to_string(),
207                );
208            }
209
210            sql.push_str(" WHERE ");
211            sql.push_str(&conditions.join(" AND "));
212
213            let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|b| b.as_ref()).collect();
214            let count: i64 = conn.query_row(&sql, param_refs.as_slice(), |row| row.get(0))?;
215
216            Ok(count)
217        })
218    }
219
220    fn search_memories(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
221        self.storage.with_connection(|conn| {
222            let config = SearchConfig::default();
223            // Note: hybrid_search expects embedding if vector search is desired.
224            // Here we only perform lexical/fuzzy search unless embedding is handled higher up.
225            // The trait signature doesn't take embedding, implying embedding generation happens
226            // outside or inside if we had the embedder.
227            // Since SqliteBackend doesn't have the Embedder, we pass None.
228            hybrid_search(conn, query, None, &options, &config)
229        })
230    }
231
232    fn create_crossref(
233        &self,
234        from_id: MemoryId,
235        to_id: MemoryId,
236        edge_type: EdgeType,
237        score: f32,
238    ) -> Result<CrossReference> {
239        self.storage.with_transaction(|conn| {
240            let input = CreateCrossRefInput {
241                from_id,
242                to_id,
243                edge_type,
244                strength: Some(score),
245                source_context: None,
246                pinned: false,
247            };
248            queries::create_crossref(conn, &input)
249        })
250    }
251
252    fn get_crossrefs(&self, memory_id: MemoryId) -> Result<Vec<CrossReference>> {
253        self.storage
254            .with_connection(|conn| get_related(conn, memory_id))
255    }
256
257    fn delete_crossref(&self, from_id: MemoryId, to_id: MemoryId) -> Result<()> {
258        self.storage.with_transaction(|conn| {
259            // Try to delete both directions if bidirectional?
260            // The trait implies directed deletion.
261            // We use queries::delete_crossref which takes an edge type.
262            // We'll delete all edge types for this pair.
263            for edge_type in EdgeType::all() {
264                // Ignore result (might not exist for all types)
265                let _ = queries::delete_crossref(conn, from_id, to_id, *edge_type);
266            }
267            Ok(())
268        })
269    }
270
271    fn list_tags(&self) -> Result<Vec<(String, i64)>> {
272        self.storage.with_connection(|conn| {
273            let tags = list_tags(conn)?;
274            Ok(tags.into_iter().map(|t| (t.name, t.count)).collect())
275        })
276    }
277
278    fn get_memories_by_tag(&self, tag: &str, limit: Option<usize>) -> Result<Vec<Memory>> {
279        self.storage.with_connection(|conn| {
280            let options = ListOptions {
281                tags: Some(vec![tag.to_string()]),
282                limit: limit.map(|v| v as i64),
283                ..Default::default()
284            };
285            queries::list_memories(conn, &options)
286        })
287    }
288
289    fn list_workspaces(&self) -> Result<Vec<(String, i64)>> {
290        self.storage.with_connection(|conn| {
291            let workspaces = queries::list_workspaces(conn)?;
292            Ok(workspaces
293                .into_iter()
294                .map(|w| (w.workspace, w.memory_count))
295                .collect())
296        })
297    }
298
299    fn get_workspace_stats(&self, workspace: &str) -> Result<HashMap<String, i64>> {
300        self.storage.with_connection(|conn| {
301            let stats: WorkspaceStats = queries::get_workspace_stats(conn, workspace)?;
302            let mut map = HashMap::new();
303            map.insert("memory_count".to_string(), stats.memory_count);
304            map.insert("permanent_count".to_string(), stats.permanent_count);
305            map.insert("daily_count".to_string(), stats.daily_count);
306            Ok(map)
307        })
308    }
309
310    fn move_to_workspace(&self, ids: Vec<MemoryId>, workspace: &str) -> Result<usize> {
311        self.storage.with_transaction(|conn| {
312            let mut moved = 0usize;
313            for id in ids {
314                if queries::move_to_workspace(conn, id, workspace).is_ok() {
315                    moved += 1;
316                }
317            }
318            Ok(moved)
319        })
320    }
321
322    fn get_stats(&self) -> Result<StorageStats> {
323        self.storage.with_connection(queries::get_stats)
324    }
325
326    fn health_check(&self) -> Result<HealthStatus> {
327        let start = Instant::now();
328
329        let result = self.storage.with_connection(|conn| {
330            conn.query_row("SELECT 1", [], |_| Ok(()))?;
331            Ok(())
332        });
333
334        let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
335        let db_path = self.storage.db_path().to_string();
336
337        match result {
338            Ok(()) => Ok(HealthStatus {
339                healthy: true,
340                latency_ms,
341                error: None,
342                details: HashMap::from([
343                    ("db_path".to_string(), db_path),
344                    (
345                        "storage_mode".to_string(),
346                        format!("{:?}", self.storage.storage_mode()),
347                    ),
348                ]),
349            }),
350            Err(e) => Ok(HealthStatus {
351                healthy: false,
352                latency_ms,
353                error: Some(e.to_string()),
354                details: HashMap::from([("db_path".to_string(), db_path)]),
355            }),
356        }
357    }
358
359    fn optimize(&self) -> Result<()> {
360        self.storage.vacuum()?;
361        self.storage.checkpoint()?;
362        Ok(())
363    }
364
365    fn backend_name(&self) -> &'static str {
366        "sqlite"
367    }
368
369    fn schema_version(&self) -> Result<i32> {
370        self.storage.with_connection(|conn| {
371            let version: i32 = conn
372                .query_row("SELECT MAX(version) FROM schema_version", [], |row| {
373                    row.get(0)
374                })
375                .unwrap_or(0);
376            Ok(version)
377        })
378    }
379}
380
381impl TransactionalBackend for SqliteBackend {
382    fn with_transaction<F, T>(&self, f: F) -> Result<T>
383    where
384        F: FnOnce(&dyn StorageBackend) -> Result<T>,
385    {
386        // Note: This is where we would ideally pass a transaction-aware
387        // backend wrapper. For now, since SQLite doesn't support nested
388        // transactions easily without savepoints (which we are adding),
389        // we just execute the closure.
390        // The closure expects &dyn StorageBackend, so we pass self.
391        f(self)
392    }
393
394    fn savepoint(&self, name: &str) -> Result<()> {
395        self.storage.with_connection(|conn| {
396            conn.execute(&format!("SAVEPOINT {}", name), [])?;
397            Ok(())
398        })
399    }
400
401    fn release_savepoint(&self, name: &str) -> Result<()> {
402        self.storage.with_connection(|conn| {
403            conn.execute(&format!("RELEASE SAVEPOINT {}", name), [])?;
404            Ok(())
405        })
406    }
407
408    fn rollback_to_savepoint(&self, name: &str) -> Result<()> {
409        self.storage.with_connection(|conn| {
410            conn.execute(&format!("ROLLBACK TO SAVEPOINT {}", name), [])?;
411            Ok(())
412        })
413    }
414}
415
416impl CloudSyncBackend for SqliteBackend {
417    fn push(&self) -> Result<SyncResult> {
418        // Placeholder - actual cloud sync is handled by the sync module
419        Ok(SyncResult {
420            success: true,
421            pushed_count: 0,
422            pulled_count: 0,
423            conflicts_resolved: 0,
424            error: None,
425            new_version: 0,
426        })
427    }
428
429    fn pull(&self) -> Result<SyncResult> {
430        // Placeholder - actual cloud sync is handled by the sync module
431        Ok(SyncResult {
432            success: true,
433            pushed_count: 0,
434            pulled_count: 0,
435            conflicts_resolved: 0,
436            error: None,
437            new_version: 0,
438        })
439    }
440
441    fn sync_delta(&self, since_version: u64) -> Result<SyncDelta> {
442        self.storage.with_connection(|conn| {
443            let delta = get_sync_delta(conn, since_version as i64)?;
444            Ok(SyncDelta {
445                created: delta.created,
446                updated: delta.updated,
447                deleted: delta.deleted,
448                version: delta.to_version as u64,
449            })
450        })
451    }
452
453    fn sync_state(&self) -> Result<SyncState> {
454        self.storage.with_connection(|conn| {
455            let version = get_sync_version(conn)?;
456            let (last_sync, pending_changes): (Option<String>, i64) = conn
457                .query_row(
458                    "SELECT last_sync, pending_changes FROM sync_state WHERE id = 1",
459                    [],
460                    |row| Ok((row.get(0)?, row.get(1)?)),
461                )
462                .unwrap_or((None, 0));
463
464            let last_sync = last_sync.and_then(|s| {
465                chrono::DateTime::parse_from_rfc3339(&s)
466                    .map(|dt| dt.with_timezone(&chrono::Utc))
467                    .ok()
468            });
469
470            Ok(SyncState {
471                local_version: version.version as u64,
472                remote_version: None,
473                last_sync,
474                has_pending_changes: pending_changes > 0,
475                pending_count: pending_changes as usize,
476            })
477        })
478    }
479
480    fn force_sync(&self) -> Result<SyncResult> {
481        // Push then pull
482        self.push()?;
483        self.pull()
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MemoryScope, MemoryTier, MemoryType};

    /// A fresh, empty in-memory backend for a single test.
    fn fresh_backend() -> SqliteBackend {
        SqliteBackend::in_memory().unwrap()
    }

    /// A global, permanent `Note` in the "default" workspace tagged `test`.
    fn note_input(content: &str) -> CreateMemoryInput {
        CreateMemoryInput {
            content: content.to_string(),
            memory_type: MemoryType::Note,
            tags: vec!["test".to_string()],
            metadata: HashMap::new(),
            importance: Some(0.5),
            scope: MemoryScope::Global,
            workspace: Some("default".to_string()),
            tier: MemoryTier::Permanent,
            defer_embedding: true,
            ttl_seconds: None,
            dedup_mode: crate::types::DedupMode::Allow,
            dedup_threshold: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            summary_of_id: None,
        }
    }

    #[test]
    fn test_create_in_memory() {
        assert_eq!(fresh_backend().backend_name(), "sqlite");
    }

    #[test]
    fn test_health_check() {
        let status = fresh_backend().health_check().unwrap();
        assert!(status.healthy);
        assert!(status.latency_ms >= 0.0);
    }

    #[test]
    fn test_get_stats() {
        let stats = fresh_backend().get_stats().unwrap();
        assert_eq!(stats.total_memories, 0);
        assert!(stats.storage_mode.starts_with("sqlite"));
    }

    #[test]
    fn test_crud_operations() {
        let backend = fresh_backend();

        // Create
        let created = backend.create_memory(note_input("Test memory")).unwrap();
        assert_eq!(created.content, "Test memory");
        assert_eq!(created.memory_type, MemoryType::Note);

        // Read back
        let fetched = backend
            .get_memory(created.id)
            .unwrap()
            .expect("memory should exist right after creation");
        assert_eq!(fetched.id, created.id);

        // Update only the content
        let patch = UpdateMemoryInput {
            content: Some("Updated memory".to_string()),
            memory_type: None,
            tags: None,
            metadata: None,
            importance: None,
            scope: None,
            ttl_seconds: None,
            event_time: None,
            trigger_pattern: None,
        };
        let updated = backend.update_memory(created.id, patch).unwrap();
        assert_eq!(updated.content, "Updated memory");

        // Delete, then confirm it is gone
        backend.delete_memory(created.id).unwrap();
        assert!(backend.get_memory(created.id).unwrap().is_none());
    }
}