Skip to main content

codemem_storage/
backend.rs

1//! `StorageBackend` trait implementation for Storage.
2
3use crate::{MemoryRow, Storage};
4use codemem_core::{
5    CodememError, ConsolidationLogEntry, Edge, GraphNode, MemoryNode, NodeKind, Session,
6    StorageBackend, StorageStats,
7};
8use rusqlite::params;
9use std::collections::HashMap;
10
11impl StorageBackend for Storage {
12    fn insert_memory(&self, memory: &MemoryNode) -> Result<(), CodememError> {
13        Storage::insert_memory(self, memory)
14    }
15
16    fn get_memory(&self, id: &str) -> Result<Option<MemoryNode>, CodememError> {
17        Storage::get_memory(self, id)
18    }
19
20    fn get_memories_batch(&self, ids: &[&str]) -> Result<Vec<MemoryNode>, CodememError> {
21        if ids.is_empty() {
22            return Ok(Vec::new());
23        }
24        let conn = self.conn();
25
26        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{i}")).collect();
27        let sql = format!(
28            "SELECT id, content, memory_type, importance, confidence, access_count, content_hash, tags, metadata, namespace, created_at, updated_at, last_accessed_at FROM memories WHERE id IN ({})",
29            placeholders.join(",")
30        );
31
32        let mut stmt = conn
33            .prepare(&sql)
34            .map_err(|e| CodememError::Storage(e.to_string()))?;
35
36        let params: Vec<&dyn rusqlite::types::ToSql> = ids
37            .iter()
38            .map(|id| id as &dyn rusqlite::types::ToSql)
39            .collect();
40
41        let rows = stmt
42            .query_map(params.as_slice(), |row| {
43                Ok(MemoryRow {
44                    id: row.get(0)?,
45                    content: row.get(1)?,
46                    memory_type: row.get(2)?,
47                    importance: row.get(3)?,
48                    confidence: row.get(4)?,
49                    access_count: row.get(5)?,
50                    content_hash: row.get(6)?,
51                    tags: row.get(7)?,
52                    metadata: row.get(8)?,
53                    namespace: row.get(9)?,
54                    created_at: row.get(10)?,
55                    updated_at: row.get(11)?,
56                    last_accessed_at: row.get(12)?,
57                })
58            })
59            .map_err(|e| CodememError::Storage(e.to_string()))?;
60
61        let mut memories = Vec::new();
62        for row in rows {
63            let row = row.map_err(|e| CodememError::Storage(e.to_string()))?;
64            memories.push(row.into_memory_node()?);
65        }
66        Ok(memories)
67    }
68
69    fn update_memory(
70        &self,
71        id: &str,
72        content: &str,
73        importance: Option<f64>,
74    ) -> Result<(), CodememError> {
75        Storage::update_memory(self, id, content, importance)
76    }
77
78    fn delete_memory(&self, id: &str) -> Result<bool, CodememError> {
79        Storage::delete_memory(self, id)
80    }
81
82    fn list_memory_ids(&self) -> Result<Vec<String>, CodememError> {
83        Storage::list_memory_ids(self)
84    }
85
86    fn list_memory_ids_for_namespace(&self, namespace: &str) -> Result<Vec<String>, CodememError> {
87        Storage::list_memory_ids_for_namespace(self, namespace)
88    }
89
90    fn list_namespaces(&self) -> Result<Vec<String>, CodememError> {
91        Storage::list_namespaces(self)
92    }
93
94    fn memory_count(&self) -> Result<usize, CodememError> {
95        Storage::memory_count(self)
96    }
97
98    fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
99        Storage::store_embedding(self, memory_id, embedding)
100    }
101
102    fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
103        Storage::get_embedding(self, memory_id)
104    }
105
106    fn delete_embedding(&self, memory_id: &str) -> Result<bool, CodememError> {
107        let conn = self.conn();
108        let deleted = conn
109            .execute(
110                "DELETE FROM memory_embeddings WHERE memory_id = ?1",
111                [memory_id],
112            )
113            .map_err(|e| CodememError::Storage(e.to_string()))?;
114        Ok(deleted > 0)
115    }
116
117    fn list_all_embeddings(&self) -> Result<Vec<(String, Vec<f32>)>, CodememError> {
118        let conn = self.conn();
119        let mut stmt = conn
120            .prepare("SELECT memory_id, embedding FROM memory_embeddings")
121            .map_err(|e| CodememError::Storage(e.to_string()))?;
122        let rows = stmt
123            .query_map([], |row| {
124                let id: String = row.get(0)?;
125                let blob: Vec<u8> = row.get(1)?;
126                Ok((id, blob))
127            })
128            .map_err(|e| CodememError::Storage(e.to_string()))?;
129        let mut result = Vec::new();
130        for row in rows {
131            let (id, blob) = row.map_err(|e| CodememError::Storage(e.to_string()))?;
132            let floats: Vec<f32> = blob
133                .chunks_exact(4)
134                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
135                .collect();
136            result.push((id, floats));
137        }
138        Ok(result)
139    }
140
141    fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
142        Storage::insert_graph_node(self, node)
143    }
144
145    fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
146        Storage::get_graph_node(self, id)
147    }
148
149    fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
150        Storage::delete_graph_node(self, id)
151    }
152
153    fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
154        Storage::all_graph_nodes(self)
155    }
156
157    fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
158        Storage::insert_graph_edge(self, edge)
159    }
160
161    fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
162        Storage::get_edges_for_node(self, node_id)
163    }
164
165    fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
166        Storage::all_graph_edges(self)
167    }
168
169    fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
170        Storage::delete_graph_edges_for_node(self, node_id)
171    }
172
173    fn start_session(&self, id: &str, namespace: Option<&str>) -> Result<(), CodememError> {
174        Storage::start_session(self, id, namespace)
175    }
176
177    fn end_session(&self, id: &str, summary: Option<&str>) -> Result<(), CodememError> {
178        Storage::end_session(self, id, summary)
179    }
180
181    fn list_sessions(
182        &self,
183        namespace: Option<&str>,
184        limit: usize,
185    ) -> Result<Vec<Session>, CodememError> {
186        self.list_sessions_with_limit(namespace, limit)
187    }
188
189    fn insert_consolidation_log(
190        &self,
191        cycle_type: &str,
192        affected_count: usize,
193    ) -> Result<(), CodememError> {
194        Storage::insert_consolidation_log(self, cycle_type, affected_count)
195    }
196
197    fn last_consolidation_runs(&self) -> Result<Vec<ConsolidationLogEntry>, CodememError> {
198        Storage::last_consolidation_runs(self)
199    }
200
201    fn get_repeated_searches(
202        &self,
203        min_count: usize,
204        namespace: Option<&str>,
205    ) -> Result<Vec<(String, usize, Vec<String>)>, CodememError> {
206        Storage::get_repeated_searches(self, min_count, namespace)
207    }
208
209    fn get_file_hotspots(
210        &self,
211        min_count: usize,
212        namespace: Option<&str>,
213    ) -> Result<Vec<(String, usize, Vec<String>)>, CodememError> {
214        Storage::get_file_hotspots(self, min_count, namespace)
215    }
216
217    fn get_tool_usage_stats(
218        &self,
219        namespace: Option<&str>,
220    ) -> Result<Vec<(String, usize)>, CodememError> {
221        let map = Storage::get_tool_usage_stats(self, namespace)?;
222        let mut vec: Vec<(String, usize)> = map.into_iter().collect();
223        vec.sort_by(|a, b| b.1.cmp(&a.1));
224        Ok(vec)
225    }
226
227    fn get_decision_chains(
228        &self,
229        min_count: usize,
230        namespace: Option<&str>,
231    ) -> Result<Vec<(String, usize, Vec<String>)>, CodememError> {
232        Storage::get_decision_chains(self, min_count, namespace)
233    }
234
235    fn decay_stale_memories(
236        &self,
237        threshold_ts: i64,
238        decay_factor: f64,
239    ) -> Result<usize, CodememError> {
240        let conn = self.conn();
241        let rows = conn
242            .execute(
243                "UPDATE memories SET importance = importance * ?1 WHERE last_accessed_at < ?2",
244                params![decay_factor, threshold_ts],
245            )
246            .map_err(|e| CodememError::Storage(e.to_string()))?;
247        Ok(rows)
248    }
249
250    fn list_memories_for_creative(
251        &self,
252    ) -> Result<Vec<(String, String, Vec<String>)>, CodememError> {
253        let conn = self.conn();
254        let mut stmt = conn
255            .prepare("SELECT id, memory_type, tags FROM memories ORDER BY created_at DESC")
256            .map_err(|e| CodememError::Storage(e.to_string()))?;
257
258        let rows = stmt
259            .query_map([], |row| {
260                Ok((
261                    row.get::<_, String>(0)?,
262                    row.get::<_, String>(1)?,
263                    row.get::<_, String>(2)?,
264                ))
265            })
266            .map_err(|e| CodememError::Storage(e.to_string()))?
267            .collect::<Result<Vec<_>, _>>()
268            .map_err(|e| CodememError::Storage(e.to_string()))?;
269
270        Ok(rows
271            .into_iter()
272            .map(|(id, mtype, tags_json)| {
273                let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
274                (id, mtype, tags)
275            })
276            .collect())
277    }
278
279    fn find_cluster_duplicates(&self) -> Result<Vec<(String, String, f64)>, CodememError> {
280        let conn = self.conn();
281        let mut stmt = conn
282            .prepare(
283                "SELECT a.id, b.id, 1.0 as similarity
284                 FROM memories a
285                 INNER JOIN memories b ON substr(a.content_hash, 1, 16) = substr(b.content_hash, 1, 16)
286                 WHERE a.id < b.id",
287            )
288            .map_err(|e| CodememError::Storage(e.to_string()))?;
289
290        let rows = stmt
291            .query_map([], |row| {
292                Ok((
293                    row.get::<_, String>(0)?,
294                    row.get::<_, String>(1)?,
295                    row.get::<_, f64>(2)?,
296                ))
297            })
298            .map_err(|e| CodememError::Storage(e.to_string()))?
299            .collect::<Result<Vec<_>, _>>()
300            .map_err(|e| CodememError::Storage(e.to_string()))?;
301
302        Ok(rows)
303    }
304
305    fn find_forgettable(&self, importance_threshold: f64) -> Result<Vec<String>, CodememError> {
306        let conn = self.conn();
307        let mut stmt = conn
308            .prepare(
309                "SELECT id FROM memories WHERE importance < ?1 AND access_count = 0 ORDER BY importance ASC, last_accessed_at ASC",
310            )
311            .map_err(|e| CodememError::Storage(e.to_string()))?;
312
313        let ids = stmt
314            .query_map(params![importance_threshold], |row| row.get(0))
315            .map_err(|e| CodememError::Storage(e.to_string()))?
316            .collect::<Result<Vec<String>, _>>()
317            .map_err(|e| CodememError::Storage(e.to_string()))?;
318
319        Ok(ids)
320    }
321
322    fn insert_memories_batch(&self, memories: &[MemoryNode]) -> Result<(), CodememError> {
323        let conn = self.conn();
324        let tx = conn
325            .unchecked_transaction()
326            .map_err(|e| CodememError::Storage(e.to_string()))?;
327
328        for memory in memories {
329            let tags_json = serde_json::to_string(&memory.tags)?;
330            let metadata_json = serde_json::to_string(&memory.metadata)?;
331
332            tx.execute(
333                "INSERT OR IGNORE INTO memories (id, content, memory_type, importance, confidence, access_count, content_hash, tags, metadata, namespace, created_at, updated_at, last_accessed_at)
334                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
335                params![
336                    memory.id,
337                    memory.content,
338                    memory.memory_type.to_string(),
339                    memory.importance,
340                    memory.confidence,
341                    memory.access_count,
342                    memory.content_hash,
343                    tags_json,
344                    metadata_json,
345                    memory.namespace,
346                    memory.created_at.timestamp(),
347                    memory.updated_at.timestamp(),
348                    memory.last_accessed_at.timestamp(),
349                ],
350            )
351            .map_err(|e| CodememError::Storage(e.to_string()))?;
352        }
353
354        tx.commit()
355            .map_err(|e| CodememError::Storage(e.to_string()))?;
356        Ok(())
357    }
358
359    fn store_embeddings_batch(&self, items: &[(&str, &[f32])]) -> Result<(), CodememError> {
360        let conn = self.conn();
361        let tx = conn
362            .unchecked_transaction()
363            .map_err(|e| CodememError::Storage(e.to_string()))?;
364
365        for (id, embedding) in items {
366            let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
367            tx.execute(
368                "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
369                params![id, blob],
370            )
371            .map_err(|e| CodememError::Storage(e.to_string()))?;
372        }
373
374        tx.commit()
375            .map_err(|e| CodememError::Storage(e.to_string()))?;
376        Ok(())
377    }
378
379    fn load_file_hashes(&self) -> Result<HashMap<String, String>, CodememError> {
380        let conn = self.conn();
381        let mut stmt = conn
382            .prepare("SELECT file_path, content_hash FROM file_hashes")
383            .map_err(|e| CodememError::Storage(e.to_string()))?;
384
385        let rows = stmt
386            .query_map([], |row| {
387                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
388            })
389            .map_err(|e| CodememError::Storage(e.to_string()))?
390            .collect::<Result<Vec<_>, _>>()
391            .map_err(|e| CodememError::Storage(e.to_string()))?;
392
393        Ok(rows.into_iter().collect())
394    }
395
396    fn save_file_hashes(&self, hashes: &HashMap<String, String>) -> Result<(), CodememError> {
397        let conn = self.conn();
398        let tx = conn
399            .unchecked_transaction()
400            .map_err(|e| CodememError::Storage(e.to_string()))?;
401
402        tx.execute("DELETE FROM file_hashes", [])
403            .map_err(|e| CodememError::Storage(e.to_string()))?;
404
405        for (path, hash) in hashes {
406            tx.execute(
407                "INSERT INTO file_hashes (file_path, content_hash) VALUES (?1, ?2)",
408                params![path, hash],
409            )
410            .map_err(|e| CodememError::Storage(e.to_string()))?;
411        }
412
413        tx.commit()
414            .map_err(|e| CodememError::Storage(e.to_string()))?;
415        Ok(())
416    }
417
418    fn insert_graph_nodes_batch(&self, nodes: &[GraphNode]) -> Result<(), CodememError> {
419        let conn = self.conn();
420        let tx = conn
421            .unchecked_transaction()
422            .map_err(|e| CodememError::Storage(e.to_string()))?;
423
424        for node in nodes {
425            let payload_json =
426                serde_json::to_string(&node.payload).unwrap_or_else(|_| "{}".to_string());
427            tx.execute(
428                "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace)
429                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
430                params![
431                    node.id,
432                    node.kind.to_string(),
433                    node.label,
434                    payload_json,
435                    node.centrality,
436                    node.memory_id,
437                    node.namespace,
438                ],
439            )
440            .map_err(|e| CodememError::Storage(e.to_string()))?;
441        }
442
443        tx.commit()
444            .map_err(|e| CodememError::Storage(e.to_string()))?;
445        Ok(())
446    }
447
448    fn insert_graph_edges_batch(&self, edges: &[Edge]) -> Result<(), CodememError> {
449        let conn = self.conn();
450        let tx = conn
451            .unchecked_transaction()
452            .map_err(|e| CodememError::Storage(e.to_string()))?;
453
454        for edge in edges {
455            let props_json =
456                serde_json::to_string(&edge.properties).unwrap_or_else(|_| "{}".to_string());
457            tx.execute(
458                "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)
459                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
460                params![
461                    edge.id,
462                    edge.src,
463                    edge.dst,
464                    edge.relationship.to_string(),
465                    edge.weight,
466                    props_json,
467                    edge.created_at.timestamp(),
468                    edge.valid_from.map(|dt| dt.timestamp()),
469                    edge.valid_to.map(|dt| dt.timestamp()),
470                ],
471            )
472            .map_err(|e| CodememError::Storage(e.to_string()))?;
473        }
474
475        tx.commit()
476            .map_err(|e| CodememError::Storage(e.to_string()))?;
477        Ok(())
478    }
479
480    fn get_edges_at_time(&self, node_id: &str, timestamp: i64) -> Result<Vec<Edge>, CodememError> {
481        let conn = self.conn();
482        let mut stmt = conn
483            .prepare(
484                "SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to
485                 FROM graph_edges
486                 WHERE (src = ?1 OR dst = ?1)
487                   AND (valid_from IS NULL OR valid_from <= ?2)
488                   AND (valid_to IS NULL OR valid_to > ?2)",
489            )
490            .map_err(|e| CodememError::Storage(e.to_string()))?;
491
492        let edges = stmt
493            .query_map(params![node_id, timestamp], |row| {
494                let rel_str: String = row.get(3)?;
495                let props_str: String = row.get(5)?;
496                let created_ts: i64 = row.get(6)?;
497                let valid_from_ts: Option<i64> = row.get(7)?;
498                let valid_to_ts: Option<i64> = row.get(8)?;
499                Ok((
500                    row.get::<_, String>(0)?,
501                    row.get::<_, String>(1)?,
502                    row.get::<_, String>(2)?,
503                    rel_str,
504                    row.get::<_, f64>(4)?,
505                    props_str,
506                    created_ts,
507                    valid_from_ts,
508                    valid_to_ts,
509                ))
510            })
511            .map_err(|e| CodememError::Storage(e.to_string()))?
512            .filter_map(|r| r.ok())
513            .filter_map(
514                |(
515                    id,
516                    src,
517                    dst,
518                    rel_str,
519                    weight,
520                    props_str,
521                    created_ts,
522                    valid_from_ts,
523                    valid_to_ts,
524                )| {
525                    let relationship: codemem_core::RelationshipType = rel_str.parse().ok()?;
526                    let properties: std::collections::HashMap<String, serde_json::Value> =
527                        serde_json::from_str(&props_str).unwrap_or_default();
528                    let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?
529                        .with_timezone(&chrono::Utc);
530                    let valid_from = valid_from_ts
531                        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
532                        .map(|dt| dt.with_timezone(&chrono::Utc));
533                    let valid_to = valid_to_ts
534                        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
535                        .map(|dt| dt.with_timezone(&chrono::Utc));
536                    Some(Edge {
537                        id,
538                        src,
539                        dst,
540                        relationship,
541                        weight,
542                        properties,
543                        created_at,
544                        valid_from,
545                        valid_to,
546                    })
547                },
548            )
549            .collect();
550
551        Ok(edges)
552    }
553
554    fn get_stale_memories_for_decay(
555        &self,
556        threshold_ts: i64,
557    ) -> Result<Vec<(String, f64, u32, i64)>, CodememError> {
558        let conn = self.conn();
559        let mut stmt = conn
560            .prepare(
561                "SELECT id, importance, access_count, last_accessed_at FROM memories WHERE last_accessed_at < ?1",
562            )
563            .map_err(|e| CodememError::Storage(e.to_string()))?;
564
565        let rows = stmt
566            .query_map(params![threshold_ts], |row| {
567                Ok((
568                    row.get::<_, String>(0)?,
569                    row.get::<_, f64>(1)?,
570                    row.get::<_, u32>(2)?,
571                    row.get::<_, i64>(3)?,
572                ))
573            })
574            .map_err(|e| CodememError::Storage(e.to_string()))?
575            .collect::<Result<Vec<_>, _>>()
576            .map_err(|e| CodememError::Storage(e.to_string()))?;
577
578        Ok(rows)
579    }
580
581    fn batch_update_importance(&self, updates: &[(String, f64)]) -> Result<usize, CodememError> {
582        if updates.is_empty() {
583            return Ok(0);
584        }
585        let conn = self.conn();
586        let tx = conn
587            .unchecked_transaction()
588            .map_err(|e| CodememError::Storage(e.to_string()))?;
589
590        let mut count = 0usize;
591        for (id, importance) in updates {
592            let rows = tx
593                .execute(
594                    "UPDATE memories SET importance = ?1 WHERE id = ?2",
595                    params![importance, id],
596                )
597                .map_err(|e| CodememError::Storage(e.to_string()))?;
598            count += rows;
599        }
600
601        tx.commit()
602            .map_err(|e| CodememError::Storage(e.to_string()))?;
603        Ok(count)
604    }
605
606    fn session_count(&self, namespace: Option<&str>) -> Result<usize, CodememError> {
607        let conn = self.conn();
608        let count: i64 = if let Some(ns) = namespace {
609            conn.query_row(
610                "SELECT COUNT(*) FROM sessions WHERE namespace = ?1",
611                params![ns],
612                |row| row.get(0),
613            )
614            .map_err(|e| CodememError::Storage(e.to_string()))?
615        } else {
616            conn.query_row("SELECT COUNT(*) FROM sessions", [], |row| row.get(0))
617                .map_err(|e| CodememError::Storage(e.to_string()))?
618        };
619        Ok(count as usize)
620    }
621
622    fn find_unembedded_memories(&self) -> Result<Vec<(String, String)>, CodememError> {
623        let conn = self.conn();
624        let mut stmt = conn
625            .prepare(
626                "SELECT m.id, m.content FROM memories m
627                 LEFT JOIN memory_embeddings me ON m.id = me.memory_id
628                 WHERE me.memory_id IS NULL",
629            )
630            .map_err(|e| CodememError::Storage(e.to_string()))?;
631
632        let rows = stmt
633            .query_map([], |row| {
634                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
635            })
636            .map_err(|e| CodememError::Storage(e.to_string()))?
637            .collect::<Result<Vec<_>, _>>()
638            .map_err(|e| CodememError::Storage(e.to_string()))?;
639
640        Ok(rows)
641    }
642
643    fn search_graph_nodes(
644        &self,
645        query: &str,
646        namespace: Option<&str>,
647        limit: usize,
648    ) -> Result<Vec<GraphNode>, CodememError> {
649        let conn = self.conn();
650        let pattern = format!("%{}%", query.to_lowercase());
651
652        let (sql, params_vec): (String, Vec<Box<dyn rusqlite::types::ToSql>>) =
653            if let Some(ns) = namespace {
654                (
655                    "SELECT id, kind, label, payload, centrality, memory_id, namespace \
656                 FROM graph_nodes WHERE LOWER(label) LIKE ?1 AND namespace = ?2 \
657                 ORDER BY centrality DESC LIMIT ?3"
658                        .to_string(),
659                    vec![
660                        Box::new(pattern) as Box<dyn rusqlite::types::ToSql>,
661                        Box::new(ns.to_string()),
662                        Box::new(limit as i64),
663                    ],
664                )
665            } else {
666                (
667                    "SELECT id, kind, label, payload, centrality, memory_id, namespace \
668                 FROM graph_nodes WHERE LOWER(label) LIKE ?1 \
669                 ORDER BY centrality DESC LIMIT ?2"
670                        .to_string(),
671                    vec![
672                        Box::new(pattern) as Box<dyn rusqlite::types::ToSql>,
673                        Box::new(limit as i64),
674                    ],
675                )
676            };
677
678        let refs: Vec<&dyn rusqlite::types::ToSql> =
679            params_vec.iter().map(|p| p.as_ref()).collect();
680        let mut stmt = conn
681            .prepare(&sql)
682            .map_err(|e| CodememError::Storage(e.to_string()))?;
683
684        let rows = stmt
685            .query_map(refs.as_slice(), |row| {
686                let kind_str: String = row.get(1)?;
687                let payload_str: String = row.get(3)?;
688                Ok(GraphNode {
689                    id: row.get(0)?,
690                    kind: kind_str.parse().unwrap_or(NodeKind::Memory),
691                    label: row.get(2)?,
692                    payload: serde_json::from_str(&payload_str).unwrap_or_default(),
693                    centrality: row.get(4)?,
694                    memory_id: row.get(5)?,
695                    namespace: row.get(6)?,
696                })
697            })
698            .map_err(|e| CodememError::Storage(e.to_string()))?
699            .collect::<Result<Vec<_>, _>>()
700            .map_err(|e| CodememError::Storage(e.to_string()))?;
701
702        Ok(rows)
703    }
704
705    fn list_memories_filtered(
706        &self,
707        namespace: Option<&str>,
708        memory_type: Option<&str>,
709    ) -> Result<Vec<MemoryNode>, CodememError> {
710        let conn = self.conn();
711        let mut sql = "SELECT id, content, memory_type, importance, confidence, access_count, \
712                        content_hash, tags, metadata, namespace, created_at, updated_at, \
713                        last_accessed_at FROM memories WHERE 1=1"
714            .to_string();
715        let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
716
717        if let Some(ns) = namespace {
718            param_values.push(Box::new(ns.to_string()));
719            sql.push_str(&format!(" AND namespace = ?{}", param_values.len()));
720        }
721        if let Some(mt) = memory_type {
722            param_values.push(Box::new(mt.to_string()));
723            sql.push_str(&format!(" AND memory_type = ?{}", param_values.len()));
724        }
725        sql.push_str(" ORDER BY created_at DESC");
726
727        let refs: Vec<&dyn rusqlite::types::ToSql> =
728            param_values.iter().map(|p| p.as_ref()).collect();
729        let mut stmt = conn
730            .prepare(&sql)
731            .map_err(|e| CodememError::Storage(e.to_string()))?;
732
733        let rows = stmt
734            .query_map(refs.as_slice(), |row| {
735                Ok(MemoryRow {
736                    id: row.get(0)?,
737                    content: row.get(1)?,
738                    memory_type: row.get(2)?,
739                    importance: row.get(3)?,
740                    confidence: row.get(4)?,
741                    access_count: row.get(5)?,
742                    content_hash: row.get(6)?,
743                    tags: row.get(7)?,
744                    metadata: row.get(8)?,
745                    namespace: row.get(9)?,
746                    created_at: row.get(10)?,
747                    updated_at: row.get(11)?,
748                    last_accessed_at: row.get(12)?,
749                })
750            })
751            .map_err(|e| CodememError::Storage(e.to_string()))?;
752
753        let mut result = Vec::new();
754        for row in rows {
755            let mr = row.map_err(|e| CodememError::Storage(e.to_string()))?;
756            result.push(mr.into_memory_node()?);
757        }
758
759        Ok(result)
760    }
761
762    fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
763        Storage::graph_edges_for_namespace(self, namespace)
764    }
765
766    fn stats(&self) -> Result<StorageStats, CodememError> {
767        Storage::stats(self)
768    }
769}
770
#[cfg(test)]
mod tests {
    use crate::Storage;
    use codemem_core::{MemoryNode, MemoryType, StorageBackend};
    use std::collections::HashMap;

    /// Build a `Context` memory with a fresh random id, importance 0.7 and a single `test` tag.
    fn test_memory() -> MemoryNode {
        let now = chrono::Utc::now();
        let content = "Test memory content";
        MemoryNode {
            id: uuid::Uuid::new_v4().to_string(),
            content: content.to_string(),
            memory_type: MemoryType::Context,
            importance: 0.7,
            confidence: 1.0,
            access_count: 0,
            content_hash: Storage::content_hash(content),
            tags: vec!["test".to_string()],
            metadata: HashMap::new(),
            namespace: None,
            created_at: now,
            updated_at: now,
            last_accessed_at: now,
        }
    }

    #[test]
    fn get_memories_batch_returns_multiple() {
        let storage = Storage::open_in_memory().unwrap();
        let m1 = test_memory();
        // `test_memory` already assigns a fresh UUID; only the content needs to differ.
        let mut m2 = test_memory();
        m2.content = "Different content".to_string();
        m2.content_hash = Storage::content_hash(&m2.content);

        storage.insert_memory(&m1).unwrap();
        storage.insert_memory(&m2).unwrap();

        let backend: &dyn StorageBackend = &storage;
        let batch = backend.get_memories_batch(&[&m1.id, &m2.id]).unwrap();
        assert_eq!(batch.len(), 2);
    }

    #[test]
    fn get_memories_batch_skips_unknown_and_duplicate_ids() {
        let storage = Storage::open_in_memory().unwrap();
        let m = test_memory();
        storage.insert_memory(&m).unwrap();

        let backend: &dyn StorageBackend = &storage;
        let batch = backend
            .get_memories_batch(&[m.id.as_str(), "no-such-id", m.id.as_str()])
            .unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].id, m.id);
    }

    #[test]
    fn get_memories_batch_empty() {
        let storage = Storage::open_in_memory().unwrap();
        let backend: &dyn StorageBackend = &storage;
        let batch = backend.get_memories_batch(&[]).unwrap();
        assert!(batch.is_empty());
    }

    #[test]
    fn storage_backend_trait_object() {
        let storage = Storage::open_in_memory().unwrap();
        let backend: Box<dyn StorageBackend> = Box::new(storage);

        let m = test_memory();
        backend.insert_memory(&m).unwrap();
        let retrieved = backend.get_memory(&m.id).unwrap().unwrap();
        assert_eq!(retrieved.id, m.id);
    }

    #[test]
    fn file_hashes_roundtrip() {
        let storage = Storage::open_in_memory().unwrap();
        let backend: &dyn StorageBackend = &storage;

        let mut hashes = HashMap::new();
        hashes.insert("src/main.rs".to_string(), "abc123".to_string());
        hashes.insert("src/lib.rs".to_string(), "def456".to_string());

        backend.save_file_hashes(&hashes).unwrap();
        let loaded = backend.load_file_hashes().unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.get("src/main.rs"), Some(&"abc123".to_string()));
    }

    #[test]
    fn decay_stale_memories_updates() {
        let storage = Storage::open_in_memory().unwrap();
        let backend: &dyn StorageBackend = &storage;

        let m = test_memory();
        backend.insert_memory(&m).unwrap();

        // A threshold of 0 (the Unix epoch) is older than every memory, so nothing decays.
        let count = backend.decay_stale_memories(0, 0.5).unwrap();
        assert_eq!(count, 0);

        // A threshold of i64::MAX is newer than every memory, so everything decays.
        let count = backend.decay_stale_memories(i64::MAX, 0.5).unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn find_forgettable_returns_low_importance() {
        let storage = Storage::open_in_memory().unwrap();
        let backend: &dyn StorageBackend = &storage;

        let mut m = test_memory();
        m.importance = 0.1;
        backend.insert_memory(&m).unwrap();

        let forgettable = backend.find_forgettable(0.5).unwrap();
        assert_eq!(forgettable.len(), 1);
        assert_eq!(forgettable[0], m.id);

        let forgettable = backend.find_forgettable(0.05).unwrap();
        assert!(forgettable.is_empty());
    }
}
881}