/// Computes the cosine similarity between `a` and `b`.
///
/// The result is the dot product of the two slices divided by the product of
/// their Euclidean norms. When that product is smaller than `1e-10` (for
/// example when either input is empty or all zeros) the function returns
/// `0.0` instead of dividing.
///
/// # Panics
///
/// With debug assertions enabled, panics if the slices differ in length.
/// In any build, panics if `b` is shorter than `a`, because `b` is indexed
/// with positions taken from `a`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let mut inner = 0.0f32;
    let mut norm_sq_a = 0.0f32;
    let mut norm_sq_b = 0.0f32;
    for (idx, &x) in a.iter().enumerate() {
        let y = b[idx];
        inner += x * y;
        norm_sq_a += x * x;
        norm_sq_b += y * y;
    }

    let scale = norm_sq_a.sqrt() * norm_sq_b.sqrt();
    // Avoid dividing by a (near-)zero norm product.
    if scale < 1e-10 {
        return 0.0;
    }
    inner / scale
}
/// Scales `v` in place so that its Euclidean norm becomes 1.
///
/// Every component is divided by the norm of the slice. If the norm is not
/// greater than `1e-10` (an empty or all-zero slice, for instance) the slice
/// is left untouched.
pub fn normalize(v: &mut [f32]) {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    let norm = sum_of_squares.sqrt();

    if norm > 1e-10 {
        v.iter_mut().for_each(|component| *component /= norm);
    }
}
/// Returns the Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; pairing stops at the end of the
/// shorter slice.
///
/// # Panics
///
/// With debug assertions enabled, panics if the slices differ in length.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let squared_total: f32 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| {
            let delta = p - q;
            delta * delta
        })
        .sum();
    squared_total.sqrt()
}
/// Returns the `k` candidates most similar to `query`, best first.
///
/// Each candidate is scored with [`cosine_similarity`] against `query`. The
/// result holds `(candidate_index, score)` pairs sorted by descending score
/// and truncated to at most `k` entries (fewer when there are fewer
/// candidates; empty when `k` is 0).
///
/// Candidates whose score is NaN are ordered after every candidate with a
/// real score, so they only appear in the result when there are not enough
/// real-scored candidates to fill `k` slots. The sort is stable: candidates
/// with equal scores (and NaN-scored candidates among themselves) keep their
/// original relative order.
pub fn top_k(query: &[f32], candidates: &[&[f32]], k: usize) -> Vec<(usize, f32)> {
    let mut scores: Vec<(usize, f32)> = candidates
        .iter()
        .enumerate()
        .map(|(i, c)| (i, cosine_similarity(query, c)))
        .collect();

    // `partial_cmp` alone is not a total order when NaN is present, which
    // makes the sort result unspecified. Rank NaN explicitly below every
    // real score so the comparator is consistent.
    scores.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        // Both scores are real numbers here, so `partial_cmp` is always `Some`.
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
    });
    scores.truncate(k);
    scores
}
/// Computes the component-wise mean of `vectors`.
///
/// The output has the same length as the first vector. An empty input
/// produces an empty `Vec`. Every sum is divided by the total number of
/// vectors, so a vector shorter than the first contributes only to the
/// leading components it actually has.
///
/// # Panics
///
/// Panics if any vector is longer than the first one, because the
/// accumulator is indexed by component position.
pub fn centroid(vectors: &[&[f32]]) -> Vec<f32> {
    let dim = match vectors.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };

    let mut sums = vec![0.0f32; dim];
    for vector in vectors.iter() {
        for (slot, &value) in vector.iter().enumerate() {
            sums[slot] += value;
        }
    }

    let count = vectors.len() as f32;
    sums.into_iter().map(|total| total / count).collect()
}
/// Replaces every NaN component of `v` with `0.0`, in place.
///
/// All other values, including infinities, are left unchanged.
pub fn sanitize_nan(v: &mut [f32]) {
    v.iter_mut()
        .filter(|value| value.is_nan())
        .for_each(|value| *value = 0.0);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results expected to be exact up to rounding.
    const EPS: f32 = 1e-6;

    /// Returns `true` when `got` lies strictly within `tol` of `want`.
    fn near(got: f32, want: f32, tol: f32) -> bool {
        (got - want).abs() < tol
    }

    /// Borrows each owned row as a slice, the shape `top_k` and `centroid` accept.
    fn as_rows(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let unit_x = [1.0f32, 0.0, 0.0];
        let same = [1.0f32, 0.0, 0.0];
        assert!(near(cosine_similarity(&unit_x, &same), 1.0, EPS));
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let unit_x = [1.0f32, 0.0, 0.0];
        let unit_y = [0.0f32, 1.0, 0.0];
        assert!(near(cosine_similarity(&unit_x, &unit_y), 0.0, EPS));
    }

    #[test]
    fn test_cosine_similarity_384_dim() {
        let mut lhs = [0.0f32; 384];
        let mut rhs = [0.0f32; 384];
        lhs[0] = 1.0;
        // Roughly (cos 60°, sin 60°), so the expected similarity is about 0.5.
        rhs[0] = 0.5;
        rhs[1] = 0.866;
        assert!(near(cosine_similarity(&lhs, &rhs), 0.5, 1e-3));
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0f32, 4.0];
        normalize(&mut v);
        assert!(near(v[0], 0.6, EPS));
        assert!(near(v[1], 0.8, EPS));
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut zeros = [0.0f32; 384];
        normalize(&mut zeros);
        assert!(zeros.iter().all(|&component| component == 0.0));
    }

    #[test]
    fn test_euclidean_distance() {
        let origin = [0.0f32, 0.0];
        let point = [3.0f32, 4.0];
        assert!(near(euclidean_distance(&origin, &point), 5.0, EPS));
    }

    #[test]
    fn test_top_k() {
        let query = [1.0f32, 0.0, 0.0];
        let candidates = vec![
            vec![0.0f32, 1.0, 0.0],
            vec![1.0, 0.0, 0.0],
            vec![0.5, 0.5, 0.0],
        ];
        let rows = as_rows(&candidates);
        let ranked = top_k(&query, &rows, 2);
        assert_eq!(ranked.len(), 2);
        // The candidate identical to the query (index 1) must rank first.
        assert_eq!(ranked[0].0, 1);
    }

    #[test]
    fn test_centroid() {
        let vectors = vec![vec![1.0f32, 0.0], vec![0.0, 1.0]];
        let rows = as_rows(&vectors);
        let mean = centroid(&rows);
        assert!(near(mean[0], 0.5, EPS));
        assert!(near(mean[1], 0.5, EPS));
    }
}