/// Returns the dot product `Σ a[i] * b[i]` of two equal-length slices.
///
/// Each element is widened to `f64` before multiplying. The product of two
/// `f32` values always fits exactly in an `f64`, so every term is exact. The
/// terms are accumulated in `f64` and rounded to `f32` once at the end. This
/// avoids the intermediate overflow/underflow of a pure-`f32` accumulator
/// (e.g. `1e40 + -1e40` becoming `inf + -inf = NaN`). An empty input yields
/// zero.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[must_use]
// The final `f64 -> f32` narrowing is the intended single rounding step.
#[allow(clippy::cast_possible_truncation)]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "dot: length mismatch");
    // `pulp::Arch` selects the instruction set at runtime and compiles the
    // closure for it.
    let arch = pulp::Arch::new();
    arch.dispatch(|| {
        let sum: f64 = a
            .iter()
            .zip(b)
            .map(|(&x, &y)| f64::from(x) * f64::from(y))
            .sum();
        sum as f32
    })
}
/// Returns the cosine similarity `a·b / (‖a‖ ‖b‖)` of two equal-length slices.
///
/// Returns `0.0` when either input is an all-zero vector, since the angle is
/// undefined there. This includes empty slices. Otherwise a NaN component
/// yields a NaN result.
///
/// The result is clamped to `[-1.0, 1.0]`, so rounding error can never push it
/// outside the domain of `acos`.
///
/// Accumulation happens in `f64`. Squaring an `f32` there cannot overflow or
/// underflow, so very large or very small vectors still produce the correct
/// similarity instead of `NaN` or `0.0`. Examples are all components equal
/// to `1e20` or `1e-25`.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[must_use]
// The final `f64 -> f32` narrowing is the intended single rounding step.
#[allow(clippy::cast_possible_truncation)]
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "cosine: length mismatch");
    let arch = pulp::Arch::new();
    arch.dispatch(|| {
        // One pass computes a·b, ‖a‖² and ‖b‖². Products of widened f32 values
        // are exact in f64, so plain `*`/`+` is as accurate as `mul_add` here
        // and does not depend on hardware FMA support.
        let (d, na, nb) = a.iter().zip(b).fold(
            (0.0f64, 0.0f64, 0.0f64),
            |(d, na, nb), (&x, &y)| {
                let (x, y) = (f64::from(x), f64::from(y));
                (d + x * y, na + x * x, nb + y * y)
            },
        );
        // Any non-zero f32 squares to a positive f64. These are therefore
        // exact "all-zero vector" tests.
        if na == 0.0 || nb == 0.0 {
            return 0.0;
        }
        (d / (na.sqrt() * nb.sqrt())).clamp(-1.0, 1.0) as f32
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    // Exact comparison is intended: every term and the sum are exactly
    // representable.
    #[allow(clippy::float_cmp)]
    fn dot_basic() {
        assert_eq!(dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);
    }

    #[test]
    // Exact comparison is intended: an empty sum is exactly zero.
    #[allow(clippy::float_cmp)]
    fn dot_empty_is_zero() {
        assert_eq!(dot(&[], &[]), 0.0);
    }

    #[test]
    #[should_panic(expected = "dot: length mismatch")]
    fn dot_length_mismatch_panics() {
        let _ = dot(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        assert!(cosine(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
    }

    #[test]
    fn cosine_parallel_is_one() {
        let c = cosine(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);
        assert!((c - 1.0).abs() < 1e-6, "got {c}");
    }

    #[test]
    fn cosine_antiparallel_is_minus_one() {
        let c = cosine(&[1.0, 2.0, 3.0], &[-2.0, -4.0, -6.0]);
        assert!((c + 1.0).abs() < 1e-6, "got {c}");
    }

    #[test]
    // Exact comparison is intended: the zero-vector branch returns a literal 0.0.
    #[allow(clippy::float_cmp)]
    fn cosine_zero_vector_returns_zero() {
        assert_eq!(cosine(&[0.0, 0.0, 0.0], &[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    // Exact comparison is intended: empty slices take the zero-vector branch.
    #[allow(clippy::float_cmp)]
    fn cosine_empty_is_zero() {
        assert_eq!(cosine(&[], &[]), 0.0);
    }

    #[test]
    #[should_panic(expected = "cosine: length mismatch")]
    fn cosine_length_mismatch_panics() {
        let _ = cosine(&[1.0], &[1.0, 2.0]);
    }
}