use super::simd;
/// Settings for Maximal Marginal Relevance (MMR) selection.
///
/// Both fields are public, so a value built with a struct literal bypasses the
/// clamping performed by [`MmrConfig::new`] and [`MmrConfig::with_lambda`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MmrConfig {
    /// Weight given to relevance; `1 - lambda` weights the similarity penalty.
    pub lambda: f32,
    /// Maximum number of items to select.
    pub k: usize,
}

impl Default for MmrConfig {
    /// An even relevance/diversity balance (`lambda = 0.5`) selecting up to 10 items.
    fn default() -> Self {
        MmrConfig { k: 10, lambda: 0.5 }
    }
}

impl MmrConfig {
    /// Builds a configuration, clamping `lambda` into `[0.0, 1.0]`.
    #[must_use]
    pub fn new(lambda: f32, k: usize) -> Self {
        let lambda = lambda.clamp(0.0, 1.0);
        Self { lambda, k }
    }

    /// Returns a copy with `lambda` replaced (clamped into `[0.0, 1.0]`).
    #[must_use]
    pub fn with_lambda(self, lambda: f32) -> Self {
        Self {
            lambda: lambda.clamp(0.0, 1.0),
            k: self.k,
        }
    }

    /// Returns a copy with the selection size `k` replaced.
    #[must_use]
    pub const fn with_k(self, k: usize) -> Self {
        Self {
            lambda: self.lambda,
            k,
        }
    }
}
pub mod tuning {
use super::{mmr, MmrConfig};
/// Quality measurements for one lambda setting, as produced by `tune_lambda`.
#[derive(Debug, Clone)]
pub struct MmrDiagnostics {
/// Lambda value these measurements belong to.
pub lambda: f32,
/// Mean relevance score of the items MMR selected for this lambda.
pub avg_relevance: f32,
/// Mean diversity of the selection: per item, one minus its highest
/// similarity to any other selected item.
pub avg_diversity: f32,
/// Combined relevance/diversity score; `tune_lambda` sorts its output by
/// this value, highest first.
pub tradeoff_score: f32,
}
/// Runs MMR once per candidate `lambda` and ranks the settings by a combined
/// relevance/diversity score.
///
/// For every value in `lambda_values` the function calls [`mmr`] with that
/// lambda and `k`, then measures the resulting selection:
///
/// * `avg_relevance` – mean of the selected items' scores;
/// * `avg_diversity` – mean over the selected items of `1 - max_sim`, where
///   `max_sim` is the largest similarity to any *other* selected item, floored
///   at `0.0` (a single-item selection therefore has diversity `1.0`);
/// * `tradeoff_score` – `λ * avg_relevance + (1 - λ) * avg_diversity`, where
///   `λ` is the clamped lambda that was actually used for the selection.
///
/// The `lambda` field of each diagnostic echoes the value exactly as supplied
/// by the caller. Lambdas that yield an empty selection (no candidates or
/// `k == 0`) produce no entry. The result is sorted by `tradeoff_score`,
/// highest first.
///
/// `similarity` is indexed as a row-major n×n matrix (`row * n + col`) with
/// rows in `candidates` order. Selected items are mapped back to their row by
/// the first candidate with an equal ID, so IDs are expected to be unique.
///
/// # Panics
///
/// Panics (inside [`mmr`]) if `similarity.len()` is not `candidates.len()²`.
#[must_use]
pub fn tune_lambda<I: Clone + Eq>(
    candidates: &[(I, f32)],
    similarity: &[f32],
    lambda_values: &[f32],
    k: usize,
) -> Vec<MmrDiagnostics> {
    let n = candidates.len();
    let mut diagnostics = Vec::with_capacity(lambda_values.len());

    for &lambda in lambda_values {
        let config = MmrConfig::new(lambda, k);
        let selected = mmr(candidates, similarity, config);
        if selected.is_empty() {
            continue;
        }
        let count = selected.len() as f32;

        let avg_relevance: f32 = selected.iter().map(|(_, score)| score).sum::<f32>() / count;

        // Resolve every selected item to its matrix row once, instead of
        // re-scanning `candidates` for both items of every pair.
        let rows: Vec<usize> = selected
            .iter()
            .map(|(selected_id, _)| {
                candidates
                    .iter()
                    .position(|(id, _)| id == selected_id)
                    .expect("selected item must exist in candidates")
            })
            .collect();

        let mut diversity_sum = 0.0_f32;
        for (i, &row) in rows.iter().enumerate() {
            let mut max_sim: f32 = 0.0;
            for (j, &col) in rows.iter().enumerate() {
                if i != j {
                    max_sim = max_sim.max(similarity[row * n + col]);
                }
            }
            diversity_sum += 1.0 - max_sim;
        }
        let avg_diversity = diversity_sum / count;

        // Weight with the lambda the selection really used; an out-of-range
        // input would otherwise put a negative weight on one of the terms.
        let effective_lambda = config.lambda;
        let tradeoff_score =
            effective_lambda * avg_relevance + (1.0 - effective_lambda) * avg_diversity;

        diagnostics.push(MmrDiagnostics {
            lambda,
            avg_relevance,
            avg_diversity,
            tradeoff_score,
        });
    }

    diagnostics.sort_unstable_by(|a, b| b.tradeoff_score.total_cmp(&a.tradeoff_score));
    diagnostics
}
/// Greedy MMR selection whose `lambda` moves linearly from `initial_lambda`
/// to `final_lambda` as items are picked.
///
/// At step `s` (0-based) the lambda is
/// `initial + (final - initial) * s / max(k - 1, 1)`, clamped into
/// `[0.0, 1.0]`. Each step picks the remaining candidate maximising
/// `λ * relevance - (1 - λ) * max_sim`, where `relevance` is the score
/// min–max normalised over all candidates (all `1.0` when the scores are
/// effectively equal) and `max_sim` is the largest similarity to an already
/// selected item (`0.0` for the first pick). With `initial_lambda ==
/// final_lambda` this is plain MMR at that lambda.
///
/// `similarity` is a row-major n×n matrix with rows in `candidates` order; it
/// is always indexed by the candidates' original positions. At most
/// `min(k, candidates.len())` items are returned, in selection order, each
/// with its original score.
///
/// # Panics
///
/// Panics if at least one item is to be selected and `similarity.len()` is
/// not `candidates.len()²`.
#[must_use]
pub fn mmr_adaptive<I: Clone + Eq>(
    candidates: &[(I, f32)],
    similarity: &[f32],
    initial_lambda: f32,
    final_lambda: f32,
    k: usize,
) -> Vec<(I, f32)> {
    const REL_RANGE_EPSILON: f32 = 1e-9;

    let n = candidates.len();
    let steps = k.min(n);
    if steps == 0 {
        return Vec::new();
    }
    assert_eq!(similarity.len(), n * n, "similarity matrix must be n×n");

    // Min–max normalise relevance so it is on the same scale as similarity.
    let (rel_min, rel_max) = candidates
        .iter()
        .map(|(_, s)| *s)
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), s| {
            (lo.min(s), hi.max(s))
        });
    let rel_range = rel_max - rel_min;
    let rel_norm: Vec<f32> = if rel_range > REL_RANGE_EPSILON {
        candidates
            .iter()
            .map(|(_, s)| (s - rel_min) / rel_range)
            .collect()
    } else {
        vec![1.0; n]
    };

    // `max_sim[i]`: largest similarity between candidate `i` and any pick so far.
    let mut max_sim = vec![f32::NEG_INFINITY; n];
    let mut remaining: Vec<usize> = (0..n).collect();
    let mut picked: Vec<usize> = Vec::with_capacity(steps);
    // `steps >= 1` implies `k >= 1`, so `k - 1` cannot underflow.
    let schedule_len = (k - 1).max(1) as f32;

    for step in 0..steps {
        let progress = step as f32 / schedule_len;
        let lambda =
            (initial_lambda + (final_lambda - initial_lambda) * progress).clamp(0.0, 1.0);

        let mut best_pos = 0;
        let mut best_score = f32::NEG_INFINITY;
        for (pos, &cand) in remaining.iter().enumerate() {
            // Nothing to be redundant with before the first pick.
            let penalty = if picked.is_empty() { 0.0 } else { max_sim[cand] };
            let score = lambda * rel_norm[cand] - (1.0 - lambda) * penalty;
            if score > best_score {
                best_score = score;
                best_pos = pos;
            }
        }

        let chosen = remaining.swap_remove(best_pos);
        picked.push(chosen);
        for &cand in &remaining {
            max_sim[cand] = max_sim[cand].max(similarity[cand * n + chosen]);
        }
    }

    picked
        .into_iter()
        .map(|idx| candidates[idx].clone())
        .collect()
}
}
/// Selects up to `config.k` candidates by Maximal Marginal Relevance.
///
/// Convenience wrapper around [`try_mmr`] for callers that treat a malformed
/// similarity matrix as a programming error.
///
/// # Panics
///
/// Panics if [`try_mmr`] returns an error, i.e. when `similarity.len()` is
/// not `candidates.len()²`.
#[must_use]
pub fn mmr<I: Clone>(
    candidates: &[(I, f32)],
    similarity: &[f32],
    config: MmrConfig,
) -> Vec<(I, f32)> {
    let outcome = try_mmr(candidates, similarity, config);
    outcome.expect("similarity matrix must be n×n")
}
/// Selects up to `config.k` candidates by Maximal Marginal Relevance using a
/// precomputed similarity matrix.
///
/// Each round picks the remaining candidate with the highest
/// `λ * relevance - (1 - λ) * max_sim`, where:
///
/// * `relevance` is the candidate's score, min–max normalised over all
///   candidates (every candidate gets `1.0` when the score range is
///   effectively zero);
/// * `max_sim` is the largest similarity to any already selected item, and
///   `0.0` for the first pick.
///
/// `similarity` is a row-major n×n matrix (`row * n + col`) with rows in
/// `candidates` order. The result lists the chosen candidates in selection
/// order with their original, un-normalised scores; it is empty when there
/// are no candidates or `config.k == 0`.
///
/// # Errors
///
/// Returns `RerankError::DimensionMismatch` when `similarity.len()` is not
/// `candidates.len()²`.
pub fn try_mmr<I: Clone>(
    candidates: &[(I, f32)],
    similarity: &[f32],
    config: MmrConfig,
) -> Result<Vec<(I, f32)>, super::RerankError> {
    // Below this score range all candidates are treated as equally relevant.
    const REL_RANGE_EPSILON: f32 = 1e-9;

    let n = candidates.len();
    if similarity.len() != n * n {
        return Err(super::RerankError::DimensionMismatch {
            expected: n * n,
            got: similarity.len(),
        });
    }
    let target = config.k.min(n);
    if target == 0 {
        return Ok(Vec::new());
    }

    let (rel_min, rel_max) = candidates
        .iter()
        .map(|(_, s)| *s)
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), s| {
            (lo.min(s), hi.max(s))
        });
    let rel_range = rel_max - rel_min;
    let rel_norm: Vec<f32> = if rel_range > REL_RANGE_EPSILON {
        candidates
            .iter()
            .map(|(_, s)| (s - rel_min) / rel_range)
            .collect()
    } else {
        vec![1.0; n]
    };

    let mut selected_indices: Vec<usize> = Vec::with_capacity(target);
    let mut remaining: Vec<usize> = (0..n).collect();
    // `max_sim[i]`: largest similarity between candidate `i` and any selected
    // item. Maintained incrementally so each round only looks at the newest
    // pick instead of re-folding over the whole selection.
    let mut max_sim = vec![f32::NEG_INFINITY; n];
    let mut last_chosen: Option<usize> = None;

    for _ in 0..target {
        if let Some(sel_idx) = last_chosen {
            for &cand_idx in &remaining {
                max_sim[cand_idx] = max_sim[cand_idx].max(similarity[cand_idx * n + sel_idx]);
            }
        }

        let mut best_pos = 0;
        let mut best_score = f32::NEG_INFINITY;
        for (remaining_pos, &cand_idx) in remaining.iter().enumerate() {
            // Before the first pick there is nothing to be redundant with.
            let penalty = if last_chosen.is_some() {
                max_sim[cand_idx]
            } else {
                0.0
            };
            let mmr_score = config.lambda * rel_norm[cand_idx] - (1.0 - config.lambda) * penalty;
            // Strict `>` keeps the earliest candidate on ties (and when every
            // score is NaN, position 0 is taken).
            if mmr_score > best_score {
                best_score = mmr_score;
                best_pos = remaining_pos;
            }
        }

        // Order of `remaining` is irrelevant, so O(1) removal is fine.
        let chosen = remaining.swap_remove(best_pos);
        selected_indices.push(chosen);
        last_chosen = Some(chosen);
    }

    Ok(selected_indices
        .into_iter()
        .map(|idx| candidates[idx].clone())
        .collect())
}
/// Selects up to `config.k` candidates by Maximal Marginal Relevance, using
/// cosine similarity between embeddings instead of a precomputed matrix.
///
/// Each round picks the remaining candidate with the highest
/// `λ * relevance - (1 - λ) * max_sim`, where `relevance` is the score
/// min–max normalised over all candidates (all `1.0` when the score range is
/// effectively zero) and `max_sim` is the largest `simd::cosine` similarity
/// to an already selected item (`0.0` for the first pick).
///
/// `embeddings[i]` belongs to `candidates[i]`. The result lists the chosen
/// candidates in selection order with their original scores.
///
/// # Panics
///
/// Panics if `embeddings.len() != candidates.len()`.
#[must_use]
pub fn mmr_cosine<I: Clone, V: AsRef<[f32]>>(
    candidates: &[(I, f32)],
    embeddings: &[V],
    config: MmrConfig,
) -> Vec<(I, f32)> {
    // Below this score range all candidates are treated as equally relevant.
    const REL_RANGE_EPSILON: f32 = 1e-9;

    let n = candidates.len();
    assert_eq!(
        embeddings.len(),
        n,
        "embeddings must have same length as candidates"
    );
    let target = config.k.min(n);
    if target == 0 {
        return Vec::new();
    }

    let (rel_min, rel_max) = candidates
        .iter()
        .map(|(_, s)| *s)
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), s| {
            (lo.min(s), hi.max(s))
        });
    let rel_range = rel_max - rel_min;
    let rel_norm: Vec<f32> = if rel_range > REL_RANGE_EPSILON {
        candidates
            .iter()
            .map(|(_, s)| (s - rel_min) / rel_range)
            .collect()
    } else {
        vec![1.0; n]
    };

    let mut selected_indices: Vec<usize> = Vec::with_capacity(target);
    let mut remaining: Vec<usize> = (0..n).collect();
    // `max_sim[i]`: largest cosine between candidate `i` and any selected
    // item. Updated against the newest pick only, so every
    // (candidate, selected) cosine is computed once rather than every round.
    let mut max_sim = vec![f32::NEG_INFINITY; n];
    let mut last_chosen: Option<usize> = None;

    for _ in 0..target {
        if let Some(sel_idx) = last_chosen {
            let sel_emb = embeddings[sel_idx].as_ref();
            for &cand_idx in &remaining {
                let sim = simd::cosine(embeddings[cand_idx].as_ref(), sel_emb);
                max_sim[cand_idx] = max_sim[cand_idx].max(sim);
            }
        }

        let mut best_pos = 0;
        let mut best_score = f32::NEG_INFINITY;
        for (remaining_pos, &cand_idx) in remaining.iter().enumerate() {
            // Before the first pick there is nothing to be redundant with.
            let penalty = if last_chosen.is_some() {
                max_sim[cand_idx]
            } else {
                0.0
            };
            let mmr_score = config.lambda * rel_norm[cand_idx] - (1.0 - config.lambda) * penalty;
            // Strict `>` keeps the earliest candidate on ties.
            if mmr_score > best_score {
                best_score = mmr_score;
                best_pos = remaining_pos;
            }
        }

        let chosen = remaining.swap_remove(best_pos);
        selected_indices.push(chosen);
        last_chosen = Some(chosen);
    }

    selected_indices
        .into_iter()
        .map(|idx| candidates[idx].clone())
        .collect()
}
/// Settings for greedy Determinantal Point Process (DPP) selection.
///
/// Both fields are public, so a value built with a struct literal bypasses the
/// non-negativity clamp applied by [`DppConfig::new`] and
/// [`DppConfig::with_alpha`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DppConfig {
    /// Maximum number of items to select.
    pub k: usize,
    /// Scale applied to a relevance score before exponentiation
    /// (`quality = exp(relevance * alpha)`).
    pub alpha: f32,
}

impl Default for DppConfig {
    /// Selects up to 10 items with `alpha = 1.0`.
    fn default() -> Self {
        DppConfig { alpha: 1.0, k: 10 }
    }
}

impl DppConfig {
    /// Builds a configuration; a negative (or NaN) `alpha` becomes `0.0`.
    #[must_use]
    pub fn new(k: usize, alpha: f32) -> Self {
        let alpha = f32::max(alpha, 0.0);
        Self { k, alpha }
    }

    /// Returns a copy with the selection size `k` replaced.
    #[must_use]
    pub const fn with_k(self, k: usize) -> Self {
        Self {
            k,
            alpha: self.alpha,
        }
    }

    /// Returns a copy with `alpha` replaced; a negative (or NaN) value becomes `0.0`.
    #[must_use]
    pub fn with_alpha(self, alpha: f32) -> Self {
        Self {
            k: self.k,
            alpha: f32::max(alpha, 0.0),
        }
    }
}
/// Selects up to `config.k` candidates by greedy MAP inference for a
/// Determinantal Point Process.
///
/// The implied kernel is `L[i][j] = q_i * q_j * dot(v_i, v_j)` with quality
/// `q_i = exp(relevance_i * config.alpha)` and `v_i = embeddings[i]`. Each
/// round picks the remaining candidate maximising `q_i * sqrt(r_i)`, where
/// `r_i` is the squared norm of the part of `v_i` orthogonal to the
/// embeddings selected so far. Items that are well explained by the current
/// selection therefore score low regardless of relevance.
///
/// Residuals are maintained incrementally (Gram–Schmidt / Cholesky style):
/// for each candidate the coefficients of its embedding along the
/// orthonormalised selected directions are stored, so every new pick is
/// projected out exactly rather than via its raw embedding.
///
/// `embeddings[i]` belongs to `candidates[i]`. The result lists the chosen
/// candidates in selection order with their original scores; it is empty when
/// there are no candidates or `config.k == 0`.
///
/// # Panics
///
/// Panics if `embeddings.len() != candidates.len()`.
#[must_use]
pub fn dpp<I: Clone, V: AsRef<[f32]>>(
    candidates: &[(I, f32)],
    embeddings: &[V],
    config: DppConfig,
) -> Vec<(I, f32)> {
    // Floor for a chosen item's residual so the division below stays finite
    // when a (near-)degenerate item has to be picked.
    const MIN_RESIDUAL: f32 = 1e-9;

    let n = candidates.len();
    assert_eq!(
        embeddings.len(),
        n,
        "embeddings must have same length as candidates"
    );
    let target = config.k.min(n);
    if target == 0 {
        return Vec::new();
    }

    let qualities: Vec<f32> = candidates
        .iter()
        .map(|(_, r)| (r * config.alpha).exp())
        .collect();

    // residual[i]: squared norm of v_i orthogonal to the selected span;
    // starts as the full squared norm.
    let mut residual: Vec<f32> = embeddings
        .iter()
        .map(|e| {
            let v = e.as_ref();
            simd::dot(v, v)
        })
        .collect();
    // coeffs[i][t]: component of v_i along the t-th orthonormalised pick.
    let mut coeffs: Vec<Vec<f32>> = vec![Vec::new(); n];

    let mut selected_indices: Vec<usize> = Vec::with_capacity(target);
    let mut remaining: Vec<usize> = (0..n).collect();

    for _ in 0..target {
        let mut best_pos = 0;
        let mut best_score = f32::NEG_INFINITY;
        for (pos, &cand_idx) in remaining.iter().enumerate() {
            // `max(0.0)` also maps a NaN residual to 0 before the square root.
            let score = qualities[cand_idx] * residual[cand_idx].max(0.0).sqrt();
            if score > best_score {
                best_score = score;
                best_pos = pos;
            }
        }

        let chosen = remaining.swap_remove(best_pos);
        selected_indices.push(chosen);
        if selected_indices.len() == target {
            // No further round will read the residuals.
            break;
        }

        let chosen_emb = embeddings[chosen].as_ref();
        let chosen_norm = residual[chosen].max(MIN_RESIDUAL).sqrt();
        // The chosen item is no longer a candidate, so its coefficients can
        // be moved out instead of cloned.
        let chosen_coeffs = std::mem::take(&mut coeffs[chosen]);

        for &idx in &remaining {
            let raw = simd::dot(embeddings[idx].as_ref(), chosen_emb);
            // Part of the raw dot product already accounted for by earlier
            // picks; subtracting it leaves the overlap with the *new*
            // orthogonal direction only.
            let explained: f32 = coeffs[idx]
                .iter()
                .zip(&chosen_coeffs)
                .map(|(a, b)| a * b)
                .sum();
            let component = (raw - explained) / chosen_norm;
            coeffs[idx].push(component);
            // Clamp away small negative values from rounding.
            residual[idx] = (residual[idx] - component * component).max(0.0);
        }
    }

    selected_indices
        .into_iter()
        .map(|idx| candidates[idx].clone())
        .collect()
}
// Unit tests for `mmr`, `try_mmr`, `mmr_cosine` and `MmrConfig` on small
// hand-written fixtures. Similarity matrices are row-major 3×3 (or smaller)
// with rows in candidate order.
#[cfg(test)]
mod tests {
use super::*;
// lambda = 1.0 gives the similarity term zero weight, so the output must
// follow relevance order even though all items are highly similar.
#[test]
fn mmr_pure_relevance() {
let candidates = vec![("a", 0.9), ("b", 0.8), ("c", 0.7)];
let sim = vec![
1.0, 0.9, 0.9, 0.9, 1.0, 0.9, 0.9, 0.9, 1.0,
];
let result = mmr(&candidates, &sim, MmrConfig::new(1.0, 3));
assert_eq!(result[0].0, "a");
assert_eq!(result[1].0, "b");
assert_eq!(result[2].0, "c");
}
// Both `new` and `with_lambda` clamp lambda into [0, 1]; in-range values pass through.
#[test]
fn mmr_config_clamps_lambda() {
assert_eq!(MmrConfig::new(-0.5, 10).lambda, 0.0);
assert_eq!(MmrConfig::new(1.5, 10).lambda, 1.0);
assert_eq!(MmrConfig::default().with_lambda(-0.5).lambda, 0.0);
assert_eq!(MmrConfig::default().with_lambda(1.5).lambda, 1.0);
assert_eq!(MmrConfig::default().with_lambda(0.7).lambda, 0.7);
}
// "a" and "b" are near-duplicates (0.95) while "c" is dissimilar to both
// (0.1): with a balanced lambda the second pick must be "c", not "b".
#[test]
fn mmr_prefers_diverse() {
let candidates = vec![("a", 0.9), ("b", 0.85), ("c", 0.8)];
let sim = vec![
1.0, 0.95, 0.1, 0.95, 1.0, 0.1, 0.1, 0.1, 1.0, ];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 2));
assert_eq!(result[0].0, "a"); assert_eq!(result[1].0, "c"); }
// lambda = 0.0 ignores relevance: the low-scoring but dissimilar "c" must
// appear among the two selected items.
#[test]
fn mmr_pure_diversity() {
let candidates = vec![("a", 0.9), ("b", 0.85), ("c", 0.1)];
let sim = vec![
1.0, 0.99, 0.01, 0.99, 1.0, 0.01, 0.01, 0.01, 1.0,
];
let result = mmr(&candidates, &sim, MmrConfig::new(0.0, 2));
assert!(result.iter().any(|(id, _)| *id == "c"));
}
// Embedding-based variant of `mmr_prefers_diverse`: "b" points almost the
// same way as "a", "c" is orthogonal to both.
#[test]
fn mmr_cosine_basic() {
let candidates = vec![("a", 0.9), ("b", 0.85), ("c", 0.8)];
let embeddings: Vec<Vec<f32>> = vec![
vec![1.0, 0.0, 0.0],
vec![0.99, 0.1, 0.0], vec![0.0, 0.0, 1.0], ];
let result = mmr_cosine(&candidates, &embeddings, MmrConfig::new(0.5, 2));
assert_eq!(result[0].0, "a");
assert_eq!(result[1].0, "c"); }
// No candidates (and an empty 0×0 matrix) yields an empty result.
#[test]
fn mmr_empty_candidates() {
let candidates: Vec<(&str, f32)> = vec![];
let sim: Vec<f32> = vec![];
let result = mmr(&candidates, &sim, MmrConfig::default());
assert!(result.is_empty());
}
// Asking for more items than exist returns every candidate.
#[test]
fn mmr_k_larger_than_n() {
let candidates = vec![("a", 0.9), ("b", 0.8)];
let sim = vec![1.0, 0.5, 0.5, 1.0];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 10));
assert_eq!(result.len(), 2);
}
// A single candidate is returned as-is.
#[test]
fn mmr_single_candidate() {
let candidates = vec![("a", 0.9)];
let sim = vec![1.0];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 1));
assert_eq!(result.len(), 1);
assert_eq!(result[0].0, "a");
}
// Two candidates need a 4-element matrix; a 1-element matrix must be
// rejected with an error instead of panicking.
#[test]
fn try_mmr_invalid_matrix() {
let candidates = vec![("a", 0.9), ("b", 0.8)];
let sim = vec![1.0]; let result = try_mmr(&candidates, &sim, MmrConfig::default());
assert!(result.is_err());
}
// With nothing selected yet the similarity penalty is zero, so the first
// pick is the most relevant item ("b") regardless of input order.
#[test]
fn mmr_exact_formula_first_selection() {
let candidates = vec![("a", 0.5), ("b", 1.0), ("c", 0.8)];
let sim = vec![
1.0, 0.5, 0.5, 0.5, 1.0, 0.5, 0.5, 0.5, 1.0,
];
let result = mmr(&candidates, &sim, MmrConfig::new(0.7, 1));
assert_eq!(
result[0].0, "b",
"First selection should be highest relevance"
);
}
// After "a" is chosen, "b" is heavily penalised (similarity 0.9 to "a")
// while "c" is not (0.1), so "c" wins the second slot despite lower relevance.
#[test]
fn mmr_exact_formula_second_selection() {
let candidates = vec![("a", 0.9), ("b", 0.6), ("c", 0.3)];
let sim = vec![
1.0, 0.9, 0.1, 0.9, 1.0, 0.2, 0.1, 0.2, 1.0, ];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 2));
assert_eq!(result[0].0, "a", "First should be 'a' (highest relevance)");
assert_eq!(
result[1].0, "c",
"Second should be 'c' (more diverse from 'a')"
);
}
// Equal relevance and lambda = 0.0: selecting all three must still return
// each item exactly once (order is not asserted).
#[test]
fn mmr_pure_diversity_equal_relevance() {
let candidates = vec![("a", 0.5), ("b", 0.5), ("c", 0.5)];
let sim = vec![
1.0, 0.99, 0.01, 0.99, 1.0, 0.01, 0.01, 0.01, 1.0, ];
let result = mmr(&candidates, &sim, MmrConfig::new(0.0, 3));
assert_eq!(result.len(), 3);
let ids: Vec<_> = result.iter().map(|(id, _)| *id).collect();
assert!(ids.contains(&"a"));
assert!(ids.contains(&"b"));
assert!(ids.contains(&"c"));
}
}
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
proptest! {
#[test]
fn mmr_output_length_bounded(
n in 1usize..10,
k in 1usize..20,
lambda in 0.0f32..1.0,
) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 1.0 - i as f32 * 0.1))
.collect();
let sim: Vec<f32> = (0..n * n).map(|_| 0.5).collect();
let result = mmr(&candidates, &sim, MmrConfig::new(lambda, k));
prop_assert!(result.len() <= k.min(n));
}
#[test]
fn mmr_diversity_uses_subtraction(k in 1usize..3usize) {
let candidates: Vec<(u32, f32)> = vec![(0, 0.9), (1, 0.9), (2, 0.9)];
let n = 3;
let mut sim: Vec<f32> = vec![0.0; n * n];
for i in 0..n {
for j in 0..n {
if i == j {
sim[i * n + j] = 1.0;
} else if (i == 0 && j == 1) || (i == 1 && j == 0) {
sim[i * n + j] = 0.9; } else {
sim[i * n + j] = 0.1; }
}
}
let result = mmr(&candidates, &sim, MmrConfig::new(0.0, k));
prop_assert!(!result.is_empty(), "MMR should return results");
let ids: std::collections::HashSet<u32> = result.iter().map(|(id, _)| *id).collect();
prop_assert_eq!(ids.len(), result.len(), "no duplicate IDs");
if k >= 2 {
let second_id = result[1].0;
prop_assert_eq!(second_id, 2u32, "doc2 should be second pick under pure diversity");
let second_score = result[1].1;
prop_assert!((0.0..=1.0).contains(&second_score),
"diversity score {second_score} out of [0,1]; division would exceed 1.0");
}
}
#[test]
fn mmr_unique_ids(n in 1usize..10) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 1.0 - i as f32 * 0.1))
.collect();
let sim: Vec<f32> = (0..n * n).map(|_| 0.5).collect();
let result = mmr(&candidates, &sim, MmrConfig::default().with_k(n));
let mut seen = std::collections::HashSet::new();
for (id, _) in &result {
prop_assert!(seen.insert(*id), "Duplicate ID: {}", id);
}
}
#[test]
fn mmr_lambda_1_is_relevance_order(n in 2usize..8) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 1.0 - i as f32 * 0.1))
.collect();
let sim: Vec<f32> = (0..n)
.flat_map(|i| (0..n).map(move |j| if i == j { 1.0 } else { 0.0 }))
.collect();
let result = mmr(&candidates, &sim, MmrConfig::new(1.0, n));
for window in result.windows(2) {
prop_assert!(window[0].1 >= window[1].1,
"Not sorted: {:?} >= {:?}", window[0], window[1]);
}
}
#[test]
fn mmr_empty_returns_empty(k in 0usize..10, lambda in 0.0f32..1.0) {
let candidates: Vec<(u32, f32)> = vec![];
let sim: Vec<f32> = vec![];
let result = mmr(&candidates, &sim, MmrConfig::new(lambda, k));
prop_assert!(result.is_empty());
}
#[test]
fn mmr_k_zero_returns_empty(n in 1usize..10) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 0.5))
.collect();
let sim: Vec<f32> = vec![0.5; n * n];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 0));
prop_assert!(result.is_empty());
}
#[test]
fn try_mmr_wrong_size_errors(n in 2usize..10, wrong_size in 0usize..5) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 0.5))
.collect();
let correct_size = n * n;
let actual_size = if wrong_size == 0 { 0 } else { correct_size.saturating_sub(wrong_size) };
prop_assume!(actual_size != correct_size);
let sim: Vec<f32> = vec![0.5; actual_size];
let result = try_mmr(&candidates, &sim, MmrConfig::default());
prop_assert!(result.is_err());
}
#[test]
fn mmr_equal_relevance(n in 1usize..8) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 0.5)) .collect();
let sim: Vec<f32> = vec![0.5; n * n];
let result = mmr(&candidates, &sim, MmrConfig::default().with_k(n));
prop_assert_eq!(result.len(), n);
}
#[test]
fn mmr_cosine_consistent_with_mmr(n in 2usize..6) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 1.0 - i as f32 * 0.1))
.collect();
let embeddings: Vec<Vec<f32>> = (0..n)
.map(|i| {
let mut v = vec![0.0; n];
v[i] = 1.0;
v
})
.collect();
let mut sim = Vec::with_capacity(n * n);
for i in 0..n {
for j in 0..n {
sim.push(simd::cosine(&embeddings[i], &embeddings[j]));
}
}
let mmr_result = mmr(&candidates, &sim, MmrConfig::new(0.5, n));
let cosine_result = mmr_cosine(&candidates, &embeddings, MmrConfig::new(0.5, n));
let mmr_ids: std::collections::HashSet<_> = mmr_result.iter().map(|(id, _)| *id).collect();
let cosine_ids: std::collections::HashSet<_> = cosine_result.iter().map(|(id, _)| *id).collect();
prop_assert_eq!(mmr_ids, cosine_ids);
}
#[test]
fn dpp_empty_returns_empty(k in 0usize..10) {
let candidates: Vec<(u32, f32)> = vec![];
let embeddings: Vec<Vec<f32>> = vec![];
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(k));
prop_assert!(result.is_empty());
}
#[test]
fn dpp_k_zero_returns_empty(n in 1usize..10) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 0.5))
.collect();
let embeddings: Vec<Vec<f32>> = (0..n)
.map(|i| {
let mut v = vec![0.0; 3];
v[i % 3] = 1.0;
v
})
.collect();
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(0));
prop_assert!(result.is_empty());
}
#[test]
fn dpp_selects_at_most_k(n in 2usize..10, k in 1usize..8) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 0.5))
.collect();
let embeddings: Vec<Vec<f32>> = (0..n)
.map(|_| vec![1.0, 0.0, 0.0])
.collect();
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(k));
prop_assert!(result.len() <= k.min(n));
}
#[test]
fn dpp_prefers_orthogonal(n in 3usize..6) {
let mut embeddings: Vec<Vec<f32>> = vec![
vec![1.0, 0.0, 0.0],
vec![0.99, 0.1, 0.0], ];
for i in 2..n {
let mut v = vec![0.0; 3];
v[i % 3] = 1.0;
embeddings.push(v);
}
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 0.9 - i as f32 * 0.05)) .collect();
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(2));
prop_assert_eq!(result[0].0, 0);
if n > 2 {
prop_assert_ne!(result[1].0, 1, "DPP should prefer orthogonal over similar");
}
}
#[test]
fn dpp_equal_embeddings(n in 1usize..8) {
let candidates: Vec<(u32, f32)> = (0..n as u32)
.map(|i| (i, 1.0 - i as f32 * 0.1))
.collect();
let embeddings: Vec<Vec<f32>> = vec![vec![1.0, 0.0]; n];
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(n));
prop_assert!(!result.is_empty() || n == 0);
}
}
}
// Unit tests for `dpp` and `DppConfig` on small hand-written fixtures.
#[cfg(test)]
mod dpp_tests {
use super::*;
// Three mutually orthogonal embeddings: asking for three items returns all three.
#[test]
fn dpp_orthogonal_prefers_diverse() {
let candidates = vec![("a", 0.9), ("b", 0.85), ("c", 0.8)];
let embeddings = vec![
vec![1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0],
];
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(3));
assert_eq!(result.len(), 3);
}
// "b" is nearly parallel to "a" while "c" is orthogonal: after "a" the
// second pick must be "c" even though "b" has the higher relevance.
#[test]
fn dpp_similar_items_penalized() {
let candidates = vec![("a", 0.95), ("b", 0.90), ("c", 0.85)];
let embeddings = vec![
vec![1.0, 0.0, 0.0],
vec![0.99, 0.1, 0.0], vec![0.0, 0.0, 1.0], ];
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(2));
assert_eq!(result.len(), 2);
assert_eq!(result[0].0, "a"); assert_eq!(result[1].0, "c");
}
// `with_alpha` keeps non-negative values and raises negative ones to 0.0.
#[test]
fn dpp_config_alpha() {
let config = DppConfig::default().with_alpha(10.0);
assert_eq!(config.alpha, 10.0);
let config = DppConfig::default().with_alpha(-1.0);
assert_eq!(config.alpha, 0.0);
}
}
// Robustness tests: unusual or degenerate inputs (negative or out-of-range
// similarities, zero-norm or identical embeddings, NaN values) must not panic
// and must still produce a selection.
#[cfg(test)]
mod failure_mode_tests {
use super::*;
// "a" and "b" have negative similarity (-0.9), which turns the penalty
// into a bonus: "b" is expected as the second pick after "a".
#[test]
fn mmr_negative_similarity_handled() {
let candidates = vec![("a", 0.9), ("b", 0.85), ("c", 0.8)];
let sim = vec![
1.0, -0.9, 0.5, -0.9, 1.0, 0.5, 0.5, 0.5, 1.0, ];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 2));
assert_eq!(result.len(), 2);
assert_eq!(result[0].0, "a");
assert_eq!(result[1].0, "b"); }
// Similarities above 1.0 are accepted; both items are still returned.
#[test]
fn mmr_similarity_outside_unit_range() {
let candidates = vec![("a", 0.9), ("b", 0.85)];
let sim = vec![
1.0, 2.0, 2.0, 1.0,
];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 2));
assert_eq!(result.len(), 2);
}
// An all-zero embedding must not break DPP selection.
#[test]
fn dpp_zero_norm_embeddings() {
let candidates = vec![("a", 0.9), ("b", 0.85)];
let embeddings = vec![
vec![0.0, 0.0, 0.0], vec![1.0, 0.0, 0.0], ];
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(2));
assert!(!result.is_empty());
}
// Identical embeddings leave no diversity to gain; a result is still returned.
#[test]
fn dpp_all_identical_embeddings() {
let candidates = vec![("a", 0.9), ("b", 0.85), ("c", 0.8)];
let embeddings = vec![
vec![1.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0], vec![1.0, 0.0, 0.0], ];
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(3));
assert!(!result.is_empty());
}
// A zero-norm embedding in cosine-based MMR: both items are still returned.
#[test]
fn mmr_cosine_zero_norm_embeddings() {
let candidates = vec![("a", 0.9), ("b", 0.85)];
let embeddings = vec![
vec![0.0, 0.0], vec![1.0, 0.0], ];
let result = mmr_cosine(&candidates, &embeddings, MmrConfig::new(0.5, 2));
assert_eq!(result.len(), 2);
}
// Exactly opposite embeddings: DPP still returns both items.
#[test]
fn dpp_anticorrelated_embeddings() {
let candidates = vec![("a", 0.9), ("b", 0.85)];
let embeddings = vec![
vec![1.0, 0.0],
vec![-1.0, 0.0], ];
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(2));
assert_eq!(result.len(), 2);
}
// NaN entries in the similarity matrix must not cause a panic or an empty result.
#[test]
fn mmr_nan_in_similarity() {
let candidates = vec![("a", 0.9), ("b", 0.85)];
let sim = vec![1.0, f32::NAN, f32::NAN, 1.0];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 2));
assert!(!result.is_empty());
}
// A NaN component in an embedding must not cause a panic or an empty result.
#[test]
fn dpp_nan_in_embeddings() {
let candidates = vec![("a", 0.9), ("b", 0.85)];
let embeddings = vec![vec![1.0, 0.0], vec![f32::NAN, 0.0]];
let result = dpp(&candidates, &embeddings, DppConfig::default().with_k(2));
assert!(!result.is_empty());
}
// MMR normalises relevance internally but must hand back each item's
// original score unchanged.
#[test]
fn mmr_preserves_original_scores() {
let candidates = vec![("a", 0.95), ("b", 0.85), ("c", 0.75)];
let sim = vec![1.0, 0.1, 0.1, 0.1, 1.0, 0.1, 0.1, 0.1, 1.0];
let result = mmr(&candidates, &sim, MmrConfig::new(0.5, 3));
for (id, score) in &result {
let original = candidates
.iter()
.find(|(i, _)| i == id)
.expect("result item must exist in candidates")
.1;
assert!(
(score - original).abs() < 1e-6,
"Score for {} was modified: {} vs {}",
id,
score,
original
);
}
}
}