use super::{DistanceMetric, compute_distance};
use grafeo_common::types::NodeId;
/// Maps a distance under `metric` to a similarity score where larger means
/// more similar.
///
/// * `Cosine`: `1 - distance`.
/// * `Euclidean` / `Manhattan`: `1 / (1 + distance)`, so a distance of zero
///   yields `1.0` and the score shrinks as the distance grows.
/// * `DotProduct`: the distance is negated (presumably the stored distance is
///   the negated dot product; confirm against `compute_distance`).
#[inline]
fn distance_to_similarity(distance: f32, metric: DistanceMetric) -> f32 {
    match metric {
        DistanceMetric::DotProduct => -distance,
        DistanceMetric::Cosine => 1.0 - distance,
        // `recip()` is exactly `1.0 / x`.
        DistanceMetric::Manhattan | DistanceMetric::Euclidean => (1.0 + distance).recip(),
    }
}
/// Re-ranks `candidates` with Maximal Marginal Relevance (MMR) and returns up
/// to `k` of them in selection order.
///
/// Each round picks the remaining candidate with the highest score
///
/// ```text
/// lambda * sim(candidate, query) - (1 - lambda) * max sim(candidate, selected)
/// ```
///
/// where similarities are derived from distances via
/// `distance_to_similarity`. The second term is `0.0` while nothing has been
/// selected yet.
///
/// # Parameters
///
/// * `query` - currently not consulted: relevance is taken from the distance
///   already stored in each candidate tuple.
/// * `candidates` - `(id, distance, vector)` tuples; `distance` is used as the
///   candidate's relevance to the query, `vector` for candidate-to-candidate
///   comparisons.
/// * `k` - maximum number of results; capped at `candidates.len()`.
/// * `lambda` - relevance/diversity trade-off, clamped to `[0.0, 1.0]`
///   (`1.0` scores on relevance only, `0.0` on diversity only).
/// * `metric` - metric used both to interpret the stored distances and to
///   compare candidate vectors with each other.
///
/// # Returns
///
/// `(id, distance)` pairs carrying the distances exactly as supplied in
/// `candidates`. Empty when `candidates` is empty or `k == 0`.
#[must_use]
pub fn mmr_select(
    query: &[f32],
    candidates: &[(NodeId, f32, &[f32])],
    k: usize,
    lambda: f32,
    metric: DistanceMetric,
) -> Vec<(NodeId, f32)> {
    if candidates.is_empty() || k == 0 {
        return Vec::new();
    }
    let k = k.min(candidates.len());
    let lambda = lambda.clamp(0.0, 1.0);
    // Kept in the signature for callers; relevance comes from the
    // precomputed distances, so the query vector itself is not needed.
    let _ = query;

    let query_similarities: Vec<f32> = candidates
        .iter()
        .map(|(_, dist, _)| distance_to_similarity(*dist, metric))
        .collect();

    // Running maximum similarity between each candidate and everything
    // selected so far. Only the most recent pick can raise it, so each round
    // needs one distance evaluation per remaining candidate instead of one
    // per (remaining, selected) pair.
    let mut max_sim_to_selected = vec![f32::NEG_INFINITY; candidates.len()];
    let mut last_selected: Option<usize> = None;

    let mut selected_indices: Vec<usize> = Vec::with_capacity(k);
    let mut remaining: Vec<usize> = (0..candidates.len()).collect();

    for _ in 0..k {
        let mut best_pos = 0;
        let mut best_mmr = f32::NEG_INFINITY;

        for (pos, &cand_idx) in remaining.iter().enumerate() {
            let relevance = query_similarities[cand_idx];
            let redundancy = match last_selected {
                // First round: nothing to be redundant with.
                None => 0.0,
                Some(sel_idx) => {
                    let dist =
                        compute_distance(candidates[cand_idx].2, candidates[sel_idx].2, metric);
                    let cached = &mut max_sim_to_selected[cand_idx];
                    *cached = f32::max(*cached, distance_to_similarity(dist, metric));
                    *cached
                }
            };

            let mmr_score = lambda * relevance - (1.0 - lambda) * redundancy;
            // Strict `>` keeps the earliest entry of `remaining` on ties (and
            // position 0 when no score compares greater, e.g. all NaN).
            if mmr_score > best_mmr {
                best_mmr = mmr_score;
                best_pos = pos;
            }
        }

        let chosen = remaining.swap_remove(best_pos);
        selected_indices.push(chosen);
        last_selected = Some(chosen);
    }

    selected_indices
        .iter()
        .map(|&idx| (candidates[idx].0, candidates[idx].1))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing similarity scores.
    const EPS: f32 = 1e-6;

    /// Attaches to every `(id, vector)` pair its distance to `query` under
    /// `metric`, producing the tuple layout `mmr_select` expects.
    fn scored<'a>(
        query: &[f32],
        items: &[(NodeId, &'a [f32])],
        metric: DistanceMetric,
    ) -> Vec<(NodeId, f32, &'a [f32])> {
        items
            .iter()
            .map(|&(id, vector)| (id, compute_distance(query, vector, metric), vector))
            .collect()
    }

    #[test]
    fn test_empty_candidates() {
        let picked = mmr_select(&[1.0, 0.0], &[], 5, 0.5, DistanceMetric::Euclidean);
        assert!(picked.is_empty());
    }

    #[test]
    fn test_k_zero() {
        let vector = [1.0f32, 0.0];
        let pool = vec![(NodeId::new(1), 0.0, vector.as_slice())];
        let picked = mmr_select(&[1.0, 0.0], &pool, 0, 0.5, DistanceMetric::Euclidean);
        assert!(picked.is_empty());
    }

    #[test]
    fn test_lambda_one_is_pure_relevance() {
        let query = [1.0f32, 0.0, 0.0];
        let near = [0.9f32, 0.1, 0.0];
        let mid = [0.5f32, 0.5, 0.0];
        let far = [0.0f32, 1.0, 0.0];
        let pool = scored(
            &query,
            &[
                (NodeId::new(1), near.as_slice()),
                (NodeId::new(2), mid.as_slice()),
                (NodeId::new(3), far.as_slice()),
            ],
            DistanceMetric::Euclidean,
        );

        let picked = mmr_select(&query, &pool, 3, 1.0, DistanceMetric::Euclidean);

        assert_eq!(picked.len(), 3);
        // With lambda = 1.0 the closest candidate must come first.
        assert_eq!(picked[0].0, NodeId::new(1));
    }

    #[test]
    fn test_diversity_avoids_redundancy() {
        let query = [1.0f32, 0.0, 0.0];
        let close = [0.9f32, 0.1, 0.0];
        let near_duplicate = [0.89f32, 0.11, 0.0];
        let distinct = [0.0f32, 0.0, 1.0];
        let pool = scored(
            &query,
            &[
                (NodeId::new(1), close.as_slice()),
                (NodeId::new(2), near_duplicate.as_slice()),
                (NodeId::new(3), distinct.as_slice()),
            ],
            DistanceMetric::Euclidean,
        );

        let picked = mmr_select(&query, &pool, 2, 0.5, DistanceMetric::Euclidean);

        assert_eq!(picked.len(), 2);
        // The near-duplicate of the first pick loses to the distinct vector.
        assert_eq!(picked[0].0, NodeId::new(1));
        assert_eq!(picked[1].0, NodeId::new(3));
    }

    #[test]
    fn test_k_larger_than_candidates() {
        let query = [1.0f32, 0.0];
        let first = [0.9f32, 0.1];
        let second = [0.5f32, 0.5];
        let pool = scored(
            &query,
            &[
                (NodeId::new(1), first.as_slice()),
                (NodeId::new(2), second.as_slice()),
            ],
            DistanceMetric::Cosine,
        );

        let picked = mmr_select(&query, &pool, 10, 0.5, DistanceMetric::Cosine);

        // `k` is capped at the number of candidates.
        assert_eq!(picked.len(), 2);
    }

    #[test]
    fn test_returns_original_distances() {
        let query = [1.0f32, 0.0, 0.0];
        let vector = [0.9f32, 0.1, 0.0];
        let pool = scored(
            &query,
            &[(NodeId::new(1), vector.as_slice())],
            DistanceMetric::Euclidean,
        );
        let supplied_distance = pool[0].1;

        let picked = mmr_select(&query, &pool, 1, 0.5, DistanceMetric::Euclidean);

        assert_eq!(picked[0].1, supplied_distance);
    }

    #[test]
    fn test_all_metrics() {
        let query = [1.0f32, 0.0, 0.0];
        let aligned = [0.9f32, 0.1, 0.0];
        let orthogonal = [0.0f32, 1.0, 0.0];

        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ] {
            let pool = scored(
                &query,
                &[
                    (NodeId::new(1), aligned.as_slice()),
                    (NodeId::new(2), orthogonal.as_slice()),
                ],
                metric,
            );
            let picked = mmr_select(&query, &pool, 2, 0.5, metric);
            assert_eq!(picked.len(), 2, "failed for metric {metric:?}");
        }
    }

    #[test]
    fn test_distance_to_similarity_cosine() {
        let at_zero = distance_to_similarity(0.0, DistanceMetric::Cosine);
        let at_one = distance_to_similarity(1.0, DistanceMetric::Cosine);
        assert!((at_zero - 1.0).abs() < EPS);
        assert!(at_one.abs() < EPS);
    }

    #[test]
    fn test_distance_to_similarity_euclidean() {
        let at_zero = distance_to_similarity(0.0, DistanceMetric::Euclidean);
        let far_away = distance_to_similarity(1000.0, DistanceMetric::Euclidean);
        assert!((at_zero - 1.0).abs() < EPS);
        assert!(far_away < 0.01);
    }

    #[test]
    fn test_distance_to_similarity_dot_product() {
        let similarity = distance_to_similarity(-32.0, DistanceMetric::DotProduct);
        assert!((similarity - 32.0).abs() < EPS);
    }

    #[test]
    fn test_single_candidate() {
        let query = [1.0f32, 0.0];
        let vector = [0.5f32, 0.5];
        let pool = scored(
            &query,
            &[(NodeId::new(42), vector.as_slice())],
            DistanceMetric::Cosine,
        );

        let picked = mmr_select(&query, &pool, 1, 0.5, DistanceMetric::Cosine);

        assert_eq!(picked.len(), 1);
        assert_eq!(picked[0].0, NodeId::new(42));
    }
}