Skip to main content

cosine_fast/
lib.rs

1//! # cosine-fast
2//!
//! Hot-loop cosine similarity for f32 slices. Scalar core that the
//! compiler auto-vectorizes well on AArch64 NEON and x86 AVX2. When the
//! same query is compared against many candidates, compute its norm once
//! with `norm` and pass it to `cosine_with_norm` (or call `batch_cosine`)
//! so the query norm is not recomputed on every call.
7//!
8//! ## Example
9//!
10//! ```
11//! use cosine_fast::{cosine, batch_cosine};
12//! let a = vec![1.0f32, 0.0, 0.0];
13//! let b = vec![0.0f32, 1.0, 0.0];
14//! assert!((cosine(&a, &b) - 0.0).abs() < 1e-6);
15//!
16//! let q = vec![1.0f32, 2.0, 3.0];
17//! let cands = vec![
18//!     vec![1.0, 2.0, 3.0], // self
19//!     vec![0.0, 0.0, 1.0],
20//! ];
21//! let out = batch_cosine(&q, cands.iter().map(|v| v.as_slice()));
22//! assert!((out[0] - 1.0).abs() < 1e-6);
23//! ```
24
25#![deny(missing_docs)]
26
/// Cosine similarity between two equal-length f32 slices.
///
/// Returns 0.0 when either input has zero norm (this includes two empty
/// slices).
///
/// # Panics
///
/// Panics with "vector length mismatch" if `a` and `b` differ in length.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vector length mismatch");
    let mut dot = 0.0_f32;
    let mut na = 0.0_f32;
    let mut nb = 0.0_f32;
    for (&ai, &bi) in a.iter().zip(b) {
        dot += ai * bi;
        na += ai * ai;
        nb += bi * bi;
    }
    // Take the two square roots separately. The product `na * nb` can
    // overflow to infinity or underflow to zero in f32 even when both
    // norms are representable, which would wrongly collapse the result
    // to 0.0.
    let denom = na.sqrt() * nb.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
49
/// L2 (Euclidean) norm of `v`: the square root of the sum of squares.
///
/// Returns zero for an empty slice. Compute this once for a vector that is
/// compared many times and pass the result as `b_norm` to
/// [`cosine_with_norm`].
pub fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
55
/// Cosine similarity between `a` and `b`, reusing a precomputed norm for `b`.
///
/// `b_norm` is expected to equal `norm(b)`; it is used as given and is not
/// re-derived from `b`. Returns 0.0 when the product of the two norms is
/// zero.
///
/// # Panics
///
/// Panics with "vector length mismatch" if the slices differ in length.
pub fn cosine_with_norm(a: &[f32], b: &[f32], b_norm: f32) -> f32 {
    assert_eq!(a.len(), b.len(), "vector length mismatch");
    // A single pass accumulates both the dot product and the squared norm
    // of `a`.
    let (dot_ab, sq_norm_a) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32), |(dot, sq), (&x, &y)| {
            (dot + x * y, sq + x * x)
        });
    let scale = sq_norm_a.sqrt() * b_norm;
    if scale != 0.0 {
        dot_ab / scale
    } else {
        0.0
    }
}
75
76/// Compute cosine similarity between `q` and every candidate.
77pub fn batch_cosine<'a, I>(q: &[f32], candidates: I) -> Vec<f32>
78where
79    I: IntoIterator<Item = &'a [f32]>,
80{
81    let q_norm = norm(q);
82    if q_norm == 0.0 {
83        return candidates.into_iter().map(|_| 0.0).collect();
84    }
85    let mut out = Vec::new();
86    for c in candidates {
87        assert_eq!(c.len(), q.len(), "vector length mismatch");
88        let mut dot = 0.0_f32;
89        let mut nc = 0.0_f32;
90        for i in 0..q.len() {
91            dot += q[i] * c[i];
92            nc += c[i] * c[i];
93        }
94        let denom = q_norm * nc.sqrt();
95        out.push(if denom == 0.0 { 0.0 } else { dot / denom });
96    }
97    out
98}