use std::collections::{HashMap, HashSet};
use std::hash::Hash;
/// Summary statistics over the scores of a single result list.
///
/// Values of this type are built by `score_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct ScoreStats {
/// Number of scores that were summarized.
pub count: usize,
/// Smallest score (first element after sorting ascending).
pub min: f32,
/// Largest score (last element after sorting ascending).
pub max: f32,
/// Arithmetic mean of the scores.
pub mean: f32,
/// Population standard deviation (the variance is divided by `count`).
pub std_dev: f32,
/// 50th percentile of the sorted scores, linearly interpolated.
pub median: f32,
/// 25th percentile of the sorted scores, linearly interpolated.
pub p25: f32,
/// 75th percentile of the sorted scores, linearly interpolated.
pub p75: f32,
}
/// Computes summary statistics over the scores of a result list.
///
/// Only the score component of each `(id, score)` pair is used.
///
/// Returns `None` when `results` is empty; otherwise returns the count,
/// minimum, maximum, mean, population standard deviation, median and the
/// 25th/75th percentiles of the scores.
///
/// Scores are ordered with `f32::total_cmp`, so the ordering is well defined
/// even if a score is NaN (positive NaN sorts after every other value,
/// negative NaN before). Mean and standard deviation still propagate NaN.
pub fn score_stats<I>(results: &[(I, f32)]) -> Option<ScoreStats> {
    if results.is_empty() {
        return None;
    }

    let mut scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
    // `partial_cmp(..).unwrap_or(Equal)` is not a total order when NaN is
    // present, which makes the sort result unspecified (and may panic on
    // recent toolchains). `total_cmp` is a real total order over all floats.
    scores.sort_unstable_by(f32::total_cmp);

    let count = scores.len();
    let min = scores[0];
    let max = scores[count - 1];

    let sum: f32 = scores.iter().sum();
    let mean = sum / count as f32;
    // Population variance: divide by `count`, not `count - 1`.
    let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / count as f32;
    let std_dev = variance.sqrt();

    let median = percentile(&scores, 50.0);
    let p25 = percentile(&scores, 25.0);
    let p75 = percentile(&scores, 75.0);

    Some(ScoreStats {
        count,
        min,
        max,
        mean,
        std_dev,
        median,
        p25,
        p75,
    })
}
/// Jaccard overlap between the id sets of two result lists.
///
/// Scores are ignored and duplicate ids within a list count once. The result
/// is `|A ∩ B| / |A ∪ B|`, or `0.0` when neither list contains any id.
pub fn overlap_ratio<I: Eq + Hash>(a: &[(I, f32)], b: &[(I, f32)]) -> f32 {
    let ids_a: HashSet<&I> = a.iter().map(|(id, _)| id).collect();
    let ids_b: HashSet<&I> = b.iter().map(|(id, _)| id).collect();

    let shared = ids_a.intersection(&ids_b).count();
    // Inclusion-exclusion gives the union size without walking it.
    let total = ids_a.len() + ids_b.len() - shared;

    match total {
        0 => 0.0,
        _ => shared as f32 / total as f32,
    }
}
/// Jaccard overlap between the ids in the first `k` entries of each list.
///
/// At most `k` leading entries of `a` and of `b` are considered; scores are
/// ignored. Returns `0.0` when both truncated lists are empty (for example
/// when `k == 0`).
pub fn overlap_at_k<I: Eq + Hash>(a: &[(I, f32)], b: &[(I, f32)], k: usize) -> f32 {
    let head_a: HashSet<&I> = a.iter().take(k).map(|(id, _)| id).collect();
    let head_b: HashSet<&I> = b.iter().take(k).map(|(id, _)| id).collect();

    let combined = head_a.union(&head_b).count();
    if combined == 0 {
        0.0
    } else {
        let shared = head_a.intersection(&head_b).count();
        shared as f32 / combined as f32
    }
}
/// Fraction of the relevant documents found by either list that were found
/// by exactly one of them.
///
/// An id is relevant when `qrels` maps it to a grade greater than zero.
/// Returns `0.0` when neither list contains a relevant id; `1.0` means the
/// two lists share no relevant id at all.
pub fn complementarity<I: Clone + Eq + Hash>(
    a: &[(I, f32)],
    b: &[(I, f32)],
    qrels: &HashMap<I, u32>,
) -> f32 {
    let is_relevant = |id: &I| qrels.get(id).is_some_and(|&grade| grade > 0);

    let hits_a: HashSet<&I> = a
        .iter()
        .filter(|(id, _)| is_relevant(id))
        .map(|(id, _)| id)
        .collect();
    let hits_b: HashSet<&I> = b
        .iter()
        .filter(|(id, _)| is_relevant(id))
        .map(|(id, _)| id)
        .collect();

    let found = hits_a.union(&hits_b).count();
    if found == 0 {
        return 0.0;
    }

    // Relevant ids that only one of the two lists retrieved.
    let exclusive = hits_a.symmetric_difference(&hits_b).count();
    exclusive as f32 / found as f32
}
/// Kendall rank correlation between two result lists, computed over the ids
/// that appear in both.
///
/// The rank of an id is its zero-based position in the list; scores are not
/// consulted. If an id occurs more than once in a list, its first (best)
/// position is used. Because distinct ids always occupy distinct positions,
/// no pair can be tied, so the value is
/// `(concordant - discordant) / (concordant + discordant)`.
///
/// Returns a value in `[-1.0, 1.0]`, or `0.0` when fewer than two ids are
/// shared. The pair scan is quadratic in the number of shared ids.
pub fn rank_correlation<I: Clone + Eq + Hash>(a: &[(I, f32)], b: &[(I, f32)]) -> f32 {
    /// Maps each id to the position of its first occurrence in `list`.
    fn first_ranks<T: Eq + Hash>(list: &[(T, f32)]) -> HashMap<&T, usize> {
        let mut ranks = HashMap::with_capacity(list.len());
        for (position, (id, _)) in list.iter().enumerate() {
            // Keep the earliest position so a repeated id keeps its best rank.
            ranks.entry(id).or_insert(position);
        }
        ranks
    }

    let rank_a = first_ranks(a);
    let rank_b = first_ranks(b);

    // Resolve both ranks once per shared id so the pair loop does no hashing.
    let paired: Vec<(usize, usize)> = rank_a
        .iter()
        .filter_map(|(id, &ra)| rank_b.get(id).map(|&rb| (ra, rb)))
        .collect();

    if paired.len() < 2 {
        return 0.0;
    }

    let mut concordant: i64 = 0;
    let mut discordant: i64 = 0;
    for (i, &(ra_i, rb_i)) in paired.iter().enumerate() {
        for &(ra_j, rb_j) in &paired[i + 1..] {
            // Ranks within one list are distinct, so comparing the two
            // "comes before" relations fully classifies the pair.
            if (ra_i < ra_j) == (rb_i < rb_j) {
                concordant += 1;
            } else {
                discordant += 1;
            }
        }
    }

    // At least one pair exists here, so the denominator is non-zero.
    let total = concordant + discordant;
    (concordant - discordant) as f32 / total as f32
}
/// Pairwise diagnostics for two result lists `a` and `b`, as returned by
/// `diagnose`.
#[derive(Debug, Clone)]
pub struct FusionDiagnostics<I> {
/// Score statistics for list `a`; `None` when `a` is empty.
pub stats_a: Option<ScoreStats>,
/// Score statistics for list `b`; `None` when `b` is empty.
pub stats_b: Option<ScoreStats>,
/// Intersection-over-union of the id sets of the two full lists.
pub overlap: f32,
/// Intersection-over-union of the ids in the first `k` entries of each list.
pub overlap_at_k: f32,
/// Complementarity of the two lists; `Some` only when relevance judgments
/// (`qrels`) were supplied.
pub complementarity: Option<f32>,
/// Rank correlation of the two lists over their shared ids.
pub rank_correlation: f32,
/// Recommendation derived from the metrics above.
pub suggestion: FusionSuggestion,
/// Ids that occur in `a` but not in `b`, in the order they appear in `a`.
pub unique_to_a: Vec<I>,
/// Ids that occur in `b` but not in `a`, in the order they appear in `b`.
pub unique_to_b: Vec<I>,
}
/// Recommendation on whether fusing the analysed result lists is worthwhile.
///
/// Every variant carries a static, human-readable `reason` naming the metric
/// that triggered it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FusionSuggestion {
/// The metrics indicate fusion is likely to help.
FuseRecommended {
/// Explanation of why fusion is recommended.
reason: &'static str,
},
/// Fusion may help, but is expected to need tuning.
FuseWithCaution {
/// Explanation of why caution is advised.
reason: &'static str,
},
/// The metrics indicate fusion is unlikely to add value.
SkipFusion {
/// Explanation of why fusion should be skipped.
reason: &'static str,
},
}
/// Produces pairwise fusion diagnostics for two result lists.
///
/// * `a`, `b` - the `(id, score)` lists to compare.
/// * `qrels` - optional relevance judgments; when present, complementarity is
///   computed and alone determines the suggestion.
/// * `k` - cutoff used for the top-`k` overlap.
///
/// Without `qrels` the suggestion is decided in this order: overall overlap
/// below `0.1`, absolute rank correlation below `0.3`, rank correlation above
/// `0.8`, and otherwise a cautious recommendation.
pub fn diagnose<I: Clone + Eq + Hash>(
    a: &[(I, f32)],
    b: &[(I, f32)],
    qrels: Option<&HashMap<I, u32>>,
    k: usize,
) -> FusionDiagnostics<I> {
    let overlap = overlap_ratio(a, b);
    let top_k_overlap = overlap_at_k(a, b, k);
    let complementarity_score = qrels.map(|judgments| complementarity(a, b, judgments));
    let tau = rank_correlation(a, b);

    // Ids exclusive to each list, preserving the list's own order.
    let ids_a: HashSet<&I> = a.iter().map(|(id, _)| id).collect();
    let ids_b: HashSet<&I> = b.iter().map(|(id, _)| id).collect();
    let unique_to_a: Vec<I> = a
        .iter()
        .map(|(id, _)| id)
        .filter(|id| !ids_b.contains(*id))
        .cloned()
        .collect();
    let unique_to_b: Vec<I> = b
        .iter()
        .map(|(id, _)| id)
        .filter(|id| !ids_a.contains(*id))
        .cloned()
        .collect();

    // Complementarity, when available, takes precedence over the
    // overlap/correlation heuristics.
    let suggestion = match complementarity_score {
        Some(c) if c > 0.5 => FusionSuggestion::FuseRecommended {
            reason: "high complementarity (>0.5): retrievers find different relevant docs",
        },
        Some(c) if c > 0.2 => FusionSuggestion::FuseWithCaution {
            reason: "moderate complementarity: tune weights for best results",
        },
        Some(_) => FusionSuggestion::SkipFusion {
            reason: "low complementarity (<0.2): retrievers are redundant",
        },
        None if overlap < 0.1 => FusionSuggestion::FuseRecommended {
            reason: "very low overlap: retrievers see different document sets",
        },
        None if tau.abs() < 0.3 => FusionSuggestion::FuseRecommended {
            reason: "low rank correlation: retrievers disagree on ordering",
        },
        None if tau > 0.8 => FusionSuggestion::SkipFusion {
            reason: "very high rank correlation (>0.8): retrievers agree, fusion adds little",
        },
        None => FusionSuggestion::FuseWithCaution {
            reason: "moderate agreement: fusion may help with tuned weights",
        },
    };

    FusionDiagnostics {
        stats_a: score_stats(a),
        stats_b: score_stats(b),
        overlap,
        overlap_at_k: top_k_overlap,
        complementarity: complementarity_score,
        rank_correlation: tau,
        suggestion,
        unique_to_a,
        unique_to_b,
    }
}
/// Linearly interpolated percentile of an ascending-sorted slice.
///
/// `p` is expressed on a 0-100 scale. An empty slice yields `0.0` and a
/// single-element slice yields that element. If the upper interpolation index
/// would fall past the end of the slice, the last element is returned.
fn percentile(sorted: &[f32], p: f32) -> f32 {
    match sorted {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = sorted.len() - 1;
            // Fractional position of the requested percentile in the slice.
            let position = (p / 100.0) * last as f32;
            let below = position.floor() as usize;
            let above = position.ceil() as usize;
            if above > last {
                return sorted[last];
            }
            let weight = position - below as f32;
            sorted[below] * (1.0 - weight) + sorted[above] * weight
        }
    }
}
/// Diagnostics across any number of named result lists ("runs"), as returned
/// by `diagnose_multi`.
///
/// Run indices below are positions in the `runs` slice given to
/// `diagnose_multi`.
#[derive(Debug, Clone)]
pub struct MultiDiagnostics<I> {
/// `(run name, score statistics)` for every run with a non-empty list;
/// runs with an empty list are left out.
pub stats: Vec<(String, ScoreStats)>,
/// `(i, j, overlap)` for every pair of run indices with `i < j`.
pub pairwise_overlap: Vec<(usize, usize, f32)>,
/// `(i, j, rank correlation)` for every pair of run indices with `i < j`.
pub pairwise_correlation: Vec<(usize, usize, f32)>,
/// `(i, j, complementarity)` for every pair of run indices with `i < j`;
/// empty when no relevance judgments were supplied.
pub pairwise_complementarity: Vec<(usize, usize, f32)>,
/// Number of ids present in every run divided by the number of ids present
/// in any run; `0.0` when there are no ids at all.
pub full_overlap: f32,
/// `(run index, ids found only in that run)`; runs without any exclusive
/// id are left out.
pub unique_docs: Vec<(usize, Vec<I>)>,
/// Recommendation derived from the averaged pairwise metrics.
pub suggestion: FusionSuggestion,
}
/// Produces fusion diagnostics across any number of named result lists.
///
/// * `runs` - `(name, results)` pairs; indices reported in the output refer
///   to positions in this slice.
/// * `qrels` - optional relevance judgments; when present, pairwise
///   complementarity is computed and its average alone determines the
///   suggestion.
///
/// Per-run score statistics are reported only for non-empty runs. Overlap and
/// rank correlation (and complementarity, with `qrels`) are computed for
/// every pair `i < j`. `full_overlap` is the share of all seen ids that occur
/// in every run. `unique_docs` lists, per run, the ids no other run
/// retrieved, in that run's own ranking order.
///
/// Without `qrels` the suggestion is based on the mean pairwise rank
/// correlation, which is taken as `0.0` when there are fewer than two runs.
pub fn diagnose_multi<I: Clone + Eq + Hash>(
    runs: &[(&str, &[(I, f32)])],
    qrels: Option<&HashMap<I, u32>>,
) -> MultiDiagnostics<I> {
    let n = runs.len();

    let stats: Vec<_> = runs
        .iter()
        .filter_map(|(name, list)| score_stats(list).map(|s| (name.to_string(), s)))
        .collect();

    let mut pairwise_overlap = Vec::new();
    let mut pairwise_correlation = Vec::new();
    let mut pairwise_complementarity = Vec::new();
    for i in 0..n {
        for j in (i + 1)..n {
            pairwise_overlap.push((i, j, overlap_ratio(runs[i].1, runs[j].1)));
            pairwise_correlation.push((i, j, rank_correlation(runs[i].1, runs[j].1)));
            if let Some(q) = qrels {
                pairwise_complementarity.push((i, j, complementarity(runs[i].1, runs[j].1, q)));
            }
        }
    }

    // For every id, the number of runs that contain it. Each run is
    // de-duplicated first so a repeated id counts once per run. This single
    // table answers both "in all runs" and "in exactly one run".
    let mut run_count: HashMap<&I, usize> = HashMap::new();
    for (_, list) in runs {
        let distinct: HashSet<&I> = list.iter().map(|(id, _)| id).collect();
        for id in distinct {
            *run_count.entry(id).or_insert(0) += 1;
        }
    }

    let union_size = run_count.len();
    let full_overlap = if union_size == 0 {
        0.0
    } else {
        let in_every_run = run_count.values().filter(|&&count| count == n).count();
        in_every_run as f32 / union_size as f32
    };

    // Walk each run in ranking order so the output is deterministic; `emitted`
    // suppresses repeats of the same id within a run.
    let mut unique_docs = Vec::new();
    for (idx, (_, list)) in runs.iter().enumerate() {
        let mut emitted: HashSet<&I> = HashSet::new();
        let unique: Vec<I> = list
            .iter()
            .map(|(id, _)| id)
            .filter(|id| run_count.get(*id) == Some(&1) && emitted.insert(*id))
            .cloned()
            .collect();
        if !unique.is_empty() {
            unique_docs.push((idx, unique));
        }
    }

    let avg_comp = if pairwise_complementarity.is_empty() {
        None
    } else {
        Some(
            pairwise_complementarity
                .iter()
                .map(|(_, _, c)| c)
                .sum::<f32>()
                / pairwise_complementarity.len() as f32,
        )
    };
    let avg_tau = if pairwise_correlation.is_empty() {
        0.0
    } else {
        pairwise_correlation.iter().map(|(_, _, t)| t).sum::<f32>()
            / pairwise_correlation.len() as f32
    };

    // Average complementarity, when available, takes precedence over the
    // rank-correlation heuristic.
    let suggestion = if let Some(c) = avg_comp {
        if c > 0.4 {
            FusionSuggestion::FuseRecommended {
                reason: "high average complementarity: retrievers find different relevant docs",
            }
        } else if c > 0.15 {
            FusionSuggestion::FuseWithCaution {
                reason: "moderate complementarity: tune weights per retriever",
            }
        } else {
            FusionSuggestion::SkipFusion {
                reason: "low complementarity: retrievers are largely redundant",
            }
        }
    } else if avg_tau.abs() < 0.3 {
        FusionSuggestion::FuseRecommended {
            reason: "low average rank correlation: retrievers disagree on ordering",
        }
    } else if avg_tau > 0.7 {
        FusionSuggestion::SkipFusion {
            reason: "high average rank correlation: retrievers agree, fusion adds little",
        }
    } else {
        FusionSuggestion::FuseWithCaution {
            reason: "moderate agreement: fusion may help with tuned weights",
        }
    };

    MultiDiagnostics {
        stats,
        pairwise_overlap,
        pairwise_correlation,
        pairwise_complementarity,
        full_overlap,
        unique_docs,
        suggestion,
    }
}
// Unit tests for the diagnostics functions defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Five evenly spaced scores: count, min, max, mean and median are checked.
#[test]
fn score_stats_basic() {
let results = vec![("a", 1.0), ("b", 2.0), ("c", 3.0), ("d", 4.0), ("e", 5.0)];
let stats = score_stats(&results).unwrap();
assert_eq!(stats.count, 5);
assert!((stats.min - 1.0).abs() < 1e-6);
assert!((stats.max - 5.0).abs() < 1e-6);
assert!((stats.mean - 3.0).abs() < 1e-6);
assert!((stats.median - 3.0).abs() < 1e-6);
}
// An empty result list has no statistics.
#[test]
fn score_stats_empty() {
let results: Vec<(&str, f32)> = vec![];
assert!(score_stats(&results).is_none());
}
// Lists with no id in common have zero overlap.
#[test]
fn overlap_ratio_disjoint() {
let a = vec![("d1", 0.9), ("d2", 0.8)];
let b = vec![("d3", 0.9), ("d4", 0.8)];
assert!((overlap_ratio(&a, &b)).abs() < 1e-6);
}
// Same ids with different scores still overlap fully.
#[test]
fn overlap_ratio_identical() {
let a = vec![("d1", 0.9), ("d2", 0.8)];
let b = vec![("d1", 0.7), ("d2", 0.6)];
assert!((overlap_ratio(&a, &b) - 1.0).abs() < 1e-6);
}
// One shared id out of three distinct ids gives 1/3.
#[test]
fn overlap_ratio_partial() {
let a = vec![("d1", 0.9), ("d2", 0.8)];
let b = vec![("d2", 0.7), ("d3", 0.6)];
assert!((overlap_ratio(&a, &b) - 1.0 / 3.0).abs() < 1e-6);
}
// Each list finds a disjoint set of relevant docs: complementarity is 1.
#[test]
fn complementarity_high() {
let qrels: HashMap<&str, u32> = HashMap::from([("d1", 1), ("d2", 1), ("d3", 1), ("d4", 1)]);
let a = vec![("d1", 0.9), ("d2", 0.8)];
let b = vec![("d3", 0.9), ("d4", 0.8)];
let c = complementarity(&a, &b, &qrels);
assert!((c - 1.0).abs() < 1e-6);
}
// Both lists find the same relevant docs: complementarity is 0.
#[test]
fn complementarity_zero() {
let qrels: HashMap<&str, u32> = HashMap::from([("d1", 1), ("d2", 1)]);
let a = vec![("d1", 0.9), ("d2", 0.8)];
let b = vec![("d1", 0.7), ("d2", 0.6)];
let c = complementarity(&a, &b, &qrels);
assert!(c.abs() < 1e-6);
}
// Identical orderings correlate perfectly.
#[test]
fn rank_correlation_identical() {
let a = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
let b = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
let tau = rank_correlation(&a, &b);
assert!(
(tau - 1.0).abs() < 1e-6,
"identical rankings should have tau=1.0"
);
}
// Fully reversed orderings are perfectly anti-correlated.
#[test]
fn rank_correlation_reversed() {
let a = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
let b = vec![("d3", 0.9), ("d2", 0.8), ("d1", 0.7)];
let tau = rank_correlation(&a, &b);
assert!(
(tau - (-1.0)).abs() < 1e-6,
"reversed rankings should have tau=-1.0, got {}",
tau
);
}
// Disjoint relevant docs with qrels: fusion is recommended and every id is
// unique to its own list.
#[test]
fn diagnose_complementary() {
let qrels: HashMap<&str, u32> = HashMap::from([("d1", 1), ("d2", 1), ("d3", 1), ("d4", 1)]);
let a = vec![("d1", 0.9), ("d2", 0.8), ("x1", 0.5)];
let b = vec![("d3", 0.9), ("d4", 0.8), ("x2", 0.5)];
let diag = diagnose(&a, &b, Some(&qrels), 3);
assert!(diag.complementarity.unwrap() > 0.9);
assert!(matches!(
diag.suggestion,
FusionSuggestion::FuseRecommended { .. }
));
assert_eq!(diag.unique_to_a.len(), 3); assert_eq!(diag.unique_to_b.len(), 3);
}
// Same relevant docs in both lists: fusion is skipped.
#[test]
fn diagnose_redundant() {
let qrels: HashMap<&str, u32> = HashMap::from([("d1", 1), ("d2", 1)]);
let a = vec![("d1", 0.9), ("d2", 0.8)];
let b = vec![("d1", 0.7), ("d2", 0.6)];
let diag = diagnose(&a, &b, Some(&qrels), 2);
assert!(diag.complementarity.unwrap() < 0.01);
assert!(matches!(
diag.suggestion,
FusionSuggestion::SkipFusion { .. }
));
}
// Three mutually disjoint runs: three pairs are reported, nothing is shared
// by all runs, and fusion is recommended.
#[test]
fn diagnose_multi_three_retrievers() {
let qrels: HashMap<&str, u32> = HashMap::from([
("d1", 1),
("d2", 1),
("d3", 1),
("d4", 1),
("d5", 1),
("d6", 1),
]);
let bm25 = vec![("d1", 0.9), ("d2", 0.8), ("x1", 0.5)];
let dense = vec![("d3", 0.9), ("d4", 0.8), ("x2", 0.5)];
let sparse = vec![("d5", 0.9), ("d6", 0.8), ("x3", 0.5)];
let diag = diagnose_multi(
&[("bm25", &bm25), ("dense", &dense), ("sparse", &sparse)],
Some(&qrels),
);
assert_eq!(diag.stats.len(), 3);
assert_eq!(diag.pairwise_overlap.len(), 3);
assert_eq!(diag.pairwise_correlation.len(), 3);
assert_eq!(diag.pairwise_complementarity.len(), 3);
assert!(diag.full_overlap < 0.01);
assert!(matches!(
diag.suggestion,
FusionSuggestion::FuseRecommended { .. }
));
}
// Three runs retrieving the same ids: full overlap is 1 and fusion is skipped.
#[test]
fn diagnose_multi_redundant() {
let qrels: HashMap<&str, u32> = HashMap::from([("d1", 1), ("d2", 1)]);
let a = vec![("d1", 0.9), ("d2", 0.8)];
let b = vec![("d1", 0.7), ("d2", 0.6)];
let c = vec![("d1", 0.85), ("d2", 0.75)];
let diag = diagnose_multi(&[("a", &a), ("b", &b), ("c", &c)], Some(&qrels));
assert!((diag.full_overlap - 1.0).abs() < 1e-6);
assert!(matches!(
diag.suggestion,
FusionSuggestion::SkipFusion { .. }
));
}
// Without qrels no complementarity is reported, but overlap still is.
#[test]
fn diagnose_multi_no_qrels() {
let a = vec![("d1", 0.9), ("d2", 0.8)];
let b = vec![("d3", 0.9), ("d4", 0.8)];
let diag = diagnose_multi(&[("a", &a), ("b", &b)], None);
assert!(diag.pairwise_complementarity.is_empty());
assert_eq!(diag.pairwise_overlap.len(), 1);
}
}