/// Computes the cosine similarity between two equal-length vectors.
///
/// The dot product and both squared norms are accumulated in `f64`. Squaring
/// an `f32` component can overflow to infinity (components around `1e20`) or
/// underflow to zero (components around `1e-30`) when done in `f32`; widening
/// first keeps the similarity meaningful for every finite `f32` input.
///
/// # Returns
///
/// A similarity in `[-1.0, 1.0]`, or `0.0` when the result is undefined:
///
/// * either slice is empty, or the slices differ in length;
/// * either vector has zero magnitude;
/// * the computation is not finite (for example a NaN or infinite component).
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass over both vectors, widening each component before it is
    // multiplied so the products cannot leave the representable range.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    if norm_a <= 0.0 || norm_b <= 0.0 {
        return 0.0;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom <= 0.0 {
        return 0.0;
    }

    let sim = dot / denom;
    if sim.is_finite() {
        // Rounding can push the ratio a hair outside [-1, 1]; narrowing to
        // `f32` is safe here because the magnitude is at most about one.
        (sim as f32).clamp(-1.0, 1.0)
    } else {
        // NaN or infinite inputs propagate to a non-finite ratio.
        0.0
    }
}
/// A corpus entry paired with its similarity score, as produced by `top_k`.
#[derive(Debug, Clone, PartialEq)]
pub struct Scored {
/// Position of the scored vector within the corpus slice given to `top_k`.
pub index: usize,
/// Result of `cosine(query, vector)` for that corpus entry.
pub score: f32,
}
/// Returns the `k` corpus vectors most similar to `query`, best match first.
///
/// Every corpus entry is scored with `cosine(query, entry)`. Results are
/// ordered by descending score; equal scores are ordered by ascending corpus
/// index, so the output is deterministic.
///
/// Only the best `k` entries are fully sorted: the corpus is first partitioned
/// around the `k`-th best entry, which avoids sorting the discarded tail.
///
/// # Returns
///
/// An empty vector when `k` is zero or `corpus` is empty; otherwise
/// `min(k, corpus.len())` entries.
pub fn top_k(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<Scored> {
    /// Best match first: descending score, then ascending index. Indices are
    /// unique, so this is a strict total order and unstable algorithms still
    /// yield one well-defined result.
    fn by_rank(a: &Scored, b: &Scored) -> std::cmp::Ordering {
        b.score
            .total_cmp(&a.score)
            .then_with(|| a.index.cmp(&b.index))
    }

    if k == 0 || corpus.is_empty() {
        return Vec::new();
    }

    let mut scored: Vec<Scored> = corpus
        .iter()
        .enumerate()
        .map(|(index, vec)| Scored {
            index,
            score: cosine(query, vec),
        })
        .collect();

    // `select_nth_unstable_by` panics when the index is out of bounds, so it
    // is only used when there is actually a tail to discard.
    if k < scored.len() {
        // After this call the first `k` slots hold the `k` best entries
        // (in arbitrary order); everything from slot `k` on ranks lower.
        scored.select_nth_unstable_by(k, by_rank);
        scored.truncate(k);
    }

    scored.sort_unstable_by(by_rank);
    scored
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing a similarity against its exact expected value.
    const EPS: f32 = 1e-6;

    /// Collects the corpus indices of a ranking, preserving order.
    fn indices(ranked: &[Scored]) -> Vec<usize> {
        ranked.iter().map(|entry| entry.index).collect()
    }

    #[test]
    fn cosine_identical_vectors_is_one() {
        let v = [1.0_f32, 2.0, 3.0];
        let sim = cosine(&v, &v);
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        let sim = cosine(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < EPS);
    }

    #[test]
    fn cosine_opposite_is_negative_one() {
        let sim = cosine(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_handles_degenerate_inputs() {
        // Empty input.
        assert_eq!(cosine(&[], &[]), 0.0);
        // Length mismatch.
        assert_eq!(cosine(&[1.0, 2.0], &[1.0]), 0.0);
        // Zero-magnitude vector.
        assert_eq!(cosine(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        // Non-finite component.
        assert_eq!(cosine(&[f32::NAN, 1.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn top_k_ranks_highest_first() {
        let query = [1.0_f32, 0.0];
        let corpus = vec![
            vec![0.0, 1.0],
            vec![1.0, 0.0],
            vec![1.0, 1.0],
            vec![-1.0, 0.0],
        ];

        let ranked = top_k(&query, &corpus, 3);

        assert_eq!(indices(&ranked), vec![1, 2, 0]);
    }

    #[test]
    fn top_k_is_deterministic_on_ties() {
        let query = [1.0_f32, 0.0];
        let corpus = vec![vec![1.0, 0.0]; 3];

        let ranked = top_k(&query, &corpus, 3);

        assert_eq!(indices(&ranked), vec![0, 1, 2]);
    }

    #[test]
    fn top_k_empty_and_zero_k() {
        let nothing_to_rank = top_k(&[1.0], &[], 5);
        assert!(nothing_to_rank.is_empty());

        let nothing_requested = top_k(&[1.0], &[vec![1.0]], 0);
        assert!(nothing_requested.is_empty());
    }

    #[test]
    fn top_k_clamps_to_corpus_size() {
        let corpus = [vec![1.0], vec![0.5]];
        let ranked = top_k(&[1.0], &corpus, 100);
        assert_eq!(ranked.len(), 2);
    }
}