/// Computes the cosine similarity of two vectors: `dot(a, b) / (|a| * |b|)`.
///
/// Sums are accumulated in `f64`, so squaring very small or very large `f32`
/// components neither underflows to zero nor overflows to infinity.
///
/// # Returns
///
/// * A value in (approximately) `[-1.0, 1.0]` for non-degenerate input.
/// * `0.0` when either vector has zero magnitude (this includes empty slices),
///   since the similarity is undefined there.
/// * `NaN` when the inputs contain `NaN` or infinite components.
///
/// Lengths are not checked: the dot product only covers the overlapping
/// prefix of `a` and `b`, while each magnitude covers its whole slice.
// The final narrowing cast is intentional: the ratio is at most ~1 in magnitude.
#[allow(clippy::cast_possible_truncation)]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let squared_norm =
        |v: &[f32]| -> f64 { v.iter().map(|&x| f64::from(x) * f64::from(x)).sum() };

    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let denom = squared_norm(a).sqrt() * squared_norm(b).sqrt();

    // Only a genuinely zero magnitude is degenerate. Comparing the product
    // against an epsilon would wrongly zero out valid low-norm vectors.
    if denom == 0.0 {
        return 0.0;
    }
    (dot / denom) as f32
}
/// Selects up to `k` embeddings with Maximal Marginal Relevance (MMR),
/// balancing similarity to the query against redundancy among the picks.
///
/// The first pick is the embedding most similar to `query_embedding`. Each
/// later pick maximises
/// `lambda * sim(query, c) - (1 - lambda) * max(sim(c, s) for s in selected)`,
/// where `sim` is [`cosine_similarity`].
///
/// # Parameters
///
/// * `query_embedding` - the query vector.
/// * `embeddings` - candidate vectors; returned indices point into this slice.
/// * `k` - maximum number of indices to return (capped at `embeddings.len()`).
/// * `lambda` - trade-off weight: `1.0` ranks purely by relevance, `0.0`
///   purely by diversity. The value is used as given and is not clamped.
///
/// # Returns
///
/// Exactly `min(k, embeddings.len())` distinct indices in selection order.
/// Ties for the first pick resolve to the highest index, ties for later picks
/// to the lowest index. Candidates whose score is not a finite number (for
/// example because an embedding contains `NaN`) are ranked last instead of
/// being dropped or distorting the ranking of valid candidates.
pub fn maximal_marginal_relevance(
    query_embedding: &[f32],
    embeddings: &[Vec<f32>],
    k: usize,
    lambda: f32,
) -> Vec<usize> {
    let k = k.min(embeddings.len());
    if k == 0 {
        return Vec::new();
    }

    // Non-finite scores (NaN in particular) compare inconsistently; mapping
    // them to the lowest value keeps every comparison below well-defined.
    let rankable = |score: f32| {
        if score.is_finite() {
            score
        } else {
            f32::NEG_INFINITY
        }
    };

    let query_sims: Vec<f32> = embeddings
        .iter()
        .map(|e| cosine_similarity(query_embedding, e))
        .collect();

    let mut selected: Vec<usize> = Vec::with_capacity(k);

    // Seed with the most query-relevant embedding. `max_by` keeps the last of
    // several equal maxima.
    let seed = (0..embeddings.len()).max_by(|&a, &b| {
        rankable(query_sims[a])
            .partial_cmp(&rankable(query_sims[b]))
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let mut latest = match seed {
        Some(idx) => idx,
        None => return selected,
    };
    selected.push(latest);

    let mut remaining: Vec<usize> = (0..embeddings.len()).filter(|&i| i != latest).collect();

    // Highest similarity between each candidate and anything selected so far.
    // Only the most recent pick can raise it, so each round needs a single
    // similarity per candidate instead of one per already-selected item.
    let mut max_sim_to_selected = vec![f32::NEG_INFINITY; embeddings.len()];

    while selected.len() < k {
        let anchor = &embeddings[latest];
        let mut best: Option<(usize, f32)> = None;

        for &candidate in &remaining {
            let sim = cosine_similarity(&embeddings[candidate], anchor);
            // `f32::max` ignores a NaN operand, so an undefined similarity
            // leaves the running maximum untouched.
            let closest = max_sim_to_selected[candidate].max(sim);
            max_sim_to_selected[candidate] = closest;

            let score =
                rankable(lambda.mul_add(query_sims[candidate], -(1.0 - lambda) * closest));

            // Strict `>` keeps the earliest candidate on ties; accepting the
            // first candidate unconditionally guarantees progress even when
            // every score is non-finite.
            let improves = match best {
                Some((_, top)) => score > top,
                None => true,
            };
            if improves {
                best = Some((candidate, score));
            }
        }

        match best {
            Some((winner, _)) => {
                selected.push(winner);
                remaining.retain(|&i| i != winner);
                latest = winner;
            }
            // Unreachable while `selected.len() < k <= embeddings.len()`.
            None => break,
        }
    }

    selected
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing similarity scores.
    const TOLERANCE: f32 = 1e-5;

    // ---- cosine_similarity -------------------------------------------------

    #[test]
    fn cosine_similarity_identical_vectors() {
        // A vector compared with itself is perfectly similar.
        let vector = [1.0_f32, 2.0, 3.0];
        let score = cosine_similarity(&vector, &vector);
        assert!((score - 1.0).abs() < TOLERANCE, "expected ~1.0, got {score}");
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        // Perpendicular unit vectors share no direction.
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let score = cosine_similarity(&x_axis, &y_axis);
        assert!(score.abs() < TOLERANCE, "expected ~0.0, got {score}");
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        // A zero-magnitude operand must produce 0.0 rather than dividing by zero.
        let non_zero = [1.0_f32, 2.0];
        let zero = [0.0_f32, 0.0];
        let score = cosine_similarity(&non_zero, &zero);
        assert!(score.abs() < TOLERANCE, "expected ~0.0, got {score}");
    }

    // ---- maximal_marginal_relevance ----------------------------------------

    #[test]
    fn mmr_empty_embeddings() {
        let query = [1.0_f32, 0.0];
        let picked = maximal_marginal_relevance(&query, &[], 5, 0.5);
        assert!(picked.is_empty());
    }

    #[test]
    fn mmr_k_zero() {
        let query = [1.0_f32, 0.0];
        let pool = vec![vec![1.0_f32, 0.0]];
        let picked = maximal_marginal_relevance(&query, &pool, 0, 0.5);
        assert!(picked.is_empty());
    }

    #[test]
    fn mmr_selects_k_results() {
        let query = [1.0_f32, 0.0];
        let pool = vec![
            vec![1.0_f32, 0.0],
            vec![0.9, 0.1],
            vec![0.0, 1.0],
            vec![0.5, 0.5],
        ];
        let picked = maximal_marginal_relevance(&query, &pool, 3, 0.5);
        assert_eq!(picked.len(), 3);
    }

    #[test]
    fn mmr_pure_relevance_selects_most_similar() {
        // With lambda = 1.0 only the similarity to the query matters.
        let query = [1.0_f32, 0.0];
        let pool = vec![
            vec![0.0_f32, 1.0],
            vec![1.0, 0.0],
            vec![0.5, 0.5],
        ];
        let picked = maximal_marginal_relevance(&query, &pool, 1, 1.0);
        assert_eq!(picked, [1]);
    }

    #[test]
    fn mmr_diverse_selection() {
        // With lambda = 0.0 the second pick is the one least similar to the first.
        let query = [1.0_f32, 0.0];
        let pool = vec![
            vec![1.0_f32, 0.0],
            vec![0.99, 0.01],
            vec![0.0, 1.0],
        ];
        let picked = maximal_marginal_relevance(&query, &pool, 2, 0.0);
        assert_eq!(picked[0], 0, "first pick should be the most relevant item");
        assert_eq!(picked[1], 2, "second pick should be the most dissimilar item");
    }

    #[test]
    fn mmr_k_exceeds_embeddings() {
        // Asking for more results than candidates returns every candidate.
        let query = [1.0_f32, 0.0];
        let pool = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0]];
        let picked = maximal_marginal_relevance(&query, &pool, 10, 0.5);
        assert_eq!(picked.len(), 2);
    }
}