// rrf/lib.rs

use std::collections::HashMap;

/// Fuses several ranked lists into a single ranking with Reciprocal Rank Fusion (RRF).
///
/// From every list it appears in, a document earns `1 / (k + rank)`, where `rank` is its
/// 1-based position in that list; these contributions are summed. A document that occurs
/// more than once in the same list is credited once per occurrence.
///
/// # Arguments
/// * `ranked_lists` - candidate rankings; each inner `Vec` holds document IDs ordered from
///   most to least relevant.
/// * `k` - smoothing constant that damps the advantage of top positions (60 is common).
///
/// # Returns
/// `(document ID, score)` pairs sorted by descending score. Equal scores are ordered by
/// ascending ID, so the output is deterministic. Empty input yields an empty `Vec`.
pub fn fuse<T: Clone + Eq + std::hash::Hash + Ord>(
    ranked_lists: &[Vec<T>],
    k: usize,
) -> Vec<(T, f64)> {
    let mut score: HashMap<T, f64> = HashMap::new();

    for list in ranked_lists {
        for (rank, doc) in list.iter().enumerate() {
            // Build the denominator in f64: `k + rank + 1` in usize overflows (and panics in
            // debug builds) for very large `k`; for ordinary `k` the value is identical.
            let denom = k as f64 + rank as f64 + 1.0;
            *score.entry(doc.clone()).or_default() += 1.0 / denom;
        }
    }

    let mut result: Vec<(T, f64)> = score.into_iter().collect();
    // Every contribution is strictly positive, so `total_cmp` orders these scores exactly
    // like `partial_cmp`, without an `unwrap` that could panic.
    result.sort_by(|(id_a, s_a), (id_b, s_b)| {
        s_b.total_cmp(s_a).then_with(|| id_a.cmp(id_b))
    });
    result
}

/// Weighted Reciprocal Rank Fusion.
///
/// Like [`fuse`], but each list's contributions are scaled by its weight: a document at
/// 1-based position `rank` in list `i` earns `weights[i] * (1 / (k + rank))`, summed over
/// all lists. With every weight equal to `1.0` this reduces to [`fuse`].
///
/// # Arguments
/// * `ranked_lists` - candidate rankings; each inner `Vec` holds document IDs ordered from
///   most to least relevant.
/// * `weights` - one weight per list, in the same order as `ranked_lists`.
/// * `k` - smoothing constant that damps the advantage of top positions (60 is common).
///
/// # Returns
/// `(document ID, score)` pairs sorted by descending score, ties ordered by ascending ID.
/// Entries whose score is NaN (e.g. produced by a NaN weight) are placed after all other
/// entries instead of aborting the sort.
///
/// # Panics
/// Panics if `ranked_lists` and `weights` have different lengths.
pub fn fuse_weighted<T: Clone + Eq + std::hash::Hash + Ord>(
    ranked_lists: &[Vec<T>],
    weights: &[f64],
    k: usize,
) -> Vec<(T, f64)> {
    assert_eq!(
        ranked_lists.len(),
        weights.len(),
        "ranked_lists 和 weights 长度必须一致"
    );

    let mut score: HashMap<T, f64> = HashMap::new();

    for (list, &w) in ranked_lists.iter().zip(weights) {
        for (rank, doc) in list.iter().enumerate() {
            // f64 denominator avoids usize overflow in `k + rank + 1` for very large `k`.
            let denom = k as f64 + rank as f64 + 1.0;
            // Kept as `w * (1 / denom)` rather than `w / denom` so that a weight of 1.0
            // reproduces the unweighted score bit-for-bit.
            *score.entry(doc.clone()).or_default() += w * (1.0 / denom);
        }
    }

    let mut result: Vec<(T, f64)> = score.into_iter().collect();
    result.sort_by(|(id_a, s_a), (id_b, s_b)| {
        // `false < true`: non-NaN scores first. Previously `partial_cmp(..).unwrap()`
        // panicked as soon as a NaN score was compared.
        s_a.is_nan()
            .cmp(&s_b.is_nan())
            .then_with(|| s_b.total_cmp(s_a))
            .then_with(|| id_a.cmp(id_b))
    });
    result
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the document IDs from a fused ranking, preserving order.
    fn ids<'a>(fused: &[(&'a str, f64)]) -> Vec<&'a str> {
        fused.iter().map(|&(doc, _)| doc).collect()
    }

    #[test]
    fn test_rrf_basic() {
        let bm25 = vec!["D3", "D1", "D2", "D5"];
        let vector = vec!["D2", "D4", "D1"];
        let rules = vec!["D5", "D2", "D6"];

        let fused = fuse(&[bm25, vector, rules], 60);
        // D2 appears in all three lists; D5 (positions 4 and 1) narrowly beats
        // D1 (positions 2 and 3).
        assert_eq!(ids(&fused), ["D2", "D5", "D1", "D3", "D4", "D6"]);
        let expected_d2 = 1.0 / 61.0 + 1.0 / 62.0 + 1.0 / 63.0;
        assert!((fused[0].1 - expected_d2).abs() < 1e-12);
    }

    #[test]
    fn test_rrf_weighted() {
        let bm25 = vec!["D3", "D1", "D2", "D5"];
        let vector = vec!["D2", "D4", "D1"];
        let rules = vec!["D5", "D2", "D6"];

        let fused_w = fuse_weighted(&[bm25, vector, rules], &[1.0, 2.0, 0.5], 60);
        // D2 should rank first; the doubled vector weight lifts D1 and D4 above D5.
        assert_eq!(ids(&fused_w), ["D2", "D1", "D4", "D5", "D3", "D6"]);
    }

    #[test]
    fn test_rrf_ties_broken_by_id() {
        let fused = fuse(&[vec!["b"], vec!["a"]], 60);
        assert_eq!(ids(&fused), ["a", "b"]);
    }

    #[test]
    fn test_rrf_unit_weights_match_unweighted() {
        let lists = [vec!["D3", "D1", "D2"], vec!["D2", "D4"]];
        assert_eq!(fuse_weighted(&lists, &[1.0, 1.0], 60), fuse(&lists, 60));
    }

    #[test]
    fn test_rrf_empty_input() {
        assert!(fuse::<&str>(&[], 60).is_empty());
        assert!(fuse_weighted::<&str>(&[], &[], 60).is_empty());
    }

    #[test]
    #[should_panic(expected = "ranked_lists 和 weights 长度必须一致")]
    fn test_rrf_weighted_length_mismatch() {
        let _ = fuse_weighted(&[vec!["D1"]], &[], 60);
    }
}