use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use chrono::{DateTime, Utc};
use storage::RuVectorStore;

use crate::episodic::EpisodicStore;
use crate::graph::{EpisodicGraph, Node};
use crate::Episode;
48
/// Vector-store table holding graph-node embeddings; `recall_candidates`
/// runs its ANN search against this table (the tests seed it with node ids).
const GRAPH_VEC: &str = "graph_vec";
53
/// Errors surfaced by [`DualMemoryReader`] reads, tagged by the store that failed.
///
/// NOTE(review): each message embeds the inner error (`{0}`) while `#[from]`
/// also exposes it through `Error::source()`, so reporters that print the
/// whole source chain may show the inner message twice.
#[derive(Debug, thiserror::Error)]
pub enum DualMemoryError {
    /// A read against the episodic graph store failed.
    #[error("graph read: {0}")]
    Graph(#[from] crate::graph::GraphError),
    /// A read against the legacy episodic store failed.
    #[error("legacy read: {0}")]
    Legacy(#[from] crate::episodic::EpisodicError),
}
62
/// A record returned by [`DualMemoryReader::read_by_id`], tagged with the
/// store it was read from.
#[derive(Debug, Clone)]
pub enum MemoryEntry {
    /// Node read from the episodic graph store.
    Graph(Node),
    /// Episode read from the legacy episodic store.
    Legacy(Episode),
}
71
72impl MemoryEntry {
73 pub fn id(&self) -> &str {
74 match self {
75 MemoryEntry::Graph(n) => &n.id,
76 MemoryEntry::Legacy(e) => &e.id,
77 }
78 }
79
80 pub fn is_graph(&self) -> bool {
81 matches!(self, MemoryEntry::Graph(_))
82 }
83
84 pub fn is_legacy(&self) -> bool {
85 matches!(self, MemoryEntry::Legacy(_))
86 }
87}
88
/// Hydrated payload for one graph recall candidate; stored by node id in
/// [`GraphCandidates::hydration`].
#[derive(Debug, Clone)]
pub struct GraphCandidate {
    /// Candidate text: the FTS hit's `text`, or the JSON-serialized node body
    /// for candidates that surfaced only through ANN search.
    pub content: String,
    /// Node weight as reported by the graph (FTS hit or fetched node).
    pub weight: f32,
    /// Node creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
}
97
/// Raw graph recall results from [`DualMemoryReader::recall_candidates`],
/// kept as separate full-text and vector lists plus a shared payload map.
#[derive(Debug, Clone, Default)]
pub struct GraphCandidates {
    /// `(node id, rank)` pairs in the order `search_text` returned them.
    /// NOTE(review): the scale/direction of `rank` is defined by
    /// `EpisodicGraph::search_text` — confirm before fusing with `ann`.
    pub fts: Vec<(String, f64)>,
    /// `(node id, similarity)` pairs from vector search, where
    /// `similarity = 1 / (1 + distance)`. Empty when no vector store is
    /// attached or the vector search failed.
    pub ann: Vec<(String, f64)>,
    /// Payload for every id in `fts` or `ann`; when an id appears in both,
    /// the FTS payload is the one kept.
    pub hydration: HashMap<String, GraphCandidate>,
}
109
/// Reads memory from the episodic graph, the legacy episodic store, or both,
/// with optional vector-store-backed ANN recall over graph nodes.
#[derive(Clone)]
pub struct DualMemoryReader {
    /// Graph store; consulted first by `read_by_id` and required for any
    /// `recall_candidates` output.
    graph: Option<Arc<dyn EpisodicGraph>>,
    /// Legacy store; `read_by_id` falls back to it when the graph misses.
    legacy: Option<Arc<EpisodicStore>>,
    /// Vector store used for ANN recall; attached via `with_vector_store`.
    vectors: Option<RuVectorStore>,
}
120
121impl DualMemoryReader {
122 pub fn graph_only(graph: Arc<dyn EpisodicGraph>) -> Self {
125 Self {
126 graph: Some(graph),
127 legacy: None,
128 vectors: None,
129 }
130 }
131
132 pub fn legacy_only(legacy: Arc<EpisodicStore>) -> Self {
136 Self {
137 graph: None,
138 legacy: Some(legacy),
139 vectors: None,
140 }
141 }
142
143 pub fn dual(legacy: Arc<EpisodicStore>, graph: Arc<dyn EpisodicGraph>) -> Self {
145 Self {
146 graph: Some(graph),
147 legacy: Some(legacy),
148 vectors: None,
149 }
150 }
151
152 pub fn with_vector_store(mut self, vectors: RuVectorStore) -> Self {
155 self.vectors = Some(vectors);
156 self
157 }
158
159 pub fn read_by_id(&self, id: &str) -> Result<Option<MemoryEntry>, DualMemoryError> {
162 if let Some(graph) = &self.graph {
163 if let Some(node) = graph.get_node(id)? {
164 return Ok(Some(MemoryEntry::Graph(node)));
165 }
166 }
167 if let Some(legacy) = &self.legacy {
168 if let Some(ep) = legacy.get_episode(id)? {
169 return Ok(Some(MemoryEntry::Legacy(ep)));
170 }
171 }
172 Ok(None)
173 }
174
175 pub async fn recall_candidates(
185 &self,
186 query: &str,
187 query_vector: Vec<f32>,
188 limit: usize,
189 namespace: Option<&str>,
190 ) -> Result<GraphCandidates, DualMemoryError> {
191 let Some(graph) = &self.graph else {
192 return Ok(GraphCandidates::default());
193 };
194 let mut out = GraphCandidates::default();
195
196 for hit in graph.search_text(query, limit, namespace)? {
198 out.fts.push((hit.id.clone(), hit.rank));
199 out.hydration.entry(hit.id).or_insert(GraphCandidate {
200 content: hit.text,
201 weight: hit.weight,
202 created_at: hit.created_at,
203 });
204 }
205
206 if let Some(vectors) = &self.vectors {
209 match vectors.search(GRAPH_VEC, query_vector, limit).await {
210 Ok(results) => {
211 for vr in results {
212 let Some(node) = graph.get_node(&vr.id)? else {
213 continue; };
215 if namespace.is_some_and(|ns| !namespace_matches(ns, &node.namespace)) {
216 continue;
217 }
218 let similarity = 1.0 / (1.0 + vr.distance as f64);
219 out.ann.push((node.id.clone(), similarity));
220 out.hydration
221 .entry(node.id.clone())
222 .or_insert_with(|| GraphCandidate {
223 content: node_content(&node),
224 weight: node.weight,
225 created_at: node.created_at,
226 });
227 }
228 }
229 Err(e) => {
230 tracing::warn!("graph_vec ANN search failed, FTS-only graph recall: {e}");
231 }
232 }
233 }
234
235 Ok(out)
236 }
237}
238
/// Returns `true` when namespace `ns` equals `scope` or lies beneath it in the
/// `/`-separated hierarchy: `"work"` matches `"work"` and `"work/team"`, but
/// not `"workshop"`.
fn namespace_matches(scope: &str, ns: &str) -> bool {
    // `strip_prefix` + a boundary check avoids allocating a `"{scope}/"`
    // string on every call (this runs once per ANN hit).
    ns.strip_prefix(scope)
        .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'))
}
244
245fn node_content(node: &Node) -> String {
248 serde_json::to_string(&node.body).unwrap_or_default()
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{Node, NodeKind, SqliteGraph};
    use storage::SqlitePool;

    fn pool() -> SqlitePool {
        SqlitePool::open_memory().expect("memory pool")
    }

    /// 384-dim one-hot vector, so identical seeds give distance 0.
    fn unit_vector(idx: usize) -> Vec<f32> {
        let mut v = vec![0.0; 384];
        v[idx % 384] = 1.0;
        v
    }

    #[tokio::test]
    async fn recall_candidates_returns_graph_fts_hit() {
        let g: Arc<dyn EpisodicGraph> = Arc::new(SqliteGraph::new(pool()));
        let mut n = Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"program": "ripgrep"}),
            "personal",
            None,
        );
        n.weight = 0.7;
        g.add_node(&n).unwrap();

        let reader = DualMemoryReader::graph_only(g);
        let cands = reader
            .recall_candidates("ripgrep", vec![0.0; 384], 10, None)
            .await
            .unwrap();

        assert_eq!(cands.fts.len(), 1, "FTS should surface the ripgrep node");
        assert_eq!(cands.fts[0].0, n.id);
        let hyd = cands.hydration.get(&n.id).expect("hydration entry");
        assert!((hyd.weight - 0.7).abs() < 1e-6);
        assert!(cands.ann.is_empty(), "no vector store wired → no ANN list");
    }

    #[tokio::test]
    async fn recall_candidates_returns_graph_ann_hit() {
        let g: Arc<dyn EpisodicGraph> = Arc::new(SqliteGraph::new(pool()));
        let n = Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"program": "opaque-binary"}),
            "personal",
            None,
        );
        g.add_node(&n).unwrap();

        let dir = tempfile::tempdir().unwrap();
        let ruv = RuVectorStore::open(dir.path(), 384).await.unwrap();
        ruv.ensure_tables().await.unwrap();
        let seeded = unit_vector(42);
        ruv.add_vectors(
            GRAPH_VEC,
            vec![n.id.clone()],
            vec!["opaque-binary".into()],
            vec![seeded.clone()],
            vec![n.created_at.to_rfc3339()],
            "graph",
        )
        .await
        .unwrap();

        let reader = DualMemoryReader::graph_only(g).with_vector_store(ruv);
        // The text query deliberately matches nothing, so any hit must come from ANN.
        let cands = reader
            .recall_candidates("xyzzy", seeded, 10, None)
            .await
            .unwrap();

        assert!(cands.fts.is_empty(), "text query must not match via FTS");
        assert_eq!(cands.ann.len(), 1, "ANN should surface the seeded node");
        assert_eq!(cands.ann[0].0, n.id);
        assert!(cands.ann[0].1 > 0.9, "identical vector → high similarity");
        assert!(cands.hydration.contains_key(&n.id));
    }

    #[tokio::test]
    async fn recall_candidates_scopes_ann_to_namespace() {
        let g: Arc<dyn EpisodicGraph> = Arc::new(SqliteGraph::new(pool()));
        let work = Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"program": "opaque-binary"}),
            "work",
            None,
        );
        let personal = Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"program": "opaque-binary"}),
            "personal",
            None,
        );
        g.add_node(&work).unwrap();
        g.add_node(&personal).unwrap();

        let dir = tempfile::tempdir().unwrap();
        let ruv = RuVectorStore::open(dir.path(), 384).await.unwrap();
        ruv.ensure_tables().await.unwrap();
        // Both nodes share one embedding, so only the namespace filter can
        // tell them apart.
        let seeded = unit_vector(7);
        ruv.add_vectors(
            GRAPH_VEC,
            vec![work.id.clone(), personal.id.clone()],
            vec!["opaque-binary".into(), "opaque-binary".into()],
            vec![seeded.clone(), seeded.clone()],
            vec![work.created_at.to_rfc3339(), personal.created_at.to_rfc3339()],
            "graph",
        )
        .await
        .unwrap();

        let reader = DualMemoryReader::graph_only(g).with_vector_store(ruv);
        let cands = reader
            .recall_candidates("xyzzy", seeded, 10, Some("work"))
            .await
            .unwrap();

        assert_eq!(cands.ann.len(), 1, "only the in-scope node may surface via ANN");
        assert_eq!(cands.ann[0].0, work.id);
        assert!(!cands.hydration.contains_key(&personal.id));
    }

    #[tokio::test]
    async fn recall_candidates_scopes_fts_to_namespace() {
        let g: Arc<dyn EpisodicGraph> = Arc::new(SqliteGraph::new(pool()));
        let work = Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"program": "deploy"}),
            "work",
            None,
        );
        let personal = Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"program": "deploy"}),
            "personal",
            None,
        );
        g.add_node(&work).unwrap();
        g.add_node(&personal).unwrap();

        let reader = DualMemoryReader::graph_only(g);
        let cands = reader
            .recall_candidates("deploy", vec![0.0; 384], 10, Some("work"))
            .await
            .unwrap();
        assert_eq!(cands.fts.len(), 1);
        assert_eq!(cands.fts[0].0, work.id);
    }

    #[tokio::test]
    async fn recall_candidates_empty_without_graph() {
        let store = EpisodicStore::new(pool());
        let reader = DualMemoryReader::legacy_only(Arc::new(store));
        let cands = reader
            .recall_candidates("anything", vec![0.0; 384], 10, None)
            .await
            .unwrap();
        assert!(cands.fts.is_empty() && cands.ann.is_empty());
    }

    #[test]
    fn namespace_matches_respects_segment_boundaries() {
        assert!(namespace_matches("work", "work"));
        assert!(namespace_matches("work", "work/team"));
        assert!(!namespace_matches("work", "workshop"));
        assert!(!namespace_matches("work/team", "work"));
    }

    #[test]
    fn graph_only_reader_finds_graph_node() {
        let p = pool();
        let g: Arc<dyn EpisodicGraph> = Arc::new(SqliteGraph::new(p));
        let n = Node::new(
            NodeKind::new("episode"),
            serde_json::json!({"x": 1}),
            "personal",
            None,
        );
        g.add_node(&n).unwrap();
        let r = DualMemoryReader::graph_only(g);
        let got = r.read_by_id(&n.id).unwrap().expect("found");
        assert!(got.is_graph());
        assert_eq!(got.id(), n.id);
    }

    #[test]
    fn legacy_only_reader_finds_episode() {
        let pool = pool();
        let store = EpisodicStore::new(pool);
        let sid = store.create_session("test").unwrap();
        let eid = store
            .store_episode(&sid, "user", "hello", 0.5, None, None)
            .unwrap();
        let r = DualMemoryReader::legacy_only(Arc::new(store));
        let got = r.read_by_id(&eid).unwrap().expect("found");
        assert!(got.is_legacy());
        assert_eq!(got.id(), &eid);
    }

    #[test]
    fn dual_reader_prefers_graph_when_both_exist() {
        let pool = pool();
        let g: Arc<dyn EpisodicGraph> = Arc::new(SqliteGraph::new(pool.clone()));
        let legacy = Arc::new(EpisodicStore::new(pool));

        let sid = legacy.create_session("test").unwrap();
        let eid = legacy
            .store_episode(&sid, "user", "legacy text", 0.5, None, None)
            .unwrap();
        // Graph node deliberately reuses the legacy episode's id.
        let n = Node {
            id: eid.clone(),
            session_id: Some(sid),
            namespace: "personal".into(),
            kind: NodeKind::new("episode"),
            body: serde_json::json!({"text": "graph text"}),
            vector_id: None,
            weight: 1.0,
            created_at: chrono::Utc::now(),
        };
        g.add_node(&n).unwrap();

        let r = DualMemoryReader::dual(legacy, g);
        let got = r.read_by_id(&eid).unwrap().expect("found");
        assert!(got.is_graph(), "graph must win when both exist");
    }

    #[test]
    fn dual_reader_falls_back_to_legacy_when_graph_misses() {
        let pool = pool();
        let g: Arc<dyn EpisodicGraph> = Arc::new(SqliteGraph::new(pool.clone()));
        let legacy = Arc::new(EpisodicStore::new(pool));
        let sid = legacy.create_session("test").unwrap();
        let eid = legacy
            .store_episode(&sid, "user", "only in legacy", 0.5, None, None)
            .unwrap();
        let r = DualMemoryReader::dual(legacy, g);
        let got = r.read_by_id(&eid).unwrap().expect("found");
        assert!(got.is_legacy(), "must fall back to legacy on graph miss");
    }

    #[test]
    fn dual_reader_returns_none_when_neither_has_id() {
        let pool = pool();
        let g: Arc<dyn EpisodicGraph> = Arc::new(SqliteGraph::new(pool.clone()));
        let legacy = Arc::new(EpisodicStore::new(pool));
        let r = DualMemoryReader::dual(legacy, g);
        assert!(r.read_by_id("does-not-exist").unwrap().is_none());
    }
}