1use simsimd::SpatialSimilarity;
11
12#[inline]
15pub fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
16 match f32::cosine(a, b) {
17 Some(distance) => (1.0 - distance) as f32,
18 None => 0.0,
19 }
20}
21
22#[inline]
25pub fn cosine_distance_f32(a: &[f32], b: &[f32]) -> f32 {
26 match f32::cosine(a, b) {
27 Some(distance) => distance as f32,
28 None => 1.0,
29 }
30}
31
32#[inline]
34pub fn l2_squared_f32(a: &[f32], b: &[f32]) -> f32 {
35 match f32::sqeuclidean(a, b) {
36 Some(distance) => distance as f32,
37 None => f32::MAX,
38 }
39}
40
41#[inline]
43pub fn l2_distance_f32(a: &[f32], b: &[f32]) -> f32 {
44 l2_squared_f32(a, b).sqrt()
45}
46
47#[inline]
49pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
50 match f32::dot(a, b) {
51 Some(product) => product as f32,
52 None => 0.0,
53 }
54}
55
/// Distance metric used to compare two `f32` vectors.
///
/// `Hash` is derived alongside `Eq` so a metric can be used as a key in
/// hashed collections (for example, per-metric caches or configuration maps).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine distance.
    Cosine,
    /// Euclidean (L2) distance.
    L2,
    /// Squared Euclidean distance (no square root taken).
    L2Squared,
    /// Inner product; the distance is the negated dot product, so a larger
    /// product means a smaller distance.
    InnerProduct,
    /// Manhattan (L1) distance: the sum of absolute coordinate differences.
    Manhattan,
}
70
71impl DistanceMetric {
72 #[inline]
74 pub fn compute(&self, a: &[f32], b: &[f32]) -> f32 {
75 match self {
76 Self::Cosine => cosine_distance_f32(a, b),
77 Self::L2 => l2_distance_f32(a, b),
78 Self::L2Squared => l2_squared_f32(a, b),
79 Self::InnerProduct => -dot_product_f32(a, b), Self::Manhattan => manhattan_distance_f32(a, b),
81 }
82 }
83
84 #[inline]
86 pub fn to_similarity(&self, distance: f32) -> f32 {
87 match self {
88 Self::Cosine => 1.0 - distance,
89 Self::L2 | Self::L2Squared | Self::Manhattan => 1.0 / (1.0 + distance),
90 Self::InnerProduct => -distance, }
92 }
93}
94
/// Manhattan (L1) distance: the sum of `|a[i] - b[i]|` over all coordinates.
///
/// Returns `f32::MAX` when `a` and `b` have different lengths. Without this
/// check, `zip` would silently compare only the common prefix and report a
/// plausible-looking distance for vectors of different dimensionality.
/// `f32::MAX` is the same sentinel `l2_squared_f32` returns when its
/// computation fails.
///
/// Two empty slices have distance zero.
pub fn manhattan_distance_f32(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }

    a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum()
}
102
103pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]]) -> Vec<f32> {
105 vectors
106 .iter()
107 .map(|v| cosine_similarity_f32(query, v))
108 .collect()
109}
110
111pub fn batch_l2_distance(query: &[f32], vectors: &[&[f32]]) -> Vec<f32> {
113 vectors.iter().map(|v| l2_distance_f32(query, v)).collect()
114}
115
116pub fn top_k_similar(
118 query: &[f32],
119 vectors: &[(&str, &[f32])],
120 k: usize,
121 metric: DistanceMetric,
122) -> Vec<(String, f32)> {
123 let mut scored: Vec<(String, f32)> = vectors
124 .iter()
125 .map(|(id, v)| {
126 let dist = metric.compute(query, v);
127 let sim = metric.to_similarity(dist);
128 (id.to_string(), sim)
129 })
130 .collect();
131
132 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
134 scored.truncate(k);
135 scored
136}
137
/// Scales `v` in place to unit Euclidean length.
///
/// The vector is left untouched when its norm is not greater than `1e-10`
/// (this covers the zero vector and a NaN norm), avoiding a division by a
/// vanishing value.
pub fn normalize_f32(v: &mut [f32]) {
    let sum_of_squares: f32 = v.iter().map(|c| c * c).sum();
    let length = sum_of_squares.sqrt();

    if !(length > 1e-10) {
        return;
    }

    v.iter_mut().for_each(|c| *c /= length);
}
147
/// Reports whether the Euclidean norm of `v` is within `tolerance` of `1.0`
/// (strictly: `|norm - 1| < tolerance`).
pub fn is_normalized(v: &[f32], tolerance: f32) -> bool {
    let sum_of_squares: f32 = v.iter().map(|c| c * c).sum();
    let deviation = (sum_of_squares.sqrt() - 1.0).abs();
    deviation < tolerance
}
153
/// Euclidean norm of `v`: the square root of the sum of squared components.
pub fn magnitude_f32(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|c| c * c).sum();
    sum_of_squares.sqrt()
}
158
/// Component-wise mean of `vectors`.
///
/// The output dimension is the length of the first vector. Components of
/// longer vectors beyond that dimension are ignored, and shorter vectors
/// simply contribute nothing to the components they lack; every component is
/// still divided by the total number of vectors.
///
/// Returns an empty vector when `vectors` is empty.
pub fn centroid(vectors: &[&[f32]]) -> Vec<f32> {
    let dim = match vectors.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };

    let mut mean = vec![0.0f32; dim];

    // `zip` stops at the shorter side, which reproduces the "ignore anything
    // past `dim`" rule without an explicit index check.
    for vector in vectors {
        for (acc, &component) in mean.iter_mut().zip(vector.iter()) {
            *acc += component;
        }
    }

    let count = vectors.len() as f32;
    mean.iter_mut().for_each(|acc| *acc /= count);

    mean
}
183
184pub fn variance(vectors: &[&[f32]], centroid: &[f32]) -> f32 {
186 if vectors.is_empty() {
187 return 0.0;
188 }
189
190 vectors
191 .iter()
192 .map(|v| l2_squared_f32(v, centroid))
193 .sum::<f32>()
194 / vectors.len() as f32
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point assertions below.
    const EPS: f32 = 0.001;

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors have similarity 1.
        let same_lhs = [1.0, 0.0, 0.0];
        let same_rhs = [1.0, 0.0, 0.0];
        assert_close(cosine_similarity_f32(&same_lhs, &same_rhs), 1.0);

        // Orthogonal vectors have similarity 0.
        let x_axis = [1.0, 0.0, 0.0];
        let y_axis = [0.0, 1.0, 0.0];
        assert!(cosine_similarity_f32(&x_axis, &y_axis).abs() < EPS);
    }

    #[test]
    fn test_l2_distance() {
        // Classic 3-4-5 right triangle.
        let origin = [0.0, 0.0, 0.0];
        let point = [3.0, 4.0, 0.0];
        assert_close(l2_distance_f32(&origin, &point), 5.0);
    }

    #[test]
    fn test_dot_product() {
        // 1*4 + 2*5 + 3*6 = 32
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];
        assert_close(dot_product_f32(&lhs, &rhs), 32.0);
    }

    #[test]
    fn test_manhattan_distance() {
        // |1-4| + |2-6| + |3-3| = 7
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 6.0, 3.0];
        assert_close(manhattan_distance_f32(&lhs, &rhs), 7.0);
    }

    #[test]
    fn test_normalize() {
        let mut vector = [3.0, 4.0, 0.0];
        normalize_f32(&mut vector);
        assert_close(magnitude_f32(&vector), 1.0);
        assert!(is_normalized(&vector, 0.001));
    }

    #[test]
    fn test_top_k() {
        let query = [1.0, 0.0, 0.0];
        let candidates: Vec<(&str, &[f32])> = vec![
            ("a", &[1.0, 0.0, 0.0][..]),
            ("b", &[0.0, 1.0, 0.0][..]),
            ("c", &[0.7, 0.7, 0.0][..]),
        ];

        let top = top_k_similar(&query, &candidates, 2, DistanceMetric::Cosine);
        assert_eq!(top.len(), 2);
        // "a" is identical to the query, so it must rank first.
        assert_eq!(top[0].0, "a");
    }

    #[test]
    fn test_centroid() {
        let first = [0.0, 0.0];
        let second = [2.0, 2.0];
        let vectors: Vec<&[f32]> = vec![&first, &second];

        let mean = centroid(&vectors);
        assert_close(mean[0], 1.0);
        assert_close(mean[1], 1.0);
    }

    #[test]
    fn test_distance_metrics() {
        let x_axis = [1.0, 0.0, 0.0];
        let y_axis = [0.0, 1.0, 0.0];

        // Orthogonal unit vectors: cosine distance 1, Euclidean distance sqrt(2).
        let cosine_dist = DistanceMetric::Cosine.compute(&x_axis, &y_axis);
        assert_close(cosine_dist, 1.0);

        let l2_dist = DistanceMetric::L2.compute(&x_axis, &y_axis);
        assert_close(l2_dist, std::f32::consts::SQRT_2);
    }
}