1use crate::Storage;
4use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
5use rusqlite::{params, OptionalExtension};
6use std::collections::HashMap;
7
8impl Storage {
9 pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
13 let conn = self.conn();
14 let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
15
16 conn.execute(
17 "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
18 params![memory_id, blob],
19 )
20 .map_err(|e| CodememError::Storage(e.to_string()))?;
21
22 Ok(())
23 }
24
25 pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
27 let conn = self.conn();
28 let blob: Option<Vec<u8>> = conn
29 .query_row(
30 "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
31 params![memory_id],
32 |row| row.get(0),
33 )
34 .optional()
35 .map_err(|e| CodememError::Storage(e.to_string()))?;
36
37 match blob {
38 Some(bytes) => {
39 let floats: Vec<f32> = bytes
40 .chunks_exact(4)
41 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
42 .collect();
43 Ok(Some(floats))
44 }
45 None => Ok(None),
46 }
47 }
48
49 pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
53 let conn = self.conn();
54 let payload_json = serde_json::to_string(&node.payload)?;
55
56 conn.execute(
57 "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace)
58 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
59 params![
60 node.id,
61 node.kind.to_string(),
62 node.label,
63 payload_json,
64 node.centrality,
65 node.memory_id,
66 node.namespace,
67 ],
68 )
69 .map_err(|e| CodememError::Storage(e.to_string()))?;
70
71 Ok(())
72 }
73
74 pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
76 let conn = self.conn();
77 conn.query_row(
78 "SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes WHERE id = ?1",
79 params![id],
80 |row| {
81 let kind_str: String = row.get(1)?;
82 let payload_str: String = row.get(3)?;
83 Ok((
84 row.get::<_, String>(0)?,
85 kind_str,
86 row.get::<_, String>(2)?,
87 payload_str,
88 row.get::<_, f64>(4)?,
89 row.get::<_, Option<String>>(5)?,
90 row.get::<_, Option<String>>(6)?,
91 ))
92 },
93 )
94 .optional()
95 .map_err(|e| CodememError::Storage(e.to_string()))?
96 .map(|(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
97 let kind: NodeKind = kind_str.parse().map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
98 let payload: HashMap<String, serde_json::Value> =
99 serde_json::from_str(&payload_str).unwrap_or_default();
100 Ok(GraphNode {
101 id,
102 kind,
103 label,
104 payload,
105 centrality,
106 memory_id,
107 namespace,
108 })
109 })
110 .transpose()
111 }
112
113 pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
115 let conn = self.conn();
116 let rows = conn
117 .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
118 .map_err(|e| CodememError::Storage(e.to_string()))?;
119 Ok(rows > 0)
120 }
121
122 pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
124 let conn = self.conn();
125 let mut stmt = conn
126 .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes")
127 .map_err(|e| CodememError::Storage(e.to_string()))?;
128
129 let nodes = stmt
130 .query_map([], |row| {
131 let kind_str: String = row.get(1)?;
132 let payload_str: String = row.get(3)?;
133 Ok((
134 row.get::<_, String>(0)?,
135 kind_str,
136 row.get::<_, String>(2)?,
137 payload_str,
138 row.get::<_, f64>(4)?,
139 row.get::<_, Option<String>>(5)?,
140 row.get::<_, Option<String>>(6)?,
141 ))
142 })
143 .map_err(|e| CodememError::Storage(e.to_string()))?
144 .filter_map(|r| r.ok())
145 .filter_map(
146 |(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
147 let kind: NodeKind = kind_str.parse().ok()?;
148 let payload: HashMap<String, serde_json::Value> =
149 serde_json::from_str(&payload_str).unwrap_or_default();
150 Some(GraphNode {
151 id,
152 kind,
153 label,
154 payload,
155 centrality,
156 memory_id,
157 namespace,
158 })
159 },
160 )
161 .collect();
162
163 Ok(nodes)
164 }
165
166 pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
170 let conn = self.conn();
171 let props_json = serde_json::to_string(&edge.properties)?;
172
173 conn.execute(
174 "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)
175 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
176 params![
177 edge.id,
178 edge.src,
179 edge.dst,
180 edge.relationship.to_string(),
181 edge.weight,
182 props_json,
183 edge.created_at.timestamp(),
184 edge.valid_from.map(|dt| dt.timestamp()),
185 edge.valid_to.map(|dt| dt.timestamp()),
186 ],
187 )
188 .map_err(|e| CodememError::Storage(e.to_string()))?;
189
190 Ok(())
191 }
192
193 pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
195 let conn = self.conn();
196 let mut stmt = conn
197 .prepare(
198 "SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges WHERE src = ?1 OR dst = ?1",
199 )
200 .map_err(|e| CodememError::Storage(e.to_string()))?;
201
202 let edges = stmt
203 .query_map(params![node_id], |row| {
204 let rel_str: String = row.get(3)?;
205 let props_str: String = row.get(5)?;
206 let created_ts: i64 = row.get(6)?;
207 let valid_from_ts: Option<i64> = row.get(7)?;
208 let valid_to_ts: Option<i64> = row.get(8)?;
209 Ok((
210 row.get::<_, String>(0)?,
211 row.get::<_, String>(1)?,
212 row.get::<_, String>(2)?,
213 rel_str,
214 row.get::<_, f64>(4)?,
215 props_str,
216 created_ts,
217 valid_from_ts,
218 valid_to_ts,
219 ))
220 })
221 .map_err(|e| CodememError::Storage(e.to_string()))?
222 .filter_map(|r| r.ok())
223 .filter_map(
224 |(
225 id,
226 src,
227 dst,
228 rel_str,
229 weight,
230 props_str,
231 created_ts,
232 valid_from_ts,
233 valid_to_ts,
234 )| {
235 let relationship: RelationshipType = rel_str.parse().ok()?;
236 let properties: HashMap<String, serde_json::Value> =
237 serde_json::from_str(&props_str).unwrap_or_default();
238 let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?
239 .with_timezone(&chrono::Utc);
240 let valid_from = valid_from_ts
241 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
242 .map(|dt| dt.with_timezone(&chrono::Utc));
243 let valid_to = valid_to_ts
244 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
245 .map(|dt| dt.with_timezone(&chrono::Utc));
246 Some(Edge {
247 id,
248 src,
249 dst,
250 relationship,
251 weight,
252 properties,
253 created_at,
254 valid_from,
255 valid_to,
256 })
257 },
258 )
259 .collect();
260
261 Ok(edges)
262 }
263
264 pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
266 let conn = self.conn();
267 let mut stmt = conn
268 .prepare("SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges")
269 .map_err(|e| CodememError::Storage(e.to_string()))?;
270
271 let edges = stmt
272 .query_map([], |row| {
273 let rel_str: String = row.get(3)?;
274 let props_str: String = row.get(5)?;
275 let created_ts: i64 = row.get(6)?;
276 let valid_from_ts: Option<i64> = row.get(7)?;
277 let valid_to_ts: Option<i64> = row.get(8)?;
278 Ok((
279 row.get::<_, String>(0)?,
280 row.get::<_, String>(1)?,
281 row.get::<_, String>(2)?,
282 rel_str,
283 row.get::<_, f64>(4)?,
284 props_str,
285 created_ts,
286 valid_from_ts,
287 valid_to_ts,
288 ))
289 })
290 .map_err(|e| CodememError::Storage(e.to_string()))?
291 .filter_map(|r| r.ok())
292 .filter_map(
293 |(
294 id,
295 src,
296 dst,
297 rel_str,
298 weight,
299 props_str,
300 created_ts,
301 valid_from_ts,
302 valid_to_ts,
303 )| {
304 let relationship: RelationshipType = rel_str.parse().ok()?;
305 let properties: HashMap<String, serde_json::Value> =
306 serde_json::from_str(&props_str).unwrap_or_default();
307 let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?
308 .with_timezone(&chrono::Utc);
309 let valid_from = valid_from_ts
310 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
311 .map(|dt| dt.with_timezone(&chrono::Utc));
312 let valid_to = valid_to_ts
313 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
314 .map(|dt| dt.with_timezone(&chrono::Utc));
315 Some(Edge {
316 id,
317 src,
318 dst,
319 relationship,
320 weight,
321 properties,
322 created_at,
323 valid_from,
324 valid_to,
325 })
326 },
327 )
328 .collect();
329
330 Ok(edges)
331 }
332
333 pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
335 let conn = self.conn();
336 let rows = conn
337 .execute(
338 "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
339 params![node_id],
340 )
341 .map_err(|e| CodememError::Storage(e.to_string()))?;
342 Ok(rows)
343 }
344
345 pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
347 let conn = self.conn();
348 let mut stmt = conn
349 .prepare(
350 "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at, e.valid_from, e.valid_to
351 FROM graph_edges e
352 INNER JOIN graph_nodes gs ON e.src = gs.id
353 INNER JOIN graph_nodes gd ON e.dst = gd.id
354 WHERE gs.namespace = ?1 AND gd.namespace = ?1",
355 )
356 .map_err(|e| CodememError::Storage(e.to_string()))?;
357
358 let edges = stmt
359 .query_map(params![namespace], |row| {
360 let rel_str: String = row.get(3)?;
361 let props_str: String = row.get(5)?;
362 let created_ts: i64 = row.get(6)?;
363 let valid_from_ts: Option<i64> = row.get(7)?;
364 let valid_to_ts: Option<i64> = row.get(8)?;
365 Ok((
366 row.get::<_, String>(0)?,
367 row.get::<_, String>(1)?,
368 row.get::<_, String>(2)?,
369 rel_str,
370 row.get::<_, f64>(4)?,
371 props_str,
372 created_ts,
373 valid_from_ts,
374 valid_to_ts,
375 ))
376 })
377 .map_err(|e| CodememError::Storage(e.to_string()))?
378 .filter_map(|r| r.ok())
379 .filter_map(
380 |(
381 id,
382 src,
383 dst,
384 rel_str,
385 weight,
386 props_str,
387 created_ts,
388 valid_from_ts,
389 valid_to_ts,
390 )| {
391 let relationship: RelationshipType = rel_str.parse().ok()?;
392 let properties: HashMap<String, serde_json::Value> =
393 serde_json::from_str(&props_str).unwrap_or_default();
394 let created_at = chrono::DateTime::from_timestamp(created_ts, 0)?
395 .with_timezone(&chrono::Utc);
396 let valid_from = valid_from_ts
397 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
398 .map(|dt| dt.with_timezone(&chrono::Utc));
399 let valid_to = valid_to_ts
400 .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
401 .map(|dt| dt.with_timezone(&chrono::Utc));
402 Some(Edge {
403 id,
404 src,
405 dst,
406 relationship,
407 weight,
408 properties,
409 created_at,
410 valid_from,
411 valid_to,
412 })
413 },
414 )
415 .collect();
416
417 Ok(edges)
418 }
419
420 pub fn delete_graph_edge(&self, id: &str) -> Result<bool, CodememError> {
422 let conn = self.conn();
423 let rows = conn
424 .execute("DELETE FROM graph_edges WHERE id = ?1", params![id])
425 .map_err(|e| CodememError::Storage(e.to_string()))?;
426 Ok(rows > 0)
427 }
428}
429
430#[cfg(test)]
431#[path = "tests/graph_persistence_tests.rs"]
432mod tests;