use super::strategy::FusionStrategy;
/// Three overlapping result lists, one per query, each ordered by descending score.
///
/// Documents 1, 2 and 3 appear in every list, while 4, 5 and 6 each appear in
/// exactly one list, so fusion sees both full and partial overlap.
fn sample_results() -> Vec<Vec<(u64, f32)>> {
    let first_query = vec![(1, 0.95), (2, 0.85), (3, 0.75), (4, 0.65)];
    let second_query = vec![(2, 0.90), (1, 0.80), (5, 0.70), (3, 0.60)];
    let third_query = vec![(1, 0.92), (3, 0.82), (2, 0.72), (6, 0.62)];
    vec![first_query, second_query, third_query]
}
/// Three short result lists forming a chain: doc 2 links lists one and two,
/// and doc 3 links lists two and three. Docs 1 and 4 each appear only once.
fn partial_overlap_results() -> Vec<Vec<(u64, f32)>> {
    let mut lists = Vec::with_capacity(3);
    lists.push(vec![(1, 0.9), (2, 0.8)]);
    lists.push(vec![(2, 0.85), (3, 0.75)]);
    lists.push(vec![(3, 0.8), (4, 0.7)]);
    lists
}
/// Results for exactly one query: three documents in descending score order.
fn single_query_results() -> Vec<Vec<(u64, f32)>> {
    let only_query: Vec<(u64, f32)> = [(1, 0.95), (2, 0.85), (3, 0.75)].to_vec();
    vec![only_query]
}
/// Zero queries: the outer list itself is empty.
fn empty_results() -> Vec<Vec<(u64, f32)>> {
    Vec::new()
}
/// Three queries where the middle one returned no hits. Doc 1 appears in both
/// non-empty lists, and docs 2 and 3 appear once each.
fn results_with_empty_query() -> Vec<Vec<(u64, f32)>> {
    let no_hits = Vec::new();
    vec![vec![(1, 0.9), (2, 0.8)], no_hits, vec![(1, 0.85), (3, 0.75)]]
}
/// Average fusion over the sample lists. Doc 1 scores 0.95, 0.80 and 0.92, so
/// its fused score should be their mean (~0.89). The output must be ranked
/// best-first.
#[test]
fn test_average_basic() {
    let fused = FusionStrategy::Average.fuse(sample_results()).unwrap();
    assert!(!fused.is_empty());

    let &(_, doc1_score) = fused.iter().find(|(id, _)| *id == 1).unwrap();
    assert!(
        (doc1_score - 0.89).abs() < 0.01,
        "Doc 1 avg should be ~0.89, got {}",
        doc1_score
    );

    // Each adjacent pair must be non-increasing.
    for window in fused.windows(2) {
        assert!(
            window[0].1 >= window[1].1,
            "Results should be sorted descending"
        );
    }
}
/// Average fusion on the chained fixture. A document is averaged only over the
/// lists it appears in: doc 1 stays at 0.9, and doc 2 is (0.8 + 0.85) / 2.
#[test]
fn test_average_partial_overlap() {
    let fused = FusionStrategy::Average
        .fuse(partial_overlap_results())
        .unwrap();
    let score_of = |wanted: u64| fused.iter().find(|(id, _)| *id == wanted).unwrap().1;

    assert!((score_of(1) - 0.9).abs() < 0.01);
    assert!((score_of(2) - 0.825).abs() < 0.01);
}
/// With one input list, average fusion keeps all three docs, and the top
/// score equals that list's best score.
#[test]
fn test_average_single_query() {
    let fused = FusionStrategy::Average
        .fuse(single_query_results())
        .unwrap();
    assert_eq!(fused.len(), 3);

    let top_score = fused[0].1;
    assert!((top_score - 0.95).abs() < 0.001);
}
/// Fusing zero queries succeeds and produces an empty ranking.
#[test]
fn test_average_empty_input() {
    let outcome = FusionStrategy::Average.fuse(empty_results());
    assert!(outcome.unwrap().is_empty());
}
/// Average fusion must tolerate a query that returned no hits.
///
/// The empty list contributes nothing. Each document is averaged only over the
/// lists it appears in, the same rule `partial_overlap_results` exercises, so:
/// doc 1 → (0.9 + 0.85) / 2, doc 2 → 0.8, doc 3 → 0.75.
#[test]
fn test_average_with_empty_query() {
    let strategy = FusionStrategy::Average;
    let results = results_with_empty_query();
    let fused = strategy.fuse(results).unwrap();
    assert!(!fused.is_empty());

    // Only docs 1, 2 and 3 exist across the non-empty lists.
    assert_eq!(fused.len(), 3, "expected exactly docs 1, 2 and 3");

    let score_of = |wanted: u64| {
        fused
            .iter()
            .find(|(id, _)| *id == wanted)
            .unwrap_or_else(|| panic!("doc {wanted} missing from fused output"))
            .1
    };
    // Doc 1's mean must not be dragged down by the empty middle query.
    let doc1 = score_of(1);
    assert!(
        (doc1 - 0.875).abs() < 0.01,
        "Doc 1 avg should be ~0.875 (empty query ignored), got {}",
        doc1
    );
    assert!((score_of(2) - 0.8).abs() < 0.01);
    assert!((score_of(3) - 0.75).abs() < 0.01);
}
/// Maximum fusion keeps each document's best score across queries: doc 1 gets
/// 0.95 from query one and doc 2 gets 0.90 from query two. The output must be
/// ranked best-first.
#[test]
fn test_maximum_basic() {
    let fused = FusionStrategy::Maximum.fuse(sample_results()).unwrap();
    let score_of = |wanted: u64| fused.iter().find(|(id, _)| *id == wanted).unwrap().1;

    let doc1_score = score_of(1);
    assert!(
        (doc1_score - 0.95).abs() < 0.001,
        "Doc 1 max should be 0.95, got {}",
        doc1_score
    );
    let doc2_score = score_of(2);
    assert!(
        (doc2_score - 0.90).abs() < 0.001,
        "Doc 2 max should be 0.90, got {}",
        doc2_score
    );

    assert!(fused.windows(2).all(|pair| pair[0].1 >= pair[1].1));
}
/// Doc 2 appears with scores 0.8 and 0.85, so maximum fusion must report 0.85.
#[test]
fn test_maximum_partial_overlap() {
    let fused = FusionStrategy::Maximum
        .fuse(partial_overlap_results())
        .unwrap();
    let &(_, doc2_score) = fused.iter().find(|(id, _)| *id == 2).unwrap();
    assert!((doc2_score - 0.85).abs() < 0.001);
}
/// With one input list, maximum fusion keeps all three docs, and the top
/// score equals that list's best score.
#[test]
fn test_maximum_single_query() {
    let fused = FusionStrategy::Maximum
        .fuse(single_query_results())
        .unwrap();
    assert_eq!(fused.len(), 3);

    let top_score = fused[0].1;
    assert!((top_score - 0.95).abs() < 0.001);
}
/// Reciprocal-rank fusion with k = 60. Doc 1 is ranked first or second in
/// every sample list, so its fused score must exceed 0.04. The output must be
/// ranked best-first.
#[test]
fn test_rrf_basic() {
    let rrf = FusionStrategy::RRF { k: 60 };
    let fused = rrf.fuse(sample_results()).unwrap();
    assert!(!fused.is_empty());

    let &(_, doc1_score) = fused.iter().find(|(id, _)| *id == 1).unwrap();
    assert!(
        doc1_score > 0.04,
        "Doc 1 RRF score should be > 0.04, got {}",
        doc1_score
    );

    for pair in fused.windows(2) {
        assert!(pair[0].1 >= pair[1].1);
    }
}
/// RRF scores shrink as `k` grows: doc 1 must score higher with k = 1 than
/// with k = 100 on identical input.
#[test]
fn test_rrf_k_parameter() {
    let results = sample_results();
    // Fuse a fresh copy of the fixture with the given k and return doc 1's score.
    let doc1_score_with = |k| {
        let fused = FusionStrategy::RRF { k }.fuse(results.clone()).unwrap();
        fused.iter().find(|(id, _)| *id == 1).unwrap().1
    };

    let low_k_score = doc1_score_with(1);
    let high_k_score = doc1_score_with(100);
    assert!(
        low_k_score > high_k_score,
        "Lower k should yield higher scores"
    );
}
/// `rrf_default()` must produce the RRF variant with k = 60.
#[test]
fn test_rrf_default_k() {
    if let FusionStrategy::RRF { k } = FusionStrategy::rrf_default() {
        assert_eq!(k, 60);
    } else {
        panic!("Expected RRF variant");
    }
}
/// A single list through RRF keeps all three docs. The first-ranked doc must
/// score strictly above the second.
#[test]
fn test_rrf_single_query() {
    let rrf = FusionStrategy::RRF { k: 60 };
    let fused = rrf.fuse(single_query_results()).unwrap();
    assert_eq!(fused.len(), 3);

    let (best, runner_up) = (fused[0].1, fused[1].1);
    assert!(best > runner_up);
}
/// Weighted fusion over the sample lists. The 0.919 target for doc 1
/// corresponds to 0.6·mean(0.89) + 0.3·max(0.95) + 0.1·(3 of 3 queries hit).
/// The output must be ranked best-first.
#[test]
fn test_weighted_basic() {
    let weighted = FusionStrategy::Weighted {
        avg_weight: 0.6,
        max_weight: 0.3,
        hit_weight: 0.1,
    };
    let fused = weighted.fuse(sample_results()).unwrap();

    let &(_, doc1_score) = fused.iter().find(|(id, _)| *id == 1).unwrap();
    assert!(
        (doc1_score - 0.919).abs() < 0.02,
        "Doc 1 weighted should be ~0.919, got {}",
        doc1_score
    );

    for pair in fused.windows(2) {
        assert!(pair[0].1 >= pair[1].1);
    }
}
/// On the chained fixture, doc 1 (one hit at 0.9) must outrank doc 2 (hits at
/// 0.8 and 0.85). Doc 1's stronger mean and max outweigh doc 2's extra hit.
#[test]
fn test_weighted_partial_overlap() {
    let weighted = FusionStrategy::Weighted {
        avg_weight: 0.6,
        max_weight: 0.3,
        hit_weight: 0.1,
    };
    let fused = weighted.fuse(partial_overlap_results()).unwrap();
    let score_of = |wanted: u64| fused.iter().find(|(id, _)| *id == wanted).unwrap().1;

    let (doc1_score, doc2_score) = (score_of(1), score_of(2));
    assert!(doc1_score > doc2_score);
}
/// `weighted` accepts weights only when they sum to 1.0. A sum of 0.9 or 1.1
/// must be rejected.
#[test]
fn test_weighted_validation_sum_to_one() {
    // (avg, max, hit), whether construction should succeed, failure message.
    let cases = [
        ((0.5, 0.3, 0.1), false, "Weights summing to 0.9 should fail"),
        ((0.6, 0.3, 0.1), true, "Weights summing to 1.0 should succeed"),
        ((0.5, 0.5, 0.1), false, "Weights summing to 1.1 should fail"),
    ];
    for ((avg_w, max_w, hit_w), should_succeed, message) in cases {
        let accepted = FusionStrategy::weighted(avg_w, max_w, hit_w).is_ok();
        assert!(accepted == should_succeed, "{}", message);
    }
}
/// A negative weight must be rejected. These weights still sum to 1.0, so the
/// rejection should come from the sign check, not the sum check.
#[test]
fn test_weighted_validation_non_negative() {
    assert!(
        FusionStrategy::weighted(-0.1, 0.6, 0.5).is_err(),
        "Negative weights should fail"
    );
}
/// A zero `hit_weight` is a valid configuration. Fusion must succeed, and the
/// score must reduce to the avg/max blend alone.
///
/// Doc 1 has mean 0.89 and max 0.95 over the sample lists, so its expected
/// score is 0.7·0.89 + 0.3·0.95 ≈ 0.908. The tolerance matches the other
/// weighted-score test.
#[test]
fn test_weighted_zero_hit_weight() {
    let strategy = FusionStrategy::Weighted {
        avg_weight: 0.7,
        max_weight: 0.3,
        hit_weight: 0.0,
    };
    let results = sample_results();
    let fused = strategy.fuse(results).unwrap();
    assert!(!fused.is_empty());

    let doc1 = fused
        .iter()
        .find(|(id, _)| *id == 1)
        .unwrap_or_else(|| panic!("doc 1 missing from fused output"));
    assert!(
        (doc1.1 - 0.908).abs() < 0.02,
        "Doc 1 weighted (no hit term) should be ~0.908, got {}",
        doc1.1
    );

    for pair in fused.windows(2) {
        assert!(pair[0].1 >= pair[1].1, "Results should be sorted descending");
    }
}
/// Every document that appears in any sample list (ids 1 through 6) must
/// survive into the fused output.
#[test]
fn test_fuse_preserves_all_documents() {
    let fused = FusionStrategy::Average.fuse(sample_results()).unwrap();
    let ids: std::collections::HashSet<u64> = fused.iter().map(|&(id, _)| id).collect();

    for expected_id in 1..=6_u64 {
        assert!(ids.contains(&expected_id));
    }
}
/// If one list repeats a document id, the fused output must still contain that
/// document exactly once.
#[test]
fn test_fuse_handles_duplicate_ids_in_same_query() {
    let repeated_doc = vec![vec![(1, 0.9), (1, 0.8), (2, 0.7)]];
    let fused = FusionStrategy::Average.fuse(repeated_doc).unwrap();

    let occurrences = fused.iter().filter(|&&(id, _)| id == 1).count();
    assert_eq!(occurrences, 1, "Doc 1 should appear exactly once");
}
/// All sample scores lie in [0, 1], so every average-fused score must too.
#[test]
fn test_fuse_score_bounds() {
    let fused = FusionStrategy::Average.fuse(sample_results()).unwrap();
    for &(_, score) in &fused {
        assert!((0.0..=1.0).contains(&score), "Score {score} out of bounds");
    }
}
/// Every document that appears in some list must get a strictly positive RRF
/// score.
#[test]
fn test_rrf_scores_are_positive() {
    let rrf = FusionStrategy::RRF { k: 60 };
    let fused = rrf.fuse(sample_results()).unwrap();
    for &(_, score) in fused.iter() {
        assert!(score > 0.0, "RRF score should be positive");
    }
}
/// The expectations below correspond to min–max normalising each branch before
/// weighting. Branch A (1, 2, 3) becomes (0, 0.5, 1) and branch B (10, 20)
/// becomes (0, 1). With 0.5/0.5 weights: doc 3 = doc 4 = 0.5, doc 2 = 0.25,
/// doc 1 = 0.
#[test]
fn test_rsf_normalization() {
    let strategy = FusionStrategy::relative_score(0.5, 0.5).unwrap();
    let branch_a = vec![(1, 1.0_f32), (2, 2.0), (3, 3.0)];
    let branch_b = vec![(2, 10.0_f32), (4, 20.0)];
    let fused = strategy.fuse(vec![branch_a, branch_b]).unwrap();

    let find = |id: u64| fused.iter().find(|(i, _)| *i == id).unwrap().1;
    for (doc, expected) in [(3, 0.5), (4, 0.5), (2, 0.25), (1, 0.0)] {
        assert!((find(doc) - expected).abs() < 1e-5);
    }
}
/// Each branch here has constant scores, so min–max scaling is degenerate. The
/// expectations encode that every doc in such a branch normalises to 0.5:
/// doc 1 (in both branches) = 0.5, docs 2 and 4 (one branch each) = 0.25.
#[test]
fn test_rsf_normalization_equal_scores() {
    let strategy = FusionStrategy::relative_score(0.5, 0.5).unwrap();
    let flat_high = vec![(1, 5.0_f32), (2, 5.0), (3, 5.0)];
    let flat_low = vec![(1, 3.0_f32), (4, 3.0)];
    let fused = strategy.fuse(vec![flat_high, flat_low]).unwrap();

    let find = |id: u64| fused.iter().find(|(i, _)| *i == id).unwrap().1;
    for (doc, expected) in [(1, 0.5), (2, 0.25), (4, 0.25)] {
        assert!((find(doc) - expected).abs() < 1e-5);
    }
}
/// Two branches weighted 0.7 / 0.3.
///
/// Branch 1 (10, 8, 6) normalises to (1, 0.5, 0), and branch 2 (5, 3, 1)
/// normalises to (1, 0.5, 0). The fused scores are doc 1 = 0.7, doc 2 = 0.35,
/// doc 3 = 0.3 and doc 4 = 0.15, so the ranking must be exactly 1, 2, 3, 4.
#[test]
fn test_rsf_fuse_two_branches() {
    let strategy = FusionStrategy::relative_score(0.7, 0.3).unwrap();
    let results = vec![
        vec![(1, 10.0_f32), (2, 8.0), (3, 6.0)],
        vec![(3, 5.0_f32), (4, 3.0), (1, 1.0)],
    ];
    let fused = strategy.fuse(results).unwrap();

    // Compare the whole ranking at once. This also pins the length, so a dropped
    // or extra document fails with a readable diff instead of an index panic.
    let ranking: Vec<u64> = fused.iter().map(|(id, _)| *id).collect();
    assert_eq!(ranking, vec![1, 2, 3, 4], "unexpected RSF ranking");
}
/// An empty first branch contributes nothing. The populated branch is
/// normalised on its own, so its top doc scores 1.0 × 0.5 = 0.5.
#[test]
fn test_rsf_single_branch_empty() {
    let strategy = FusionStrategy::relative_score(0.5, 0.5).unwrap();
    let populated = vec![(1, 5.0_f32), (2, 3.0), (3, 1.0)];
    let fused = strategy.fuse(vec![Vec::new(), populated]).unwrap();
    assert_eq!(fused.len(), 3);

    let top = &fused[0];
    assert_eq!(top.0, 1);
    assert!((top.1 - 0.5).abs() < 1e-5);
}
/// `relative_score` must reject a negative weight, even though -0.1 + 1.1 sums
/// to 1.0.
#[test]
fn test_rsf_validation_negative_weight() {
    assert!(FusionStrategy::relative_score(-0.1, 1.1).is_err());
}
/// `relative_score` must reject weights that do not sum to 1.0 (here 0.6).
#[test]
fn test_rsf_validation_sum_not_one() {
    assert!(FusionStrategy::relative_score(0.3, 0.3).is_err());
}
/// Relative-score fusion uses only the first two branches. Appending a third
/// branch must not add its documents and must not change any score.
#[test]
fn test_rsf_ignores_extra_branches_beyond_two() {
    let strategy = FusionStrategy::relative_score(0.6, 0.4).unwrap();
    let two_branches = vec![
        vec![(1_u64, 10.0_f32), (2, 8.0)],
        vec![(2_u64, 5.0_f32), (3, 3.0)],
    ];
    let three_branches = vec![
        vec![(1_u64, 10.0_f32), (2, 8.0)],
        vec![(2_u64, 5.0_f32), (3, 3.0)],
        vec![(4_u64, 99.0_f32)],
    ];
    let fused_two = strategy.fuse(two_branches).unwrap();
    let fused_three = strategy.fuse(three_branches).unwrap();

    assert!(
        !fused_three.iter().any(|(id, _)| *id == 4),
        "doc 4 from the extra branch must be absent from RSF output"
    );
    // Equal lengths plus the per-doc check below make the two outputs match as
    // sets, not just "two-branch ⊆ three-branch".
    assert_eq!(
        fused_two.len(),
        fused_three.len(),
        "extra branch must not change the number of fused documents"
    );

    for (id, score) in &fused_two {
        let score_three = match fused_three.iter().find(|(i, _)| i == id) {
            Some((_, s)) => s,
            None => panic!("doc {id} must appear in three-branch result"),
        };
        assert!(
            (score - score_three).abs() < 1e-5,
            "score for doc {id} must not change when an extra branch is present"
        );
    }
}