1use std::collections::HashMap;
2
/// Fuses several ranked lists into a single ranking using Reciprocal Rank
/// Fusion (RRF).
///
/// Every occurrence of an item at zero-based position `rank` in a list adds
/// `1 / (k + rank + 1)` to that item's score. Contributions are summed across
/// all lists, including repeated occurrences inside the same list.
///
/// # Arguments
///
/// * `ranked_lists` - The rankings to combine; earlier positions in a list
///   contribute more to the score.
/// * `k` - Constant added to every rank before taking the reciprocal.
///
/// # Returns
///
/// `(item, score)` pairs sorted by descending score; ties are broken by the
/// ascending `Ord` order of the items. Empty input yields an empty vector.
pub fn fuse<T: Clone + Eq + std::hash::Hash + Ord>(
    ranked_lists: &[Vec<T>],
    k: usize,
) -> Vec<(T, f64)> {
    let mut score: HashMap<T, f64> = HashMap::new();

    for list in ranked_lists {
        for (rank, doc) in list.iter().enumerate() {
            // Saturate rather than overflow: with a huge `k`, plain `k + rank + 1`
            // panics in debug builds and wraps (possibly to 0) in release builds.
            let denominator = k.saturating_add(rank).saturating_add(1);
            let contribution = 1.0 / (denominator as f64);
            *score.entry(doc.clone()).or_insert(0.0) += contribution;
        }
    }

    let mut result: Vec<(T, f64)> = score.into_iter().collect();
    // `total_cmp` is a total order on f64, so the comparison has no panic path.
    // Keys are unique (they come from a map), so the comparator never reports
    // two distinct entries as equal and an unstable sort is deterministic.
    result.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    result
}
28
/// Fuses several ranked lists into a single ranking using weighted Reciprocal
/// Rank Fusion (RRF).
///
/// Every occurrence of an item at zero-based position `rank` in list `i` adds
/// `weights[i] * (1 / (k + rank + 1))` to that item's score. Contributions are
/// summed across all lists, including repeated occurrences inside one list.
///
/// # Arguments
///
/// * `ranked_lists` - The rankings to combine.
/// * `weights` - One multiplier per list, in the same order as `ranked_lists`.
/// * `k` - Constant added to every rank before taking the reciprocal.
///
/// # Returns
///
/// `(item, score)` pairs sorted by descending score; ties are broken by the
/// ascending `Ord` order of the items. Scores are compared with
/// [`f64::total_cmp`], so NaN scores (for example from a NaN weight) are
/// ordered deterministically instead of aborting the sort.
///
/// # Panics
///
/// Panics if `ranked_lists` and `weights` have different lengths.
pub fn fuse_weighted<T: Clone + Eq + std::hash::Hash + Ord>(
    ranked_lists: &[Vec<T>],
    weights: &[f64],
    k: usize,
) -> Vec<(T, f64)> {
    assert_eq!(
        ranked_lists.len(),
        weights.len(),
        "ranked_lists 和 weights 长度必须一致"
    );

    let mut score: HashMap<T, f64> = HashMap::new();

    for (list, &w) in ranked_lists.iter().zip(weights.iter()) {
        for (rank, doc) in list.iter().enumerate() {
            // Saturate rather than overflow: with a huge `k`, plain `k + rank + 1`
            // panics in debug builds and wraps (possibly to 0) in release builds.
            let denominator = k.saturating_add(rank).saturating_add(1);
            let contribution = w * (1.0 / (denominator as f64));
            *score.entry(doc.clone()).or_insert(0.0) += contribution;
        }
    }

    let mut result: Vec<(T, f64)> = score.into_iter().collect();
    // Weights are caller-supplied, so scores may be NaN; `partial_cmp` would
    // return `None` for those. `total_cmp` orders every f64 value. Keys are
    // unique (they come from a map), so an unstable sort is deterministic.
    result.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    result
}
54
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rankings shared by the tests, in the order: bm25, vector, rules.
    fn sample_rankings() -> [Vec<&'static str>; 3] {
        [
            vec!["D3", "D1", "D2", "D5"],
            vec!["D2", "D4", "D1"],
            vec!["D5", "D2", "D6"],
        ]
    }

    /// The unweighted fusion result must contain "D2", which appears in all
    /// three sample rankings.
    #[test]
    fn test_rrf_basic() {
        let rankings = sample_rankings();

        let fused = fuse(&rankings, 60);

        let has_d2 = fused.iter().map(|(doc, _)| *doc).any(|doc| doc == "D2");
        assert!(has_d2);
    }

    /// With weights 1.0 / 2.0 / 0.5 the top-scoring document must be "D2".
    #[test]
    fn test_rrf_weighted() {
        let rankings = sample_rankings();
        let weights = [1.0, 2.0, 0.5];

        let fused_w = fuse_weighted(&rankings, &weights, 60);

        let top_doc = fused_w[0].0;
        assert!(top_doc == "D2");
    }
}