Skip to main content

codemem_storage/
graph_persistence.rs

1//! Graph node/edge CRUD and embedding storage on Storage.
2
3use crate::Storage;
4use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
5use rusqlite::{params, OptionalExtension};
6use std::collections::HashMap;
7
8impl Storage {
9    // ── Embedding Storage ───────────────────────────────────────────────
10
11    /// Store an embedding for a memory.
12    pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
13        let conn = self.conn();
14        let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
15
16        conn.execute(
17            "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
18            params![memory_id, blob],
19        )
20        .map_err(|e| CodememError::Storage(e.to_string()))?;
21
22        Ok(())
23    }
24
25    /// Get an embedding by memory ID.
26    pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
27        let conn = self.conn();
28        let blob: Option<Vec<u8>> = conn
29            .query_row(
30                "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
31                params![memory_id],
32                |row| row.get(0),
33            )
34            .optional()
35            .map_err(|e| CodememError::Storage(e.to_string()))?;
36
37        match blob {
38            Some(bytes) => {
39                let floats: Vec<f32> = bytes
40                    .chunks_exact(4)
41                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
42                    .collect();
43                Ok(Some(floats))
44            }
45            None => Ok(None),
46        }
47    }
48
49    // ── Graph Node Storage ──────────────────────────────────────────────
50
51    /// Insert a graph node.
52    pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
53        let conn = self.conn();
54        let payload_json = serde_json::to_string(&node.payload)?;
55
56        conn.execute(
57            "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace)
58             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
59            params![
60                node.id,
61                node.kind.to_string(),
62                node.label,
63                payload_json,
64                node.centrality,
65                node.memory_id,
66                node.namespace,
67            ],
68        )
69        .map_err(|e| CodememError::Storage(e.to_string()))?;
70
71        Ok(())
72    }
73
74    /// Get a graph node by ID.
75    pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
76        let conn = self.conn();
77        conn.query_row(
78            "SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes WHERE id = ?1",
79            params![id],
80            |row| {
81                let kind_str: String = row.get(1)?;
82                let payload_str: String = row.get(3)?;
83                Ok((
84                    row.get::<_, String>(0)?,
85                    kind_str,
86                    row.get::<_, String>(2)?,
87                    payload_str,
88                    row.get::<_, f64>(4)?,
89                    row.get::<_, Option<String>>(5)?,
90                    row.get::<_, Option<String>>(6)?,
91                ))
92            },
93        )
94        .optional()
95        .map_err(|e| CodememError::Storage(e.to_string()))?
96        .map(|(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
97            let kind: NodeKind = kind_str.parse().map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
98            let payload: HashMap<String, serde_json::Value> =
99                serde_json::from_str(&payload_str).unwrap_or_default();
100            Ok(GraphNode {
101                id,
102                kind,
103                label,
104                payload,
105                centrality,
106                memory_id,
107                namespace,
108            })
109        })
110        .transpose()
111    }
112
113    /// Delete a graph node by ID.
114    pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
115        let conn = self.conn();
116        let rows = conn
117            .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
118            .map_err(|e| CodememError::Storage(e.to_string()))?;
119        Ok(rows > 0)
120    }
121
122    /// Get all graph nodes.
123    pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
124        let conn = self.conn();
125        let mut stmt = conn
126            .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes")
127            .map_err(|e| CodememError::Storage(e.to_string()))?;
128
129        let nodes = stmt
130            .query_map([], |row| {
131                let kind_str: String = row.get(1)?;
132                let payload_str: String = row.get(3)?;
133                Ok((
134                    row.get::<_, String>(0)?,
135                    kind_str,
136                    row.get::<_, String>(2)?,
137                    payload_str,
138                    row.get::<_, f64>(4)?,
139                    row.get::<_, Option<String>>(5)?,
140                    row.get::<_, Option<String>>(6)?,
141                ))
142            })
143            .map_err(|e| CodememError::Storage(e.to_string()))?
144            .filter_map(|r| r.ok())
145            .filter_map(
146                |(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
147                    let kind: NodeKind = kind_str.parse().ok()?;
148                    let payload: HashMap<String, serde_json::Value> =
149                        serde_json::from_str(&payload_str).unwrap_or_default();
150                    Some(GraphNode {
151                        id,
152                        kind,
153                        label,
154                        payload,
155                        centrality,
156                        memory_id,
157                        namespace,
158                    })
159                },
160            )
161            .collect();
162
163        Ok(nodes)
164    }
165
166    // ── Graph Edge Storage ──────────────────────────────────────────────
167
168    /// Insert a graph edge.
169    pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
170        let conn = self.conn();
171        let props_json = serde_json::to_string(&edge.properties)?;
172
173        conn.execute(
174            "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)
175             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
176            params![
177                edge.id,
178                edge.src,
179                edge.dst,
180                edge.relationship.to_string(),
181                edge.weight,
182                props_json,
183                edge.created_at.timestamp(),
184                edge.valid_from.map(|dt| dt.timestamp()),
185                edge.valid_to.map(|dt| dt.timestamp()),
186            ],
187        )
188        .map_err(|e| CodememError::Storage(e.to_string()))?;
189
190        Ok(())
191    }
192
193    /// Get all edges from or to a node.
194    pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
195        let conn = self.conn();
196        let mut stmt = conn
197            .prepare(
198                "SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges WHERE src = ?1 OR dst = ?1",
199            )
200            .map_err(|e| CodememError::Storage(e.to_string()))?;
201
202        let edges = stmt
203            .query_map(params![node_id], |row| {
204                let rel_str: String = row.get(3)?;
205                let props_str: String = row.get(5)?;
206                let created_ts: i64 = row.get(6)?;
207                let valid_from_ts: Option<i64> = row.get(7)?;
208                let valid_to_ts: Option<i64> = row.get(8)?;
209                Ok((
210                    row.get::<_, String>(0)?,
211                    row.get::<_, String>(1)?,
212                    row.get::<_, String>(2)?,
213                    rel_str,
214                    row.get::<_, f64>(4)?,
215                    props_str,
216                    created_ts,
217                    valid_from_ts,
218                    valid_to_ts,
219                ))
220            })
221            .map_err(|e| CodememError::Storage(e.to_string()))?
222            .filter_map(|r| r.ok())
223            .filter_map(
224                |(
225                    id,
226                    src,
227                    dst,
228                    rel_str,
229                    weight,
230                    props_str,
231                    created_ts,
232                    valid_from_ts,
233                    valid_to_ts,
234                )| {
235                    let relationship: RelationshipType = rel_str.parse().ok()?;
236                    let properties: HashMap<String, serde_json::Value> =
237                        serde_json::from_str(&props_str).unwrap_or_default();
238                    let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?
239                        .with_timezone(&chrono::Utc);
240                    let valid_from = valid_from_ts
241                        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
242                        .map(|dt| dt.with_timezone(&chrono::Utc));
243                    let valid_to = valid_to_ts
244                        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
245                        .map(|dt| dt.with_timezone(&chrono::Utc));
246                    Some(Edge {
247                        id,
248                        src,
249                        dst,
250                        relationship,
251                        weight,
252                        properties,
253                        created_at,
254                        valid_from,
255                        valid_to,
256                    })
257                },
258            )
259            .collect();
260
261        Ok(edges)
262    }
263
264    /// Get all graph edges.
265    pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
266        let conn = self.conn();
267        let mut stmt = conn
268            .prepare("SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges")
269            .map_err(|e| CodememError::Storage(e.to_string()))?;
270
271        let edges = stmt
272            .query_map([], |row| {
273                let rel_str: String = row.get(3)?;
274                let props_str: String = row.get(5)?;
275                let created_ts: i64 = row.get(6)?;
276                let valid_from_ts: Option<i64> = row.get(7)?;
277                let valid_to_ts: Option<i64> = row.get(8)?;
278                Ok((
279                    row.get::<_, String>(0)?,
280                    row.get::<_, String>(1)?,
281                    row.get::<_, String>(2)?,
282                    rel_str,
283                    row.get::<_, f64>(4)?,
284                    props_str,
285                    created_ts,
286                    valid_from_ts,
287                    valid_to_ts,
288                ))
289            })
290            .map_err(|e| CodememError::Storage(e.to_string()))?
291            .filter_map(|r| r.ok())
292            .filter_map(
293                |(
294                    id,
295                    src,
296                    dst,
297                    rel_str,
298                    weight,
299                    props_str,
300                    created_ts,
301                    valid_from_ts,
302                    valid_to_ts,
303                )| {
304                    let relationship: RelationshipType = rel_str.parse().ok()?;
305                    let properties: HashMap<String, serde_json::Value> =
306                        serde_json::from_str(&props_str).unwrap_or_default();
307                    let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?
308                        .with_timezone(&chrono::Utc);
309                    let valid_from = valid_from_ts
310                        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
311                        .map(|dt| dt.with_timezone(&chrono::Utc));
312                    let valid_to = valid_to_ts
313                        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
314                        .map(|dt| dt.with_timezone(&chrono::Utc));
315                    Some(Edge {
316                        id,
317                        src,
318                        dst,
319                        relationship,
320                        weight,
321                        properties,
322                        created_at,
323                        valid_from,
324                        valid_to,
325                    })
326                },
327            )
328            .collect();
329
330        Ok(edges)
331    }
332
333    /// Delete all graph edges connected to a node (as src or dst).
334    pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
335        let conn = self.conn();
336        let rows = conn
337            .execute(
338                "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
339                params![node_id],
340            )
341            .map_err(|e| CodememError::Storage(e.to_string()))?;
342        Ok(rows)
343    }
344
345    /// Get all graph edges where both src and dst nodes belong to the given namespace.
346    pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
347        let conn = self.conn();
348        let mut stmt = conn
349            .prepare(
350                "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at, e.valid_from, e.valid_to
351                 FROM graph_edges e
352                 INNER JOIN graph_nodes gs ON e.src = gs.id
353                 INNER JOIN graph_nodes gd ON e.dst = gd.id
354                 WHERE gs.namespace = ?1 AND gd.namespace = ?1",
355            )
356            .map_err(|e| CodememError::Storage(e.to_string()))?;
357
358        let edges = stmt
359            .query_map(params![namespace], |row| {
360                let rel_str: String = row.get(3)?;
361                let props_str: String = row.get(5)?;
362                let created_ts: i64 = row.get(6)?;
363                let valid_from_ts: Option<i64> = row.get(7)?;
364                let valid_to_ts: Option<i64> = row.get(8)?;
365                Ok((
366                    row.get::<_, String>(0)?,
367                    row.get::<_, String>(1)?,
368                    row.get::<_, String>(2)?,
369                    rel_str,
370                    row.get::<_, f64>(4)?,
371                    props_str,
372                    created_ts,
373                    valid_from_ts,
374                    valid_to_ts,
375                ))
376            })
377            .map_err(|e| CodememError::Storage(e.to_string()))?
378            .filter_map(|r| r.ok())
379            .filter_map(
380                |(
381                    id,
382                    src,
383                    dst,
384                    rel_str,
385                    weight,
386                    props_str,
387                    created_ts,
388                    valid_from_ts,
389                    valid_to_ts,
390                )| {
391                    let relationship: RelationshipType = rel_str.parse().ok()?;
392                    let properties: HashMap<String, serde_json::Value> =
393                        serde_json::from_str(&props_str).unwrap_or_default();
394                    let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?
395                        .with_timezone(&chrono::Utc);
396                    let valid_from = valid_from_ts
397                        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
398                        .map(|dt| dt.with_timezone(&chrono::Utc));
399                    let valid_to = valid_to_ts
400                        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
401                        .map(|dt| dt.with_timezone(&chrono::Utc));
402                    Some(Edge {
403                        id,
404                        src,
405                        dst,
406                        relationship,
407                        weight,
408                        properties,
409                        created_at,
410                        valid_from,
411                        valid_to,
412                    })
413                },
414            )
415            .collect();
416
417        Ok(edges)
418    }
419
420    /// Delete a graph edge by ID.
421    pub fn delete_graph_edge(&self, id: &str) -> Result<bool, CodememError> {
422        let conn = self.conn();
423        let rows = conn
424            .execute("DELETE FROM graph_edges WHERE id = ?1", params![id])
425            .map_err(|e| CodememError::Storage(e.to_string()))?;
426        Ok(rows > 0)
427    }
428}
429
#[cfg(test)]
mod tests {
    use crate::Storage;
    use codemem_core::{GraphNode, MemoryNode, MemoryType, NodeKind};
    use std::collections::HashMap;

    /// Build a minimal `Context` memory with a fresh UUID, used as the owner
    /// of an embedding row.
    fn test_memory() -> MemoryNode {
        let now = chrono::Utc::now();
        let content = "Test memory content";
        MemoryNode {
            id: uuid::Uuid::new_v4().to_string(),
            content: content.to_string(),
            memory_type: MemoryType::Context,
            importance: 0.7,
            confidence: 1.0,
            access_count: 0,
            content_hash: Storage::content_hash(content),
            tags: vec!["test".to_string()],
            metadata: HashMap::new(),
            namespace: None,
            created_at: now,
            updated_at: now,
            last_accessed_at: now,
        }
    }

    #[test]
    fn store_and_get_embedding() {
        let storage = Storage::open_in_memory().unwrap();
        let memory = test_memory();
        storage.insert_memory(&memory).unwrap();

        // An ID with no stored embedding yields `None`, not an error.
        assert!(storage.get_embedding("no-such-memory").unwrap().is_none());

        let embedding: Vec<f32> = (0..768).map(|i| i as f32 / 768.0).collect();
        storage.store_embedding(&memory.id, &embedding).unwrap();

        // The little-endian blob round-trips bit-for-bit, so compare every
        // component rather than only the first.
        let retrieved = storage.get_embedding(&memory.id).unwrap().unwrap();
        assert_eq!(retrieved.len(), 768);
        assert_eq!(retrieved, embedding);
    }

    #[test]
    fn graph_node_crud() {
        let storage = Storage::open_in_memory().unwrap();
        let mut payload = HashMap::new();
        payload.insert(
            "lang".to_string(),
            serde_json::Value::String("rust".to_string()),
        );
        let node = GraphNode {
            id: "file:src/main.rs".to_string(),
            kind: NodeKind::File,
            label: "src/main.rs".to_string(),
            payload,
            centrality: 0.5,
            memory_id: None,
            namespace: Some("test-ns".to_string()),
        };

        storage.insert_graph_node(&node).unwrap();

        // Every persisted field round-trips through get_graph_node.
        let retrieved = storage.get_graph_node(&node.id).unwrap().unwrap();
        assert_eq!(retrieved.id, node.id);
        assert_eq!(retrieved.kind, NodeKind::File);
        assert_eq!(retrieved.label, node.label);
        assert_eq!(retrieved.payload, node.payload);
        assert!((retrieved.centrality - node.centrality).abs() < f64::EPSILON);
        assert_eq!(retrieved.memory_id, None);
        assert_eq!(retrieved.namespace, node.namespace);

        // The node is also visible through the bulk loader.
        assert!(storage
            .all_graph_nodes()
            .unwrap()
            .iter()
            .any(|n| n.id == node.id));

        assert!(storage.delete_graph_node(&node.id).unwrap());
        assert!(storage.get_graph_node(&node.id).unwrap().is_none());
        // Deleting a node that no longer exists reports that nothing was removed.
        assert!(!storage.delete_graph_node(&node.id).unwrap());
    }
}