use std::collections::{HashMap, HashSet};
use uni_common::Vid;
/// Fuses any number of ranked result lists with Reciprocal Rank Fusion (RRF).
///
/// Each entry at zero-based position `rank` in a list contributes
/// `1 / (k + rank + 1)` to its VID's fused score, and contributions from all
/// lists are summed. Only positions matter: the `f32` carried in each input
/// tuple is not read. A VID that appears several times in one list contributes
/// once per occurrence.
///
/// Returns every distinct VID, sorted by fused score, highest first. Ties are
/// broken by first-seen order (earlier lists first, then lower rank), so the
/// output is reproducible for a given input. Empty input (no lists, or only
/// empty lists) yields an empty vector.
pub fn fuse_rrf_multi(ranked_lists: &[&[(Vid, f32)]], k: usize) -> Vec<(Vid, f32)> {
    // `results` holds VIDs in first-seen order and `slot_of` points into it.
    // Iterating a hash map directly would make tie order vary between runs,
    // because std's hasher is randomly seeded per process.
    let mut slot_of: HashMap<Vid, usize> = HashMap::new();
    let mut results: Vec<(Vid, f32)> = Vec::new();
    for ranked_list in ranked_lists {
        for (rank, (vid, _)) in ranked_list.iter().enumerate() {
            let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
            let slot = *slot_of.entry(*vid).or_insert_with(|| {
                results.push((*vid, 0.0));
                results.len() - 1
            });
            results[slot].1 += rrf_score;
        }
    }
    // RRF scores are always finite and positive, so `total_cmp` orders them
    // exactly like `partial_cmp`. The sort is stable, so ties keep first-seen order.
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results
}
/// Fuses vector-search and full-text result lists with Reciprocal Rank Fusion.
///
/// Two-source convenience wrapper: the result is exactly what
/// [`fuse_rrf_multi`] returns for the lists `[vec_results, fts_results]` with
/// the same smoothing constant `k`. Only list positions influence the fused
/// score; see [`fuse_rrf_multi`] for the scoring formula and output ordering.
pub fn fuse_rrf(
    vec_results: &[(Vid, f32)],
    fts_results: &[(Vid, f32)],
    k: usize,
) -> Vec<(Vid, f32)> {
    let sources = [vec_results, fts_results];
    fuse_rrf_multi(&sources, k)
}
/// Fuses vector-search and full-text results with a weighted linear blend.
///
/// * `vec_results`: `(vid, distance)` pairs, where a smaller distance is better.
///   Distances are min-max normalized and flipped into a similarity
///   `1 - (dist - min) / (max - min)`. When every distance is equal (or the list
///   is empty), the range is taken as `1.0`, so each entry gets similarity `1.0`.
/// * `fts_results`: `(vid, score)` pairs, where a larger score is better. Scores
///   are divided by the list maximum (floored at `0.0`). When that maximum is not
///   positive, every full-text entry normalizes to `0.0`.
/// * `alpha`: the weight of the vector similarity; full-text gets `1 - alpha`.
///
/// A VID absent from one list contributes `0.0` for that side. A VID repeated
/// within one list uses its last occurrence.
///
/// Returns every distinct VID from either list, sorted by fused score, highest
/// first. Ties keep first-seen order (vector list, then full-text list), so the
/// output is reproducible. NaN scores, which can arise from NaN or infinite
/// inputs, sort last instead of making the comparator inconsistent.
pub fn fuse_weighted(
    vec_results: &[(Vid, f32)],
    fts_results: &[(Vid, f32)],
    alpha: f32,
) -> Vec<(Vid, f32)> {
    let vec_max = vec_results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
    let vec_min = vec_results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
    // A zero range (all distances equal) or an empty list would otherwise divide
    // by zero or a negative span; fall back to 1.0.
    let vec_range = if vec_max > vec_min { vec_max - vec_min } else { 1.0 };
    let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);

    // Collecting into a map keeps the last value for a repeated VID.
    let vec_scores: HashMap<Vid, f32> = vec_results
        .iter()
        .map(|(vid, dist)| (*vid, 1.0 - (dist - vec_min) / vec_range))
        .collect();
    let fts_scores: HashMap<Vid, f32> = fts_results
        .iter()
        .map(|(vid, score)| {
            let norm = if fts_max > 0.0 { score / fts_max } else { 0.0 };
            (*vid, norm)
        })
        .collect();

    // Emit VIDs in first-seen order. Iterating a hash set instead would make the
    // order of tied scores vary between runs (std's hasher is randomly seeded).
    let mut seen: HashSet<Vid> = HashSet::new();
    let mut results: Vec<(Vid, f32)> = vec_results
        .iter()
        .chain(fts_results)
        .map(|(vid, _)| *vid)
        .filter(|vid| seen.insert(*vid))
        .map(|vid| {
            let vec_score = vec_scores.get(&vid).copied().unwrap_or(0.0);
            let fts_score = fts_scores.get(&vid).copied().unwrap_or(0.0);
            (vid, alpha * vec_score + (1.0 - alpha) * fts_score)
        })
        .collect();

    // Map NaN below every real score so that `total_cmp` yields the total order
    // `sort_by` requires. The sort is stable, so ties keep first-seen order.
    let sort_key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
    results.sort_by(|a, b| sort_key(b.1).total_cmp(&sort_key(a.1)));
    results
}
/// How the raw values of a [`WeightedSource`] are rescaled before weighting in
/// [`fuse_weighted_sources`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormKind {
    /// Raw values are distances (lower is better). They are min-max normalized
    /// and flipped to `1 - (d - min) / (max - min)`, with the range taken as
    /// `1.0` when all distances are equal.
    DistanceToSim,
    /// Raw values are scores (higher is better). They are divided by the
    /// source's maximum, or mapped to `0.0` when that maximum is not positive.
    ScoreByMax,
}

/// One input to [`fuse_weighted_sources`]: `(results, weight, normalization)`,
/// where `results` are `(vid, raw value)` pairs.
pub type WeightedSource<'src> = (&'src [(Vid, f32)], f32, NormKind);
/// Fuses any number of result lists using per-source weights and normalization.
///
/// For each `(results, weight, norm)` source:
/// * [`NormKind::DistanceToSim`] treats raw values as distances (lower is better)
///   and maps them to `1 - (d - min) / (max - min)`, using a range of `1.0` when
///   all distances are equal.
/// * [`NormKind::ScoreByMax`] treats raw values as scores (higher is better) and
///   divides them by the source maximum (floored at `0.0`). A non-positive maximum
///   maps every entry to `0.0`.
///
/// A VID's fused score is the sum of `weight * normalized` over the sources it
/// appears in; sources it is absent from contribute nothing. A VID repeated
/// within one source is counted once, using its last occurrence's value.
///
/// Returns every distinct VID, sorted by fused score, highest first. Ties keep
/// first-seen order (earlier sources first, then list position), so the output
/// is reproducible. NaN scores, which are possible with NaN or infinite inputs
/// or weights, sort last.
pub fn fuse_weighted_sources(sources: &[WeightedSource<'_>]) -> Vec<(Vid, f32)> {
    // `results` holds VIDs in first-seen order and `slot_of` points into it.
    // Iterating a hash map instead would make the order of tied scores vary
    // between runs (std's hasher is randomly seeded per process).
    let mut slot_of: HashMap<Vid, usize> = HashMap::new();
    let mut results: Vec<(Vid, f32)> = Vec::new();

    for &(list, weight, norm) in sources {
        // Collecting into a map keeps the last value for a repeated VID.
        let mut normalized: HashMap<Vid, f32> = match norm {
            NormKind::DistanceToSim => {
                let max = list.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
                let min = list.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
                let range = if max > min { max - min } else { 1.0 };
                list.iter()
                    .map(|(vid, dist)| (*vid, 1.0 - (dist - min) / range))
                    .collect()
            }
            NormKind::ScoreByMax => {
                let max = list.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
                list.iter()
                    .map(|(vid, score)| {
                        let scaled = if max > 0.0 { score / max } else { 0.0 };
                        (*vid, scaled)
                    })
                    .collect()
            }
        };
        // Walk the list in order so new VIDs get slots in first-seen order.
        // `remove` guarantees each VID is added at most once per source.
        for (vid, _) in list {
            if let Some(score) = normalized.remove(vid) {
                let slot = *slot_of.entry(*vid).or_insert_with(|| {
                    results.push((*vid, 0.0));
                    results.len() - 1
                });
                results[slot].1 += weight * score;
            }
        }
    }

    // Map NaN below every real score so that `total_cmp` yields the total order
    // `sort_by` requires. The sort is stable, so ties keep first-seen order.
    let sort_key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
    results.sort_by(|a, b| sort_key(b.1).total_cmp(&sort_key(a.1)));
    results
}
/// Combines one item's per-source scores into a single weighted score.
///
/// Returns `Σ scores[i] * weights[i]`. Both slices are expected to have the same
/// length. This is checked only in debug builds; in release builds, any surplus
/// entries in the longer slice are ignored.
pub fn fuse_weighted_multi(scores: &[f32], weights: &[f32]) -> f32 {
    debug_assert_eq!(scores.len(), weights.len());
    weights
        .iter()
        .zip(scores)
        .map(|(weight, score)| weight * score)
        .sum()
}
/// Fuses a single item's per-source scores by equal-weight averaging.
///
/// Each score is multiplied by `1 / scores.len()` and the products are summed,
/// which gives the arithmetic mean. The flag is `true` whenever a value was
/// computed; the tests in this file name it `used_fallback`. Empty input
/// returns `(0.0, false)`.
pub fn fuse_rrf_point(scores: &[f32]) -> (f32, bool) {
    match scores.len() {
        0 => (0.0, false),
        count => {
            // Apply the equal share before summing, as a weighted sum.
            let share = 1.0 / count as f32;
            let mean: f32 = scores.iter().map(|score| score * share).sum();
            (mean, true)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fuse_weighted_multi() {
        let scores = vec![0.8, 0.6];
        let weights = vec![0.7, 0.3];
        let result = fuse_weighted_multi(&scores, &weights);
        assert!((result - 0.74).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_weighted_multi_equal() {
        let scores = vec![0.5, 0.5, 0.5];
        let weights = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let result = fuse_weighted_multi(&scores, &weights);
        assert!((result - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_rrf_point_fallback() {
        let scores = vec![0.8, 0.6];
        let (result, used_fallback) = fuse_rrf_point(&scores);
        assert!(used_fallback);
        assert!((result - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_rrf_point_empty() {
        let (result, used_fallback) = fuse_rrf_point(&[]);
        assert!(!used_fallback);
        assert!((result - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_rrf_disjoint_lists() {
        let vec_results = vec![(Vid::from(1u64), 0.9), (Vid::from(2u64), 0.7)];
        let fts_results = vec![(Vid::from(3u64), 0.8), (Vid::from(4u64), 0.6)];
        let fused = fuse_rrf(&vec_results, &fts_results, 60);
        assert_eq!(fused.len(), 4);
        let vids: HashSet<Vid> = fused.iter().map(|(v, _)| *v).collect();
        assert!(vids.contains(&Vid::from(1u64)));
        assert!(vids.contains(&Vid::from(2u64)));
        assert!(vids.contains(&Vid::from(3u64)));
        assert!(vids.contains(&Vid::from(4u64)));
    }

    #[test]
    fn test_fuse_rrf_overlapping_lists() {
        let vec_results = vec![(Vid::from(1u64), 0.9), (Vid::from(2u64), 0.7)];
        let fts_results = vec![(Vid::from(1u64), 0.8), (Vid::from(3u64), 0.6)];
        let fused = fuse_rrf(&vec_results, &fts_results, 60);
        assert_eq!(fused.len(), 3);
        assert_eq!(
            fused[0].0,
            Vid::from(1u64),
            "Overlapping VID should rank first"
        );
    }

    #[test]
    fn test_fuse_rrf_empty_lists() {
        let fused = fuse_rrf(&[], &[], 60);
        assert!(fused.is_empty());
    }

    #[test]
    fn test_fuse_rrf_multi_three_sources_overlap_wins() {
        let vec_results = vec![(Vid::from(1u64), 0.9), (Vid::from(2u64), 0.7)];
        let fts_results = vec![(Vid::from(1u64), 0.8), (Vid::from(3u64), 0.6)];
        let sparse_results = vec![(Vid::from(1u64), 5.0), (Vid::from(4u64), 1.0)];
        let fused = fuse_rrf_multi(&[&vec_results, &fts_results, &sparse_results], 60);
        assert_eq!(fused.len(), 4);
        assert_eq!(fused[0].0, Vid::from(1u64));
    }

    #[test]
    fn test_fuse_rrf_multi_empty_third_source_is_noop() {
        let vec_results = vec![(Vid::from(1u64), 0.9), (Vid::from(2u64), 0.7)];
        let fts_results = vec![(Vid::from(1u64), 0.8), (Vid::from(3u64), 0.6)];
        let two_way: HashMap<Vid, f32> = fuse_rrf(&vec_results, &fts_results, 60)
            .into_iter()
            .collect();
        let three_way: HashMap<Vid, f32> = fuse_rrf_multi(&[&vec_results, &fts_results, &[]], 60)
            .into_iter()
            .collect();
        assert_eq!(two_way, three_way, "absent sparse source must be a no-op");
    }

    #[test]
    fn test_fuse_weighted_sources_normalizes_per_source() {
        let vec_results = vec![(Vid::from(1u64), 0.0), (Vid::from(2u64), 1.0)];
        let sparse_results = vec![(Vid::from(1u64), 2.0), (Vid::from(2u64), 4.0)];
        let fused = fuse_weighted_sources(&[
            (&vec_results, 0.5, NormKind::DistanceToSim),
            (&sparse_results, 0.5, NormKind::ScoreByMax),
        ]);
        let v1 = fused.iter().find(|(v, _)| *v == Vid::from(1u64)).unwrap().1;
        let v2 = fused.iter().find(|(v, _)| *v == Vid::from(2u64)).unwrap().1;
        assert!((v1 - 0.75).abs() < 1e-6);
        assert!((v2 - 0.50).abs() < 1e-6);
        assert_eq!(fused[0].0, Vid::from(1u64));
    }

    #[test]
    fn test_fuse_weighted_sources_zero_max_sparse() {
        let vec_results = vec![(Vid::from(1u64), 0.0), (Vid::from(2u64), 1.0)];
        let sparse_results = vec![(Vid::from(1u64), 0.0), (Vid::from(2u64), 0.0)];
        let fused = fuse_weighted_sources(&[
            (&vec_results, 0.5, NormKind::DistanceToSim),
            (&sparse_results, 0.5, NormKind::ScoreByMax),
        ]);
        let v1 = fused.iter().find(|(v, _)| *v == Vid::from(1u64)).unwrap().1;
        assert!(
            (v1 - 0.5).abs() < 1e-6,
            "sparse contributes 0 when all zero"
        );
    }

    // `fuse_weighted` previously had no coverage. These expectations are
    // hand-computed from its normalization rules.
    #[test]
    fn test_fuse_weighted_blends_normalized_scores() {
        let vec_results = vec![(Vid::from(1u64), 0.0), (Vid::from(2u64), 1.0)];
        let fts_results = vec![(Vid::from(2u64), 2.0), (Vid::from(3u64), 4.0)];
        let fused = fuse_weighted(&vec_results, &fts_results, 0.5);
        assert_eq!(fused.len(), 3);
        let score_of = |id: u64| fused.iter().find(|(v, _)| *v == Vid::from(id)).unwrap().1;
        // Vector-only: similarity 1.0, weighted by alpha.
        assert!((score_of(1) - 0.5).abs() < 1e-6);
        // In both lists: similarity 0.0 plus FTS 2/4 weighted by 1 - alpha.
        assert!((score_of(2) - 0.25).abs() < 1e-6);
        // FTS-only: 4/4, weighted by 1 - alpha.
        assert!((score_of(3) - 0.5).abs() < 1e-6);
        assert_eq!(fused[2].0, Vid::from(2u64), "lowest fused score ranks last");
    }

    // The two-source API must agree with the generic weighted-sources API.
    #[test]
    fn test_fuse_weighted_matches_weighted_sources() {
        let vec_results = vec![
            (Vid::from(1u64), 0.2),
            (Vid::from(2u64), 0.9),
            (Vid::from(4u64), 0.5),
        ];
        let fts_results = vec![(Vid::from(2u64), 3.0), (Vid::from(3u64), 1.5)];
        let alpha = 0.3;
        let two_source: HashMap<Vid, f32> = fuse_weighted(&vec_results, &fts_results, alpha)
            .into_iter()
            .collect();
        let generic: HashMap<Vid, f32> = fuse_weighted_sources(&[
            (&vec_results, alpha, NormKind::DistanceToSim),
            (&fts_results, 1.0 - alpha, NormKind::ScoreByMax),
        ])
        .into_iter()
        .collect();
        assert_eq!(two_source.len(), generic.len());
        for (vid, score) in &two_source {
            assert!((score - generic[vid]).abs() < 1e-6);
        }
    }
}