1use std::collections::HashMap;
8
9use crate::episodic::{EpisodicStore, FtsResult};
10use crate::semantic::{SemanticResult, SemanticStore};
11
/// One recalled item, independent of the retrieval path that produced it.
///
/// Returned (best first) by [`RecallEngine::recall`].
#[derive(Clone, Debug)]
pub struct Memory {
    /// Id of the underlying record: an episode id, a semantic fact id, or a graph node id.
    pub id: String,
    /// Text of the memory (for semantic facts: `"subject predicate object"`).
    pub content: String,
    /// Retrieval path that hydrated this memory.
    pub source: MemorySource,
    /// Final ranking score (higher is better): fused RRF score plus weighted
    /// importance and weighted retention.
    pub score: f64,
    /// Importance used in scoring (episode importance, fact confidence, or graph weight).
    pub importance: f64,
    /// Timestamp string of the record, as reported by its source.
    pub timestamp: String,
    /// Owning agent, when the source records one.
    pub agent: Option<String>,
}

/// Which store a [`Memory`] was hydrated from.
#[derive(Clone, Debug, PartialEq)]
pub enum MemorySource {
    /// Full-text (BM25) hit from the episodic store.
    Episodic,
    /// Vector-similarity hit from the semantic fact store.
    Semantic,
    /// Candidate supplied by the optional graph reader.
    Graph,
}
34
/// Tuning parameters for [`RecallEngine`]'s hybrid retrieval and re-ranking.
#[derive(Debug, Clone)]
pub struct RecallConfig {
    /// RRF smoothing constant `k`: an item at 1-based rank `r` scores `1 / (k + r)`.
    pub rrf_k: f64,
    /// Maximum number of candidates requested from each retriever before fusion.
    pub pre_fusion_limit: usize,
    /// Weight of a memory's importance in its final score.
    pub importance_weight: f64,
    /// Weight of the forgetting-curve retention term in the final score.
    pub recency_weight: f64,
    /// Exponential decay rate per elapsed hour, passed to [`forgetting_curve`].
    pub decay_rate: f64,
    /// Minimum similarity an ANN candidate (semantic or graph) needs to be fused;
    /// semantic distances are mapped to similarity as `1 / (1 + distance)`.
    pub similarity_threshold: f64,
}

impl RecallConfig {
    /// Builds a config from integer-typed `rrf_k` / `pre_fusion_limit` values,
    /// widening them to the field types; the remaining values are stored as given.
    pub fn from_config(
        rrf_k: u32,
        pre_fusion_limit: u32,
        importance_weight: f64,
        recency_weight: f64,
        decay_rate: f64,
        similarity_threshold: f64,
    ) -> Self {
        // Every u32 is exactly representable as an f64.
        let rrf_k = f64::from(rrf_k);
        // u32 -> usize is lossless on 32- and 64-bit targets.
        let pre_fusion_limit = pre_fusion_limit as usize;

        Self {
            rrf_k,
            pre_fusion_limit,
            importance_weight,
            recency_weight,
            decay_rate,
            similarity_threshold,
        }
    }
}

impl Default for RecallConfig {
    /// `k = 60`, 50 candidates per retriever, importance weight 0.3, recency
    /// weight 0.2, decay 0.01/hour, similarity threshold 0.65.
    fn default() -> Self {
        Self::from_config(60, 50, 0.3, 0.2, 0.01, 0.65)
    }
}
86
/// Fuses several ranked candidate lists with Reciprocal Rank Fusion (RRF).
///
/// Each list is assumed to be ordered best-first. The score attached to each
/// entry is ignored; only its position matters. An id at 1-based rank `r` in a
/// list contributes `1 / (k + r)`, and contributions are summed across lists.
///
/// If an id appears more than once in the same list, only its first (best)
/// occurrence counts. A single list therefore contributes at most
/// `1 / (k + 1)` to any id.
///
/// # Returns
/// One entry per distinct id across all lists, sorted by fused score
/// descending. Exactly-equal scores are ordered by id ascending, so the
/// output is deterministic.
pub fn rrf_fuse(ranked_lists: &[Vec<(String, f64)>], k: f64) -> Vec<(String, f64)> {
    let mut scores: HashMap<String, f64> = HashMap::new();

    for list in ranked_lists {
        // A ranking lists each item once. If a retriever repeats an id, keep
        // only its best position instead of double-counting it.
        let mut seen = std::collections::HashSet::with_capacity(list.len());
        for (rank, (id, _original_score)) in list.iter().enumerate() {
            if !seen.insert(id.as_str()) {
                continue;
            }
            // `rank` is 0-based; RRF uses 1-based ranks.
            let rrf_score = 1.0 / (k + (rank as f64 + 1.0));
            *scores.entry(id.clone()).or_default() += rrf_score;
        }
    }

    let mut fused: Vec<(String, f64)> = scores.into_iter().collect();
    // `total_cmp` is a total order even if a score is NaN (e.g. NaN `k`).
    // The id tie-break removes the dependence on `HashMap`'s randomized
    // iteration order. Ids are unique, so an unstable sort is exact.
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}
105
/// Exponential forgetting curve: `importance * e^(-decay_rate * hours_since_access)`.
///
/// The result equals `importance` when no time has elapsed or `decay_rate` is
/// zero. For positive `decay_rate` it shrinks towards zero as
/// `hours_since_access` grows. The result is linear in `importance`.
pub fn forgetting_curve(importance: f64, hours_since_access: f64, decay_rate: f64) -> f64 {
    let retained_fraction = f64::exp(-decay_rate * hours_since_access);
    retained_fraction * importance
}
126
127pub struct RecallEngine {
132 config: RecallConfig,
133}
134
135impl RecallEngine {
136 pub fn new(config: RecallConfig) -> Self {
137 Self { config }
138 }
139
140 pub fn with_defaults() -> Self {
141 Self::new(RecallConfig::default())
142 }
143
144 #[allow(clippy::too_many_arguments)]
154 pub async fn recall(
155 &self,
156 query: &str,
157 query_vector: Vec<f32>,
158 episodic: &EpisodicStore,
159 semantic: &SemanticStore,
160 top_k: usize,
161 namespace: Option<&str>,
162 agent: Option<&str>,
163 graph: Option<&crate::DualMemoryReader>,
164 ) -> Result<Vec<Memory>, RecallError> {
165 let limit = self.config.pre_fusion_limit;
166 let threshold = self.config.similarity_threshold;
167
168 let bm25_results = episodic
170 .search_bm25(query, limit, namespace, agent)
171 .map_err(RecallError::Episodic)?;
172
173 let bm25_ranked: Vec<(String, f64)> = bm25_results
174 .iter()
175 .map(|r| (r.episode_id.clone(), r.rank))
176 .collect();
177
178 let ann_results = semantic
180 .search_similar(query_vector.clone(), limit, namespace, agent)
181 .await
182 .map_err(RecallError::Semantic)?;
183
184 let ann_ranked: Vec<(String, f64)> = ann_results
187 .iter()
188 .map(|r| (r.fact.id.clone(), 1.0 / (1.0 + r.distance as f64)))
189 .filter(|(_, sim)| *sim >= threshold)
190 .collect();
191
192 let graph_candidates = match graph {
198 Some(reader) => reader
199 .recall_candidates(query, query_vector, limit, namespace)
200 .await
201 .map_err(RecallError::Graph)?,
202 None => crate::GraphCandidates::default(),
203 };
204 let graph_fts_ranked = graph_candidates.fts.clone();
205 let graph_ann_ranked: Vec<(String, f64)> = graph_candidates
206 .ann
207 .iter()
208 .filter(|(_, sim)| *sim >= threshold)
209 .cloned()
210 .collect();
211
212 let fused = rrf_fuse(
214 &[bm25_ranked, ann_ranked, graph_fts_ranked, graph_ann_ranked],
215 self.config.rrf_k,
216 );
217
218 let bm25_map: HashMap<&str, &FtsResult> = bm25_results
220 .iter()
221 .map(|r| (r.episode_id.as_str(), r))
222 .collect();
223 let ann_map: HashMap<&str, &SemanticResult> = ann_results
224 .iter()
225 .map(|r| (r.fact.id.as_str(), r))
226 .collect();
227
228 let now = chrono::Utc::now();
230 let mut memories: Vec<Memory> = Vec::new();
231
232 for (id, rrf_score) in &fused {
233 if let Some(fts) = bm25_map.get(id.as_str()) {
235 let importance = fts.importance;
236 let hours = parse_elapsed_hours(&fts.timestamp, &now);
237 let retention = forgetting_curve(importance, hours, self.config.decay_rate);
238 let final_score = rrf_score
239 + self.config.importance_weight * importance
240 + self.config.recency_weight * retention;
241
242 memories.push(Memory {
243 id: id.clone(),
244 content: fts.content.clone(),
245 source: MemorySource::Episodic,
246 score: final_score,
247 importance,
248 timestamp: fts.timestamp.clone(),
249 agent: fts.agent.clone(),
250 });
251 continue;
252 }
253
254 if let Some(sr) = ann_map.get(id.as_str()) {
256 let importance = sr.fact.confidence;
257 let hours = parse_elapsed_hours(&sr.created_at, &now);
258 let retention = forgetting_curve(importance, hours, self.config.decay_rate);
259 let final_score = rrf_score
260 + self.config.importance_weight * importance
261 + self.config.recency_weight * retention;
262
263 let content = format!(
264 "{} {} {}",
265 sr.fact.subject, sr.fact.predicate, sr.fact.object
266 );
267
268 memories.push(Memory {
269 id: id.clone(),
270 content,
271 source: MemorySource::Semantic,
272 score: final_score,
273 importance,
274 timestamp: sr.created_at.clone(),
275 agent: sr.fact.agent.clone(),
276 });
277 continue;
278 }
279
280 if let Some(gc) = graph_candidates.hydration.get(id.as_str()) {
283 let importance = gc.weight as f64;
284 let timestamp = gc.created_at.to_rfc3339();
285 let hours = parse_elapsed_hours(×tamp, &now);
286 let retention = forgetting_curve(importance, hours, self.config.decay_rate);
287 let final_score = rrf_score
288 + self.config.importance_weight * importance
289 + self.config.recency_weight * retention;
290
291 memories.push(Memory {
292 id: id.clone(),
293 content: gc.content.clone(),
294 source: MemorySource::Graph,
295 score: final_score,
296 importance,
297 timestamp,
298 agent: None,
299 });
300 }
301 }
302
303 memories.sort_by(|a, b| {
305 b.score
306 .partial_cmp(&a.score)
307 .unwrap_or(std::cmp::Ordering::Equal)
308 });
309 memories.truncate(top_k);
310
311 Ok(memories)
312 }
313}
314
315fn parse_elapsed_hours(timestamp: &str, now: &chrono::DateTime<chrono::Utc>) -> f64 {
321 if timestamp.is_empty() {
322 tracing::warn!("Empty timestamp in recall — using 1.0h fallback");
323 return 1.0;
324 }
325 if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(timestamp) {
327 let elapsed = *now - dt.with_timezone(&chrono::Utc);
328 return (elapsed.num_seconds() as f64 / 3600.0).max(0.01);
329 }
330 if let Ok(naive) = chrono::NaiveDateTime::parse_from_str(timestamp, "%Y-%m-%d %H:%M:%S") {
332 let dt = naive.and_utc();
333 let elapsed = *now - dt;
334 return (elapsed.num_seconds() as f64 / 3600.0).max(0.01);
335 }
336 tracing::warn!(
337 timestamp,
338 "Unparseable timestamp in recall — using 1.0h fallback"
339 );
340 1.0 }
342
343#[derive(Debug, thiserror::Error)]
345pub enum RecallError {
346 #[error("Episodic search failed: {0}")]
347 Episodic(crate::episodic::EpisodicError),
348
349 #[error("Semantic search failed: {0}")]
350 Semantic(crate::semantic::SemanticError),
351
352 #[error("Graph recall failed: {0}")]
353 Graph(crate::dual_memory::DualMemoryError),
354}
355
// Unit tests for RRF fusion and the forgetting curve, one async end-to-end
// recall test, and proptest properties for `rrf_fuse` / `forgetting_curve`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::episodic::EpisodicStore;
    use crate::graph::{EpisodicGraph, Node, NodeKind, SqliteGraph};
    use crate::semantic::SemanticStore;
    use crate::DualMemoryReader;
    use std::sync::Arc;
    use storage::{RuVectorStore, SqlitePool};

    // End-to-end: a node added only to the graph must show up in recall results
    // as a `MemorySource::Graph` hit, and must vanish once no graph reader is
    // passed. The episodic and semantic stores are created empty here.
    #[tokio::test]
    async fn recall_fuses_graph_only_fts_hit() {
        let episodic = EpisodicStore::new(SqlitePool::open_memory().unwrap());
        let ruv_dir = tempfile::tempdir().unwrap();
        // 384 matches the length of the query vectors passed below.
        let ruv = RuVectorStore::open(ruv_dir.path(), 384).await.unwrap();
        ruv.ensure_tables().await.unwrap();
        let semantic = SemanticStore::new(SqlitePool::open_memory().unwrap(), ruv);

        let g: Arc<dyn EpisodicGraph> =
            Arc::new(SqliteGraph::new(SqlitePool::open_memory().unwrap()));
        let node = Node::new(
            NodeKind::new("tool_call"),
            serde_json::json!({"verb": "terminal.open", "program": "htop"}),
            "personal",
            None,
        );
        g.add_node(&node).unwrap();
        let reader = DualMemoryReader::graph_only(g);

        let engine = RecallEngine::with_defaults();
        let results = engine
            .recall(
                "htop",
                vec![0.0; 384],
                &episodic,
                &semantic,
                10,
                None,
                None,
                Some(&reader),
            )
            .await
            .unwrap();

        let graph_hit = results
            .iter()
            .find(|m| m.source == MemorySource::Graph)
            .expect("the graph node must appear in recall results");
        assert_eq!(graph_hit.id, node.id);
        assert!(graph_hit.content.contains("htop"));

        // Same query without the graph reader: nothing else was inserted into
        // the episodic or semantic stores, so the result must be empty.
        let without = engine
            .recall(
                "htop",
                vec![0.0; 384],
                &episodic,
                &semantic,
                10,
                None,
                None,
                None,
            )
            .await
            .unwrap();
        assert!(
            without.is_empty(),
            "the hit must come from the graph path, not episodic/semantic"
        );
    }

    // A single list keeps its order; rank 1 with k = 60 scores 1 / (60 + 1).
    #[test]
    fn test_rrf_single_list() {
        let lists = vec![vec![
            ("a".to_string(), 10.0),
            ("b".to_string(), 5.0),
            ("c".to_string(), 1.0),
        ]];

        let fused = rrf_fuse(&lists, 60.0);
        assert_eq!(fused[0].0, "a");
        assert_eq!(fused[1].0, "b");
        assert_eq!(fused[2].0, "c");

        assert!((fused[0].1 - 1.0 / 61.0).abs() < 1e-6);
    }

    // Swapped ranks across two lists give both ids the same total
    // (1/61 + 1/62), so their fused scores must be equal.
    #[test]
    fn test_rrf_two_lists() {
        let lists = vec![
            vec![("a".to_string(), 10.0), ("b".to_string(), 5.0)],
            vec![("b".to_string(), 10.0), ("a".to_string(), 5.0)],
        ];

        let fused = rrf_fuse(&lists, 60.0);

        assert_eq!(fused.len(), 2);
        let score_a = fused.iter().find(|(id, _)| id == "a").unwrap().1;
        let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
        assert!((score_a - score_b).abs() < 1e-10);
    }

    // Each id is rank 1 in exactly one list, so both fuse to 1/61.
    #[test]
    fn test_rrf_disjoint_lists() {
        let lists = vec![vec![("a".to_string(), 10.0)], vec![("b".to_string(), 10.0)]];

        let fused = rrf_fuse(&lists, 60.0);
        assert_eq!(fused.len(), 2);
        let score_a = fused.iter().find(|(id, _)| id == "a").unwrap().1;
        let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
        assert!((score_a - score_b).abs() < 1e-10);
    }

    // Appearing in several lists outweighs a better single-list rank:
    // c (ranks 3 and 2) must beat b (rank 2 in one list only).
    #[test]
    fn test_rrf_overlap_boost() {
        let lists = vec![
            vec![
                ("a".to_string(), 10.0),
                ("b".to_string(), 5.0),
                ("c".to_string(), 1.0),
            ],
            vec![("a".to_string(), 10.0), ("c".to_string(), 5.0)],
        ];

        let fused = rrf_fuse(&lists, 60.0);

        assert_eq!(fused[0].0, "a");

        let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
        let score_c = fused.iter().find(|(id, _)| id == "c").unwrap().1;
        assert!(score_c > score_b, "c (in both) should rank > b (in one)");
    }

    // Zero elapsed time leaves importance untouched.
    #[test]
    fn test_forgetting_curve_no_decay() {
        let retention = forgetting_curve(1.0, 0.0, 0.01);
        assert!((retention - 1.0).abs() < 1e-6);
    }

    // Retention falls as time passes and rises with importance.
    #[test]
    fn test_forgetting_curve_decay() {
        let retention_1h = forgetting_curve(1.0, 1.0, 0.01);
        let retention_24h = forgetting_curve(1.0, 24.0, 0.01);
        let retention_168h = forgetting_curve(1.0, 168.0, 0.01);
        assert!(retention_1h > retention_24h);
        assert!(retention_24h > retention_168h);

        let retention_high = forgetting_curve(1.0, 24.0, 0.01);
        let retention_low = forgetting_curve(0.5, 24.0, 0.01);
        assert!(retention_high > retention_low);
    }

    // Doubling importance doubles retention at the same elapsed time.
    #[test]
    fn test_forgetting_curve_importance_scaling() {
        let ret_a = forgetting_curve(1.0, 10.0, 0.01);
        let ret_b = forgetting_curve(0.5, 10.0, 0.01);
        assert!((ret_a / ret_b - 2.0).abs() < 1e-6);
    }

    // No lists, or only empty lists, fuse to nothing.
    #[test]
    fn test_rrf_empty_lists() {
        let fused = rrf_fuse(&[], 60.0);
        assert!(fused.is_empty());

        let fused2 = rrf_fuse(&[vec![]], 60.0);
        assert!(fused2.is_empty());
    }

    // Pins the documented default tuning values.
    #[test]
    fn test_recall_config_defaults() {
        let config = RecallConfig::default();
        assert_eq!(config.rrf_k, 60.0);
        assert_eq!(config.pre_fusion_limit, 50);
        assert!((config.importance_weight - 0.3).abs() < 1e-6);
        assert!((config.recency_weight - 0.2).abs() < 1e-6);
    }

    use proptest::prelude::*;
    use std::collections::HashSet;

    // Strategy: one ranked list of distinct ids drawn from `id0`..`id49`.
    // Up to 11 draws, with duplicates removed keeping the first occurrence.
    // Scores are all 0.0 because `rrf_fuse` ignores them.
    fn ranked_list() -> impl Strategy<Value = Vec<(String, f64)>> {
        prop::collection::vec(0u32..50, 0..12).prop_map(|idxs| {
            let mut seen = HashSet::new();
            idxs.into_iter()
                .filter(|i| seen.insert(*i))
                .map(|i| (format!("id{i}"), 0.0))
                .collect()
        })
    }

    // Strategy: zero to four independent ranked lists.
    fn ranked_lists() -> impl Strategy<Value = Vec<Vec<(String, f64)>>> {
        prop::collection::vec(ranked_list(), 0..5)
    }

    proptest! {
        #![proptest_config(ProptestConfig { cases: 512, .. ProptestConfig::default() })]

        // Fused output is ordered best-first.
        #[test]
        fn rrf_output_is_sorted_descending(lists in ranked_lists(), k in 0.5f64..200.0) {
            let fused = rrf_fuse(&lists, k);
            for w in fused.windows(2) {
                prop_assert!(w[0].1 >= w[1].1);
            }
        }

        // Every id from any input list appears exactly once in the output,
        // and nothing else does.
        #[test]
        fn rrf_output_is_exactly_the_union(lists in ranked_lists(), k in 0.5f64..200.0) {
            let fused = rrf_fuse(&lists, k);
            let union: HashSet<&str> = lists
                .iter()
                .flat_map(|l| l.iter().map(|(id, _)| id.as_str()))
                .collect();
            let got: HashSet<&str> = fused.iter().map(|(id, _)| id.as_str()).collect();
            prop_assert_eq!(fused.len(), union.len(), "no duplicate ids in output");
            prop_assert_eq!(got, union);
        }

        // Each list contributes at most 1/(k + 1) per id (rank 1), so a fused
        // score never exceeds `lists.len() / (k + 1)`.
        #[test]
        fn rrf_scores_are_positive_and_bounded(lists in ranked_lists(), k in 0.5f64..200.0) {
            let fused = rrf_fuse(&lists, k);
            let ceiling = lists.len() as f64 / (k + 1.0);
            for (_, score) in &fused {
                prop_assert!(*score > 0.0);
                prop_assert!(*score <= ceiling + 1e-12);
            }
        }

        // Fusing a list with itself exactly doubles every score.
        #[test]
        fn rrf_is_additive_over_repeated_lists(list in ranked_list(), k in 0.5f64..200.0) {
            let single = rrf_fuse(std::slice::from_ref(&list), k);
            let doubled = rrf_fuse(&[list.clone(), list], k);
            let single_map: HashMap<&str, f64> =
                single.iter().map(|(id, s)| (id.as_str(), *s)).collect();
            for (id, s) in &doubled {
                let one = single_map[id.as_str()];
                prop_assert!((s - 2.0 * one).abs() <= 1e-12);
            }
        }

        // Scaling importance by `factor` scales retention by `factor`,
        // up to floating-point rounding.
        #[test]
        fn forgetting_curve_is_linear_in_importance(
            importance in 0.0f64..1.0,
            hours in 0.0f64..10_000.0,
            decay in 0.0f64..1.0,
            factor in 0.0f64..5.0,
        ) {
            let base = forgetting_curve(importance, hours, decay);
            let scaled = forgetting_curve(importance * factor, hours, decay);
            prop_assert!((scaled - factor * base).abs() <= 1e-9 + base.abs() * 1e-9);
        }

        // With non-negative inputs, retention lies in [0, importance].
        #[test]
        fn forgetting_curve_stays_within_zero_and_importance(
            importance in 0.0f64..1e6,
            hours in 0.0f64..10_000.0,
            decay in 0.0f64..1.0,
        ) {
            let r = forgetting_curve(importance, hours, decay);
            prop_assert!(r >= 0.0);
            prop_assert!(r <= importance + 1e-9);
        }

        // More elapsed time never increases retention.
        #[test]
        fn forgetting_curve_is_monotone_in_elapsed(
            importance in 0.0f64..1e3,
            decay in 0.0f64..1.0,
            h1 in 0.0f64..10_000.0,
            h2 in 0.0f64..10_000.0,
        ) {
            let (lo, hi) = if h1 <= h2 { (h1, h2) } else { (h2, h1) };
            prop_assert!(forgetting_curve(importance, lo, decay) + 1e-12
                >= forgetting_curve(importance, hi, decay));
        }

        // Zero elapsed time or zero decay returns importance exactly
        // (exp(-0.0) == 1.0, so this is bit-exact, not approximate).
        #[test]
        fn forgetting_curve_is_identity_without_decay(
            importance in 0.0f64..1e3,
            hours in 0.0f64..10_000.0,
            decay in 0.0f64..1.0,
        ) {
            prop_assert_eq!(forgetting_curve(importance, 0.0, decay), importance);
            prop_assert_eq!(forgetting_curve(importance, hours, 0.0), importance);
        }
    }
}