#![deny(missing_docs)]
/// Returns the cosine similarity of `a` and `b`, i.e. `a·b / (‖a‖·‖b‖)`.
///
/// Mathematically the result lies in `[-1.0, 1.0]`; floating-point rounding
/// may push it very slightly outside that range. When the denominator is
/// zero (for example, either vector is all zeros, or both slices are empty),
/// `0.0` is returned instead of dividing by zero.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vector length mismatch");
    // Single pass accumulating the dot product and both squared norms,
    // in the same left-to-right order as a plain loop.
    let (dot, na, nb) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (&ai, &bi)| (dot + ai * bi, na + ai * ai, nb + bi * bi),
    );
    // Take each square root separately: `(na * nb).sqrt()` overflows to
    // infinity (or underflows to zero) in f32 long before the individual
    // norms do, which turned valid inputs into a bogus 0.0 similarity.
    let denom = na.sqrt() * nb.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
/// Returns the Euclidean (L2) length of `v`: the square root of the sum of
/// the squares of its components.
///
/// An empty slice yields zero.
pub fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
/// Cosine similarity of `a` and `b` using a caller-supplied norm for `b`.
///
/// Only the norm of `a` is computed here. `b_norm` is used as-is for `‖b‖`
/// and is never checked against the contents of `b`, so a caller can supply
/// a cached value. An incorrect `b_norm` therefore yields an incorrect result.
///
/// Returns `0.0` when `‖a‖ * b_norm` is exactly zero.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn cosine_with_norm(a: &[f32], b: &[f32], b_norm: f32) -> f32 {
    assert_eq!(a.len(), b.len(), "vector length mismatch");
    let (dot, a_sq_sum) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32), |(acc_dot, acc_sq), (&x, &y)| {
            (acc_dot + x * y, acc_sq + x * x)
        });
    let denominator = a_sq_sum.sqrt() * b_norm;
    if denominator == 0.0 {
        return 0.0;
    }
    dot / denominator
}
/// Computes the cosine similarity between `q` and every slice in `candidates`.
///
/// The norm of `q` is computed once (via [`norm`]) and reused for every
/// candidate. Results are returned in candidate order, one per candidate.
/// A pair whose norm product is zero yields `0.0`. When `q` is the zero
/// vector, every result is `0.0`.
///
/// # Panics
///
/// Panics if any candidate's length differs from `q.len()`. The check is
/// applied on every path, including when `q` is the zero vector, so invalid
/// input is rejected regardless of the query's values.
pub fn batch_cosine<'a, I>(q: &[f32], candidates: I) -> Vec<f32>
where
    I: IntoIterator<Item = &'a [f32]>,
{
    let q_norm = norm(q);
    candidates
        .into_iter()
        .map(|c| {
            // Validate before any short-circuit. Previously a zero query
            // skipped this check and silently accepted mismatched candidates.
            assert_eq!(c.len(), q.len(), "vector length mismatch");
            // A zero query makes every similarity 0.0. Returning early also
            // keeps non-finite candidates from turning `0 * inf` into NaN.
            if q_norm == 0.0 {
                return 0.0;
            }
            let (dot, nc) = q
                .iter()
                .zip(c)
                .fold((0.0_f32, 0.0_f32), |(dot, nc), (&qi, &ci)| {
                    (dot + qi * ci, nc + ci * ci)
                });
            let denom = q_norm * nc.sqrt();
            if denom == 0.0 {
                0.0
            } else {
                dot / denom
            }
        })
        .collect()
}