use grafeo_common::types::NodeId;
use std::collections::HashMap;
/// Strategy used by [`fuse_results`] to merge several ranked result lists.
#[derive(Debug, Clone)]
pub enum FusionMethod {
/// Reciprocal rank fusion: every entry contributes `1 / (k + rank)` to its
/// node's total, where `rank` is the entry's 1-based position in its source.
/// The scores carried by the sources are not used.
Rrf {
/// Constant added to the 1-based rank in the denominator.
k: usize,
},
/// Weighted sum of per-source min-max normalized scores.
Weighted {
/// Weight for each source, matched by position. A source with no
/// corresponding entry is weighted `1 / number_of_sources`.
weights: Vec<f64>,
},
}
impl Default for FusionMethod {
fn default() -> Self {
Self::Rrf { k: 60 }
}
}
/// Merges ranked result lists from several sources into a single list of at
/// most `k` entries.
///
/// * With no sources, the result is empty.
/// * With exactly one source, its first `k` entries are returned unchanged
///   (original order and original scores); `method` is not consulted.
/// * With two or more sources, the lists are combined by the strategy selected
///   in `method`.
pub fn fuse_results(
    sources: &[Vec<(NodeId, f64)>],
    method: &FusionMethod,
    k: usize,
) -> Vec<(NodeId, f64)> {
    match sources {
        [] => Vec::new(),
        // Nothing to fuse: pass the lone source through, capped at `k`.
        [only] => only.iter().take(k).cloned().collect(),
        _ => match method {
            FusionMethod::Rrf { k: rrf_k } => fuse_rrf(sources, *rrf_k, k),
            FusionMethod::Weighted { weights } => fuse_weighted(sources, weights, k),
        },
    }
}
/// Combines `sources` with reciprocal rank fusion and returns the `k`
/// highest-scoring nodes.
///
/// An entry at 0-based position `rank` within a source adds
/// `1 / (rrf_k + rank + 1)` to its node's total; the scores stored in the
/// input lists are ignored. The output is sorted by total, highest first.
/// Nodes with equal totals come out in the hash map's iteration order.
fn fuse_rrf(sources: &[Vec<(NodeId, f64)>], rrf_k: usize, k: usize) -> Vec<(NodeId, f64)> {
    let offset = rrf_k as f64;

    // Flatten every (source, position) pair into a per-node contribution.
    let contributions = sources.iter().flat_map(|ranking| {
        ranking
            .iter()
            .enumerate()
            .map(move |(position, &(id, _))| (id, 1.0 / (offset + (position + 1) as f64)))
    });

    let mut totals: HashMap<NodeId, f64> = HashMap::new();
    for (id, contribution) in contributions {
        *totals.entry(id).or_default() += contribution;
    }

    let mut ranked: Vec<(NodeId, f64)> = totals.into_iter().collect();
    ranked.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(k);
    ranked
}
/// Combines `sources` by a weighted sum of min-max normalized scores and
/// returns the `k` highest-scoring nodes.
///
/// For each source, finite scores are rescaled to `[0, 1]` using that source's
/// finite minimum and maximum. When the range is not larger than
/// `f64::EPSILON` (for example a single-entry source), every finite score
/// normalizes to `1.0`. Each normalized score is multiplied by the source's
/// weight and added to the node's total.
///
/// * `weights[i]` is the weight of `sources[i]`; a source without an entry in
///   `weights` is weighted `1 / sources.len()`.
/// * Non-finite scores (NaN, `+inf`, `-inf`) do not take part in the min/max
///   computation and contribute `0.0`, so one bad value cannot poison a node's
///   total or flatten the other scores from the same source. The node itself
///   is still reported.
///
/// The output is sorted by total, highest first. A total that is still NaN
/// (for example because a weight was NaN) is ranked last.
fn fuse_weighted(sources: &[Vec<(NodeId, f64)>], weights: &[f64], k: usize) -> Vec<(NodeId, f64)> {
    // Loop-invariant fallback for sources that have no explicit weight.
    let default_weight = 1.0 / sources.len() as f64;
    let mut scores: HashMap<NodeId, f64> = HashMap::new();

    for (i, source) in sources.iter().enumerate() {
        let weight = weights.get(i).copied().unwrap_or(default_weight);

        // Only finite scores define the normalization range; an infinity here
        // would otherwise stretch the range and push every other score to 0.
        let (min_score, max_score) = source
            .iter()
            .map(|&(_, score)| score)
            .filter(|score| score.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), score| {
                (lo.min(score), hi.max(score))
            });
        let range = max_score - min_score;

        for &(node_id, score) in source {
            let normalized = if !score.is_finite() {
                0.0
            } else if range > f64::EPSILON {
                (score - min_score) / range
            } else {
                // Degenerate range: all finite scores are (nearly) identical.
                1.0
            };
            *scores.entry(node_id).or_insert(0.0) += weight * normalized;
        }
    }

    let mut results: Vec<(NodeId, f64)> = scores.into_iter().collect();
    // Mapping NaN to -inf keeps the comparator a total order, which `sort_by`
    // requires, and places any NaN total at the end of the ranking.
    let sort_key = |total: f64| if total.is_nan() { f64::NEG_INFINITY } else { total };
    results.sort_by(|a, b| {
        sort_key(b.1)
            .partial_cmp(&sort_key(a.1))
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    results.truncate(k);
    results
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fused score recorded for `wanted`, or `None` when the node is absent.
    fn score_of(results: &[(NodeId, f64)], wanted: NodeId) -> Option<f64> {
        results
            .iter()
            .find(|(id, _)| *id == wanted)
            .map(|entry| entry.1)
    }

    /// Reciprocal rank fusion with the rank constant used throughout these tests.
    fn rrf() -> FusionMethod {
        FusionMethod::Rrf { k: 60 }
    }

    /// Weighted fusion giving both sources the same weight.
    fn even_weights() -> FusionMethod {
        FusionMethod::Weighted {
            weights: vec![0.5, 0.5],
        }
    }

    /// A node ranked by both sources must outscore one ranked by a single source.
    #[test]
    fn test_rrf_basic() {
        let sources = [
            vec![
                (NodeId::new(1), 10.0),
                (NodeId::new(2), 8.0),
                (NodeId::new(3), 6.0),
            ],
            vec![
                (NodeId::new(2), 0.9),
                (NodeId::new(3), 0.8),
                (NodeId::new(4), 0.7),
            ],
        ];

        let fused = fuse_results(&sources, &rrf(), 10);
        assert!(!fused.is_empty());

        let in_both = score_of(&fused, NodeId::new(2));
        let in_one = score_of(&fused, NodeId::new(1));
        assert!(in_both.is_some());
        assert!(in_one.is_some());
        assert!(in_both.unwrap() > in_one.unwrap());
    }

    /// A lone source is returned as-is, keeping its leading entry first.
    #[test]
    fn test_rrf_single_source() {
        let sources = [vec![(NodeId::new(1), 10.0), (NodeId::new(2), 8.0)]];

        let fused = fuse_results(&sources, &rrf(), 10);

        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].0, sources[0][0].0);
    }

    /// The result never holds more than the requested number of entries.
    #[test]
    fn test_rrf_k_limit() {
        let sources = [vec![
            (NodeId::new(1), 10.0),
            (NodeId::new(2), 8.0),
            (NodeId::new(3), 6.0),
        ]];

        let fused = fuse_results(&sources, &rrf(), 1);

        assert_eq!(fused.len(), 1);
    }

    /// A node present in both sources shows up in the weighted result.
    #[test]
    fn test_weighted_basic() {
        let text_hits = vec![(NodeId::new(1), 2.5), (NodeId::new(2), 1.0)];
        let vector_hits = vec![(NodeId::new(2), 0.95), (NodeId::new(3), 0.80)];

        let fused = fuse_results(&[text_hits, vector_hits], &even_weights(), 10);

        assert!(!fused.is_empty());
        assert!(score_of(&fused, NodeId::new(2)).is_some());
    }

    /// Sources on very different scales contribute equally after normalization.
    #[test]
    fn test_weighted_normalization() {
        let large_scale = vec![(NodeId::new(1), 100.0)];
        let small_scale = vec![(NodeId::new(1), 0.01)];

        let fused = fuse_results(&[large_scale, small_scale], &even_weights(), 10);

        assert_eq!(fused.len(), 1);
        assert!((fused[0].1 - 1.0).abs() < 0.01);
    }

    /// No sources at all produce no results.
    #[test]
    fn test_empty_sources() {
        let fused = fuse_results(&[], &FusionMethod::default(), 10);
        assert!(fused.is_empty());
    }

    /// An empty source next to a populated one is skipped without error.
    #[test]
    fn test_empty_source_lists() {
        let nothing: Vec<(NodeId, f64)> = Vec::new();
        let something = vec![(NodeId::new(1), 1.0)];

        let fused = fuse_results(&[nothing, something], &even_weights(), 10);

        assert_eq!(fused.len(), 1);
    }
}