use crate::types::NodeRef;
use std::collections::HashMap;
#[cfg(test)]
pub fn rrf_merge(result_sets: &[Vec<(NodeRef, f64)>], k: u32) -> Vec<(NodeRef, f64)> {
rrf_merge_weighted(result_sets, k, None)
}
pub fn rrf_merge_weighted(
result_sets: &[Vec<(NodeRef, f64)>],
k: u32,
weights: Option<&[f32]>,
) -> Vec<(NodeRef, f64)> {
let mut scores: HashMap<NodeRef, f64> = HashMap::new();
for (set_idx, result_set) in result_sets.iter().enumerate() {
let weight = weights.and_then(|w| w.get(set_idx)).copied().unwrap_or(1.0) as f64;
for (rank, (node_ref, _original_score)) in result_set.iter().enumerate() {
*scores.entry(*node_ref).or_default() += weight / (k as f64 + rank as f64 + 1.0);
}
}
let mut merged: Vec<(NodeRef, f64)> = scores.into_iter().collect();
merged.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
merged
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::*;
use proptest::prelude::*;
proptest! {
#[test]
fn prop_rrf_scores_are_positive(
scores in proptest::collection::vec(0.0f64..1.0f64, 1..10),
) {
let items: Vec<(NodeRef, f64)> = scores.into_iter()
.enumerate()
.map(|(i, s)| (NodeRef::Episode(EpisodeId(i as i64 + 1)), s))
.collect();
let result = rrf_merge(&[items], 60);
for (_, score) in &result {
prop_assert!(*score > 0.0, "RRF score should be positive, got {}", score);
}
}
#[test]
fn prop_rrf_preserves_ordering(
n in 2usize..10,
) {
let items: Vec<(NodeRef, f64)> = (0..n)
.map(|i| (NodeRef::Episode(EpisodeId(i as i64 + 1)), 1.0 - (i as f64 / n as f64)))
.collect();
let result = rrf_merge(&[items], 60);
for i in 1..result.len() {
prop_assert!(
result[i - 1].1 >= result[i].1,
"RRF should preserve ordering: {} < {}",
result[i - 1].1,
result[i].1
);
}
}
}
#[test]
fn test_rrf_single_set() {
let set = vec![
(NodeRef::Episode(EpisodeId(1)), 0.9),
(NodeRef::Episode(EpisodeId(2)), 0.5),
];
let merged = rrf_merge(&[set], 60);
assert_eq!(merged.len(), 2);
assert!(merged[0].1 > merged[1].1);
}
#[test]
fn test_rrf_two_sets_overlap() {
let set_a = vec![
(NodeRef::Episode(EpisodeId(1)), 0.9),
(NodeRef::Episode(EpisodeId(2)), 0.5),
];
let set_b = vec![
(NodeRef::Episode(EpisodeId(2)), 0.8),
(NodeRef::Episode(EpisodeId(3)), 0.3),
];
let merged = rrf_merge(&[set_a, set_b], 60);
assert_eq!(merged[0].0, NodeRef::Episode(EpisodeId(2)));
}
#[test]
fn test_rrf_disjoint() {
let set_a = vec![(NodeRef::Episode(EpisodeId(1)), 0.9)];
let set_b = vec![(NodeRef::Episode(EpisodeId(2)), 0.8)];
let merged = rrf_merge(&[set_a, set_b], 60);
assert_eq!(merged.len(), 2);
assert!((merged[0].1 - merged[1].1).abs() < 1e-10);
}
#[test]
fn test_rrf_empty_sets() {
let sets: Vec<Vec<(NodeRef, f64)>> = vec![];
let result = rrf_merge(&sets, 60);
assert!(result.is_empty());
}
#[test]
fn test_rrf_single_item() {
let sets = vec![vec![(NodeRef::Episode(EpisodeId(1)), 1.0)]];
let result = rrf_merge(&sets, 60);
assert_eq!(result.len(), 1);
assert_eq!(result[0].0, NodeRef::Episode(EpisodeId(1)));
}
}