Skip to main content

codemem_storage/
graph_persistence.rs

1//! Graph node/edge CRUD and embedding storage on Storage.
2
3use crate::{MapStorageErr, Storage};
4use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
5use rusqlite::{params, OptionalExtension};
6use std::collections::HashMap;
7
/// Column values of a single edge row, prior to conversion into an `Edge`.
///
/// The positions follow the column order used by the edge queries in this module:
/// `(id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)`.
pub(crate) type EdgeRow = (
    String,      // id
    String,      // src: id of the source node
    String,      // dst: id of the destination node
    String,      // relationship: textual form, parsed into `RelationshipType` later
    f64,         // weight
    String,      // properties: JSON text
    i64,         // created_at: Unix timestamp in seconds
    Option<i64>, // valid_from: Unix timestamp in seconds, if set
    Option<i64>, // valid_to: Unix timestamp in seconds, if set
);
22
23/// Deserialize an `Edge` from a raw database row tuple.
24/// Logs a warning if an edge is dropped due to an unrecognized relationship type.
25pub(crate) fn edge_from_row(row: EdgeRow) -> Option<Edge> {
26    let (id, src, dst, rel_str, weight, props_str, created_ts, valid_from_ts, valid_to_ts) = row;
27    let relationship: RelationshipType = match rel_str.parse() {
28        Ok(r) => r,
29        Err(_) => {
30            tracing::warn!(
31                edge_id = %id,
32                relationship = %rel_str,
33                "Dropping edge with unrecognized relationship type"
34            );
35            return None;
36        }
37    };
38    let properties: HashMap<String, serde_json::Value> =
39        serde_json::from_str(&props_str).unwrap_or_default();
40    let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
41    let valid_from = valid_from_ts
42        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
43        .map(|dt| dt.with_timezone(&chrono::Utc));
44    let valid_to = valid_to_ts
45        .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
46        .map(|dt| dt.with_timezone(&chrono::Utc));
47    Some(Edge {
48        id,
49        src,
50        dst,
51        relationship,
52        weight,
53        properties,
54        created_at,
55        valid_from,
56        valid_to,
57    })
58}
59
60/// Extract the 9-column edge tuple from a database row.
61///
62/// Use with `query_map` to produce the tuple that `edge_from_row` consumes.
63pub(crate) fn extract_edge_tuple(row: &rusqlite::Row<'_>) -> rusqlite::Result<EdgeRow> {
64    let rel_str: String = row.get(3)?;
65    let props_str: String = row.get(5)?;
66    let created_ts: i64 = row.get(6)?;
67    let valid_from_ts: Option<i64> = row.get(7)?;
68    let valid_to_ts: Option<i64> = row.get(8)?;
69    Ok((
70        row.get::<_, String>(0)?,
71        row.get::<_, String>(1)?,
72        row.get::<_, String>(2)?,
73        rel_str,
74        row.get::<_, f64>(4)?,
75        props_str,
76        created_ts,
77        valid_from_ts,
78        valid_to_ts,
79    ))
80}
81
82impl Storage {
83    // ── Embedding Storage ───────────────────────────────────────────────
84
85    /// Store an embedding for a memory.
86    pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
87        let conn = self.conn()?;
88        let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
89
90        conn.execute(
91            "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
92            params![memory_id, blob],
93        )
94        .storage_err()?;
95
96        Ok(())
97    }
98
99    /// Get an embedding by memory ID.
100    pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
101        let conn = self.conn()?;
102        let blob: Option<Vec<u8>> = conn
103            .query_row(
104                "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
105                params![memory_id],
106                |row| row.get(0),
107            )
108            .optional()
109            .storage_err()?;
110
111        match blob {
112            Some(bytes) => {
113                let floats: Vec<f32> = bytes
114                    .chunks_exact(4)
115                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
116                    .collect();
117                Ok(Some(floats))
118            }
119            None => Ok(None),
120        }
121    }
122
123    // ── Graph Node Storage ──────────────────────────────────────────────
124
125    /// Insert a graph node.
126    pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
127        let conn = self.conn()?;
128        let payload_json = serde_json::to_string(&node.payload)?;
129
130        conn.execute(
131            "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to)
132             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
133            params![
134                node.id,
135                node.kind.to_string(),
136                node.label,
137                payload_json,
138                node.centrality,
139                node.memory_id,
140                node.namespace,
141                node.valid_from.map(|dt| dt.timestamp()),
142                node.valid_to.map(|dt| dt.timestamp()),
143            ],
144        )
145        .storage_err()?;
146
147        Ok(())
148    }
149
150    /// Get a graph node by ID.
151    pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
152        let conn = self.conn()?;
153        conn.query_row(
154            "SELECT id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to FROM graph_nodes WHERE id = ?1",
155            params![id],
156            |row| {
157                let kind_str: String = row.get(1)?;
158                let payload_str: String = row.get(3)?;
159                Ok((
160                    row.get::<_, String>(0)?,
161                    kind_str,
162                    row.get::<_, String>(2)?,
163                    payload_str,
164                    row.get::<_, f64>(4)?,
165                    row.get::<_, Option<String>>(5)?,
166                    row.get::<_, Option<String>>(6)?,
167                    row.get::<_, Option<i64>>(7)?,
168                    row.get::<_, Option<i64>>(8)?,
169                ))
170            },
171        )
172        .optional()
173        .storage_err()?
174        .map(|(id, kind_str, label, payload_str, centrality, memory_id, namespace, valid_from_ts, valid_to_ts)| {
175            let kind: NodeKind = kind_str.parse().map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
176            let payload: HashMap<String, serde_json::Value> =
177                serde_json::from_str(&payload_str).unwrap_or_default();
178            Ok(GraphNode {
179                id,
180                kind,
181                label,
182                payload,
183                centrality,
184                memory_id,
185                namespace,
186                valid_from: valid_from_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
187                valid_to: valid_to_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
188            })
189        })
190        .transpose()
191    }
192
193    /// Delete a graph node by ID.
194    pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
195        let conn = self.conn()?;
196        let rows = conn
197            .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
198            .storage_err()?;
199        Ok(rows > 0)
200    }
201
202    /// Get all graph nodes. Logs warnings for rows with parse errors instead of silently dropping.
203    pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
204        let conn = self.conn()?;
205        let mut stmt = conn
206            .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to FROM graph_nodes")
207            .storage_err()?;
208
209        let rows = stmt
210            .query_map([], |row| {
211                let kind_str: String = row.get(1)?;
212                let payload_str: String = row.get(3)?;
213                Ok((
214                    row.get::<_, String>(0)?,
215                    kind_str,
216                    row.get::<_, String>(2)?,
217                    payload_str,
218                    row.get::<_, f64>(4)?,
219                    row.get::<_, Option<String>>(5)?,
220                    row.get::<_, Option<String>>(6)?,
221                    row.get::<_, Option<i64>>(7)?,
222                    row.get::<_, Option<i64>>(8)?,
223                ))
224            })
225            .storage_err()?;
226
227        let mut nodes = Vec::new();
228        for row_result in rows {
229            let (
230                id,
231                kind_str,
232                label,
233                payload_str,
234                centrality,
235                memory_id,
236                namespace,
237                valid_from_ts,
238                valid_to_ts,
239            ) = row_result.storage_err()?;
240            let kind: NodeKind = match kind_str.parse() {
241                Ok(k) => k,
242                Err(_) => {
243                    tracing::warn!(
244                        node_id = %id,
245                        kind = %kind_str,
246                        "Skipping graph node with unrecognized kind"
247                    );
248                    continue;
249                }
250            };
251            let payload: HashMap<String, serde_json::Value> =
252                serde_json::from_str(&payload_str).unwrap_or_default();
253            nodes.push(GraphNode {
254                id,
255                kind,
256                label,
257                payload,
258                centrality,
259                memory_id,
260                namespace,
261                valid_from: valid_from_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
262                valid_to: valid_to_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
263            });
264        }
265
266        Ok(nodes)
267    }
268
269    // ── Graph Edge Storage ──────────────────────────────────────────────
270
271    /// Insert a graph edge.
272    pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
273        let conn = self.conn()?;
274        let props_json = serde_json::to_string(&edge.properties)?;
275
276        conn.execute(
277            "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)
278             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
279            params![
280                edge.id,
281                edge.src,
282                edge.dst,
283                edge.relationship.to_string(),
284                edge.weight,
285                props_json,
286                edge.created_at.timestamp(),
287                edge.valid_from.map(|dt| dt.timestamp()),
288                edge.valid_to.map(|dt| dt.timestamp()),
289            ],
290        )
291        .storage_err()?;
292
293        Ok(())
294    }
295
296    /// Get all edges from or to a node.
297    pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
298        let conn = self.conn()?;
299        let mut stmt = conn
300            .prepare(
301                "SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges WHERE src = ?1 OR dst = ?1",
302            )
303            .storage_err()?;
304
305        let edges = stmt
306            .query_map(params![node_id], extract_edge_tuple)
307            .storage_err()?
308            .filter_map(|r| match r {
309                Ok(v) => Some(v),
310                Err(e) => {
311                    tracing::warn!("Failed to process edge row: {e}");
312                    None
313                }
314            })
315            .filter_map(edge_from_row)
316            .collect();
317
318        Ok(edges)
319    }
320
321    /// Get all graph edges.
322    pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
323        let conn = self.conn()?;
324        let mut stmt = conn
325            .prepare("SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges")
326            .storage_err()?;
327
328        let edges = stmt
329            .query_map([], extract_edge_tuple)
330            .storage_err()?
331            .filter_map(|r| match r {
332                Ok(v) => Some(v),
333                Err(e) => {
334                    tracing::warn!("Failed to process edge row: {e}");
335                    None
336                }
337            })
338            .filter_map(edge_from_row)
339            .collect();
340
341        Ok(edges)
342    }
343
344    /// Delete a single graph edge by its ID. Returns true if a row was deleted.
345    pub fn delete_graph_edge(&self, edge_id: &str) -> Result<bool, CodememError> {
346        let conn = self.conn()?;
347        let rows = conn
348            .execute("DELETE FROM graph_edges WHERE id = ?1", params![edge_id])
349            .storage_err()?;
350        Ok(rows > 0)
351    }
352
353    /// Delete all graph edges connected to a node (as src or dst).
354    pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
355        let conn = self.conn()?;
356        let rows = conn
357            .execute(
358                "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
359                params![node_id],
360            )
361            .storage_err()?;
362        Ok(rows)
363    }
364
365    /// Get all graph edges where both src and dst nodes belong to the given namespace.
366    pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
367        self.graph_edges_for_namespace_with_cross(namespace, false)
368    }
369
370    /// Get graph edges for a namespace, optionally including cross-namespace edges.
371    ///
372    /// When `include_cross_namespace` is false, both endpoints must belong to the namespace.
373    /// When true, at least one endpoint must belong to the namespace (OR instead of AND).
374    pub fn graph_edges_for_namespace_with_cross(
375        &self,
376        namespace: &str,
377        include_cross_namespace: bool,
378    ) -> Result<Vec<Edge>, CodememError> {
379        let conn = self.conn()?;
380        let condition = if include_cross_namespace {
381            "gs.namespace = ?1 OR gd.namespace = ?1"
382        } else {
383            "gs.namespace = ?1 AND gd.namespace = ?1"
384        };
385        let sql = format!(
386            "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at, e.valid_from, e.valid_to
387             FROM graph_edges e
388             INNER JOIN graph_nodes gs ON e.src = gs.id
389             INNER JOIN graph_nodes gd ON e.dst = gd.id
390             WHERE {condition}"
391        );
392        let mut stmt = conn.prepare(&sql).storage_err()?;
393
394        let edges = stmt
395            .query_map(params![namespace], extract_edge_tuple)
396            .storage_err()?
397            .filter_map(|r| match r {
398                Ok(v) => Some(v),
399                Err(e) => {
400                    tracing::warn!("Failed to process edge row: {e}");
401                    None
402                }
403            })
404            .filter_map(edge_from_row)
405            .collect();
406
407        Ok(edges)
408    }
409}
410
411#[cfg(test)]
412#[path = "tests/graph_persistence_tests.rs"]
413mod tests;