use std::collections::HashMap;
#[derive(Debug, Clone, Copy)]
pub struct RrfParams {
pub k: f64,
}
impl Default for RrfParams {
fn default() -> Self {
Self { k: 60.0 }
}
}
#[must_use]
pub fn rrf_fuse(lists: &[Vec<u64>], params: RrfParams) -> Vec<(u64, f64)> {
let mut scores: HashMap<u64, f64> = HashMap::new();
for list in lists {
for (rank, doc_id) in list.iter().enumerate() {
let contrib = 1.0 / (params.k + (rank + 1) as f64);
*scores.entry(*doc_id).or_insert(0.0) += contrib;
}
}
let mut out: Vec<_> = scores.into_iter().collect();
out.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then(a.0.cmp(&b.0))
});
out
}
#[must_use]
pub fn mmr_rerank(mut candidates: Vec<(u64, f64, Vec<f32>)>, lambda: f64, top_k: usize) -> Vec<(u64, f64)> {
let lambda = lambda.clamp(0.0, 1.0);
let mut selected: Vec<(u64, f64, Vec<f32>)> = Vec::with_capacity(top_k.min(candidates.len()));
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
while !candidates.is_empty() && selected.len() < top_k {
let mut best_idx = 0;
let mut best_score = f64::MIN;
for (i, (_, q_sim, vec)) in candidates.iter().enumerate() {
let max_sim = selected
.iter()
.map(|(_, _, sv)| cosine(vec, sv))
.fold(0.0_f64, f64::max);
let mmr = lambda * q_sim - (1.0 - lambda) * max_sim;
if mmr > best_score {
best_score = mmr;
best_idx = i;
}
}
let chosen = candidates.swap_remove(best_idx);
selected.push((chosen.0, best_score, chosen.2));
}
selected.into_iter().map(|(id, score, _)| (id, score)).collect()
}
fn cosine(a: &[f32], b: &[f32]) -> f64 {
if a.is_empty() || b.is_empty() {
return 0.0;
}
let n = a.len().min(b.len());
let mut dot = 0.0_f64;
let mut na = 0.0_f64;
let mut nb = 0.0_f64;
for i in 0..n {
let x = f64::from(a[i]);
let y = f64::from(b[i]);
dot += x * y;
na += x * x;
nb += y * y;
}
let denom = (na.sqrt() * nb.sqrt()).max(f64::MIN_POSITIVE);
dot / denom
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rrf_fuses_two_lists() {
let lists = vec![vec![1u64, 2, 3], vec![3, 1, 4]];
let fused = rrf_fuse(&lists, RrfParams::default());
let ids: Vec<_> = fused.iter().map(|(id, _)| *id).collect();
assert_eq!(ids[0], 1, "doc 1 should rank first");
assert_eq!(ids[1], 3);
}
#[test]
fn rrf_empty_lists_yields_empty() {
let fused = rrf_fuse(&[], RrfParams::default());
assert!(fused.is_empty());
}
#[test]
fn rrf_known_smoke_values() {
let lists = vec![vec![10u64], vec![10u64]];
let fused = rrf_fuse(&lists, RrfParams { k: 60.0 });
assert_eq!(fused.len(), 1);
assert_eq!(fused[0].0, 10);
let expected = 2.0 / 61.0;
assert!((fused[0].1 - expected).abs() < 1e-9);
}
#[test]
fn mmr_picks_top_relevance_first() {
let candidates = vec![
(1u64, 0.9, vec![1.0, 0.0]),
(2, 0.8, vec![1.0, 0.0]), (3, 0.5, vec![0.0, 1.0]), ];
let res = mmr_rerank(candidates, 1.0, 3);
assert_eq!(res[0].0, 1);
assert_eq!(res[1].0, 2);
assert_eq!(res[2].0, 3);
}
#[test]
fn mmr_balances_diversity_when_lambda_low() {
let candidates = vec![
(1u64, 0.9, vec![1.0, 0.0]),
(2, 0.85, vec![1.0, 0.0]), (3, 0.7, vec![0.0, 1.0]), ];
let res = mmr_rerank(candidates, 0.3, 3);
assert_eq!(res[0].0, 1);
assert_eq!(res[1].0, 3);
}
#[test]
fn mmr_top_k_bounds_output() {
let candidates = vec![(1u64, 0.9, vec![1.0]), (2, 0.8, vec![1.0]), (3, 0.7, vec![1.0])];
let res = mmr_rerank(candidates, 0.5, 2);
assert_eq!(res.len(), 2);
}
#[test]
fn cosine_handles_zero_vector() {
assert_eq!(cosine(&[], &[1.0]), 0.0);
assert_eq!(cosine(&[1.0], &[]), 0.0);
}
#[test]
fn cosine_unit_vectors() {
let s = cosine(&[1.0, 0.0], &[1.0, 0.0]);
assert!((s - 1.0).abs() < 1e-9);
let o = cosine(&[1.0, 0.0], &[0.0, 1.0]);
assert!(o.abs() < 1e-9);
}
}