use std::collections::HashMap;
use crate::search::SearchResult;
/// Rescales the scores in `results` in place onto `[0, 1]` (min-max normalization).
///
/// Steps, in order:
/// - An empty slice is returned untouched.
/// - Every non-finite score (`NaN`, `+inf`, `-inf`) is replaced with `0.0` *before* the
///   range is computed, so it participates in min/max as an ordinary zero.
/// - If `max - min` is smaller than `f32::EPSILON`, every score is set to `1.0`
///   (there is no usable range to divide by).
/// - Otherwise each score becomes `(score - min) / (max - min)`.
///
/// The rescaling arithmetic is done in `f64`. With plain `f32` math, finite inputs whose
/// range exceeds `f32::MAX` (for example `-f32::MAX` and `f32::MAX`) make `max - min`
/// overflow to infinity, which turns the largest score into `NaN` instead of `1.0`.
pub fn normalize_minmax(results: &mut [SearchResult]) {
    if results.is_empty() {
        return;
    }

    // Sanitize non-finite scores and track the extremes in a single pass.
    let mut min = f32::INFINITY;
    let mut max = f32::NEG_INFINITY;
    for r in results.iter_mut() {
        if !r.score.is_finite() {
            r.score = 0.0;
        }
        min = min.min(r.score);
        max = max.max(r.score);
    }

    // Degenerate range: keep the f32 comparison so the cut-off is unchanged.
    // (`max >= min` always holds here, so no `abs()` is needed.)
    if max - min < f32::EPSILON {
        for r in results.iter_mut() {
            r.score = 1.0;
        }
        return;
    }

    // Widen to f64 so that `max - min` and `score - min` cannot overflow.
    let lo = f64::from(min);
    let span = f64::from(max) - lo;
    for r in results.iter_mut() {
        // The quotient lies in [0, 1], so narrowing back to f32 is safe.
        r.score = ((f64::from(r.score) - lo) / span) as f32;
    }
}
/// Fuses several result lists into one by weighted score summation.
///
/// For every distinct `id` found in `result_sets`, the fused score is the sum of
/// `weights[i] * score` over each occurrence of that id in `result_sets[i]`. The fused
/// entry's `text` is taken from the first occurrence encountered (lists in order, then
/// entries in order within a list).
///
/// The returned vector is sorted by descending fused score. Entries with equal scores are
/// ordered by ascending `id` and `NaN` scores are placed last, so the output order is
/// deterministic and the comparator is a total order even when scores are not finite.
///
/// # Panics
///
/// Panics if `result_sets.len() != weights.len()`.
pub fn weighted_sum_fuse(result_sets: &[Vec<SearchResult>], weights: &[f32]) -> Vec<SearchResult> {
    use std::cmp::Ordering;

    assert_eq!(
        result_sets.len(),
        weights.len(),
        "weighted_sum_fuse: result_sets.len() ({}) must equal weights.len() ({})",
        result_sets.len(),
        weights.len(),
    );

    // One accumulator per id, keyed by a borrow of the input so nothing is cloned
    // until the output is built. Value is (running weighted sum, first-seen text).
    let mut fused_by_id: HashMap<&str, (f32, &str)> = HashMap::new();
    for (results, &weight) in result_sets.iter().zip(weights) {
        for result in results {
            let slot = fused_by_id
                .entry(result.id.as_str())
                .or_insert((0.0, result.text.as_str()));
            slot.0 += weight * result.score;
        }
    }

    let mut fused: Vec<SearchResult> = fused_by_id
        .into_iter()
        .map(|(id, (score, text))| SearchResult {
            id: id.to_owned(),
            score,
            text: text.to_owned(),
        })
        .collect();

    // Ids are unique here, so this comparator never returns `Equal` for two distinct
    // elements and an unstable sort yields a fully determined order.
    fused.sort_unstable_by(|a, b| match (a.score.is_nan(), b.score.is_nan()) {
        (true, true) => a.id.cmp(&b.id),
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => b
            .score
            .partial_cmp(&a.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| a.id.cmp(&b.id)),
    });
    fused
}
/// Fuses several result lists with a CombMNZ-style rule restricted to the top `k` ranks.
///
/// For every distinct `id` found in `result_sets`:
/// - `sum` is the plain sum of its `score` over every occurrence in every list;
/// - `hits` is the number of occurrences located within the first `k` positions of a list;
/// - the fused score is `hits * sum`.
///
/// An id that never appears within the first `k` positions of any list stays in the
/// output with a score of exactly `0.0` (never `-0.0` or `NaN`, whatever its `sum` is).
/// The fused entry's `text` is taken from the first occurrence encountered.
///
/// The returned vector is sorted by descending fused score. Entries with equal scores —
/// common here, since every zero-hit id scores `0.0` — are ordered by ascending `id`,
/// and `NaN` scores are placed last, so the output order is deterministic.
///
/// # Panics
///
/// Panics if `k == 0`, because that would zero every score.
pub fn combmnz_fuse(result_sets: &[Vec<SearchResult>], k: usize) -> Vec<SearchResult> {
    use std::cmp::Ordering;

    /// Per-id running totals, borrowing the text from the input lists.
    struct Tally<'a> {
        sum: f32,
        top_k_hits: usize,
        text: &'a str,
    }

    assert!(
        k >= 1,
        "combmnz_fuse: k must be >= 1 (got {k}); k == 0 zeroes every score"
    );

    let mut tallies: HashMap<&str, Tally<'_>> = HashMap::new();
    for results in result_sets {
        for (rank, result) in results.iter().enumerate() {
            let tally = tallies.entry(result.id.as_str()).or_insert(Tally {
                sum: 0.0,
                top_k_hits: 0,
                text: result.text.as_str(),
            });
            tally.sum += result.score;
            // `rank` is always < results.len(), so comparing against `k` alone suffices.
            if rank < k {
                tally.top_k_hits += 1;
            }
        }
    }

    let mut fused: Vec<SearchResult> = tallies
        .into_iter()
        .map(|(id, tally)| SearchResult {
            id: id.to_owned(),
            // Short-circuit zero hits: `0.0 * sum` would give -0.0 for a negative sum
            // and NaN for a non-finite one.
            score: if tally.top_k_hits == 0 {
                0.0
            } else {
                // Hit counts are small; the usize -> f32 conversion is exact below 2^24.
                tally.top_k_hits as f32 * tally.sum
            },
            text: tally.text.to_owned(),
        })
        .collect();

    // Ids are unique here, so this comparator never returns `Equal` for two distinct
    // elements and an unstable sort yields a fully determined order.
    fused.sort_unstable_by(|a, b| match (a.score.is_nan(), b.score.is_nan()) {
        (true, true) => a.id.cmp(&b.id),
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => b
            .score
            .partial_cmp(&a.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| a.id.cmp(&b.id)),
    });
    fused
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a result list from `(id, score)` pairs; each entry's text is `"text for <id>"`.
    fn make_results(ids: &[(&str, f32)]) -> Vec<SearchResult> {
        let mut out = Vec::with_capacity(ids.len());
        for &(id, score) in ids {
            out.push(SearchResult {
                id: id.to_string(),
                score,
                text: format!("text for {id}"),
            });
        }
        out
    }

    /// Returns the score of the entry with the given id; panics if it is absent.
    fn score_of(results: &[SearchResult], id: &str) -> f32 {
        results.iter().find(|r| r.id == id).unwrap().score
    }

    /// True when `actual` is within `tol` (exclusive) of `expected`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn normalize_minmax_maps_extremes() {
        let mut results = make_results(&[("a", 0.2), ("b", 0.5), ("c", 0.9)]);
        normalize_minmax(&mut results);

        // Smallest maps to 0, largest to 1, the middle one is interpolated linearly.
        assert!(close(score_of(&results, "a"), 0.0, 1e-6));
        assert!(close(score_of(&results, "c"), 1.0, 1e-6));
        assert!(close(
            score_of(&results, "b"),
            (0.5 - 0.2) / (0.9 - 0.2),
            1e-6
        ));
    }

    #[test]
    fn normalize_minmax_all_equal_sets_one() {
        let mut results = make_results(&[("a", 0.5), ("b", 0.5), ("c", 0.5)]);
        normalize_minmax(&mut results);
        assert!(results.iter().all(|r| close(r.score, 1.0, 1e-6)));
    }

    #[test]
    fn normalize_minmax_empty_is_noop() {
        let mut results: Vec<SearchResult> = Vec::new();
        normalize_minmax(&mut results);
        assert!(results.is_empty());
    }

    #[test]
    fn normalize_minmax_clamps_non_finite() {
        let mut results = make_results(&[("a", f32::NAN), ("b", f32::INFINITY), ("c", 0.5)]);
        normalize_minmax(&mut results);
        for r in results.iter() {
            assert!(r.score.is_finite(), "score {:?} is not finite", r.score);
            assert!(
                (0.0..=1.0).contains(&r.score),
                "score {:?} out of [0,1]",
                r.score
            );
        }
    }

    #[test]
    fn weighted_sum_fuse_formula() {
        let first = make_results(&[("a", 1.0), ("b", 0.5)]);
        let second = make_results(&[("b", 1.0), ("c", 0.4)]);
        let fused = weighted_sum_fuse(&[first, second], &[0.7, 0.3]);

        // Exactly three entries, ranked a > b > c.
        let order: Vec<&str> = fused.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(order, ["a", "b", "c"]);

        // a: 0.7 * 1.0; b: 0.7 * 0.5 + 0.3 * 1.0; c: 0.3 * 0.4.
        assert!(close(score_of(&fused, "a"), 0.70, 1e-6));
        assert!(close(score_of(&fused, "b"), 0.65, 1e-6));
        assert!(close(score_of(&fused, "c"), 0.12, 1e-6));
    }

    #[test]
    fn combmnz_boosts_multi_list_doc() {
        let sets = [
            make_results(&[("hot", 0.99), ("cold", 0.1)]),
            make_results(&[("cold", 0.99), ("hot", 0.1)]),
            make_results(&[("cold", 0.99), ("warm", 0.1)]),
        ];
        let fused = combmnz_fuse(&sets, 1);
        assert_eq!(fused[0].id, "cold");

        let cold = score_of(&fused, "cold");
        let hot = score_of(&fused, "hot");
        assert!(cold > hot);
        // cold leads two lists: 2 * (0.1 + 0.99 + 0.99); hot leads one: 1 * (0.99 + 0.1).
        assert!(close(cold, 4.16, 1e-5));
        assert!(close(hot, 1.09, 1e-5));
    }

    #[test]
    #[should_panic(expected = "must equal weights.len()")]
    fn weighted_sum_fuse_panics_on_length_mismatch() {
        let single = make_results(&[("a", 1.0)]);
        let sets = vec![single.clone(), single];
        let _ = weighted_sum_fuse(&sets, &[0.5]);
    }
}