/// Returns the dot product of `a` and `b`.
///
/// Elements are paired by position. If the slices differ in length, the
/// trailing elements of the longer one are ignored.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Returns the Euclidean (L2) length of `a`: the square root of the sum of
/// its squared components.
fn norm(a: &[f64]) -> f64 {
    let sum_of_squares: f64 = a.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
/// Returns the cosine similarity between `a` and `b`.
///
/// The result is the dot product divided by the product of the two norms.
/// If either vector has a norm of exactly zero, `0.0` is returned instead of
/// dividing by zero.
fn cosine_similarity_pair(a: &[f64], b: &[f64]) -> f64 {
    let magnitude_a = norm(a);
    let magnitude_b = norm(b);
    if magnitude_a != 0.0 && magnitude_b != 0.0 {
        dot(a, b) / (magnitude_a * magnitude_b)
    } else {
        // A zero-length vector has no direction; report "no similarity".
        0.0
    }
}
/// Computes the pairwise cosine similarity between the rows of `x` and the
/// rows of `y`.
///
/// Entry `[i][j]` of the returned matrix is the cosine similarity between
/// `x[i]` and `y[j]`, so the result has `x.len()` rows of `y.len()` values.
/// If either input has no rows, the result is a single empty row.
///
/// # Errors
///
/// Returns a descriptive message when the first rows of `x` and `y` differ in
/// length, or when any row of `x` or `y` differs in length from the first row
/// of its own matrix.
pub fn cosine_similarity_matrix(x: &[Vec<f64>], y: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
    // The first row of each matrix defines its expected width.
    let (width_x, width_y) = match (x.first(), y.first()) {
        (Some(first_x), Some(first_y)) => (first_x.len(), first_y.len()),
        _ => return Ok(vec![Vec::new()]),
    };

    if width_x != width_y {
        return Err(format!(
            "Number of columns in X and Y must be the same. X has {} columns and Y has {} columns.",
            width_x, width_y
        ));
    }

    // Report the first ragged row, checking X before Y.
    if let Some((index, ragged)) = x.iter().enumerate().find(|(_, row)| row.len() != width_x) {
        return Err(format!(
            "Inconsistent row length in X: row {} has {} columns, expected {}.",
            index,
            ragged.len(),
            width_x
        ));
    }
    if let Some((index, ragged)) = y.iter().enumerate().find(|(_, row)| row.len() != width_y) {
        return Err(format!(
            "Inconsistent row length in Y: row {} has {} columns, expected {}.",
            index,
            ragged.len(),
            width_y
        ));
    }

    let mut similarities = Vec::with_capacity(x.len());
    for row_x in x {
        let mut similarity_row = Vec::with_capacity(y.len());
        for row_y in y {
            similarity_row.push(cosine_similarity_pair(row_x, row_y));
        }
        similarities.push(similarity_row);
    }
    Ok(similarities)
}
/// Selects up to `k` embeddings by Maximal Marginal Relevance (MMR).
///
/// The first pick is the embedding with the highest cosine similarity to
/// `query_embedding`. Each later pick maximises
///
/// `lambda_mult * sim(query, candidate) - (1 - lambda_mult) * max_sim(candidate, selected)`
///
/// over the embeddings not yet chosen, where `max_sim` is the highest cosine
/// similarity between the candidate and any already-selected embedding.
///
/// # Parameters
///
/// * `query_embedding` - the vector candidates are scored against.
/// * `embedding_list` - the candidate vectors; returned indices refer to this slice.
/// * `lambda_mult` - weight of query relevance versus redundancy. At `1.0`
///   only query similarity counts; at `0.0` only dissimilarity to the
///   already-selected embeddings counts.
/// * `k` - maximum number of indices to return.
///
/// # Returns
///
/// `min(k, embedding_list.len())` distinct indices into `embedding_list`, in
/// selection order. The result is empty when `k` is zero or the list is empty.
///
/// # Tie-breaking and NaN handling
///
/// If several embeddings share the highest query similarity, the last one is
/// the seed. For later picks the first candidate with the highest score wins.
/// Scores that are NaN (for example from NaN components) never beat a
/// comparable score. If no remaining candidate has a comparable score, the
/// remaining indices are appended in ascending order, so the result never
/// contains duplicates.
pub fn maximal_marginal_relevance(
    query_embedding: &[f64],
    embedding_list: &[Vec<f64>],
    lambda_mult: f64,
    k: usize,
) -> Vec<usize> {
    let effective_k = k.min(embedding_list.len());
    if effective_k == 0 {
        return Vec::new();
    }

    let similarity_to_query: Vec<f64> = embedding_list
        .iter()
        .map(|emb| cosine_similarity_pair(query_embedding, emb))
        .collect();

    // Seed with the most query-similar embedding. `>=` keeps the last of
    // several equal maxima, and a NaN similarity never satisfies it.
    let mut most_similar = 0;
    let mut top_similarity = f64::NEG_INFINITY;
    for (i, &score) in similarity_to_query.iter().enumerate() {
        if score >= top_similarity {
            top_similarity = score;
            most_similar = i;
        }
    }

    let mut idxs: Vec<usize> = Vec::with_capacity(effective_k);
    let mut is_selected = vec![false; embedding_list.len()];
    // Running maximum similarity of each candidate to the selected set. It is
    // updated after every pick, so each round compares candidates against the
    // newest pick only instead of rescanning the whole selection.
    let mut max_similarity_to_selected = vec![f64::NEG_INFINITY; embedding_list.len()];

    let mut newest = most_similar;
    loop {
        idxs.push(newest);
        is_selected[newest] = true;
        if idxs.len() >= effective_k {
            break;
        }

        let newest_embedding = &embedding_list[newest];
        let mut best_score = f64::NEG_INFINITY;
        let mut best_idx: Option<usize> = None;
        // First unselected index, used only when no candidate has a score
        // that compares greater than negative infinity (e.g. all NaN).
        let mut fallback_idx: Option<usize> = None;

        for (i, &query_score) in similarity_to_query.iter().enumerate() {
            if is_selected[i] {
                continue;
            }
            if fallback_idx.is_none() {
                fallback_idx = Some(i);
            }

            // `f64::max` ignores a NaN operand, matching a fold over all
            // selected embeddings in selection order.
            let redundant_score = max_similarity_to_selected[i]
                .max(cosine_similarity_pair(&embedding_list[i], newest_embedding));
            max_similarity_to_selected[i] = redundant_score;

            let equation_score =
                lambda_mult * query_score - (1.0 - lambda_mult) * redundant_score;
            if equation_score > best_score {
                best_score = equation_score;
                best_idx = Some(i);
            }
        }

        match best_idx.or(fallback_idx) {
            Some(next) => newest = next,
            // Unreachable while fewer than `embedding_list.len()` indices are
            // selected; kept so the loop cannot spin or panic.
            None => break,
        }
    }
    idxs
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const TOLERANCE: f64 = 1e-9;

    /// Returns `true` when `lhs` and `rhs` differ by less than [`TOLERANCE`].
    fn is_close(lhs: f64, rhs: f64) -> bool {
        (lhs - rhs).abs() < TOLERANCE
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn test_cosine_similarity_identical_vectors() {
        let v = [1.0, 2.0, 3.0];
        let similarity = cosine_similarity_pair(&v, &v);
        assert!(is_close(similarity, 1.0));
    }

    /// Perpendicular vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal_vectors() {
        let east = [1.0, 0.0];
        let north = [0.0, 1.0];
        assert!(is_close(cosine_similarity_pair(&east, &north), 0.0));
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let east = [1.0, 0.0];
        let west = [-1.0, 0.0];
        assert!(is_close(cosine_similarity_pair(&east, &west), -1.0));
    }

    /// A zero vector on either side yields 0 rather than a division by zero.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let v = [1.0, 2.0];
        let origin = [0.0, 0.0];
        assert!(is_close(cosine_similarity_pair(&v, &origin), 0.0));
        assert!(is_close(cosine_similarity_pair(&origin, &v), 0.0));
    }

    /// Comparing the standard basis with itself produces the identity matrix.
    #[test]
    fn test_matrix_identity() {
        let x = [vec![1.0, 0.0], vec![0.0, 1.0]];
        let y = [vec![1.0, 0.0], vec![0.0, 1.0]];
        let result = cosine_similarity_matrix(&x, &y).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 2);

        let expected = [[1.0, 0.0], [0.0, 1.0]];
        for (i, expected_row) in expected.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert!(is_close(result[i][j], want));
            }
        }
    }

    /// An empty `x` produces a single empty row.
    #[test]
    fn test_matrix_empty_x() {
        let x: Vec<Vec<f64>> = Vec::new();
        let y = [vec![1.0, 0.0]];
        let result = cosine_similarity_matrix(&x, &y).unwrap();
        let expected: Vec<Vec<f64>> = vec![vec![]];
        assert_eq!(result, expected);
    }

    /// An empty `y` produces a single empty row.
    #[test]
    fn test_matrix_empty_y() {
        let x = [vec![1.0, 0.0]];
        let y: Vec<Vec<f64>> = Vec::new();
        let result = cosine_similarity_matrix(&x, &y).unwrap();
        let expected: Vec<Vec<f64>> = vec![vec![]];
        assert_eq!(result, expected);
    }

    /// Rows of different widths in `x` and `y` are rejected.
    #[test]
    fn test_matrix_dimension_mismatch() {
        let x = [vec![1.0, 0.0]];
        let y = [vec![1.0, 0.0, 0.0]];
        assert!(cosine_similarity_matrix(&x, &y).is_err());
    }

    /// One row in `x` against three rows in `y` gives a 1x3 result.
    #[test]
    fn test_matrix_single_row() {
        let x = [vec![1.0, 1.0]];
        let y = [vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let result = cosine_similarity_matrix(&x, &y).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 3);

        // (1, 1) is at 45 degrees to each axis.
        let expected_cos = 2.0_f64.sqrt().recip();
        let expected = [expected_cos, expected_cos, 1.0];
        for (j, &want) in expected.iter().enumerate() {
            assert!(is_close(result[0][j], want));
        }
    }

    /// Two rows against three rows gives a 2x3 result.
    #[test]
    fn test_matrix_non_square() {
        let x = [vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let y = [
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let result = cosine_similarity_matrix(&x, &y).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 3);

        let expected = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        for (i, expected_row) in expected.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert!(is_close(result[i][j], want));
            }
        }
    }

    /// No candidates means nothing is selected.
    #[test]
    fn test_mmr_empty_embeddings() {
        let query = [1.0, 0.0];
        let embeddings: Vec<Vec<f64>> = Vec::new();
        let picked = maximal_marginal_relevance(&query, &embeddings, 0.5, 4);
        assert!(picked.is_empty());
    }

    /// Asking for zero results returns nothing.
    #[test]
    fn test_mmr_k_zero() {
        let query = [1.0, 0.0];
        let embeddings = [vec![1.0, 0.0]];
        let picked = maximal_marginal_relevance(&query, &embeddings, 0.5, 0);
        assert!(picked.is_empty());
    }

    /// `k` is capped at the number of available embeddings.
    #[test]
    fn test_mmr_k_greater_than_embeddings() {
        let query = [1.0, 0.0];
        let embeddings = [vec![1.0, 0.0], vec![0.0, 1.0]];
        let picked = maximal_marginal_relevance(&query, &embeddings, 0.5, 10);
        assert_eq!(picked.len(), 2);
    }

    /// The first pick is the embedding closest to the query.
    #[test]
    fn test_mmr_selects_most_similar_first() {
        let query = [1.0, 0.0];
        let embeddings = [
            vec![0.0, 1.0],
            vec![1.0, 0.0],
            vec![0.5, 0.5],
        ];
        let picked = maximal_marginal_relevance(&query, &embeddings, 0.5, 1);
        assert_eq!(picked, vec![1]);
    }

    /// With `lambda_mult = 0` the second pick is the one least similar to the first.
    #[test]
    fn test_mmr_diversity() {
        let query = [1.0, 0.0];
        let embeddings = [
            vec![1.0, 0.0],
            vec![0.99, 0.01],
            vec![0.0, 1.0],
        ];
        let picked = maximal_marginal_relevance(&query, &embeddings, 0.0, 2);
        assert_eq!(picked[0], 0);
        assert_eq!(picked[1], 2);
    }

    /// With `lambda_mult = 1` the order follows query similarity alone.
    #[test]
    fn test_mmr_pure_relevance() {
        let query = [1.0, 0.0];
        let embeddings = [
            vec![0.0, 1.0],
            vec![1.0, 0.0],
            vec![0.9, 0.1],
        ];
        let picked = maximal_marginal_relevance(&query, &embeddings, 1.0, 3);
        assert_eq!(picked[0], 1);
        assert_eq!(picked[1], 2);
        assert_eq!(picked[2], 0);
    }

    /// Selecting every embedding never repeats an index.
    #[test]
    fn test_mmr_returns_unique_indices() {
        let query = [1.0, 0.0, 0.0];
        let embeddings = [
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![0.5, 0.5, 0.0],
            vec![0.5, 0.0, 0.5],
        ];
        let picked = maximal_marginal_relevance(&query, &embeddings, 0.5, 5);
        assert_eq!(picked.len(), 5);

        let mut distinct = picked.clone();
        distinct.sort();
        distinct.dedup();
        assert_eq!(distinct.len(), 5, "All indices should be unique");
    }

    /// A single candidate is always the one selected.
    #[test]
    fn test_mmr_single_embedding() {
        let query = [1.0, 0.0];
        let embeddings = [vec![0.5, 0.5]];
        let picked = maximal_marginal_relevance(&query, &embeddings, 0.5, 4);
        assert_eq!(picked, vec![0]);
    }
}