1use crate::Storage;
4use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
5use rusqlite::{params, OptionalExtension};
6use std::collections::HashMap;
7
8impl Storage {
9 pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
13 let conn = self.conn();
14 let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
15
16 conn.execute(
17 "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
18 params![memory_id, blob],
19 )
20 .map_err(|e| CodememError::Storage(e.to_string()))?;
21
22 Ok(())
23 }
24
25 pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
27 let conn = self.conn();
28 let blob: Option<Vec<u8>> = conn
29 .query_row(
30 "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
31 params![memory_id],
32 |row| row.get(0),
33 )
34 .optional()
35 .map_err(|e| CodememError::Storage(e.to_string()))?;
36
37 match blob {
38 Some(bytes) => {
39 let floats: Vec<f32> = bytes
40 .chunks_exact(4)
41 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
42 .collect();
43 Ok(Some(floats))
44 }
45 None => Ok(None),
46 }
47 }
48
49 pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
53 let conn = self.conn();
54 let payload_json = serde_json::to_string(&node.payload)?;
55
56 conn.execute(
57 "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace)
58 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
59 params![
60 node.id,
61 node.kind.to_string(),
62 node.label,
63 payload_json,
64 node.centrality,
65 node.memory_id,
66 node.namespace,
67 ],
68 )
69 .map_err(|e| CodememError::Storage(e.to_string()))?;
70
71 Ok(())
72 }
73
74 pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
76 let conn = self.conn();
77 conn.query_row(
78 "SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes WHERE id = ?1",
79 params![id],
80 |row| {
81 let kind_str: String = row.get(1)?;
82 let payload_str: String = row.get(3)?;
83 Ok((
84 row.get::<_, String>(0)?,
85 kind_str,
86 row.get::<_, String>(2)?,
87 payload_str,
88 row.get::<_, f64>(4)?,
89 row.get::<_, Option<String>>(5)?,
90 row.get::<_, Option<String>>(6)?,
91 ))
92 },
93 )
94 .optional()
95 .map_err(|e| CodememError::Storage(e.to_string()))?
96 .map(|(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
97 let kind: NodeKind = kind_str.parse().map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
98 let payload: HashMap<String, serde_json::Value> =
99 serde_json::from_str(&payload_str).unwrap_or_default();
100 Ok(GraphNode {
101 id,
102 kind,
103 label,
104 payload,
105 centrality,
106 memory_id,
107 namespace,
108 })
109 })
110 .transpose()
111 }
112
113 pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
115 let conn = self.conn();
116 let rows = conn
117 .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
118 .map_err(|e| CodememError::Storage(e.to_string()))?;
119 Ok(rows > 0)
120 }
121
122 pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
124 let conn = self.conn();
125 let mut stmt = conn
126 .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace FROM graph_nodes")
127 .map_err(|e| CodememError::Storage(e.to_string()))?;
128
129 let nodes = stmt
130 .query_map([], |row| {
131 let kind_str: String = row.get(1)?;
132 let payload_str: String = row.get(3)?;
133 Ok((
134 row.get::<_, String>(0)?,
135 kind_str,
136 row.get::<_, String>(2)?,
137 payload_str,
138 row.get::<_, f64>(4)?,
139 row.get::<_, Option<String>>(5)?,
140 row.get::<_, Option<String>>(6)?,
141 ))
142 })
143 .map_err(|e| CodememError::Storage(e.to_string()))?
144 .filter_map(|r| r.ok())
145 .filter_map(
146 |(id, kind_str, label, payload_str, centrality, memory_id, namespace)| {
147 let kind: NodeKind = kind_str.parse().ok()?;
148 let payload: HashMap<String, serde_json::Value> =
149 serde_json::from_str(&payload_str).unwrap_or_default();
150 Some(GraphNode {
151 id,
152 kind,
153 label,
154 payload,
155 centrality,
156 memory_id,
157 namespace,
158 })
159 },
160 )
161 .collect();
162
163 Ok(nodes)
164 }
165
166 pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
170 let conn = self.conn();
171 let props_json = serde_json::to_string(&edge.properties)?;
172
173 conn.execute(
174 "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at)
175 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
176 params![
177 edge.id,
178 edge.src,
179 edge.dst,
180 edge.relationship.to_string(),
181 edge.weight,
182 props_json,
183 edge.created_at.timestamp(),
184 ],
185 )
186 .map_err(|e| CodememError::Storage(e.to_string()))?;
187
188 Ok(())
189 }
190
191 pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
193 let conn = self.conn();
194 let mut stmt = conn
195 .prepare(
196 "SELECT id, src, dst, relationship, weight, properties, created_at FROM graph_edges WHERE src = ?1 OR dst = ?1",
197 )
198 .map_err(|e| CodememError::Storage(e.to_string()))?;
199
200 let edges = stmt
201 .query_map(params![node_id], |row| {
202 let rel_str: String = row.get(3)?;
203 let props_str: String = row.get(5)?;
204 let created_ts: i64 = row.get(6)?;
205 Ok((
206 row.get::<_, String>(0)?,
207 row.get::<_, String>(1)?,
208 row.get::<_, String>(2)?,
209 rel_str,
210 row.get::<_, f64>(4)?,
211 props_str,
212 created_ts,
213 ))
214 })
215 .map_err(|e| CodememError::Storage(e.to_string()))?
216 .filter_map(|r| r.ok())
217 .filter_map(|(id, src, dst, rel_str, weight, props_str, created_ts)| {
218 let relationship: RelationshipType = rel_str.parse().ok()?;
219 let properties: HashMap<String, serde_json::Value> =
220 serde_json::from_str(&props_str).unwrap_or_default();
221 let created_at =
222 chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
223 Some(Edge {
224 id,
225 src,
226 dst,
227 relationship,
228 weight,
229 properties,
230 created_at,
231 })
232 })
233 .collect();
234
235 Ok(edges)
236 }
237
238 pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
240 let conn = self.conn();
241 let mut stmt = conn
242 .prepare("SELECT id, src, dst, relationship, weight, properties, created_at FROM graph_edges")
243 .map_err(|e| CodememError::Storage(e.to_string()))?;
244
245 let edges = stmt
246 .query_map([], |row| {
247 let rel_str: String = row.get(3)?;
248 let props_str: String = row.get(5)?;
249 let created_ts: i64 = row.get(6)?;
250 Ok((
251 row.get::<_, String>(0)?,
252 row.get::<_, String>(1)?,
253 row.get::<_, String>(2)?,
254 rel_str,
255 row.get::<_, f64>(4)?,
256 props_str,
257 created_ts,
258 ))
259 })
260 .map_err(|e| CodememError::Storage(e.to_string()))?
261 .filter_map(|r| r.ok())
262 .filter_map(|(id, src, dst, rel_str, weight, props_str, created_ts)| {
263 let relationship: RelationshipType = rel_str.parse().ok()?;
264 let properties: HashMap<String, serde_json::Value> =
265 serde_json::from_str(&props_str).unwrap_or_default();
266 let created_at =
267 chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
268 Some(Edge {
269 id,
270 src,
271 dst,
272 relationship,
273 weight,
274 properties,
275 created_at,
276 })
277 })
278 .collect();
279
280 Ok(edges)
281 }
282
283 pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
285 let conn = self.conn();
286 let rows = conn
287 .execute(
288 "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
289 params![node_id],
290 )
291 .map_err(|e| CodememError::Storage(e.to_string()))?;
292 Ok(rows)
293 }
294
295 pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
297 let conn = self.conn();
298 let mut stmt = conn
299 .prepare(
300 "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at
301 FROM graph_edges e
302 INNER JOIN graph_nodes gs ON e.src = gs.id
303 INNER JOIN graph_nodes gd ON e.dst = gd.id
304 WHERE gs.namespace = ?1 AND gd.namespace = ?1",
305 )
306 .map_err(|e| CodememError::Storage(e.to_string()))?;
307
308 let edges = stmt
309 .query_map(params![namespace], |row| {
310 let rel_str: String = row.get(3)?;
311 let props_str: String = row.get(5)?;
312 let created_ts: i64 = row.get(6)?;
313 Ok((
314 row.get::<_, String>(0)?,
315 row.get::<_, String>(1)?,
316 row.get::<_, String>(2)?,
317 rel_str,
318 row.get::<_, f64>(4)?,
319 props_str,
320 created_ts,
321 ))
322 })
323 .map_err(|e| CodememError::Storage(e.to_string()))?
324 .filter_map(|r| r.ok())
325 .filter_map(|(id, src, dst, rel_str, weight, props_str, created_ts)| {
326 let relationship: RelationshipType = rel_str.parse().ok()?;
327 let properties: HashMap<String, serde_json::Value> =
328 serde_json::from_str(&props_str).unwrap_or_default();
329 let created_at =
330 chrono::DateTime::from_timestamp(created_ts, 0)?.with_timezone(&chrono::Utc);
331 Some(Edge {
332 id,
333 src,
334 dst,
335 relationship,
336 weight,
337 properties,
338 created_at,
339 })
340 })
341 .collect();
342
343 Ok(edges)
344 }
345
346 pub fn delete_graph_edge(&self, id: &str) -> Result<bool, CodememError> {
348 let conn = self.conn();
349 let rows = conn
350 .execute("DELETE FROM graph_edges WHERE id = ?1", params![id])
351 .map_err(|e| CodememError::Storage(e.to_string()))?;
352 Ok(rows > 0)
353 }
354}
355
356#[cfg(test)]
357mod tests {
358 use crate::Storage;
359 use codemem_core::{GraphNode, MemoryNode, MemoryType, NodeKind};
360 use std::collections::HashMap;
361
362 fn test_memory() -> MemoryNode {
363 let now = chrono::Utc::now();
364 let content = "Test memory content";
365 MemoryNode {
366 id: uuid::Uuid::new_v4().to_string(),
367 content: content.to_string(),
368 memory_type: MemoryType::Context,
369 importance: 0.7,
370 confidence: 1.0,
371 access_count: 0,
372 content_hash: Storage::content_hash(content),
373 tags: vec!["test".to_string()],
374 metadata: HashMap::new(),
375 namespace: None,
376 created_at: now,
377 updated_at: now,
378 last_accessed_at: now,
379 }
380 }
381
382 #[test]
383 fn store_and_get_embedding() {
384 let storage = Storage::open_in_memory().unwrap();
385 let memory = test_memory();
386 storage.insert_memory(&memory).unwrap();
387
388 let embedding: Vec<f32> = (0..768).map(|i| i as f32 / 768.0).collect();
389 storage.store_embedding(&memory.id, &embedding).unwrap();
390
391 let retrieved = storage.get_embedding(&memory.id).unwrap().unwrap();
392 assert_eq!(retrieved.len(), 768);
393 assert!((retrieved[0] - embedding[0]).abs() < f32::EPSILON);
394 }
395
396 #[test]
397 fn graph_node_crud() {
398 let storage = Storage::open_in_memory().unwrap();
399 let node = GraphNode {
400 id: "file:src/main.rs".to_string(),
401 kind: NodeKind::File,
402 label: "src/main.rs".to_string(),
403 payload: HashMap::new(),
404 centrality: 0.0,
405 memory_id: None,
406 namespace: None,
407 };
408
409 storage.insert_graph_node(&node).unwrap();
410 let retrieved = storage.get_graph_node(&node.id).unwrap().unwrap();
411 assert_eq!(retrieved.kind, NodeKind::File);
412 assert!(storage.delete_graph_node(&node.id).unwrap());
413 }
414}