Skip to main content

codemem_storage/
graph_persistence.rs

1//! Graph node/edge CRUD and embedding storage on Storage.
2
3use crate::{MapStorageErr, Storage};
4use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
5use rusqlite::{params, OptionalExtension};
6use std::collections::HashMap;
7
/// Raw tuple type for an edge row from the database.
///
/// Fields: `(id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)`.
///
/// The tuple is positional, so the column order of every `SELECT` that feeds
/// `extract_edge_tuple` must match the order listed above.
pub(crate) type EdgeRow = (
    // id
    String,
    // src
    String,
    // dst
    String,
    // relationship, as text; parsed into `RelationshipType` by `edge_from_row`
    String,
    // weight
    f64,
    // properties, as JSON text; decoded into a map by `edge_from_row`
    String,
    // created_at, Unix timestamp in whole seconds
    i64,
    // valid_from, Unix timestamp in whole seconds (NULL when absent)
    Option<i64>,
    // valid_to, Unix timestamp in whole seconds (NULL when absent)
    Option<i64>,
);
22
23/// Deserialize an `Edge` from a raw database row tuple.
24/// Logs a warning if an edge is dropped due to an unrecognized relationship type.
25pub(crate) fn edge_from_row(row: EdgeRow) -> Option<Edge> {
26    let (id, src, dst, rel_str, weight, props_str, created_ts, valid_from_ts, valid_to_ts) = row;
27    let relationship: RelationshipType = match rel_str.parse() {
28        Ok(r) => r,
29        Err(_) => {
30            tracing::warn!(
31                edge_id = %id,
32                relationship = %rel_str,
33                "Dropping edge with unrecognized relationship type"
34            );
35            return None;
36        }
37    };
38    let properties: HashMap<String, serde_json::Value> =
39        serde_json::from_str(&props_str).unwrap_or_default();
40    let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
41    let valid_from = valid_from_ts
42        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
43        .map(|dt| dt.with_timezone(&chrono::Utc));
44    let valid_to = valid_to_ts
45        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
46        .map(|dt| dt.with_timezone(&chrono::Utc));
47    Some(Edge {
48        id,
49        src,
50        dst,
51        relationship,
52        weight,
53        properties,
54        created_at,
55        valid_from,
56        valid_to,
57    })
58}
59
60/// Extract the 9-column edge tuple from a database row.
61///
62/// Use with `query_map` to produce the tuple that `edge_from_row` consumes.
63pub(crate) fn extract_edge_tuple(row: &rusqlite::Row<'_>) -> rusqlite::Result<EdgeRow> {
64    let rel_str: String = row.get(3)?;
65    let props_str: String = row.get(5)?;
66    let created_ts: i64 = row.get(6)?;
67    let valid_from_ts: Option<i64> = row.get(7)?;
68    let valid_to_ts: Option<i64> = row.get(8)?;
69    Ok((
70        row.get::<_, String>(0)?,
71        row.get::<_, String>(1)?,
72        row.get::<_, String>(2)?,
73        rel_str,
74        row.get::<_, f64>(4)?,
75        props_str,
76        created_ts,
77        valid_from_ts,
78        valid_to_ts,
79    ))
80}
81
82impl Storage {
83    // ── Embedding Storage ───────────────────────────────────────────────
84
85    /// Store an embedding for a memory.
86    pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
87        let conn = self.conn()?;
88        let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
89
90        conn.execute(
91            "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
92            params![memory_id, blob],
93        )
94        .storage_err()?;
95
96        Ok(())
97    }
98
99    /// Get an embedding by memory ID.
100    pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
101        let conn = self.conn()?;
102        let blob: Option<Vec<u8>> = conn
103            .query_row(
104                "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
105                params![memory_id],
106                |row| row.get(0),
107            )
108            .optional()
109            .storage_err()?;
110
111        match blob {
112            Some(bytes) => {
113                let floats: Vec<f32> = bytes
114                    .chunks_exact(4)
115                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
116                    .collect();
117                Ok(Some(floats))
118            }
119            None => Ok(None),
120        }
121    }
122
123    // ── Graph Node Storage ──────────────────────────────────────────────
124
125    /// Insert a graph node.
126    pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
127        let conn = self.conn()?;
128        let payload_json = serde_json::to_string(&node.payload)?;
129
130        conn.execute(
131            "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace)
132             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
133            params![
134                node.id,
135                node.kind.to_string(),
136                node.label,
137                payload_json,
138                node.centrality,
139                node.memory_id,
140                node.namespace,
141            ],
142        )
143        .storage_err()?;
144
145        Ok(())
146    }
147
148    /// Get a graph node by ID.
149    pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
150        let conn = self.conn()?;
151        conn.query_row(
152            "SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes WHERE id = ?1",
153            params![id],
154            |row| {
155                let kind_str: String = row.get(1)?;
156                let payload_str: String = row.get(3)?;
157                Ok((
158                    row.get::<_, String>(0)?,
159                    kind_str,
160                    row.get::<_, String>(2)?,
161                    payload_str,
162                    row.get::<_, f64>(4)?,
163                    row.get::<_, Option<String>>(5)?,
164                    row.get::<_, Option<String>>(6)?,
165                ))
166            },
167        )
168        .optional()
169        .storage_err()?
170        .map(|(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
171            let kind: NodeKind = kind_str.parse().map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
172            let payload: HashMap<String, serde_json::Value> =
173                serde_json::from_str(&payload_str).unwrap_or_default();
174            Ok(GraphNode {
175                id,
176                kind,
177                label,
178                payload,
179                centrality,
180                memory_id,
181                namespace,
182            })
183        })
184        .transpose()
185    }
186
187    /// Delete a graph node by ID.
188    pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
189        let conn = self.conn()?;
190        let rows = conn
191            .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
192            .storage_err()?;
193        Ok(rows > 0)
194    }
195
196    /// Get all graph nodes. Logs warnings for rows with parse errors instead of silently dropping.
197    pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
198        let conn = self.conn()?;
199        let mut stmt = conn
200            .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes")
201            .storage_err()?;
202
203        let rows = stmt
204            .query_map([], |row| {
205                let kind_str: String = row.get(1)?;
206                let payload_str: String = row.get(3)?;
207                Ok((
208                    row.get::<_, String>(0)?,
209                    kind_str,
210                    row.get::<_, String>(2)?,
211                    payload_str,
212                    row.get::<_, f64>(4)?,
213                    row.get::<_, Option<String>>(5)?,
214                    row.get::<_, Option<String>>(6)?,
215                ))
216            })
217            .storage_err()?;
218
219        let mut nodes = Vec::new();
220        for row_result in rows {
221            let (id, kind_str, label, payload_str, centrality, memory_id, namespace) =
222                row_result.storage_err()?;
223            let kind: NodeKind = match kind_str.parse() {
224                Ok(k) => k,
225                Err(_) => {
226                    tracing::warn!(
227                        node_id = %id,
228                        kind = %kind_str,
229                        "Skipping graph node with unrecognized kind"
230                    );
231                    continue;
232                }
233            };
234            let payload: HashMap<String, serde_json::Value> =
235                serde_json::from_str(&payload_str).unwrap_or_default();
236            nodes.push(GraphNode {
237                id,
238                kind,
239                label,
240                payload,
241                centrality,
242                memory_id,
243                namespace,
244            });
245        }
246
247        Ok(nodes)
248    }
249
250    // ── Graph Edge Storage ──────────────────────────────────────────────
251
252    /// Insert a graph edge.
253    pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
254        let conn = self.conn()?;
255        let props_json = serde_json::to_string(&edge.properties)?;
256
257        conn.execute(
258            "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)
259             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
260            params![
261                edge.id,
262                edge.src,
263                edge.dst,
264                edge.relationship.to_string(),
265                edge.weight,
266                props_json,
267                edge.created_at.timestamp(),
268                edge.valid_from.map(|dt| dt.timestamp()),
269                edge.valid_to.map(|dt| dt.timestamp()),
270            ],
271        )
272        .storage_err()?;
273
274        Ok(())
275    }
276
277    /// Get all edges from or to a node.
278    pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
279        let conn = self.conn()?;
280        let mut stmt = conn
281            .prepare(
282                "SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges WHERE src = ?1 OR dst = ?1",
283            )
284            .storage_err()?;
285
286        let edges = stmt
287            .query_map(params![node_id], extract_edge_tuple)
288            .storage_err()?
289            .filter_map(|r| match r {
290                Ok(v) => Some(v),
291                Err(e) => {
292                    tracing::warn!("Failed to process edge row: {e}");
293                    None
294                }
295            })
296            .filter_map(edge_from_row)
297            .collect();
298
299        Ok(edges)
300    }
301
302    /// Get all graph edges.
303    pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
304        let conn = self.conn()?;
305        let mut stmt = conn
306            .prepare("SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges")
307            .storage_err()?;
308
309        let edges = stmt
310            .query_map([], extract_edge_tuple)
311            .storage_err()?
312            .filter_map(|r| match r {
313                Ok(v) => Some(v),
314                Err(e) => {
315                    tracing::warn!("Failed to process edge row: {e}");
316                    None
317                }
318            })
319            .filter_map(edge_from_row)
320            .collect();
321
322        Ok(edges)
323    }
324
325    /// Delete all graph edges connected to a node (as src or dst).
326    pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
327        let conn = self.conn()?;
328        let rows = conn
329            .execute(
330                "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
331                params![node_id],
332            )
333            .storage_err()?;
334        Ok(rows)
335    }
336
337    /// Get all graph edges where both src and dst nodes belong to the given namespace.
338    pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
339        let conn = self.conn()?;
340        let mut stmt = conn
341            .prepare(
342                "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at, e.valid_from, e.valid_to
343                 FROM graph_edges e
344                 INNER JOIN graph_nodes gs ON e.src = gs.id
345                 INNER JOIN graph_nodes gd ON e.dst = gd.id
346                 WHERE gs.namespace = ?1 AND gd.namespace = ?1",
347            )
348            .storage_err()?;
349
350        let edges = stmt
351            .query_map(params![namespace], extract_edge_tuple)
352            .storage_err()?
353            .filter_map(|r| match r {
354                Ok(v) => Some(v),
355                Err(e) => {
356                    tracing::warn!("Failed to process edge row: {e}");
357                    None
358                }
359            })
360            .filter_map(edge_from_row)
361            .collect();
362
363        Ok(edges)
364    }
365}
366
367#[cfg(test)]
368#[path = "tests/graph_persistence_tests.rs"]
369mod tests;