Skip to main content

codemem_storage/
graph_persistence.rs

1//! Graph node/edge CRUD and embedding storage on Storage.
2
3use crate::Storage;
4use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
5use rusqlite::{params, OptionalExtension};
6use std::collections::HashMap;
7
8impl Storage {
9    // ── Embedding Storage ───────────────────────────────────────────────
10
11    /// Store an embedding for a memory.
12    pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
13        let conn = self.conn();
14        let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
15
16        conn.execute(
17            "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
18            params![memory_id, blob],
19        )
20        .map_err(|e| CodememError::Storage(e.to_string()))?;
21
22        Ok(())
23    }
24
25    /// Get an embedding by memory ID.
26    pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
27        let conn = self.conn();
28        let blob: Option<Vec<u8>> = conn
29            .query_row(
30                "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
31                params![memory_id],
32                |row| row.get(0),
33            )
34            .optional()
35            .map_err(|e| CodememError::Storage(e.to_string()))?;
36
37        match blob {
38            Some(bytes) => {
39                let floats: Vec<f32> = bytes
40                    .chunks_exact(4)
41                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
42                    .collect();
43                Ok(Some(floats))
44            }
45            None => Ok(None),
46        }
47    }
48
49    // ── Graph Node Storage ──────────────────────────────────────────────
50
51    /// Insert a graph node.
52    pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
53        let conn = self.conn();
54        let payload_json = serde_json::to_string(&node.payload)?;
55
56        conn.execute(
57            "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace)
58             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
59            params![
60                node.id,
61                node.kind.to_string(),
62                node.label,
63                payload_json,
64                node.centrality,
65                node.memory_id,
66                node.namespace,
67            ],
68        )
69        .map_err(|e| CodememError::Storage(e.to_string()))?;
70
71        Ok(())
72    }
73
74    /// Get a graph node by ID.
75    pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
76        let conn = self.conn();
77        conn.query_row(
78            "SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes WHERE id = ?1",
79            params![id],
80            |row| {
81                let kind_str: String = row.get(1)?;
82                let payload_str: String = row.get(3)?;
83                Ok((
84                    row.get::<_, String>(0)?,
85                    kind_str,
86                    row.get::<_, String>(2)?,
87                    payload_str,
88                    row.get::<_, f64>(4)?,
89                    row.get::<_, Option<String>>(5)?,
90                    row.get::<_, Option<String>>(6)?,
91                ))
92            },
93        )
94        .optional()
95        .map_err(|e| CodememError::Storage(e.to_string()))?
96        .map(|(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
97            let kind: NodeKind = kind_str.parse().map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
98            let payload: HashMap<String, serde_json::Value> =
99                serde_json::from_str(&payload_str).unwrap_or_default();
100            Ok(GraphNode {
101                id,
102                kind,
103                label,
104                payload,
105                centrality,
106                memory_id,
107                namespace,
108            })
109        })
110        .transpose()
111    }
112
113    /// Delete a graph node by ID.
114    pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
115        let conn = self.conn();
116        let rows = conn
117            .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
118            .map_err(|e| CodememError::Storage(e.to_string()))?;
119        Ok(rows > 0)
120    }
121
122    /// Get all graph nodes.
123    pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
124        let conn = self.conn();
125        let mut stmt = conn
126            .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes")
127            .map_err(|e| CodememError::Storage(e.to_string()))?;
128
129        let nodes = stmt
130            .query_map([], |row| {
131                let kind_str: String = row.get(1)?;
132                let payload_str: String = row.get(3)?;
133                Ok((
134                    row.get::<_, String>(0)?,
135                    kind_str,
136                    row.get::<_, String>(2)?,
137                    payload_str,
138                    row.get::<_, f64>(4)?,
139                    row.get::<_, Option<String>>(5)?,
140                    row.get::<_, Option<String>>(6)?,
141                ))
142            })
143            .map_err(|e| CodememError::Storage(e.to_string()))?
144            .filter_map(|r| r.ok())
145            .filter_map(
146                |(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
147                    let kind: NodeKind = kind_str.parse().ok()?;
148                    let payload: HashMap<String, serde_json::Value> =
149                        serde_json::from_str(&payload_str).unwrap_or_default();
150                    Some(GraphNode {
151                        id,
152                        kind,
153                        label,
154                        payload,
155                        centrality,
156                        memory_id,
157                        namespace,
158                    })
159                },
160            )
161            .collect();
162
163        Ok(nodes)
164    }
165
166    // ── Graph Edge Storage ──────────────────────────────────────────────
167
168    /// Insert a graph edge.
169    pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
170        let conn = self.conn();
171        let props_json = serde_json::to_string(&edge.properties)?;
172
173        conn.execute(
174            "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at)
175             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
176            params![
177                edge.id,
178                edge.src,
179                edge.dst,
180                edge.relationship.to_string(),
181                edge.weight,
182                props_json,
183                edge.created_at.timestamp(),
184            ],
185        )
186        .map_err(|e| CodememError::Storage(e.to_string()))?;
187
188        Ok(())
189    }
190
191    /// Get all edges from or to a node.
192    pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
193        let conn = self.conn();
194        let mut stmt = conn
195            .prepare(
196                "SELECT id, src, dst, relationship, weight, properties, created_at FROM graph_edges WHERE src = ?1 OR dst = ?1",
197            )
198            .map_err(|e| CodememError::Storage(e.to_string()))?;
199
200        let edges = stmt
201            .query_map(params![node_id], |row| {
202                let rel_str: String = row.get(3)?;
203                let props_str: String = row.get(5)?;
204                let created_ts: i64 = row.get(6)?;
205                Ok((
206                    row.get::<_, String>(0)?,
207                    row.get::<_, String>(1)?,
208                    row.get::<_, String>(2)?,
209                    rel_str,
210                    row.get::<_, f64>(4)?,
211                    props_str,
212                    created_ts,
213                ))
214            })
215            .map_err(|e| CodememError::Storage(e.to_string()))?
216            .filter_map(|r| r.ok())
217            .filter_map(|(id, src, dst, rel_str, weight, props_str, created_ts)| {
218                let relationship: RelationshipType = rel_str.parse().ok()?;
219                let properties: HashMap<String, serde_json::Value> =
220                    serde_json::from_str(&props_str).unwrap_or_default();
221                let created_at =
222                    chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
223                Some(Edge {
224                    id,
225                    src,
226                    dst,
227                    relationship,
228                    weight,
229                    properties,
230                    created_at,
231                })
232            })
233            .collect();
234
235        Ok(edges)
236    }
237
238    /// Get all graph edges.
239    pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
240        let conn = self.conn();
241        let mut stmt = conn
242            .prepare("SELECT id, src, dst, relationship, weight, properties, created_at FROM graph_edges")
243            .map_err(|e| CodememError::Storage(e.to_string()))?;
244
245        let edges = stmt
246            .query_map([], |row| {
247                let rel_str: String = row.get(3)?;
248                let props_str: String = row.get(5)?;
249                let created_ts: i64 = row.get(6)?;
250                Ok((
251                    row.get::<_, String>(0)?,
252                    row.get::<_, String>(1)?,
253                    row.get::<_, String>(2)?,
254                    rel_str,
255                    row.get::<_, f64>(4)?,
256                    props_str,
257                    created_ts,
258                ))
259            })
260            .map_err(|e| CodememError::Storage(e.to_string()))?
261            .filter_map(|r| r.ok())
262            .filter_map(|(id, src, dst, rel_str, weight, props_str, created_ts)| {
263                let relationship: RelationshipType = rel_str.parse().ok()?;
264                let properties: HashMap<String, serde_json::Value> =
265                    serde_json::from_str(&props_str).unwrap_or_default();
266                let created_at =
267                    chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
268                Some(Edge {
269                    id,
270                    src,
271                    dst,
272                    relationship,
273                    weight,
274                    properties,
275                    created_at,
276                })
277            })
278            .collect();
279
280        Ok(edges)
281    }
282
283    /// Delete all graph edges connected to a node (as src or dst).
284    pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
285        let conn = self.conn();
286        let rows = conn
287            .execute(
288                "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
289                params![node_id],
290            )
291            .map_err(|e| CodememError::Storage(e.to_string()))?;
292        Ok(rows)
293    }
294
295    /// Get all graph edges where both src and dst nodes belong to the given namespace.
296    pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
297        let conn = self.conn();
298        let mut stmt = conn
299            .prepare(
300                "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at
301                 FROM graph_edges e
302                 INNER JOIN graph_nodes gs ON e.src = gs.id
303                 INNER JOIN graph_nodes gd ON e.dst = gd.id
304                 WHERE gs.namespace = ?1 AND gd.namespace = ?1",
305            )
306            .map_err(|e| CodememError::Storage(e.to_string()))?;
307
308        let edges = stmt
309            .query_map(params![namespace], |row| {
310                let rel_str: String = row.get(3)?;
311                let props_str: String = row.get(5)?;
312                let created_ts: i64 = row.get(6)?;
313                Ok((
314                    row.get::<_, String>(0)?,
315                    row.get::<_, String>(1)?,
316                    row.get::<_, String>(2)?,
317                    rel_str,
318                    row.get::<_, f64>(4)?,
319                    props_str,
320                    created_ts,
321                ))
322            })
323            .map_err(|e| CodememError::Storage(e.to_string()))?
324            .filter_map(|r| r.ok())
325            .filter_map(|(id, src, dst, rel_str, weight, props_str, created_ts)| {
326                let relationship: RelationshipType = rel_str.parse().ok()?;
327                let properties: HashMap<String, serde_json::Value> =
328                    serde_json::from_str(&props_str).unwrap_or_default();
329                let created_at =
330                    chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
331                Some(Edge {
332                    id,
333                    src,
334                    dst,
335                    relationship,
336                    weight,
337                    properties,
338                    created_at,
339                })
340            })
341            .collect();
342
343        Ok(edges)
344    }
345
346    /// Delete a graph edge by ID.
347    pub fn delete_graph_edge(&self, id: &str) -> Result<bool, CodememError> {
348        let conn = self.conn();
349        let rows = conn
350            .execute("DELETE FROM graph_edges WHERE id = ?1", params![id])
351            .map_err(|e| CodememError::Storage(e.to_string()))?;
352        Ok(rows > 0)
353    }
354}
355
356#[cfg(test)]
357mod tests {
358    use crate::Storage;
359    use codemem_core::{GraphNode, MemoryNode, MemoryType, NodeKind};
360    use std::collections::HashMap;
361
362    fn test_memory() -> MemoryNode {
363        let now = chrono::Utc::now();
364        let content = "Test memory content";
365        MemoryNode {
366            id: uuid::Uuid::new_v4().to_string(),
367            content: content.to_string(),
368            memory_type: MemoryType::Context,
369            importance: 0.7,
370            confidence: 1.0,
371            access_count: 0,
372            content_hash: Storage::content_hash(content),
373            tags: vec!["test".to_string()],
374            metadata: HashMap::new(),
375            namespace: None,
376            created_at: now,
377            updated_at: now,
378            last_accessed_at: now,
379        }
380    }
381
382    #[test]
383    fn store_and_get_embedding() {
384        let storage = Storage::open_in_memory().unwrap();
385        let memory = test_memory();
386        storage.insert_memory(&memory).unwrap();
387
388        let embedding: Vec<f32> = (0..768).map(|i| i as f32 / 768.0).collect();
389        storage.store_embedding(&memory.id, &embedding).unwrap();
390
391        let retrieved = storage.get_embedding(&memory.id).unwrap().unwrap();
392        assert_eq!(retrieved.len(), 768);
393        assert!((retrieved[0] - embedding[0]).abs() < f32::EPSILON);
394    }
395
396    #[test]
397    fn graph_node_crud() {
398        let storage = Storage::open_in_memory().unwrap();
399        let node = GraphNode {
400            id: "file:src/main.rs".to_string(),
401            kind: NodeKind::File,
402            label: "src/main.rs".to_string(),
403            payload: HashMap::new(),
404            centrality: 0.0,
405            memory_id: None,
406            namespace: None,
407        };
408
409        storage.insert_graph_node(&node).unwrap();
410        let retrieved = storage.get_graph_node(&node.id).unwrap().unwrap();
411        assert_eq!(retrieved.kind, NodeKind::File);
412        assert!(storage.delete_graph_node(&node.id).unwrap());
413    }
414}