1use std::collections::HashMap;
8use std::time::Instant;
9
10use crate::error::Result;
11use crate::types::{
12 CreateCrossRefInput, CreateMemoryInput, CrossReference, EdgeType, ListOptions, Memory,
13 MemoryId, SearchOptions, SearchResult, StorageConfig, UpdateMemoryInput, WorkspaceStats,
14};
15
16use super::backend::{
17 BatchCreateResult, BatchDeleteResult, CloudSyncBackend, HealthStatus, StorageBackend,
18 StorageStats, SyncDelta, SyncResult, SyncState, TransactionalBackend,
19};
20use super::connection::Storage;
21use super::queries::{
22 self, delete_memory_batch, get_related, get_sync_delta, get_sync_version, list_tags,
23};
24use crate::search::{hybrid_search, SearchConfig};
25
26pub struct SqliteBackend {
32 storage: Storage,
33}
34
35impl SqliteBackend {
36 pub fn new(config: StorageConfig) -> Result<Self> {
38 let storage = Storage::open(config)?;
39 Ok(Self { storage })
40 }
41
42 pub fn in_memory() -> Result<Self> {
44 let storage = Storage::open_in_memory()?;
45 Ok(Self { storage })
46 }
47
48 pub fn storage(&self) -> &Storage {
50 &self.storage
51 }
52
53 pub fn storage_mut(&mut self) -> &mut Storage {
55 &mut self.storage
56 }
57}
58
59impl StorageBackend for SqliteBackend {
60 fn create_memory(&self, input: CreateMemoryInput) -> Result<Memory> {
61 self.storage
62 .with_transaction(|conn| queries::create_memory(conn, &input))
63 }
64
65 fn get_memory(&self, id: MemoryId) -> Result<Option<Memory>> {
66 self.storage
67 .with_connection(|conn| match queries::get_memory(conn, id) {
68 Ok(memory) => Ok(Some(memory)),
69 Err(crate::error::EngramError::NotFound(_)) => Ok(None),
70 Err(e) => Err(e),
71 })
72 }
73
74 fn update_memory(&self, id: MemoryId, input: UpdateMemoryInput) -> Result<Memory> {
75 self.storage
76 .with_transaction(|conn| queries::update_memory(conn, id, &input))
77 }
78
79 fn delete_memory(&self, id: MemoryId) -> Result<()> {
80 self.storage
81 .with_transaction(|conn| queries::delete_memory(conn, id))
82 }
83
84 fn create_memories_batch(&self, inputs: Vec<CreateMemoryInput>) -> Result<BatchCreateResult> {
85 let start = Instant::now();
86 let mut created = Vec::new();
87 let mut failed = Vec::new();
88
89 self.storage.with_transaction(|conn| {
90 for (idx, input) in inputs.into_iter().enumerate() {
91 match queries::create_memory(conn, &input) {
92 Ok(memory) => created.push(memory),
93 Err(e) => failed.push((idx, e.to_string())),
94 }
95 }
96 Ok(())
97 })?;
98
99 Ok(BatchCreateResult {
100 created,
101 failed,
102 elapsed_ms: start.elapsed().as_secs_f64() * 1000.0,
103 })
104 }
105
106 fn delete_memories_batch(&self, ids: Vec<MemoryId>) -> Result<BatchDeleteResult> {
107 self.storage.with_transaction(|conn| {
108 let result = delete_memory_batch(conn, &ids)?;
109 let mut not_found = Vec::new();
110 let mut failed = Vec::new();
111
112 for err in &result.failed {
113 if let Some(id) = err.id {
114 let msg = err.error.clone();
115 if msg.to_lowercase().contains("notfound")
117 || msg.to_lowercase().contains("not found")
118 {
119 not_found.push(id);
120 } else {
121 failed.push((id, msg));
122 }
123 }
124 }
125
126 Ok(BatchDeleteResult {
127 deleted_count: result.total_deleted,
128 not_found,
129 failed,
130 })
131 })
132 }
133
134 fn list_memories(&self, options: ListOptions) -> Result<Vec<Memory>> {
135 self.storage
136 .with_connection(|conn| queries::list_memories(conn, &options))
137 }
138
139 fn count_memories(&self, options: ListOptions) -> Result<i64> {
140 self.storage.with_connection(|conn| {
141 let now = chrono::Utc::now().to_rfc3339();
142
143 let mut sql = String::from("SELECT COUNT(DISTINCT m.id) FROM memories m");
144 let mut conditions = vec!["m.valid_to IS NULL".to_string()];
145 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
146
147 conditions.push("(m.expires_at IS NULL OR m.expires_at > ?)".to_string());
149 params.push(Box::new(now));
150
151 if let Some(ref tags) = options.tags {
153 if !tags.is_empty() {
154 sql.push_str(
155 " JOIN memory_tags mt ON m.id = mt.memory_id
156 JOIN tags t ON mt.tag_id = t.id",
157 );
158 let placeholders: Vec<String> = tags.iter().map(|_| "?".to_string()).collect();
159 conditions.push(format!("t.name IN ({})", placeholders.join(", ")));
160 for tag in tags {
161 params.push(Box::new(tag.clone()));
162 }
163 }
164 }
165
166 if let Some(ref memory_type) = options.memory_type {
168 conditions.push("m.memory_type = ?".to_string());
169 params.push(Box::new(memory_type.as_str().to_string()));
170 }
171
172 if let Some(ref metadata_filter) = options.metadata_filter {
174 for (key, value) in metadata_filter {
175 queries::metadata_value_to_param(key, value, &mut conditions, &mut params)?;
176 }
177 }
178
179 if let Some(ref scope) = options.scope {
181 conditions.push("m.scope_type = ?".to_string());
182 params.push(Box::new(scope.scope_type().to_string()));
183 if let Some(scope_id) = scope.scope_id() {
184 conditions.push("m.scope_id = ?".to_string());
185 params.push(Box::new(scope_id.to_string()));
186 } else {
187 conditions.push("m.scope_id IS NULL".to_string());
188 }
189 }
190
191 if let Some(ref workspace) = options.workspace {
193 conditions.push("m.workspace = ?".to_string());
194 params.push(Box::new(workspace.clone()));
195 }
196
197 if let Some(ref tier) = options.tier {
199 conditions.push("m.tier = ?".to_string());
200 params.push(Box::new(tier.as_str().to_string()));
201 }
202
203 if !options.include_archived {
205 conditions.push(
206 "(m.lifecycle_state IS NULL OR m.lifecycle_state != 'archived')".to_string(),
207 );
208 }
209
210 sql.push_str(" WHERE ");
211 sql.push_str(&conditions.join(" AND "));
212
213 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|b| b.as_ref()).collect();
214 let count: i64 = conn.query_row(&sql, param_refs.as_slice(), |row| row.get(0))?;
215
216 Ok(count)
217 })
218 }
219
220 fn search_memories(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
221 self.storage.with_connection(|conn| {
222 let config = SearchConfig::default();
223 hybrid_search(conn, query, None, &options, &config)
229 })
230 }
231
232 fn create_crossref(
233 &self,
234 from_id: MemoryId,
235 to_id: MemoryId,
236 edge_type: EdgeType,
237 score: f32,
238 ) -> Result<CrossReference> {
239 self.storage.with_transaction(|conn| {
240 let input = CreateCrossRefInput {
241 from_id,
242 to_id,
243 edge_type,
244 strength: Some(score),
245 source_context: None,
246 pinned: false,
247 };
248 queries::create_crossref(conn, &input)
249 })
250 }
251
252 fn get_crossrefs(&self, memory_id: MemoryId) -> Result<Vec<CrossReference>> {
253 self.storage
254 .with_connection(|conn| get_related(conn, memory_id))
255 }
256
257 fn delete_crossref(&self, from_id: MemoryId, to_id: MemoryId) -> Result<()> {
258 self.storage.with_transaction(|conn| {
259 for edge_type in EdgeType::all() {
264 let _ = queries::delete_crossref(conn, from_id, to_id, *edge_type);
266 }
267 Ok(())
268 })
269 }
270
271 fn list_tags(&self) -> Result<Vec<(String, i64)>> {
272 self.storage.with_connection(|conn| {
273 let tags = list_tags(conn)?;
274 Ok(tags.into_iter().map(|t| (t.name, t.count)).collect())
275 })
276 }
277
278 fn get_memories_by_tag(&self, tag: &str, limit: Option<usize>) -> Result<Vec<Memory>> {
279 self.storage.with_connection(|conn| {
280 let options = ListOptions {
281 tags: Some(vec![tag.to_string()]),
282 limit: limit.map(|v| v as i64),
283 ..Default::default()
284 };
285 queries::list_memories(conn, &options)
286 })
287 }
288
289 fn list_workspaces(&self) -> Result<Vec<(String, i64)>> {
290 self.storage.with_connection(|conn| {
291 let workspaces = queries::list_workspaces(conn)?;
292 Ok(workspaces
293 .into_iter()
294 .map(|w| (w.workspace, w.memory_count))
295 .collect())
296 })
297 }
298
299 fn get_workspace_stats(&self, workspace: &str) -> Result<HashMap<String, i64>> {
300 self.storage.with_connection(|conn| {
301 let stats: WorkspaceStats = queries::get_workspace_stats(conn, workspace)?;
302 let mut map = HashMap::new();
303 map.insert("memory_count".to_string(), stats.memory_count);
304 map.insert("permanent_count".to_string(), stats.permanent_count);
305 map.insert("daily_count".to_string(), stats.daily_count);
306 Ok(map)
307 })
308 }
309
310 fn move_to_workspace(&self, ids: Vec<MemoryId>, workspace: &str) -> Result<usize> {
311 self.storage.with_transaction(|conn| {
312 let mut moved = 0usize;
313 for id in ids {
314 if queries::move_to_workspace(conn, id, workspace).is_ok() {
315 moved += 1;
316 }
317 }
318 Ok(moved)
319 })
320 }
321
322 fn get_stats(&self) -> Result<StorageStats> {
323 self.storage.with_connection(queries::get_stats)
324 }
325
326 fn health_check(&self) -> Result<HealthStatus> {
327 let start = Instant::now();
328
329 let result = self.storage.with_connection(|conn| {
330 conn.query_row("SELECT 1", [], |_| Ok(()))?;
331 Ok(())
332 });
333
334 let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
335 let db_path = self.storage.db_path().to_string();
336
337 match result {
338 Ok(()) => Ok(HealthStatus {
339 healthy: true,
340 latency_ms,
341 error: None,
342 details: HashMap::from([
343 ("db_path".to_string(), db_path),
344 (
345 "storage_mode".to_string(),
346 format!("{:?}", self.storage.storage_mode()),
347 ),
348 ]),
349 }),
350 Err(e) => Ok(HealthStatus {
351 healthy: false,
352 latency_ms,
353 error: Some(e.to_string()),
354 details: HashMap::from([("db_path".to_string(), db_path)]),
355 }),
356 }
357 }
358
359 fn optimize(&self) -> Result<()> {
360 self.storage.vacuum()?;
361 self.storage.checkpoint()?;
362 Ok(())
363 }
364
365 fn backend_name(&self) -> &'static str {
366 "sqlite"
367 }
368
369 fn schema_version(&self) -> Result<i32> {
370 self.storage.with_connection(|conn| {
371 let version: i32 = conn
372 .query_row("SELECT MAX(version) FROM schema_version", [], |row| {
373 row.get(0)
374 })
375 .unwrap_or(0);
376 Ok(version)
377 })
378 }
379}
380
381impl TransactionalBackend for SqliteBackend {
382 fn with_transaction<F, T>(&self, f: F) -> Result<T>
383 where
384 F: FnOnce(&dyn StorageBackend) -> Result<T>,
385 {
386 f(self)
392 }
393
394 fn savepoint(&self, name: &str) -> Result<()> {
395 self.storage.with_connection(|conn| {
396 conn.execute(&format!("SAVEPOINT {}", name), [])?;
397 Ok(())
398 })
399 }
400
401 fn release_savepoint(&self, name: &str) -> Result<()> {
402 self.storage.with_connection(|conn| {
403 conn.execute(&format!("RELEASE SAVEPOINT {}", name), [])?;
404 Ok(())
405 })
406 }
407
408 fn rollback_to_savepoint(&self, name: &str) -> Result<()> {
409 self.storage.with_connection(|conn| {
410 conn.execute(&format!("ROLLBACK TO SAVEPOINT {}", name), [])?;
411 Ok(())
412 })
413 }
414}
415
416impl CloudSyncBackend for SqliteBackend {
417 fn push(&self) -> Result<SyncResult> {
418 Ok(SyncResult {
420 success: true,
421 pushed_count: 0,
422 pulled_count: 0,
423 conflicts_resolved: 0,
424 error: None,
425 new_version: 0,
426 })
427 }
428
429 fn pull(&self) -> Result<SyncResult> {
430 Ok(SyncResult {
432 success: true,
433 pushed_count: 0,
434 pulled_count: 0,
435 conflicts_resolved: 0,
436 error: None,
437 new_version: 0,
438 })
439 }
440
441 fn sync_delta(&self, since_version: u64) -> Result<SyncDelta> {
442 self.storage.with_connection(|conn| {
443 let delta = get_sync_delta(conn, since_version as i64)?;
444 Ok(SyncDelta {
445 created: delta.created,
446 updated: delta.updated,
447 deleted: delta.deleted,
448 version: delta.to_version as u64,
449 })
450 })
451 }
452
453 fn sync_state(&self) -> Result<SyncState> {
454 self.storage.with_connection(|conn| {
455 let version = get_sync_version(conn)?;
456 let (last_sync, pending_changes): (Option<String>, i64) = conn
457 .query_row(
458 "SELECT last_sync, pending_changes FROM sync_state WHERE id = 1",
459 [],
460 |row| Ok((row.get(0)?, row.get(1)?)),
461 )
462 .unwrap_or((None, 0));
463
464 let last_sync = last_sync.and_then(|s| {
465 chrono::DateTime::parse_from_rfc3339(&s)
466 .map(|dt| dt.with_timezone(&chrono::Utc))
467 .ok()
468 });
469
470 Ok(SyncState {
471 local_version: version.version as u64,
472 remote_version: None,
473 last_sync,
474 has_pending_changes: pending_changes > 0,
475 pending_count: pending_changes as usize,
476 })
477 })
478 }
479
480 fn force_sync(&self) -> Result<SyncResult> {
481 self.push()?;
483 self.pull()
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MemoryScope, MemoryTier, MemoryType};

    /// An in-memory backend can be constructed and identifies itself as "sqlite".
    #[test]
    fn test_create_in_memory() {
        let backend = SqliteBackend::in_memory().unwrap();
        assert_eq!(backend.backend_name(), "sqlite");
    }

    /// A fresh in-memory backend reports healthy with a non-negative latency.
    #[test]
    fn test_health_check() {
        let backend = SqliteBackend::in_memory().unwrap();
        let report = backend.health_check().unwrap();
        assert!(report.healthy);
        assert!(report.latency_ms >= 0.0);
    }

    /// A fresh in-memory backend holds no memories and reports a sqlite storage mode.
    #[test]
    fn test_get_stats() {
        let backend = SqliteBackend::in_memory().unwrap();
        let snapshot = backend.get_stats().unwrap();
        assert_eq!(snapshot.total_memories, 0);
        assert!(snapshot.storage_mode.starts_with("sqlite"));
    }

    /// Exercises create -> read -> update -> delete on a single memory.
    #[test]
    fn test_crud_operations() {
        let backend = SqliteBackend::in_memory().unwrap();

        // Create.
        let new_memory = CreateMemoryInput {
            content: "Test memory".to_string(),
            memory_type: MemoryType::Note,
            tags: vec!["test".to_string()],
            metadata: HashMap::new(),
            importance: Some(0.5),
            scope: MemoryScope::Global,
            workspace: Some("default".to_string()),
            tier: MemoryTier::Permanent,
            defer_embedding: true,
            ttl_seconds: None,
            dedup_mode: crate::types::DedupMode::Allow,
            dedup_threshold: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            summary_of_id: None,
        };
        let memory = backend.create_memory(new_memory).unwrap();
        assert_eq!(memory.content, "Test memory");
        assert_eq!(memory.memory_type, MemoryType::Note);

        // Read back.
        let fetched = backend.get_memory(memory.id).unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, memory.id);

        // Update only the content.
        let content_change = UpdateMemoryInput {
            content: Some("Updated memory".to_string()),
            memory_type: None,
            tags: None,
            metadata: None,
            importance: None,
            scope: None,
            ttl_seconds: None,
            event_time: None,
            trigger_pattern: None,
        };
        let updated = backend.update_memory(memory.id, content_change).unwrap();
        assert_eq!(updated.content, "Updated memory");

        // Delete, then confirm the lookup reports absence rather than an error.
        backend.delete_memory(memory.id).unwrap();
        let after_delete = backend.get_memory(memory.id).unwrap();
        assert!(after_delete.is_none());
    }
}