// mnemo_core/query/retrieval.rs
use std::collections::HashMap;

use uuid::Uuid;

4pub fn weighted_reciprocal_rank_fusion(
8 ranked_lists: &[Vec<(Uuid, f32)>],
9 k: f32,
10 weights: &[f32],
11) -> Vec<(Uuid, f32)> {
12 let mut scores: HashMap<Uuid, f32> = HashMap::new();
13 for (i, list) in ranked_lists.iter().enumerate() {
14 let w = weights.get(i).copied().unwrap_or(1.0);
15 for (rank, (id, _original_score)) in list.iter().enumerate() {
16 *scores.entry(*id).or_insert(0.0) += w / (k + rank as f32 + 1.0);
17 }
18 }
19 let mut fused: Vec<(Uuid, f32)> = scores.into_iter().collect();
20 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
21 fused
22}
23
24pub fn reciprocal_rank_fusion(ranked_lists: &[Vec<(Uuid, f32)>], k: f32) -> Vec<(Uuid, f32)> {
28 weighted_reciprocal_rank_fusion(ranked_lists, k, &[])
29}
30
31pub fn recency_score(created_at: &str, half_life_hours: f64) -> f32 {
35 let now = chrono::Utc::now();
36 let created = match chrono::DateTime::parse_from_rfc3339(created_at) {
37 Ok(dt) => dt.with_timezone(&chrono::Utc),
38 Err(_) => return 0.5, };
40 let age_hours = (now - created).num_seconds() as f64 / 3600.0;
41 if age_hours < 0.0 {
42 return 1.0; }
44 let decay = (-age_hours * (2.0_f64.ln()) / half_life_hours).exp();
45 decay as f32
46}
47
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing fused `f32` scores.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_rrf_basic() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();
        let id3 = Uuid::now_v7();

        let list1 = vec![(id1, 0.9), (id2, 0.8), (id3, 0.7)];
        let list2 = vec![(id2, 0.95), (id1, 0.85), (id3, 0.75)];

        let fused = reciprocal_rank_fusion(&[list1, list2], 60.0);
        assert_eq!(fused.len(), 3);

        // id1 and id2 swap ranks 1 and 2 between the lists, so they tie at the top.
        let expected_top: f32 = 1.0 / 61.0 + 1.0 / 62.0;
        assert!((fused[0].1 - expected_top).abs() < EPS);
        assert!((fused[1].1 - expected_top).abs() < EPS);
        assert!(fused[0].0 == id1 || fused[0].0 == id2);
        assert!(fused[1].0 == id1 || fused[1].0 == id2);

        // id3 is last in both lists and must be ranked last.
        assert_eq!(fused[2].0, id3);
        assert!((fused[2].1 - 2.0 / 63.0).abs() < EPS);
    }

    #[test]
    fn test_rrf_disjoint() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();

        let list1 = vec![(id1, 0.9)];
        let list2 = vec![(id2, 0.8)];

        let fused = reciprocal_rank_fusion(&[list1, list2], 60.0);
        assert_eq!(fused.len(), 2);
        // Both ids are first in their own list, so both score exactly 1 / (60 + 1).
        assert!((fused[0].1 - fused[1].1).abs() < EPS);
        assert!((fused[0].1 - 1.0 / 61.0).abs() < EPS);
        assert!(fused.iter().any(|(id, _)| *id == id1));
        assert!(fused.iter().any(|(id, _)| *id == id2));
    }

    #[test]
    fn test_rrf_single_list() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();

        let list1 = vec![(id1, 0.9), (id2, 0.8)];
        let fused = reciprocal_rank_fusion(&[list1], 60.0);
        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].0, id1);
        assert!(fused[0].1 > fused[1].1);
    }

    #[test]
    fn test_weighted_rrf() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();

        let lists = [
            vec![(id1, 0.9), (id2, 0.8)],
            vec![(id2, 0.95), (id1, 0.85)],
        ];

        // Doubling the first list's weight makes its top item win overall.
        let fused = weighted_reciprocal_rank_fusion(&lists, 60.0, &[2.0, 1.0]);
        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].0, id1);

        // Lists without a weight entry default to 1.0.
        let defaulted = weighted_reciprocal_rank_fusion(&lists, 60.0, &[2.0]);
        assert_eq!(defaulted, fused);
    }

    #[test]
    fn test_recency_score() {
        let now = chrono::Utc::now().to_rfc3339();
        let score = recency_score(&now, 168.0);
        assert!(score > 0.99);

        let old = (chrono::Utc::now() - chrono::Duration::days(365)).to_rfc3339();
        let score = recency_score(&old, 168.0);
        assert!(score < 0.01);

        let week_ago = (chrono::Utc::now() - chrono::Duration::hours(168)).to_rfc3339();
        let score = recency_score(&week_ago, 168.0);
        assert!((score - 0.5).abs() < 0.05);
    }

    #[test]
    fn test_recency_score_fallbacks() {
        // Unparseable timestamps get the neutral score.
        assert_eq!(recency_score("not a timestamp", 168.0), 0.5);

        // Future timestamps (e.g. clock skew) are treated as brand new.
        let future = (chrono::Utc::now() + chrono::Duration::days(1)).to_rfc3339();
        assert_eq!(recency_score(&future, 168.0), 1.0);
    }
}