Skip to main content

mnemo_core/query/
retrieval.rs

1use std::collections::HashMap;
2use uuid::Uuid;
3
4/// Weighted Reciprocal Rank Fusion: combines multiple ranked lists with per-list weights.
5/// Each item's score = sum over all lists of weights[i] / (k + rank_in_list + 1.0).
6/// If weights is empty or shorter than ranked_lists, uses 1.0 as default weight.
7pub fn weighted_reciprocal_rank_fusion(
8    ranked_lists: &[Vec<(Uuid, f32)>],
9    k: f32,
10    weights: &[f32],
11) -> Vec<(Uuid, f32)> {
12    let mut scores: HashMap<Uuid, f32> = HashMap::new();
13    for (i, list) in ranked_lists.iter().enumerate() {
14        let w = weights.get(i).copied().unwrap_or(1.0);
15        for (rank, (id, _original_score)) in list.iter().enumerate() {
16            *scores.entry(*id).or_insert(0.0) += w / (k + rank as f32 + 1.0);
17        }
18    }
19    let mut fused: Vec<(Uuid, f32)> = scores.into_iter().collect();
20    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
21    fused
22}
23
24/// Reciprocal Rank Fusion: combines multiple ranked lists into a single fused ranking.
25/// Each item's score = sum over all lists of 1/(k + rank_in_list).
26/// Items not present in a list are not penalized.
27pub fn reciprocal_rank_fusion(ranked_lists: &[Vec<(Uuid, f32)>], k: f32) -> Vec<(Uuid, f32)> {
28    weighted_reciprocal_rank_fusion(ranked_lists, k, &[])
29}
30
31/// Compute a recency score using exponential decay.
32/// Returns a value in [0, 1] where 1 = just created, 0 = very old.
33/// half_life_hours controls decay rate (e.g., 168 = 1 week half-life).
34pub fn recency_score(created_at: &str, half_life_hours: f64) -> f32 {
35    let now = chrono::Utc::now();
36    let created = match chrono::DateTime::parse_from_rfc3339(created_at) {
37        Ok(dt) => dt.with_timezone(&chrono::Utc),
38        Err(_) => return 0.5, // fallback for unparseable dates
39    };
40    let age_hours = (now - created).num_seconds() as f64 / 3600.0;
41    if age_hours < 0.0 {
42        return 1.0; // future timestamp
43    }
44    let decay = (-age_hours * (2.0_f64.ln()) / half_life_hours).exp();
45    decay as f32
46}
47
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing fused scores computed in `f32`.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_rrf_basic() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();
        let id3 = Uuid::now_v7();

        let list1 = vec![(id1, 0.9), (id2, 0.8), (id3, 0.7)];
        let list2 = vec![(id2, 0.95), (id1, 0.85), (id3, 0.75)];

        let fused = reciprocal_rank_fusion(&[list1, list2], 60.0);
        assert_eq!(fused.len(), 3);

        // id1 and id2 swap the first two ranks, so both score 1/(60+1) + 1/(60+2).
        let expected_top: f32 = 1.0 / 61.0 + 1.0 / 62.0;
        assert!((fused[0].1 - expected_top).abs() < EPS);
        assert!((fused[1].1 - expected_top).abs() < EPS);

        // Their relative order is a tie, so only check that both occupy the top two slots.
        let top_two = [fused[0].0, fused[1].0];
        assert!(top_two.contains(&id1));
        assert!(top_two.contains(&id2));

        // id3 is third in both lists: 2 * 1/(60+3).
        let expected_last: f32 = 2.0 / 63.0;
        assert_eq!(fused[2].0, id3);
        assert!((fused[2].1 - expected_last).abs() < EPS);
    }

    #[test]
    fn test_rrf_disjoint() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();

        let list1 = vec![(id1, 0.9)];
        let list2 = vec![(id2, 0.8)];

        let fused = reciprocal_rank_fusion(&[list1, list2], 60.0);
        assert_eq!(fused.len(), 2);

        // Each id is first in exactly one list, so both get 1/(60+1).
        let expected: f32 = 1.0 / 61.0;
        assert!((fused[0].1 - expected).abs() < EPS);
        assert!((fused[1].1 - expected).abs() < EPS);
    }

    #[test]
    fn test_rrf_single_list() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();

        let list1 = vec![(id1, 0.9), (id2, 0.8)];
        let fused = reciprocal_rank_fusion(&[list1], 60.0);
        assert_eq!(fused.len(), 2);

        // With a single list the fused order must match the input order.
        assert_eq!(fused[0].0, id1);
        assert_eq!(fused[1].0, id2);
        assert!(fused[0].1 > fused[1].1);
    }

    #[test]
    fn test_weighted_rrf() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();

        let list1 = vec![(id1, 0.9), (id2, 0.8)];
        let list2 = vec![(id2, 0.95), (id1, 0.85)];

        // With weights [2.0, 1.0], list1 has more influence than list2.
        let fused = weighted_reciprocal_rank_fusion(&[list1, list2], 60.0, &[2.0, 1.0]);
        assert_eq!(fused.len(), 2);

        // id1: rank 0 in list1 (weight 2.0), rank 1 in list2 (weight 1.0)
        //      2.0/(60+1) + 1.0/(60+2) = ~0.0489
        // id2: rank 1 in list1 (weight 2.0), rank 0 in list2 (weight 1.0)
        //      2.0/(60+2) + 1.0/(60+1) = ~0.0487
        let expected_id1: f32 = 2.0 / 61.0 + 1.0 / 62.0;
        let expected_id2: f32 = 2.0 / 62.0 + 1.0 / 61.0;
        assert_eq!(fused[0].0, id1);
        assert!((fused[0].1 - expected_id1).abs() < EPS);
        assert_eq!(fused[1].0, id2);
        assert!((fused[1].1 - expected_id2).abs() < EPS);
    }

    #[test]
    fn test_weighted_rrf_missing_weights_default_to_one() {
        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();

        let list1 = vec![(id1, 0.9)];
        let list2 = vec![(id2, 0.8)];

        // Only the first list has an explicit weight; the second falls back to 1.0.
        let fused = weighted_reciprocal_rank_fusion(&[list1, list2], 60.0, &[2.0]);
        assert_eq!(fused.len(), 2);

        let expected_weighted: f32 = 2.0 / 61.0;
        let expected_default: f32 = 1.0 / 61.0;
        assert_eq!(fused[0].0, id1);
        assert!((fused[0].1 - expected_weighted).abs() < EPS);
        assert_eq!(fused[1].0, id2);
        assert!((fused[1].1 - expected_default).abs() < EPS);
    }

    #[test]
    fn test_recency_score() {
        // Just created
        let now = chrono::Utc::now().to_rfc3339();
        let score = recency_score(&now, 168.0);
        assert!(score > 0.99);

        // Very old
        let old = (chrono::Utc::now() - chrono::Duration::days(365)).to_rfc3339();
        let score = recency_score(&old, 168.0);
        assert!(score < 0.01);

        // One week ago (half-life = 168 hours)
        let week_ago = (chrono::Utc::now() - chrono::Duration::hours(168)).to_rfc3339();
        let score = recency_score(&week_ago, 168.0);
        assert!((score - 0.5).abs() < 0.05);
    }

    #[test]
    fn test_recency_score_fallbacks() {
        // Unparseable timestamps get the neutral fallback.
        let score = recency_score("not-a-timestamp", 168.0);
        assert!((score - 0.5).abs() < f32::EPSILON);

        // Timestamps in the future are treated as brand new.
        let future = (chrono::Utc::now() + chrono::Duration::hours(1)).to_rfc3339();
        let score = recency_score(&future, 168.0);
        assert!((score - 1.0).abs() < f32::EPSILON);
    }
}