/// Greedily picks up to `top_k` candidates by Maximal Marginal Relevance (MMR).
///
/// Every round scores each remaining candidate as
///
/// ```text
/// lambda * relevance - (1 - lambda) * max_sim
/// ```
///
/// where `max_sim` is the largest similarity between that candidate and any
/// item already selected, floored at `0.0` (which is also its value while
/// nothing has been selected yet). The best-scoring candidate is moved into
/// the result and the process repeats.
///
/// # Arguments
///
/// * `candidates` - `(item, relevance)` pairs; consumed by the call.
/// * `similarity` - invoked as `similarity(candidate, selected)`. It is called
///   once per (remaining candidate, newly selected item) pair; a NaN result is
///   ignored when tracking the running maximum.
/// * `lambda` - trade-off weight: `1.0` ranks purely by relevance, `0.0`
///   purely by dissimilarity to earlier picks. The value is not validated or
///   clamped.
/// * `top_k` - maximum number of items to return; clipped to the number of
///   candidates.
///
/// # Returns
///
/// The selected items in selection order. The `f32` paired with each item is
/// its MMR score at the moment it was picked, *not* its input relevance.
///
/// Candidates whose score is NaN rank below every real-valued score, so they
/// are only picked once nothing else is left. Which of several exactly tied
/// candidates is picked first is unspecified.
///
/// The `T: Clone` bound is kept for interface compatibility; items are moved,
/// never cloned.
pub fn mmr_select<T, F>(
    candidates: Vec<(T, f32)>,
    similarity: F,
    lambda: f32,
    top_k: usize,
) -> Vec<(T, f32)>
where
    T: Clone,
    F: Fn(&T, &T) -> f32,
{
    let cap = top_k.min(candidates.len());
    let mut selected: Vec<(T, f32)> = Vec::with_capacity(cap);

    // Each pool entry is (item, relevance, running max similarity to the
    // selected set). Caching the running max means a round only has to compare
    // against the newest pick instead of re-scanning every selected item.
    let mut pool: Vec<(T, f32, f32)> = candidates
        .into_iter()
        .map(|(item, rel)| (item, rel, 0.0_f32))
        .collect();

    while selected.len() < cap && !pool.is_empty() {
        // Fold the most recent pick into every remaining candidate's max.
        if let Some((newest, _)) = selected.last() {
            for (item, _, max_sim) in &mut pool {
                *max_sim = (*max_sim).max(similarity(&*item, newest));
            }
        }

        let (best_idx, best_score) = pool
            .iter()
            .enumerate()
            .map(|(i, (_, rel, max_sim))| (i, lambda * rel - (1.0 - lambda) * max_sim))
            .max_by(|a, b| {
                // `partial_cmp` is `None` only when a NaN is involved. Order
                // NaN below real scores so it cannot displace a genuine
                // maximum (treating it as "equal" lets whatever follows the
                // NaN win, regardless of its score).
                a.1.partial_cmp(&b.1)
                    .unwrap_or_else(|| b.1.is_nan().cmp(&a.1.is_nan()))
            })
            .expect("loop condition guarantees candidates is non-empty");

        // Order of the remaining pool is irrelevant, so O(1) removal is fine.
        let (item, _rel, _max_sim) = pool.swap_remove(best_idx);
        selected.push((item, best_score));
    }

    selected
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Similarity stub that treats every pair as completely dissimilar, so the
    /// diversity term never contributes to the score.
    fn sim_zero(_: &&str, _: &&str) -> f32 {
        0.0
    }

    /// No candidates in means no selections out, whatever `top_k` is.
    #[test]
    fn empty_input_returns_empty() {
        let got: Vec<(&str, f32)> = mmr_select(Vec::new(), sim_zero, 0.5, 10);
        assert!(got.is_empty());
    }

    /// Asking for more items than exist returns every candidate exactly once.
    #[test]
    fn top_k_clipped_to_candidate_count() {
        let cands = vec![("a", 0.9_f32), ("b", 0.5)];
        let got = mmr_select(cands, sim_zero, 1.0, 10);
        assert_eq!(got.len(), 2);
    }

    /// With `lambda == 1.0` the diversity term vanishes, so the output is the
    /// candidates sorted by relevance and each score equals the relevance.
    #[test]
    fn lambda_one_yields_relevance_descending() {
        let cands = vec![("a", 0.5_f32), ("b", 0.9), ("c", 0.7)];
        let got = mmr_select(cands, sim_zero, 1.0, 3);
        assert_eq!(got[0].0, "b");
        assert_eq!(got[1].0, "c");
        assert_eq!(got[2].0, "a");
        assert!((got[0].1 - 0.9).abs() < f32::EPSILON);
        assert!((got[1].1 - 0.7).abs() < f32::EPSILON);
        assert!((got[2].1 - 0.5).abs() < f32::EPSILON);
    }

    /// With `lambda == 0.0` only dissimilarity matters: the two "a…" items are
    /// duplicates of each other, so they must not both appear in the first two
    /// picks, and whichever item comes last carries the full penalty.
    #[test]
    fn lambda_zero_with_uniform_relevance_picks_diverse() {
        let cands = vec![("a", 1.0_f32), ("a-dup", 1.0), ("b", 1.0)];
        // Two items are duplicates (similarity 1.0) when they share a first
        // character, and unrelated (0.0) otherwise.
        let sim = |x: &&str, y: &&str| {
            if x.chars().next() == y.chars().next() {
                1.0
            } else {
                0.0
            }
        };
        let got = mmr_select(cands, sim, 0.0, 3);
        assert_eq!(got.len(), 3);
        // The first two picks come from different groups, regardless of how
        // the initial three-way tie is broken.
        assert_ne!(got[0].0.chars().next(), got[1].0.chars().next());
        // The final pick always has a duplicate already selected:
        // score = 0 * 1.0 - 1 * 1.0 = -1.0.
        assert!((got[2].1 + 1.0).abs() < f32::EPSILON);
    }
}