uni_query/query/
fusion.rs1use std::collections::{HashMap, HashSet};
9use uni_common::Vid;
10
11pub fn fuse_rrf(
16 vec_results: &[(Vid, f32)],
17 fts_results: &[(Vid, f32)],
18 k: usize,
19) -> Vec<(Vid, f32)> {
20 let mut scores: HashMap<Vid, f32> = HashMap::new();
21
22 for ranked_list in [vec_results, fts_results] {
23 for (rank, (vid, _)) in ranked_list.iter().enumerate() {
24 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
25 *scores.entry(*vid).or_default() += rrf_score;
26 }
27 }
28
29 let mut results: Vec<_> = scores.into_iter().collect();
30 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
31 results
32}
33
34pub fn fuse_weighted(
40 vec_results: &[(Vid, f32)],
41 fts_results: &[(Vid, f32)],
42 alpha: f32,
43) -> Vec<(Vid, f32)> {
44 let vec_max = vec_results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
46 let vec_min = vec_results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
47 let vec_range = if vec_max > vec_min {
48 vec_max - vec_min
49 } else {
50 1.0
51 };
52
53 let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
54
55 let vec_scores: HashMap<Vid, f32> = vec_results
56 .iter()
57 .map(|(vid, dist)| {
58 let norm = 1.0 - (dist - vec_min) / vec_range;
59 (*vid, norm)
60 })
61 .collect();
62
63 let fts_scores: HashMap<Vid, f32> = fts_results
64 .iter()
65 .map(|(vid, score)| {
66 let norm = if fts_max > 0.0 { score / fts_max } else { 0.0 };
67 (*vid, norm)
68 })
69 .collect();
70
71 let all_vids: HashSet<Vid> = vec_scores
72 .keys()
73 .chain(fts_scores.keys())
74 .cloned()
75 .collect();
76
77 let mut results: Vec<(Vid, f32)> = all_vids
78 .into_iter()
79 .map(|vid| {
80 let vec_score = *vec_scores.get(&vid).unwrap_or(&0.0);
81 let fts_score = *fts_scores.get(&vid).unwrap_or(&0.0);
82 let fused = alpha * vec_score + (1.0 - alpha) * fts_score;
83 (vid, fused)
84 })
85 .collect();
86
87 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
88 results
89}
90
/// Returns the dot product of `scores` and `weights`: `Σ scores[i] * weights[i]`.
///
/// The two slices are expected to have the same length. Debug builds assert
/// this; release builds silently ignore the surplus elements of the longer slice.
pub fn fuse_weighted_multi(scores: &[f32], weights: &[f32]) -> f32 {
    debug_assert_eq!(scores.len(), weights.len());
    let weighted_terms = scores
        .iter()
        .zip(weights)
        .map(|(&score, &weight)| score * weight);
    weighted_terms.sum()
}
99
/// Fuses per-source scores for a single point by taking their equal-weight mean.
///
/// # Returns
///
/// `(fused, true)`, where `fused` is each score multiplied by `1 / len` and then
/// summed. For an empty slice it returns `(0.0, false)` and no fusion is performed.
pub fn fuse_rrf_point(scores: &[f32]) -> (f32, bool) {
    match scores.len() {
        0 => (0.0, false),
        count => {
            let share = 1.0 / count as f32;
            let mean: f32 = scores.iter().map(|score| score * share).sum();
            (mean, true)
        }
    }
}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies within 1e-6 of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_weighted_multi() {
        assert_close(fuse_weighted_multi(&[0.8, 0.6], &[0.7, 0.3]), 0.74);
    }

    #[test]
    fn test_fuse_weighted_multi_equal() {
        let third = 1.0 / 3.0;
        assert_close(
            fuse_weighted_multi(&[0.5, 0.5, 0.5], &[third, third, third]),
            0.5,
        );
    }

    #[test]
    fn test_fuse_rrf_point_fallback() {
        let (fused, used_fallback) = fuse_rrf_point(&[0.8, 0.6]);
        assert!(used_fallback);
        assert_close(fused, 0.7);
    }

    #[test]
    fn test_fuse_rrf_point_empty() {
        let (fused, used_fallback) = fuse_rrf_point(&[]);
        assert!(!used_fallback);
        assert_close(fused, 0.0);
    }
}