1#![allow(dead_code)] use wide::f32x8;
17
18const SIMD_WIDTH: usize = 8;
20
21#[inline]
30#[must_use]
31pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
32 debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
33
34 let len = a.len();
35 let simd_len = len - (len % SIMD_WIDTH);
36
37 let mut sum = f32x8::ZERO;
38
39 for i in (0..simd_len).step_by(SIMD_WIDTH) {
41 let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
42 let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
43 let diff = va - vb;
44 sum += diff * diff;
45 }
46
47 let mut result = horizontal_sum(sum);
49
50 for i in simd_len..len {
52 let diff = a[i] - b[i];
53 result += diff * diff;
54 }
55
56 result
57}
58
59#[inline]
65#[must_use]
66pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
67 euclidean_distance_squared(a, b).sqrt()
68}
69
70#[inline]
76#[must_use]
77pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
78 debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
79
80 let len = a.len();
81 let simd_len = len - (len % SIMD_WIDTH);
82
83 let mut sum = f32x8::ZERO;
84
85 for i in (0..simd_len).step_by(SIMD_WIDTH) {
87 let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
88 let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
89 sum += va * vb;
90 }
91
92 let mut result = horizontal_sum(sum);
94
95 for i in simd_len..len {
97 result += a[i] * b[i];
98 }
99
100 result
101}
102
103#[inline]
107#[must_use]
108pub fn sum_of_squares(v: &[f32]) -> f32 {
109 let len = v.len();
110 let simd_len = len - (len % SIMD_WIDTH);
111
112 let mut sum = f32x8::ZERO;
113
114 for i in (0..simd_len).step_by(SIMD_WIDTH) {
116 let vv = f32x8::new(v[i..i + SIMD_WIDTH].try_into().unwrap());
117 sum += vv * vv;
118 }
119
120 let mut result = horizontal_sum(sum);
122
123 for i in simd_len..len {
125 result += v[i] * v[i];
126 }
127
128 result
129}
130
131#[inline]
133#[must_use]
134pub fn l2_norm(v: &[f32]) -> f32 {
135 sum_of_squares(v).sqrt()
136}
137
138#[inline]
147#[must_use]
148pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
149 debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
150
151 let len = a.len();
152 let simd_len = len - (len % SIMD_WIDTH);
153
154 let mut sum = f32x8::ZERO;
155
156 for i in (0..simd_len).step_by(SIMD_WIDTH) {
158 let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
159 let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
160 let diff = va - vb;
161 sum += diff.abs();
162 }
163
164 let mut result = horizontal_sum(sum);
166
167 for i in simd_len..len {
169 result += (a[i] - b[i]).abs();
170 }
171
172 result
173}
174
175#[inline]
184#[must_use]
185pub fn chebyshev_distance(a: &[f32], b: &[f32]) -> f32 {
186 debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
187
188 let len = a.len();
189 let simd_len = len - (len % SIMD_WIDTH);
190
191 let mut max_simd = f32x8::ZERO;
192
193 for i in (0..simd_len).step_by(SIMD_WIDTH) {
195 let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
196 let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
197 let diff = (va - vb).abs();
198 max_simd = max_simd.max(diff);
199 }
200
201 let mut result = horizontal_max(max_simd);
203
204 for i in simd_len..len {
206 result = result.max((a[i] - b[i]).abs());
207 }
208
209 result
210}
211
212#[inline]
216fn horizontal_max(v: f32x8) -> f32 {
217 let arr: [f32; 8] = v.to_array();
218 arr.iter().copied().fold(f32::MIN, f32::max)
219}
220
221#[inline]
234#[must_use]
235pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
236 debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
237
238 let len = a.len();
239 let simd_len = len - (len % SIMD_WIDTH);
240
241 let mut dot_sum = f32x8::ZERO;
242 let mut norm_a_sum = f32x8::ZERO;
243 let mut norm_b_sum = f32x8::ZERO;
244
245 for i in (0..simd_len).step_by(SIMD_WIDTH) {
247 let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
248 let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
249
250 dot_sum += va * vb;
251 norm_a_sum += va * va;
252 norm_b_sum += vb * vb;
253 }
254
255 let mut dot = horizontal_sum(dot_sum);
257 let mut norm_a_sq = horizontal_sum(norm_a_sum);
258 let mut norm_b_sq = horizontal_sum(norm_b_sum);
259
260 for i in simd_len..len {
262 dot += a[i] * b[i];
263 norm_a_sq += a[i] * a[i];
264 norm_b_sq += b[i] * b[i];
265 }
266
267 let norm_product = (norm_a_sq * norm_b_sq).sqrt();
268
269 if norm_product == 0.0 {
270 return 0.0;
271 }
272
273 dot / norm_product
274}
275
276#[inline]
287#[must_use]
288pub fn cosine_similarity_with_norms(a: &[f32], b: &[f32], norm_a: f32, norm_b: f32) -> Option<f32> {
289 if norm_a == 0.0 || norm_b == 0.0 {
290 return None;
291 }
292
293 let dot = dot_product(a, b);
294 Some(dot / (norm_a * norm_b))
295}
296
297#[inline]
305#[must_use]
306pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
307 1.0 - cosine_similarity(a, b)
308}
309
310#[derive(Debug, Clone, Copy)]
330pub struct CachedNorm {
331 norm_squared: f32,
333 norm: f32,
335}
336
337impl CachedNorm {
338 #[must_use]
340 pub fn new(v: &[f32]) -> Self {
341 let norm_squared = sum_of_squares(v);
342 let norm = norm_squared.sqrt();
343 Self { norm_squared, norm }
344 }
345
346 #[inline]
348 #[must_use]
349 pub const fn norm(&self) -> f32 {
350 self.norm
351 }
352
353 #[inline]
355 #[must_use]
356 pub const fn norm_squared(&self) -> f32 {
357 self.norm_squared
358 }
359
360 #[inline]
362 #[must_use]
363 pub fn is_zero(&self) -> bool {
364 self.norm == 0.0
365 }
366}
367
368#[inline]
372fn horizontal_sum(v: f32x8) -> f32 {
373 let arr: [f32; 8] = v.to_array();
375 arr.iter().sum()
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-5;

    /// Asserts that `a` and `b` differ by less than `epsilon`.
    ///
    /// A NaN on either side fails, because every comparison with NaN is false.
    fn assert_near(a: f32, b: f32, epsilon: f32) {
        assert!(
            (a - b).abs() < epsilon,
            "assertion failed: {} !~ {} (diff: {})",
            a,
            b,
            (a - b).abs()
        );
    }

    #[test]
    fn test_dot_product_small() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        assert_near(dot_product(&a, &b), 32.0, EPSILON);
    }

    #[test]
    fn test_dot_product_simd_aligned() {
        // Exactly one SIMD chunk, no scalar tail.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [1.0; 8];
        assert_near(dot_product(&a, &b), 36.0, EPSILON);
    }

    #[test]
    fn test_dot_product_mixed() {
        // One SIMD chunk plus a two-element scalar tail.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let b = [1.0; 10];
        assert_near(dot_product(&a, &b), 55.0, EPSILON);
    }

    #[test]
    fn test_euclidean_distance_small() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_near(euclidean_distance(&a, &b), 5.0, EPSILON);
    }

    #[test]
    fn test_euclidean_large() {
        // Every component differs by ~0.001, so the distance is
        // sqrt(1536 * 1e-6) ≈ 0.0392.
        let a: Vec<f32> = (0..1536).map(|i| i as f32 * 0.001).collect();
        let b: Vec<f32> = (0..1536).map(|i| (i + 1) as f32 * 0.001).collect();

        let dist = euclidean_distance(&a, &b);
        assert!(dist > 0.039 && dist < 0.040, "Expected ~0.0392, got {}", dist);
    }

    #[test]
    fn test_sum_of_squares() {
        let v = [3.0, 4.0];
        assert_near(sum_of_squares(&v), 25.0, EPSILON);
    }

    #[test]
    fn test_l2_norm() {
        let v = [3.0, 4.0];
        assert_near(l2_norm(&v), 5.0, EPSILON);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0, 0.0];
        assert_near(cosine_similarity(&a, &a), 1.0, EPSILON);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert_near(cosine_similarity(&a, &b), 0.0, EPSILON);
    }

    #[test]
    fn test_cosine_similarity_large() {
        let a: Vec<f32> = (0..1536).map(|i| (i % 10) as f32).collect();
        let b = a.clone();
        assert_near(cosine_similarity(&a, &b), 1.0, EPSILON);
    }

    #[test]
    fn test_cosine_with_norms() {
        let a = [3.0, 4.0];
        let b = [3.0, 4.0];
        let norm_a = l2_norm(&a);
        let norm_b = l2_norm(&b);

        let sim = cosine_similarity_with_norms(&a, &b, norm_a, norm_b);
        assert!(sim.is_some());
        assert_near(sim.unwrap(), 1.0, EPSILON);
    }

    #[test]
    fn test_cached_norm() {
        let v = [3.0, 4.0];
        let cached = CachedNorm::new(&v);

        assert_near(cached.norm(), 5.0, EPSILON);
        assert_near(cached.norm_squared(), 25.0, EPSILON);
        assert!(!cached.is_zero());
    }

    #[test]
    fn test_cached_norm_zero() {
        let v = [0.0, 0.0, 0.0];
        let cached = CachedNorm::new(&v);

        assert!(cached.is_zero());
    }

    #[test]
    fn test_horizontal_sum() {
        let v = f32x8::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        assert_near(horizontal_sum(v), 36.0, EPSILON);
    }

    #[test]
    fn test_horizontal_max() {
        let v = f32x8::new([1.0, 8.0, 3.0, 4.0, 5.0, 6.0, 7.0, 2.0]);
        assert_near(horizontal_max(v), 8.0, EPSILON);
    }

    #[test]
    fn test_manhattan_distance_small() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_near(manhattan_distance(&a, &b), 7.0, EPSILON);
    }

    #[test]
    fn test_manhattan_distance_simd_aligned() {
        // Exactly one SIMD chunk, no scalar tail.
        let a = [0.0; 8];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert_near(manhattan_distance(&a, &b), 36.0, EPSILON);
    }

    #[test]
    fn test_manhattan_distance_mixed() {
        // One SIMD chunk plus a two-element scalar tail.
        let a = [0.0; 10];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        assert_near(manhattan_distance(&a, &b), 55.0, EPSILON);
    }

    #[test]
    fn test_manhattan_distance_large() {
        // Every component differs by exactly 1.0.
        let a: Vec<f32> = (0..1536).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..1536).map(|i| (i + 1) as f32).collect();

        let dist = manhattan_distance(&a, &b);
        assert_near(dist, 1536.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_small() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_near(chebyshev_distance(&a, &b), 4.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_simd_aligned() {
        // Exactly one SIMD chunk, no scalar tail.
        let a = [0.0; 8];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert_near(chebyshev_distance(&a, &b), 8.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_mixed() {
        // One SIMD chunk plus a two-element scalar tail.
        let a = [0.0; 10];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        assert_near(chebyshev_distance(&a, &b), 10.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_max_in_remainder() {
        let a = [0.0; 10];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, 10.0];
        assert_near(chebyshev_distance(&a, &b), 100.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_large() {
        let a: Vec<f32> = (0..1536).map(|_| 0.0).collect();
        let mut b: Vec<f32> = (0..1536).map(|i| i as f32 * 0.001).collect();
        b[1000] = 999.0;

        let dist = chebyshev_distance(&a, &b);
        assert_near(dist, 999.0, EPSILON);
    }

    #[test]
    fn test_empty_inputs() {
        let empty: [f32; 0] = [];
        assert_near(dot_product(&empty, &empty), 0.0, EPSILON);
        assert_near(euclidean_distance(&empty, &empty), 0.0, EPSILON);
        assert_near(manhattan_distance(&empty, &empty), 0.0, EPSILON);
        assert_near(chebyshev_distance(&empty, &empty), 0.0, EPSILON);
        assert_near(sum_of_squares(&empty), 0.0, EPSILON);
        // Zero norm => similarity is defined as 0.0.
        assert_near(cosine_similarity(&empty, &empty), 0.0, EPSILON);
    }

    #[test]
    fn test_tail_only_lengths() {
        // Seven elements never fill an eight-lane chunk, so only the scalar
        // tail loop runs.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b = [0.0; 7];
        assert_near(dot_product(&a, &a), 140.0, EPSILON);
        assert_near(euclidean_distance_squared(&a, &b), 140.0, EPSILON);
        assert_near(manhattan_distance(&a, &b), 28.0, EPSILON);
        assert_near(chebyshev_distance(&a, &b), 7.0, EPSILON);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0, 2.0, 3.0];
        let b = [-1.0, -2.0, -3.0];
        assert_near(cosine_similarity(&a, &b), -1.0, EPSILON);
        assert_near(cosine_distance(&a, &b), 2.0, EPSILON);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero = [0.0; 10];
        let v = [1.0; 10];
        assert_near(cosine_similarity(&zero, &v), 0.0, EPSILON);
        assert_near(cosine_distance(&zero, &v), 1.0, EPSILON);
    }

    #[test]
    fn test_cosine_with_norms_zero_norm() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert!(cosine_similarity_with_norms(&a, &b, 0.0, 5.0).is_none());
        assert!(cosine_similarity_with_norms(&a, &b, 5.0, 0.0).is_none());
    }
}