use serde_json::Value;
/// Tuning parameters read by `mmr_select`.
#[derive(Debug, Clone)]
pub struct MmrConfig {
/// Weight applied to a candidate's relevance in the MMR score; the
/// redundancy penalty is weighted by `1.0 - lambda`.
pub lambda: f32,
/// Upper bound on the number of results `mmr_select` returns.
pub top_k: usize,
/// Number of leading candidates `mmr_select` keeps for consideration;
/// candidates beyond this count are dropped before scoring.
pub candidate_pool: usize,
}
impl Default for MmrConfig {
fn default() -> Self {
Self {
lambda: 0.7,
top_k: 10,
candidate_pool: 50,
}
}
}
/// An input item offered to `mmr_select` for re-ranking.
#[derive(Debug, Clone)]
pub struct MmrCandidate {
/// Caller-supplied identifier, copied unchanged into the matching `MmrResult`.
pub id: i64,
/// Text payload, moved unchanged into the matching `MmrResult`.
pub content: String,
/// Vector compared (via cosine similarity) against the query embedding and
/// against the embeddings of already-selected candidates.
pub embedding: Vec<f32>,
/// Upstream score; used as the relevance term when the query embedding is
/// empty, and always reported as `MmrResult::score`.
pub original_score: f32,
/// Optional JSON payload, passed through to the result untouched.
pub metadata: Option<Value>,
}
/// One item selected by `mmr_select`, in selection order.
#[derive(Debug, Clone)]
pub struct MmrResult {
/// Identifier copied from the source `MmrCandidate`.
pub id: i64,
/// Text payload moved from the source `MmrCandidate`.
pub content: String,
/// The candidate's `original_score`, unchanged.
pub score: f32,
/// The MMR score the candidate had in the round in which it was selected.
pub mmr_score: f32,
/// Metadata moved from the source `MmrCandidate`.
pub metadata: Option<Value>,
}
/// Computes the cosine similarity between two equally sized vectors.
///
/// The dot product and both squared norms are accumulated in `f64` in a
/// single pass. Squaring an `f32` component such as `1e20` overflows to
/// infinity and squaring `1e-30` underflows to zero when done in `f32`, which
/// would turn the result into `NaN` or a spurious `0.0`; every finite `f32`
/// squares safely in `f64`.
///
/// # Returns
///
/// A value in `[-1.0, 1.0]`, or `0.0` when either slice is empty, the slices
/// differ in length, or either vector has zero magnitude. Non-finite
/// components (`NaN`, `±∞`) are not sanitized and may produce `NaN`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Equal lengths plus a non-empty `a` imply a non-empty `b`.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a_sq, norm_b_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a_sq + x * x, norm_b_sq + y * y)
        },
    );

    // A zero vector has no direction; report "no similarity" instead of 0/0.
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let cosine = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    // Rounding can push the ratio marginally outside the mathematical range.
    // The narrowing cast is intentional: the clamped value always fits in f32.
    cosine.clamp(-1.0, 1.0) as f32
}
/// Re-ranks `candidates` with Maximal Marginal Relevance (MMR).
///
/// Only the first `config.candidate_pool` candidates are considered. Results
/// are picked greedily, one per round, maximising
///
/// ```text
/// lambda * relevance - (1 - lambda) * max_similarity_to_already_selected
/// ```
///
/// until `config.top_k` results are chosen or the pool is exhausted.
///
/// * `relevance` is the cosine similarity between `query_embedding` and the
///   candidate's embedding, or the candidate's `original_score` when
///   `query_embedding` is empty.
/// * The redundancy term is `0.0` in the first round (nothing selected yet).
/// * Ties go to the candidate that appears earliest in the input order.
/// * A candidate whose score is `NaN` is only chosen when no remaining
///   candidate has a comparable (non-`NaN`) score, so one corrupt score at the
///   head of the list can no longer outrank every valid candidate.
///
/// # Returns
///
/// The selected items in selection order. `score` carries the candidate's
/// `original_score`; `mmr_score` is the score it had in the round it was
/// picked. The result is empty when there are no candidates or when
/// `top_k` / `candidate_pool` is `0`.
pub fn mmr_select(
    query_embedding: &[f32],
    candidates: Vec<MmrCandidate>,
    config: &MmrConfig,
) -> Vec<MmrResult> {
    /// A pooled candidate together with the similarity values that stay
    /// valid from one selection round to the next.
    struct Pooled {
        candidate: MmrCandidate,
        /// Similarity to the query; independent of what has been selected.
        relevance: f32,
        /// Running maximum similarity to every result selected so far.
        max_redundancy: f32,
    }

    // Relevance never changes between rounds, so compute it exactly once.
    let mut pool: Vec<Pooled> = candidates
        .into_iter()
        .take(config.candidate_pool)
        .map(|candidate| {
            let relevance = if query_embedding.is_empty() {
                candidate.original_score
            } else {
                cosine_similarity(query_embedding, &candidate.embedding)
            };
            Pooled {
                candidate,
                relevance,
                max_redundancy: f32::NEG_INFINITY,
            }
        })
        .collect();

    let target_k = config.top_k.min(pool.len());
    let mut selected: Vec<MmrResult> = Vec::with_capacity(target_k);

    while selected.len() < target_k {
        let first_round = selected.is_empty();
        let mut best: Option<(usize, f32)> = None;

        for (i, entry) in pool.iter().enumerate() {
            let max_redundancy = if first_round {
                0.0
            } else {
                entry.max_redundancy
            };
            let mmr_score =
                config.lambda * entry.relevance - (1.0 - config.lambda) * max_redundancy;

            // Strict `>` keeps the earliest candidate on ties. The extra NaN
            // clause lets any comparable score displace a NaN incumbent,
            // because `x > NaN` is false for every `x`.
            let is_better = match best {
                None => true,
                Some((_, best_score)) => {
                    mmr_score > best_score || (best_score.is_nan() && !mmr_score.is_nan())
                }
            };
            if is_better {
                best = Some((i, mmr_score));
            }
        }

        let (idx, mmr_score) = match best {
            Some(found) => found,
            None => break,
        };

        // `remove` (not `swap_remove`) preserves input order, which the
        // tie-breaking rule above relies on.
        let picked = pool.remove(idx).candidate;

        // Fold the new pick into each survivor's running maximum; skipped
        // after the final pick because nobody will read it.
        if selected.len() + 1 < target_k {
            for entry in &mut pool {
                let similarity = cosine_similarity(&entry.candidate.embedding, &picked.embedding);
                entry.max_redundancy = entry.max_redundancy.max(similarity);
            }
        }

        selected.push(MmrResult {
            id: picked.id,
            content: picked.content,
            score: picked.original_score,
            mmr_score,
            metadata: picked.metadata,
        });
    }

    selected
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate with generated content and no metadata.
    fn make_candidate(id: i64, embedding: Vec<f32>, score: f32) -> MmrCandidate {
        MmrCandidate {
            id,
            content: format!("content {id}"),
            embedding,
            original_score: score,
            metadata: None,
        }
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "identical vectors must have similarity 1.0, got {sim}"
        );
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![0.0_f32, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            sim.abs() < 1e-6,
            "orthogonal vectors must have similarity 0.0, got {sim}"
        );
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0_f32, 2.0, 3.0];
        let b = vec![-1.0_f32, -2.0, -3.0];
        let sim = cosine_similarity(&a, &b);
        assert!(
            (sim + 1.0).abs() < 1e-6,
            "opposite vectors must have similarity -1.0, got {sim}"
        );
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero = vec![0.0_f32, 0.0, 0.0];
        let other = vec![1.0_f32, 2.0, 3.0];
        assert_eq!(cosine_similarity(&zero, &other), 0.0);
        assert_eq!(cosine_similarity(&other, &zero), 0.0);
        assert_eq!(cosine_similarity(&zero, &zero), 0.0);
    }

    #[test]
    fn test_cosine_similarity_empty_or_mismatched() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_mmr_empty_candidates() {
        let config = MmrConfig::default();
        let result = mmr_select(&[1.0, 0.0], vec![], &config);
        assert!(result.is_empty());
    }

    #[test]
    fn test_mmr_single_candidate() {
        let config = MmrConfig {
            top_k: 5,
            ..Default::default()
        };
        let candidates = vec![make_candidate(42, vec![1.0, 0.0], 0.9)];
        let result = mmr_select(&[1.0, 0.0], candidates, &config);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].id, 42);
    }

    #[test]
    fn test_mmr_top_k_limit() {
        let config = MmrConfig {
            top_k: 3,
            candidate_pool: 50,
            lambda: 0.7,
        };
        let candidates: Vec<MmrCandidate> = (0..10)
            .map(|i| {
                let emb = vec![i as f32 + 1.0, 0.0];
                make_candidate(i, emb, 1.0 - i as f32 * 0.05)
            })
            .collect();
        let result = mmr_select(&[1.0, 0.0], candidates, &config);
        assert_eq!(result.len(), 3, "must respect top_k=3");
    }

    #[test]
    fn test_mmr_top_k_zero() {
        let config = MmrConfig {
            top_k: 0,
            ..Default::default()
        };
        let candidates = vec![make_candidate(1, vec![1.0, 0.0], 0.9)];
        let result = mmr_select(&[1.0, 0.0], candidates, &config);
        assert!(result.is_empty(), "top_k=0 must select nothing");
    }

    #[test]
    fn test_mmr_candidate_pool_limit() {
        let config = MmrConfig {
            lambda: 1.0,
            top_k: 5,
            candidate_pool: 2,
        };
        let candidates = vec![
            make_candidate(1, vec![1.0, 0.0], 0.1),
            make_candidate(2, vec![0.0, 1.0], 0.2),
            // Outside the pool of two: must never be returned despite high scores.
            make_candidate(3, vec![1.0, 1.0], 0.9),
            make_candidate(4, vec![1.0, 0.0], 0.8),
        ];
        // An empty query makes `original_score` the relevance term.
        let result = mmr_select(&[], candidates, &config);
        let ids: Vec<i64> = result.iter().map(|r| r.id).collect();
        assert_eq!(
            ids,
            vec![2, 1],
            "only the first candidate_pool candidates may be ranked"
        );
    }

    #[test]
    fn test_mmr_pure_relevance() {
        let config = MmrConfig {
            lambda: 1.0,
            top_k: 3,
            candidate_pool: 50,
        };
        let candidates = vec![
            make_candidate(1, vec![1.0, 0.0, 0.0], 0.9),
            make_candidate(2, vec![0.0, 1.0, 0.0], 0.7),
            make_candidate(3, vec![0.0, 0.0, 1.0], 0.5),
        ];
        let query = vec![1.0_f32, 0.0, 0.0];
        let result = mmr_select(&query, candidates, &config);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].id, 1, "first pick must be the most relevant");
    }

    #[test]
    fn test_mmr_pure_diversity() {
        let config = MmrConfig {
            lambda: 0.0,
            top_k: 3,
            candidate_pool: 50,
        };
        let candidates = vec![
            make_candidate(1, vec![1.0, 0.0], 0.9),
            // Same direction as candidate 1.
            make_candidate(2, vec![1.0, 0.0], 0.8),
            // Orthogonal to candidates 1 and 2.
            make_candidate(3, vec![0.0, 1.0], 0.5),
        ];
        // A zero query yields relevance 0.0 for every candidate.
        let query = vec![0.0_f32, 0.0];
        let result = mmr_select(&query, candidates, &config);
        assert_eq!(result.len(), 3);
        assert_eq!(
            result[1].id, 3,
            "second pick must be the most diverse (orthogonal) candidate"
        );
    }

    #[test]
    fn test_mmr_balanced() {
        let config = MmrConfig {
            lambda: 0.5,
            top_k: 3,
            candidate_pool: 50,
        };
        let candidates = vec![
            make_candidate(1, vec![1.0, 0.0, 0.0], 0.9),
            // Nearly parallel to candidate 1.
            make_candidate(2, vec![0.9, 0.1, 0.0], 0.85),
            // Orthogonal to candidate 1.
            make_candidate(3, vec![0.0, 1.0, 0.0], 0.6),
            // Orthogonal to candidates 1 and 3.
            make_candidate(4, vec![0.0, 0.0, 1.0], 0.4),
        ];
        let query = vec![1.0_f32, 0.0, 0.0];
        let result = mmr_select(&query, candidates, &config);
        assert_eq!(result.len(), 3);
        let ids: Vec<i64> = result.iter().map(|r| r.id).collect();
        assert!(
            ids.contains(&3) || ids.contains(&4),
            "balanced MMR must include at least one diverse candidate, got ids: {ids:?}"
        );
    }
}