1use crate::{MapStorageErr, Storage};
4use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
5use rusqlite::{params, OptionalExtension};
6use std::collections::HashMap;
7
/// Raw column values of a single `graph_edges` row, in the order used by the
/// edge `SELECT` statements in this file:
/// `(id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)`.
///
/// This is the intermediate form produced by `extract_edge_tuple` and consumed
/// by `edge_from_row`; the string and timestamp fields are still unparsed.
pub(crate) type EdgeRow = (
    String,      // id
    String,      // src node id
    String,      // dst node id
    String,      // relationship, still in its textual form
    f64,         // weight
    String,      // properties, serialized as JSON text
    i64,         // created_at, Unix timestamp in seconds
    Option<i64>, // valid_from, Unix timestamp in seconds (nullable)
    Option<i64>, // valid_to, Unix timestamp in seconds (nullable)
);
22
23pub(crate) fn edge_from_row(row: EdgeRow) -> Option<Edge> {
26 let (id, src, dst, rel_str, weight, props_str, created_ts, valid_from_ts, valid_to_ts) = row;
27 let relationship: RelationshipType = match rel_str.parse() {
28 Ok(r) => r,
29 Err(_) => {
30 tracing::warn!(
31 edge_id = %id,
32 relationship = %rel_str,
33 "Dropping edge with unrecognized relationship type"
34 );
35 return None;
36 }
37 };
38 let properties: HashMap<String, serde_json::Value> =
39 serde_json::from_str(&props_str).unwrap_or_default();
40 let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
41 let valid_from = valid_from_ts
42 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
43 .map(|dt| dt.with_timezone(&chrono::Utc));
44 let valid_to = valid_to_ts
45 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
46 .map(|dt| dt.with_timezone(&chrono::Utc));
47 Some(Edge {
48 id,
49 src,
50 dst,
51 relationship,
52 weight,
53 properties,
54 created_at,
55 valid_from,
56 valid_to,
57 })
58}
59
60pub(crate) fn extract_edge_tuple(row: &rusqlite::Row<'_>) -> rusqlite::Result<EdgeRow> {
64 let rel_str: String = row.get(3)?;
65 let props_str: String = row.get(5)?;
66 let created_ts: i64 = row.get(6)?;
67 let valid_from_ts: Option<i64> = row.get(7)?;
68 let valid_to_ts: Option<i64> = row.get(8)?;
69 Ok((
70 row.get::<_, String>(0)?,
71 row.get::<_, String>(1)?,
72 row.get::<_, String>(2)?,
73 rel_str,
74 row.get::<_, f64>(4)?,
75 props_str,
76 created_ts,
77 valid_from_ts,
78 valid_to_ts,
79 ))
80}
81
82impl Storage {
83 pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
87 let conn = self.conn()?;
88 let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
89
90 conn.execute(
91 "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
92 params![memory_id, blob],
93 )
94 .storage_err()?;
95
96 Ok(())
97 }
98
99 pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
101 let conn = self.conn()?;
102 let blob: Option<Vec<u8>> = conn
103 .query_row(
104 "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
105 params![memory_id],
106 |row| row.get(0),
107 )
108 .optional()
109 .storage_err()?;
110
111 match blob {
112 Some(bytes) => {
113 let floats: Vec<f32> = bytes
114 .chunks_exact(4)
115 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
116 .collect();
117 Ok(Some(floats))
118 }
119 None => Ok(None),
120 }
121 }
122
123 pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
127 let conn = self.conn()?;
128 let payload_json = serde_json::to_string(&node.payload)?;
129
130 conn.execute(
131 "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to)
132 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
133 params![
134 node.id,
135 node.kind.to_string(),
136 node.label,
137 payload_json,
138 node.centrality,
139 node.memory_id,
140 node.namespace,
141 node.valid_from.map(|dt| dt.timestamp()),
142 node.valid_to.map(|dt| dt.timestamp()),
143 ],
144 )
145 .storage_err()?;
146
147 Ok(())
148 }
149
150 pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
152 let conn = self.conn()?;
153 conn.query_row(
154 "SELECT id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to FROM graph_nodes WHERE id = ?1",
155 params![id],
156 |row| {
157 let kind_str: String = row.get(1)?;
158 let payload_str: String = row.get(3)?;
159 Ok((
160 row.get::<_, String>(0)?,
161 kind_str,
162 row.get::<_, String>(2)?,
163 payload_str,
164 row.get::<_, f64>(4)?,
165 row.get::<_, Option<String>>(5)?,
166 row.get::<_, Option<String>>(6)?,
167 row.get::<_, Option<i64>>(7)?,
168 row.get::<_, Option<i64>>(8)?,
169 ))
170 },
171 )
172 .optional()
173 .storage_err()?
174 .map(|(id, kind_str, label, payload_str, centrality, memory_id, namespace, valid_from_ts, valid_to_ts)| {
175 let kind: NodeKind = kind_str.parse().map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
176 let payload: HashMap<String, serde_json::Value> =
177 serde_json::from_str(&payload_str).unwrap_or_default();
178 Ok(GraphNode {
179 id,
180 kind,
181 label,
182 payload,
183 centrality,
184 memory_id,
185 namespace,
186 valid_from: valid_from_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
187 valid_to: valid_to_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
188 })
189 })
190 .transpose()
191 }
192
193 pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
195 let conn = self.conn()?;
196 let rows = conn
197 .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
198 .storage_err()?;
199 Ok(rows > 0)
200 }
201
202 pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
204 let conn = self.conn()?;
205 let mut stmt = conn
206 .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to FROM graph_nodes")
207 .storage_err()?;
208
209 let rows = stmt
210 .query_map([], |row| {
211 let kind_str: String = row.get(1)?;
212 let payload_str: String = row.get(3)?;
213 Ok((
214 row.get::<_, String>(0)?,
215 kind_str,
216 row.get::<_, String>(2)?,
217 payload_str,
218 row.get::<_, f64>(4)?,
219 row.get::<_, Option<String>>(5)?,
220 row.get::<_, Option<String>>(6)?,
221 row.get::<_, Option<i64>>(7)?,
222 row.get::<_, Option<i64>>(8)?,
223 ))
224 })
225 .storage_err()?;
226
227 let mut nodes = Vec::new();
228 for row_result in rows {
229 let (
230 id,
231 kind_str,
232 label,
233 payload_str,
234 centrality,
235 memory_id,
236 namespace,
237 valid_from_ts,
238 valid_to_ts,
239 ) = row_result.storage_err()?;
240 let kind: NodeKind = match kind_str.parse() {
241 Ok(k) => k,
242 Err(_) => {
243 tracing::warn!(
244 node_id = %id,
245 kind = %kind_str,
246 "Skipping graph node with unrecognized kind"
247 );
248 continue;
249 }
250 };
251 let payload: HashMap<String, serde_json::Value> =
252 serde_json::from_str(&payload_str).unwrap_or_default();
253 nodes.push(GraphNode {
254 id,
255 kind,
256 label,
257 payload,
258 centrality,
259 memory_id,
260 namespace,
261 valid_from: valid_from_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
262 valid_to: valid_to_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
263 });
264 }
265
266 Ok(nodes)
267 }
268
269 pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
273 let conn = self.conn()?;
274 let props_json = serde_json::to_string(&edge.properties)?;
275
276 conn.execute(
277 "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)
278 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
279 params![
280 edge.id,
281 edge.src,
282 edge.dst,
283 edge.relationship.to_string(),
284 edge.weight,
285 props_json,
286 edge.created_at.timestamp(),
287 edge.valid_from.map(|dt| dt.timestamp()),
288 edge.valid_to.map(|dt| dt.timestamp()),
289 ],
290 )
291 .storage_err()?;
292
293 Ok(())
294 }
295
296 pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
298 let conn = self.conn()?;
299 let mut stmt = conn
300 .prepare(
301 "SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges WHERE src = ?1 OR dst = ?1",
302 )
303 .storage_err()?;
304
305 let edges = stmt
306 .query_map(params![node_id], extract_edge_tuple)
307 .storage_err()?
308 .filter_map(|r| match r {
309 Ok(v) => Some(v),
310 Err(e) => {
311 tracing::warn!("Failed to process edge row: {e}");
312 None
313 }
314 })
315 .filter_map(edge_from_row)
316 .collect();
317
318 Ok(edges)
319 }
320
321 pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
323 let conn = self.conn()?;
324 let mut stmt = conn
325 .prepare("SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges")
326 .storage_err()?;
327
328 let edges = stmt
329 .query_map([], extract_edge_tuple)
330 .storage_err()?
331 .filter_map(|r| match r {
332 Ok(v) => Some(v),
333 Err(e) => {
334 tracing::warn!("Failed to process edge row: {e}");
335 None
336 }
337 })
338 .filter_map(edge_from_row)
339 .collect();
340
341 Ok(edges)
342 }
343
344 pub fn delete_graph_edge(&self, edge_id: &str) -> Result<bool, CodememError> {
346 let conn = self.conn()?;
347 let rows = conn
348 .execute("DELETE FROM graph_edges WHERE id = ?1", params![edge_id])
349 .storage_err()?;
350 Ok(rows > 0)
351 }
352
353 pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
355 let conn = self.conn()?;
356 let rows = conn
357 .execute(
358 "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
359 params![node_id],
360 )
361 .storage_err()?;
362 Ok(rows)
363 }
364
365 pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
367 self.graph_edges_for_namespace_with_cross(namespace, false)
368 }
369
370 pub fn graph_edges_for_namespace_with_cross(
375 &self,
376 namespace: &str,
377 include_cross_namespace: bool,
378 ) -> Result<Vec<Edge>, CodememError> {
379 let conn = self.conn()?;
380 let condition = if include_cross_namespace {
381 "gs.namespace = ?1 OR gd.namespace = ?1"
382 } else {
383 "gs.namespace = ?1 AND gd.namespace = ?1"
384 };
385 let sql = format!(
386 "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at, e.valid_from, e.valid_to
387 FROM graph_edges e
388 INNER JOIN graph_nodes gs ON e.src = gs.id
389 INNER JOIN graph_nodes gd ON e.dst = gd.id
390 WHERE {condition}"
391 );
392 let mut stmt = conn.prepare(&sql).storage_err()?;
393
394 let edges = stmt
395 .query_map(params![namespace], extract_edge_tuple)
396 .storage_err()?
397 .filter_map(|r| match r {
398 Ok(v) => Some(v),
399 Err(e) => {
400 tracing::warn!("Failed to process edge row: {e}");
401 None
402 }
403 })
404 .filter_map(edge_from_row)
405 .collect();
406
407 Ok(edges)
408 }
409}
410
411#[cfg(test)]
412#[path = "tests/graph_persistence_tests.rs"]
413mod tests;