use std::collections::HashMap;
use crate::episodic::EpisodicStore;
use crate::semantic::SemanticStore;
/// A single recalled item produced by `RecallEngine::recall`.
///
/// A `Memory` is built either from an episodic search hit or from a semantic
/// search hit; `source` records which one.
#[derive(Debug, Clone)]
pub struct Memory {
/// Identifier of the underlying episode (episodic hit) or fact (semantic hit).
pub id: String,
/// Text of the memory. For semantic hits this is the fact rendered as
/// `"<subject> <predicate> <object>"`.
pub content: String,
/// Which store the memory was taken from.
pub source: MemorySource,
/// Final ranking score: fused rank score plus the weighted importance and
/// weighted retention terms.
pub score: f64,
/// Importance of the item; for semantic hits the fact's confidence is used.
pub importance: f64,
/// Timestamp string copied from the search hit (not re-formatted).
pub timestamp: String,
/// Agent associated with the item, if any.
pub agent: Option<String>,
}
/// Identifies which store a recalled memory was taken from.
///
/// Equality on this fieldless enum is total, so `Eq` is derived alongside
/// `PartialEq`; this lets the type be used wherever an `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemorySource {
    /// The memory was built from an episodic (keyword search) hit.
    Episodic,
    /// The memory was built from a semantic (vector similarity) hit.
    Semantic,
}
/// Tuning parameters for the recall pipeline (fusion and final scoring).
#[derive(Debug, Clone)]
pub struct RecallConfig {
/// The `k` constant passed to `rrf_fuse`; each hit contributes
/// `1 / (k + rank)` with a 1-based rank.
pub rrf_k: f64,
/// Maximum number of hits requested from each store before fusion.
pub pre_fusion_limit: usize,
/// Weight applied to an item's importance in the final score.
pub importance_weight: f64,
/// Weight applied to an item's retention (forgetting-curve value) in the
/// final score.
pub recency_weight: f64,
/// Decay rate handed to `forgetting_curve`; elapsed time is measured in hours.
pub decay_rate: f64,
/// Minimum similarity, computed as `1 / (1 + distance)`, that a semantic hit
/// must reach to take part in fusion.
pub similarity_threshold: f64,
}
impl RecallConfig {
    /// Builds a `RecallConfig` from individually supplied settings.
    ///
    /// The two integer settings are widened to the types the struct stores
    /// (`rrf_k` to `f64`, `pre_fusion_limit` to `usize`); the floating-point
    /// settings are stored unchanged. No validation is performed.
    pub fn from_config(
        rrf_k: u32,
        pre_fusion_limit: u32,
        importance_weight: f64,
        recency_weight: f64,
        decay_rate: f64,
        similarity_threshold: f64,
    ) -> Self {
        // u32 -> f64 is lossless, so the infallible conversion is used.
        let rrf_k = f64::from(rrf_k);
        let pre_fusion_limit = pre_fusion_limit as usize;

        Self {
            rrf_k,
            pre_fusion_limit,
            importance_weight,
            recency_weight,
            decay_rate,
            similarity_threshold,
        }
    }
}
impl Default for RecallConfig {
    /// Returns the stock configuration: `rrf_k = 60`, `pre_fusion_limit = 50`,
    /// `importance_weight = 0.3`, `recency_weight = 0.2`, `decay_rate = 0.01`
    /// and `similarity_threshold = 0.65`.
    fn default() -> Self {
        // Candidate gathering and fusion.
        let rrf_k = 60.0;
        let pre_fusion_limit = 50;
        let similarity_threshold = 0.65;

        // Final score blending.
        let importance_weight = 0.3;
        let recency_weight = 0.2;
        let decay_rate = 0.01;

        Self {
            rrf_k,
            pre_fusion_limit,
            importance_weight,
            recency_weight,
            decay_rate,
            similarity_threshold,
        }
    }
}
/// Merges several ranked lists into one using Reciprocal Rank Fusion.
///
/// Every occurrence of an id at 1-based position `rank` in a list adds
/// `1 / (k + rank)` to that id's fused score. Only the position within each
/// list matters; the `f64` carried alongside each id is ignored.
///
/// # Returns
///
/// One `(id, fused_score)` pair per distinct id, ordered by descending fused
/// score. Ids with equal scores are ordered by ascending id, so the output is
/// deterministic regardless of hash-map iteration order. An empty input (or
/// only empty lists) yields an empty vector.
pub fn rrf_fuse(ranked_lists: &[Vec<(String, f64)>], k: f64) -> Vec<(String, f64)> {
    // Key by borrowed id so each distinct id is allocated once, at the end,
    // instead of being cloned for every occurrence.
    let mut scores: HashMap<&str, f64> = HashMap::new();
    for list in ranked_lists {
        for (rank, (id, _original_score)) in list.iter().enumerate() {
            let rrf_score = 1.0 / (k + (rank as f64 + 1.0));
            *scores.entry(id.as_str()).or_default() += rrf_score;
        }
    }

    let mut fused: Vec<(String, f64)> = scores
        .into_iter()
        .map(|(id, score)| (id.to_owned(), score))
        .collect();

    // `total_cmp` gives a genuine total order even if a score is NaN (e.g. a
    // NaN `k`), which `sort_by` requires; the id comparison breaks ties.
    fused.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}
/// Exponential forgetting curve: `importance * e^(-decay_rate * hours_since_access)`.
///
/// With zero elapsed hours the result equals `importance`; for a positive
/// `decay_rate` the result shrinks as `hours_since_access` grows.
pub fn forgetting_curve(importance: f64, hours_since_access: f64, decay_rate: f64) -> f64 {
    let exponent = -decay_rate * hours_since_access;
    let decay_factor = exponent.exp();
    importance * decay_factor
}
/// Runs recall queries against an episodic and a semantic store and ranks the
/// combined results.
pub struct RecallEngine {
/// Tuning parameters used for candidate limits, fusion and final scoring.
config: RecallConfig,
}
impl RecallEngine {
    /// Creates an engine that uses the supplied configuration.
    pub fn new(config: RecallConfig) -> Self {
        Self { config }
    }

    /// Creates an engine that uses `RecallConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(RecallConfig::default())
    }

    /// Retrieves the `top_k` best memories for a query from both stores.
    ///
    /// Steps:
    /// 1. Request up to `pre_fusion_limit` hits from `episodic.search_bm25`
    ///    (text query) and from `semantic.search_similar` (query vector).
    /// 2. Convert each semantic distance to a similarity `1 / (1 + distance)`
    ///    and drop hits below `similarity_threshold`.
    /// 3. Fuse the two ranked id lists with `rrf_fuse`.
    /// 4. Score every fused id as
    ///    `rrf + importance_weight * importance + recency_weight * retention`,
    ///    where retention comes from `forgetting_curve` over the hours elapsed
    ///    since the hit's timestamp. Semantic hits use the fact's confidence
    ///    as their importance.
    /// 5. Sort by descending score (ties broken by ascending id so the result
    ///    is deterministic) and keep the first `top_k`.
    ///
    /// An id found among the episodic hits is always reported as
    /// `MemorySource::Episodic`, even if a semantic hit carries the same id.
    ///
    /// NOTE(review): fusion relies on list position only, so this assumes both
    /// stores return hits best-first; confirm against the store
    /// implementations.
    ///
    /// # Errors
    ///
    /// Returns `RecallError::Episodic` if the episodic search fails and
    /// `RecallError::Semantic` if the semantic search fails.
    // The parameter list mirrors the inputs of the two underlying searches.
    #[allow(clippy::too_many_arguments)]
    pub async fn recall(
        &self,
        query: &str,
        query_vector: Vec<f32>,
        episodic: &EpisodicStore,
        semantic: &SemanticStore,
        top_k: usize,
        namespace: Option<&str>,
        agent: Option<&str>,
    ) -> Result<Vec<Memory>, RecallError> {
        let limit = self.config.pre_fusion_limit;

        let bm25_results = episodic
            .search_bm25(query, limit, namespace, agent)
            .map_err(RecallError::Episodic)?;
        let bm25_ranked: Vec<(String, f64)> = bm25_results
            .iter()
            .map(|r| (r.episode_id.clone(), r.rank))
            .collect();

        let ann_results = semantic
            .search_similar(query_vector, limit, namespace, agent)
            .await
            .map_err(RecallError::Semantic)?;
        let threshold = self.config.similarity_threshold;
        let ann_ranked: Vec<(String, f64)> = ann_results
            .iter()
            .map(|r| (r.fact.id.clone(), 1.0 / (1.0 + r.distance as f64)))
            .filter(|(_, sim)| *sim >= threshold)
            .collect();

        let fused = rrf_fuse(&[bm25_ranked, ann_ranked], self.config.rrf_k);
        let now = chrono::Utc::now();

        let mut memories: Vec<Memory> = Vec::with_capacity(fused.len());
        // `fused` is consumed so each id can be moved into its `Memory`.
        for (id, rrf_score) in fused {
            // Episodic hits take precedence over semantic hits for the same id.
            if let Some(fts) = bm25_results.iter().find(|r| r.episode_id == id) {
                let importance = fts.importance;
                let score = self.blended_score(rrf_score, importance, &fts.timestamp, &now);
                memories.push(Memory {
                    id,
                    content: fts.content.clone(),
                    source: MemorySource::Episodic,
                    score,
                    importance,
                    timestamp: fts.timestamp.clone(),
                    agent: fts.agent.clone(),
                });
            } else if let Some(sr) = ann_results.iter().find(|r| r.fact.id == id) {
                // Facts carry no importance of their own; confidence stands in.
                let importance = sr.fact.confidence;
                let score = self.blended_score(rrf_score, importance, &sr.created_at, &now);
                let content = format!(
                    "{} {} {}",
                    sr.fact.subject, sr.fact.predicate, sr.fact.object
                );
                memories.push(Memory {
                    id,
                    content,
                    source: MemorySource::Semantic,
                    score,
                    importance,
                    timestamp: sr.created_at.clone(),
                    agent: sr.fact.agent.clone(),
                });
            }
        }

        // `total_cmp` is a real total order (NaN included), as `sort_by`
        // requires; the id tie-break keeps truncation deterministic.
        memories.sort_by(|a, b| b.score.total_cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
        memories.truncate(top_k);
        Ok(memories)
    }

    /// Combines a fused rank score with the weighted importance and the
    /// weighted retention derived from the time elapsed since `timestamp`.
    fn blended_score(
        &self,
        rrf_score: f64,
        importance: f64,
        timestamp: &str,
        now: &chrono::DateTime<chrono::Utc>,
    ) -> f64 {
        let hours = parse_elapsed_hours(timestamp, now);
        let retention = forgetting_curve(importance, hours, self.config.decay_rate);
        rrf_score
            + self.config.importance_weight * importance
            + self.config.recency_weight * retention
    }
}
/// Returns how many hours lie between `timestamp` and `now`.
///
/// `timestamp` is parsed as RFC 3339 first and, failing that, as
/// `YYYY-MM-DD HH:MM:SS` interpreted as UTC. The result is never smaller than
/// `0.01`, so timestamps at or after `now` yield `0.01`. An empty or
/// unparseable timestamp logs a warning and yields `1.0`.
fn parse_elapsed_hours(timestamp: &str, now: &chrono::DateTime<chrono::Utc>) -> f64 {
    if timestamp.is_empty() {
        tracing::warn!("Empty timestamp in recall — using 1.0h fallback");
        return 1.0;
    }

    // Try the offset-aware format first, then the naive format read as UTC.
    let parsed = chrono::DateTime::parse_from_rfc3339(timestamp)
        .map(|dt| dt.with_timezone(&chrono::Utc))
        .or_else(|_| {
            chrono::NaiveDateTime::parse_from_str(timestamp, "%Y-%m-%d %H:%M:%S")
                .map(|naive| naive.and_utc())
        });

    match parsed {
        Ok(then) => {
            let elapsed = *now - then;
            (elapsed.num_seconds() as f64 / 3600.0).max(0.01)
        }
        Err(_) => {
            tracing::warn!(
                timestamp,
                "Unparseable timestamp in recall — using 1.0h fallback"
            );
            1.0
        }
    }
}
/// Errors that `RecallEngine::recall` can return.
///
/// Each variant wraps the error of the store whose search failed; the wrapped
/// error's message is included in this error's display text.
#[derive(Debug, thiserror::Error)]
pub enum RecallError {
/// The episodic store's `search_bm25` call failed.
#[error("Episodic search failed: {0}")]
Episodic(crate::episodic::EpisodicError),
/// The semantic store's `search_similar` call failed.
#[error("Semantic search failed: {0}")]
Semantic(crate::semantic::SemanticError),
}
// Unit tests for the pure helpers in this module: `rrf_fuse`,
// `forgetting_curve` and `RecallConfig::default`.
#[cfg(test)]
mod tests {
use super::*;
// A single list keeps its order, and the top item scores 1 / (k + 1).
#[test]
fn test_rrf_single_list() {
let lists = vec![vec![
("a".to_string(), 10.0),
("b".to_string(), 5.0),
("c".to_string(), 1.0),
]];
let fused = rrf_fuse(&lists, 60.0);
assert_eq!(fused[0].0, "a");
assert_eq!(fused[1].0, "b");
assert_eq!(fused[2].0, "c");
assert!((fused[0].1 - 1.0 / 61.0).abs() < 1e-6);
}
// Two ids that swap ranks across two lists end up with equal fused scores.
#[test]
fn test_rrf_two_lists() {
let lists = vec![
vec![("a".to_string(), 10.0), ("b".to_string(), 5.0)],
vec![("b".to_string(), 10.0), ("a".to_string(), 5.0)],
];
let fused = rrf_fuse(&lists, 60.0);
assert_eq!(fused.len(), 2);
let score_a = fused.iter().find(|(id, _)| id == "a").unwrap().1;
let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
assert!((score_a - score_b).abs() < 1e-10);
}
// Ids that each top a different list get the same fused score.
#[test]
fn test_rrf_disjoint_lists() {
let lists = vec![vec![("a".to_string(), 10.0)], vec![("b".to_string(), 10.0)]];
let fused = rrf_fuse(&lists, 60.0);
assert_eq!(fused.len(), 2);
let score_a = fused.iter().find(|(id, _)| id == "a").unwrap().1;
let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
assert!((score_a - score_b).abs() < 1e-10);
}
// An id present in both lists outranks an id present in only one, even when
// the latter has the better single-list rank.
#[test]
fn test_rrf_overlap_boost() {
let lists = vec![
vec![
("a".to_string(), 10.0),
("b".to_string(), 5.0),
("c".to_string(), 1.0),
],
vec![("a".to_string(), 10.0), ("c".to_string(), 5.0)],
];
let fused = rrf_fuse(&lists, 60.0);
assert_eq!(fused[0].0, "a");
let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
let score_c = fused.iter().find(|(id, _)| id == "c").unwrap().1;
assert!(score_c > score_b, "c (in both) should rank > b (in one)");
}
// With zero elapsed hours, retention equals the importance.
#[test]
fn test_forgetting_curve_no_decay() {
let retention = forgetting_curve(1.0, 0.0, 0.01);
assert!((retention - 1.0).abs() < 1e-6);
}
// Retention falls as elapsed hours grow, and is lower for lower importance
// at the same age.
#[test]
fn test_forgetting_curve_decay() {
let retention_1h = forgetting_curve(1.0, 1.0, 0.01);
let retention_24h = forgetting_curve(1.0, 24.0, 0.01);
let retention_168h = forgetting_curve(1.0, 168.0, 0.01);
assert!(retention_1h > retention_24h);
assert!(retention_24h > retention_168h);
let retention_high = forgetting_curve(1.0, 24.0, 0.01);
let retention_low = forgetting_curve(0.5, 24.0, 0.01);
assert!(retention_high > retention_low);
}
// Retention scales linearly with importance: doubling importance doubles it.
#[test]
fn test_forgetting_curve_importance_scaling() {
let ret_a = forgetting_curve(1.0, 10.0, 0.01);
let ret_b = forgetting_curve(0.5, 10.0, 0.01);
assert!((ret_a / ret_b - 2.0).abs() < 1e-6);
}
// No lists, or a single empty list, fuse to an empty result.
#[test]
fn test_rrf_empty_lists() {
let fused = rrf_fuse(&[], 60.0);
assert!(fused.is_empty());
let fused2 = rrf_fuse(&[vec![]], 60.0);
assert!(fused2.is_empty());
}
// Pins four of the default configuration values (decay_rate and
// similarity_threshold are not checked here).
#[test]
fn test_recall_config_defaults() {
let config = RecallConfig::default();
assert_eq!(config.rrf_k, 60.0);
assert_eq!(config.pre_fusion_limit, 50);
assert!((config.importance_weight - 0.3).abs() < 1e-6);
assert!((config.recency_weight - 0.2).abs() < 1e-6);
}
}