1#![deny(missing_docs)]
26
/// Returns the cosine similarity of `a` and `b`: `dot(a, b) / (‖a‖ · ‖b‖)`.
///
/// For non-zero vectors the exact value lies in `[-1.0, 1.0]`; the computed result may
/// stray marginally outside that range due to rounding. If either vector has zero
/// magnitude the similarity is undefined and `0.0` is returned instead.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vector length mismatch");
    let mut dot = 0.0_f32;
    let mut na = 0.0_f32;
    let mut nb = 0.0_f32;
    for (&ai, &bi) in a.iter().zip(b) {
        dot += ai * bi;
        na += ai * ai;
        nb += bi * bi;
    }
    // Take each square root before multiplying. `(na * nb).sqrt()` overflows to
    // infinity (or underflows to zero) when the squared norms are individually
    // representable but their product is not, which silently turned the result into 0.0.
    let denom = na.sqrt() * nb.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
49
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of its squared
/// components.
///
/// An empty slice has norm zero. The sum of squares is accumulated in `f32`, so it
/// overflows to infinity once it exceeds `f32::MAX`.
pub fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
55
/// Returns the cosine similarity of `a` and `b`, using the caller-supplied `b_norm`
/// as the magnitude of `b`.
///
/// `b_norm` is used as-is and never recomputed. The result is a true cosine similarity
/// only when it equals the Euclidean norm of `b`. This lets callers comparing many
/// vectors against the same `b` compute its norm once. Here `b` is read only for the
/// dot product. Returns `0.0` when `‖a‖ · b_norm` is zero.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn cosine_with_norm(a: &[f32], b: &[f32], b_norm: f32) -> f32 {
    assert_eq!(a.len(), b.len(), "vector length mismatch");
    let (dot_product, a_sq) = a
        .iter()
        .zip(b.iter())
        .fold((0.0_f32, 0.0_f32), |(acc_dot, acc_sq), (&x, &y)| {
            (acc_dot + x * y, acc_sq + x * x)
        });
    let magnitude = a_sq.sqrt() * b_norm;
    // A zero magnitude means one side is the zero vector; report "no similarity".
    if magnitude == 0.0 {
        return 0.0;
    }
    dot_product / magnitude
}
75
76pub fn batch_cosine<'a, I>(q: &[f32], candidates: I) -> Vec<f32>
78where
79 I: IntoIterator<Item = &'a [f32]>,
80{
81 let q_norm = norm(q);
82 if q_norm == 0.0 {
83 return candidates.into_iter().map(|_| 0.0).collect();
84 }
85 let mut out = Vec::new();
86 for c in candidates {
87 assert_eq!(c.len(), q.len(), "vector length mismatch");
88 let mut dot = 0.0_f32;
89 let mut nc = 0.0_f32;
90 for i in 0..q.len() {
91 dot += q[i] * c[i];
92 nc += c[i] * c[i];
93 }
94 let denom = q_norm * nc.sqrt();
95 out.push(if denom == 0.0 { 0.0 } else { dot / denom });
96 }
97 out
98}