1use crate::{MapStorageErr, Storage};
4use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
5use rusqlite::{params, OptionalExtension};
6use std::collections::HashMap;
7
/// Raw column values of one `graph_edges` row, in `SELECT` order, before they are
/// validated and converted into an [`Edge`] by [`edge_from_row`].
pub(crate) type EdgeRow = (
    String,      // id
    String,      // src
    String,      // dst
    String,      // relationship (text form, parsed into `RelationshipType`)
    f64,         // weight
    String,      // properties (JSON object text)
    i64,         // created_at (Unix seconds)
    Option<i64>, // valid_from (Unix seconds; NULL when unset)
    Option<i64>, // valid_to (Unix seconds; NULL when unset)
);
22
23pub(crate) fn edge_from_row(row: EdgeRow) -> Option<Edge> {
26 let (id, src, dst, rel_str, weight, props_str, created_ts, valid_from_ts, valid_to_ts) = row;
27 let relationship: RelationshipType = match rel_str.parse() {
28 Ok(r) => r,
29 Err(_) => {
30 tracing::warn!(
31 edge_id = %id,
32 relationship = %rel_str,
33 "Dropping edge with unrecognized relationship type"
34 );
35 return None;
36 }
37 };
38 let properties: HashMap<String, serde_json::Value> =
39 serde_json::from_str(&props_str).unwrap_or_default();
40 let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
41 let valid_from = valid_from_ts
42 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
43 .map(|dt| dt.with_timezone(&chrono::Utc));
44 let valid_to = valid_to_ts
45 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
46 .map(|dt| dt.with_timezone(&chrono::Utc));
47 Some(Edge {
48 id,
49 src,
50 dst,
51 relationship,
52 weight,
53 properties,
54 created_at,
55 valid_from,
56 valid_to,
57 })
58}
59
60pub(crate) fn extract_edge_tuple(row: &rusqlite::Row<'_>) -> rusqlite::Result<EdgeRow> {
64 let rel_str: String = row.get(3)?;
65 let props_str: String = row.get(5)?;
66 let created_ts: i64 = row.get(6)?;
67 let valid_from_ts: Option<i64> = row.get(7)?;
68 let valid_to_ts: Option<i64> = row.get(8)?;
69 Ok((
70 row.get::<_, String>(0)?,
71 row.get::<_, String>(1)?,
72 row.get::<_, String>(2)?,
73 rel_str,
74 row.get::<_, f64>(4)?,
75 props_str,
76 created_ts,
77 valid_from_ts,
78 valid_to_ts,
79 ))
80}
81
82impl Storage {
83 pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
87 let conn = self.conn()?;
88 let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
89
90 conn.execute(
91 "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
92 params![memory_id, blob],
93 )
94 .storage_err()?;
95
96 Ok(())
97 }
98
99 pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
101 let conn = self.conn()?;
102 let blob: Option<Vec<u8>> = conn
103 .query_row(
104 "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
105 params![memory_id],
106 |row| row.get(0),
107 )
108 .optional()
109 .storage_err()?;
110
111 match blob {
112 Some(bytes) => {
113 let floats: Vec<f32> = bytes
114 .chunks_exact(4)
115 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
116 .collect();
117 Ok(Some(floats))
118 }
119 None => Ok(None),
120 }
121 }
122
123 pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
127 let conn = self.conn()?;
128 let payload_json = serde_json::to_string(&node.payload)?;
129
130 conn.execute(
131 "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace)
132 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
133 params![
134 node.id,
135 node.kind.to_string(),
136 node.label,
137 payload_json,
138 node.centrality,
139 node.memory_id,
140 node.namespace,
141 ],
142 )
143 .storage_err()?;
144
145 Ok(())
146 }
147
148 pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
150 let conn = self.conn()?;
151 conn.query_row(
152 "SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes WHERE id = ?1",
153 params![id],
154 |row| {
155 let kind_str: String = row.get(1)?;
156 let payload_str: String = row.get(3)?;
157 Ok((
158 row.get::<_, String>(0)?,
159 kind_str,
160 row.get::<_, String>(2)?,
161 payload_str,
162 row.get::<_, f64>(4)?,
163 row.get::<_, Option<String>>(5)?,
164 row.get::<_, Option<String>>(6)?,
165 ))
166 },
167 )
168 .optional()
169 .storage_err()?
170 .map(|(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
171 let kind: NodeKind = kind_str.parse().map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
172 let payload: HashMap<String, serde_json::Value> =
173 serde_json::from_str(&payload_str).unwrap_or_default();
174 Ok(GraphNode {
175 id,
176 kind,
177 label,
178 payload,
179 centrality,
180 memory_id,
181 namespace,
182 })
183 })
184 .transpose()
185 }
186
187 pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
189 let conn = self.conn()?;
190 let rows = conn
191 .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
192 .storage_err()?;
193 Ok(rows > 0)
194 }
195
196 pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
198 let conn = self.conn()?;
199 let mut stmt = conn
200 .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes")
201 .storage_err()?;
202
203 let rows = stmt
204 .query_map([], |row| {
205 let kind_str: String = row.get(1)?;
206 let payload_str: String = row.get(3)?;
207 Ok((
208 row.get::<_, String>(0)?,
209 kind_str,
210 row.get::<_, String>(2)?,
211 payload_str,
212 row.get::<_, f64>(4)?,
213 row.get::<_, Option<String>>(5)?,
214 row.get::<_, Option<String>>(6)?,
215 ))
216 })
217 .storage_err()?;
218
219 let mut nodes = Vec::new();
220 for row_result in rows {
221 let (id, kind_str, label, payload_str, centrality, memory_id, namespace) =
222 row_result.storage_err()?;
223 let kind: NodeKind = match kind_str.parse() {
224 Ok(k) => k,
225 Err(_) => {
226 tracing::warn!(
227 node_id = %id,
228 kind = %kind_str,
229 "Skipping graph node with unrecognized kind"
230 );
231 continue;
232 }
233 };
234 let payload: HashMap<String, serde_json::Value> =
235 serde_json::from_str(&payload_str).unwrap_or_default();
236 nodes.push(GraphNode {
237 id,
238 kind,
239 label,
240 payload,
241 centrality,
242 memory_id,
243 namespace,
244 });
245 }
246
247 Ok(nodes)
248 }
249
250 pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
254 let conn = self.conn()?;
255 let props_json = serde_json::to_string(&edge.properties)?;
256
257 conn.execute(
258 "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)
259 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
260 params![
261 edge.id,
262 edge.src,
263 edge.dst,
264 edge.relationship.to_string(),
265 edge.weight,
266 props_json,
267 edge.created_at.timestamp(),
268 edge.valid_from.map(|dt| dt.timestamp()),
269 edge.valid_to.map(|dt| dt.timestamp()),
270 ],
271 )
272 .storage_err()?;
273
274 Ok(())
275 }
276
277 pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
279 let conn = self.conn()?;
280 let mut stmt = conn
281 .prepare(
282 "SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges WHERE src = ?1 OR dst = ?1",
283 )
284 .storage_err()?;
285
286 let edges = stmt
287 .query_map(params![node_id], extract_edge_tuple)
288 .storage_err()?
289 .filter_map(|r| match r {
290 Ok(v) => Some(v),
291 Err(e) => {
292 tracing::warn!("Failed to process edge row: {e}");
293 None
294 }
295 })
296 .filter_map(edge_from_row)
297 .collect();
298
299 Ok(edges)
300 }
301
302 pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
304 let conn = self.conn()?;
305 let mut stmt = conn
306 .prepare("SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges")
307 .storage_err()?;
308
309 let edges = stmt
310 .query_map([], extract_edge_tuple)
311 .storage_err()?
312 .filter_map(|r| match r {
313 Ok(v) => Some(v),
314 Err(e) => {
315 tracing::warn!("Failed to process edge row: {e}");
316 None
317 }
318 })
319 .filter_map(edge_from_row)
320 .collect();
321
322 Ok(edges)
323 }
324
325 pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
327 let conn = self.conn()?;
328 let rows = conn
329 .execute(
330 "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
331 params![node_id],
332 )
333 .storage_err()?;
334 Ok(rows)
335 }
336
337 pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
339 let conn = self.conn()?;
340 let mut stmt = conn
341 .prepare(
342 "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at, e.valid_from, e.valid_to
343 FROM graph_edges e
344 INNER JOIN graph_nodes gs ON e.src = gs.id
345 INNER JOIN graph_nodes gd ON e.dst = gd.id
346 WHERE gs.namespace = ?1 AND gd.namespace = ?1",
347 )
348 .storage_err()?;
349
350 let edges = stmt
351 .query_map(params![namespace], extract_edge_tuple)
352 .storage_err()?
353 .filter_map(|r| match r {
354 Ok(v) => Some(v),
355 Err(e) => {
356 tracing::warn!("Failed to process edge row: {e}");
357 None
358 }
359 })
360 .filter_map(edge_from_row)
361 .collect();
362
363 Ok(edges)
364 }
365}
366
367#[cfg(test)]
368#[path = "tests/graph_persistence_tests.rs"]
369mod tests;