/// Slice-level numeric kernels for a scalar element type.
///
/// All operations are associated functions (no receiver) over `&[Self]`.
/// Two-slice operations expect both inputs to have the same length; the
/// `f32`/`f64` implementations in this module assert this.
///
/// Every method carries `where Self: Sized` because a slice `[Self]` needs a
/// sized element type.
pub trait SimdOps {
    // ----- element-wise operations (return a fresh vector) ---------------

    /// Element-wise sum: `out[i] = a[i] + b[i]`.
    fn add(a: &[Self], b: &[Self]) -> Vec<Self>
    where
        Self: Sized;

    /// Element-wise difference: `out[i] = a[i] - b[i]`.
    fn sub(a: &[Self], b: &[Self]) -> Vec<Self>
    where
        Self: Sized;

    /// Element-wise product: `out[i] = a[i] * b[i]`.
    fn mul(a: &[Self], b: &[Self]) -> Vec<Self>
    where
        Self: Sized;

    // ----- reductions ------------------------------------------------------

    /// Sum of all elements.
    fn sum(a: &[Self]) -> Self
    where
        Self: Sized;

    /// Arithmetic mean of all elements. The `f32`/`f64` impls in this module
    /// return `0.0` for an empty slice.
    fn mean(a: &[Self]) -> Self
    where
        Self: Sized;

    /// Dot product: the sum of `a[i] * b[i]`.
    fn dot(a: &[Self], b: &[Self]) -> Self
    where
        Self: Sized;

    /// Euclidean (L2) length of `a`, e.g. `[3, 4, 0]` has norm `5`.
    fn norm(a: &[Self]) -> Self
    where
        Self: Sized;

    // ----- distances -------------------------------------------------------

    /// Cosine distance, i.e. one minus cosine similarity: `0` for identical
    /// directions, `1` for orthogonal, `2` for opposite.
    /// NOTE(review): behaviour for zero-length (all-zero) vectors is defined
    /// by the backend kernels, which are not visible here.
    fn cosine_distance(a: &[Self], b: &[Self]) -> Self
    where
        Self: Sized;

    /// Euclidean (L2) distance between `a` and `b`.
    fn euclidean_distance(a: &[Self], b: &[Self]) -> Self
    where
        Self: Sized;

    /// Manhattan (L1) distance: the sum of `|a[i] - b[i]|`.
    fn manhattan_distance(a: &[Self], b: &[Self]) -> Self
    where
        Self: Sized;
}
58
59#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
61mod x86_simd;
62
63#[cfg(all(target_arch = "aarch64", feature = "simd"))]
65mod arm_simd;
66
67mod scalar;
69
70#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
72pub use x86_simd::*;
73
74#[cfg(all(target_arch = "aarch64", feature = "simd"))]
75pub use arm_simd::*;
76
77#[cfg(not(feature = "simd"))]
78pub use scalar::*;
79
80#[cfg(all(
82 feature = "simd",
83 not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))
84))]
85pub use scalar::*;
86
87impl SimdOps for f32 {
89 fn add(a: &[Self], b: &[Self]) -> Vec<Self> {
90 debug_assert_eq!(a.len(), b.len());
91 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
92 unsafe {
93 x86_simd::add_f32(a, b)
94 }
95 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
96 unsafe {
97 arm_simd::add_f32(a, b)
98 }
99 #[cfg(not(any(
100 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
101 all(target_arch = "aarch64", feature = "simd")
102 )))]
103 scalar::add_f32(a, b)
104 }
105
106 fn sub(a: &[Self], b: &[Self]) -> Vec<Self> {
107 debug_assert_eq!(a.len(), b.len());
108 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
109 unsafe {
110 x86_simd::sub_f32(a, b)
111 }
112 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
113 unsafe {
114 arm_simd::sub_f32(a, b)
115 }
116 #[cfg(not(any(
117 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
118 all(target_arch = "aarch64", feature = "simd")
119 )))]
120 scalar::sub_f32(a, b)
121 }
122
123 fn mul(a: &[Self], b: &[Self]) -> Vec<Self> {
124 debug_assert_eq!(a.len(), b.len());
125 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
126 unsafe {
127 x86_simd::mul_f32(a, b)
128 }
129 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
130 unsafe {
131 arm_simd::mul_f32(a, b)
132 }
133 #[cfg(not(any(
134 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
135 all(target_arch = "aarch64", feature = "simd")
136 )))]
137 scalar::mul_f32(a, b)
138 }
139
140 fn dot(a: &[Self], b: &[Self]) -> Self {
141 debug_assert_eq!(a.len(), b.len());
142 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
143 unsafe {
144 x86_simd::dot_f32(a, b)
145 }
146 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
147 unsafe {
148 arm_simd::dot_f32(a, b)
149 }
150 #[cfg(not(any(
151 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
152 all(target_arch = "aarch64", feature = "simd")
153 )))]
154 scalar::dot_f32(a, b)
155 }
156
157 fn cosine_distance(a: &[Self], b: &[Self]) -> Self {
158 debug_assert_eq!(a.len(), b.len());
159 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
160 unsafe {
161 x86_simd::cosine_distance_f32(a, b)
162 }
163 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
164 unsafe {
165 arm_simd::cosine_distance_f32(a, b)
166 }
167 #[cfg(not(any(
168 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
169 all(target_arch = "aarch64", feature = "simd")
170 )))]
171 scalar::cosine_distance_f32(a, b)
172 }
173
174 fn euclidean_distance(a: &[Self], b: &[Self]) -> Self {
175 debug_assert_eq!(a.len(), b.len());
176 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
177 unsafe {
178 x86_simd::euclidean_distance_f32(a, b)
179 }
180 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
181 unsafe {
182 arm_simd::euclidean_distance_f32(a, b)
183 }
184 #[cfg(not(any(
185 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
186 all(target_arch = "aarch64", feature = "simd")
187 )))]
188 scalar::euclidean_distance_f32(a, b)
189 }
190
191 fn manhattan_distance(a: &[Self], b: &[Self]) -> Self {
192 debug_assert_eq!(a.len(), b.len());
193 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
194 unsafe {
195 x86_simd::manhattan_distance_f32(a, b)
196 }
197 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
198 unsafe {
199 arm_simd::manhattan_distance_f32(a, b)
200 }
201 #[cfg(not(any(
202 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
203 all(target_arch = "aarch64", feature = "simd")
204 )))]
205 scalar::manhattan_distance_f32(a, b)
206 }
207
208 fn norm(a: &[Self]) -> Self {
209 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
210 unsafe {
211 x86_simd::norm_f32(a)
212 }
213 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
214 unsafe {
215 arm_simd::norm_f32(a)
216 }
217 #[cfg(not(any(
218 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
219 all(target_arch = "aarch64", feature = "simd")
220 )))]
221 scalar::norm_f32(a)
222 }
223
224 fn sum(a: &[Self]) -> Self {
225 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
226 unsafe {
227 x86_simd::sum_f32(a)
228 }
229 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
230 unsafe {
231 arm_simd::sum_f32(a)
232 }
233 #[cfg(not(any(
234 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
235 all(target_arch = "aarch64", feature = "simd")
236 )))]
237 scalar::sum_f32(a)
238 }
239
240 fn mean(a: &[Self]) -> Self {
241 if a.is_empty() {
242 return 0.0;
243 }
244 Self::sum(a) / a.len() as f32
245 }
246}
247
248impl SimdOps for f64 {
250 fn add(a: &[Self], b: &[Self]) -> Vec<Self> {
251 debug_assert_eq!(a.len(), b.len());
252 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
253 unsafe {
254 x86_simd::add_f64(a, b)
255 }
256 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
257 unsafe {
258 arm_simd::add_f64(a, b)
259 }
260 #[cfg(not(any(
261 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
262 all(target_arch = "aarch64", feature = "simd")
263 )))]
264 scalar::add_f64(a, b)
265 }
266
267 fn sub(a: &[Self], b: &[Self]) -> Vec<Self> {
268 debug_assert_eq!(a.len(), b.len());
269 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
270 unsafe {
271 x86_simd::sub_f64(a, b)
272 }
273 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
274 unsafe {
275 arm_simd::sub_f64(a, b)
276 }
277 #[cfg(not(any(
278 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
279 all(target_arch = "aarch64", feature = "simd")
280 )))]
281 scalar::sub_f64(a, b)
282 }
283
284 fn mul(a: &[Self], b: &[Self]) -> Vec<Self> {
285 debug_assert_eq!(a.len(), b.len());
286 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
287 unsafe {
288 x86_simd::mul_f64(a, b)
289 }
290 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
291 unsafe {
292 arm_simd::mul_f64(a, b)
293 }
294 #[cfg(not(any(
295 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
296 all(target_arch = "aarch64", feature = "simd")
297 )))]
298 scalar::mul_f64(a, b)
299 }
300
301 fn dot(a: &[Self], b: &[Self]) -> Self {
302 debug_assert_eq!(a.len(), b.len());
303 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
304 unsafe {
305 x86_simd::dot_f64(a, b)
306 }
307 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
308 unsafe {
309 arm_simd::dot_f64(a, b)
310 }
311 #[cfg(not(any(
312 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
313 all(target_arch = "aarch64", feature = "simd")
314 )))]
315 scalar::dot_f64(a, b)
316 }
317
318 fn cosine_distance(a: &[Self], b: &[Self]) -> Self {
319 debug_assert_eq!(a.len(), b.len());
320 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
321 unsafe {
322 x86_simd::cosine_distance_f64(a, b)
323 }
324 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
325 unsafe {
326 arm_simd::cosine_distance_f64(a, b)
327 }
328 #[cfg(not(any(
329 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
330 all(target_arch = "aarch64", feature = "simd")
331 )))]
332 scalar::cosine_distance_f64(a, b)
333 }
334
335 fn euclidean_distance(a: &[Self], b: &[Self]) -> Self {
336 debug_assert_eq!(a.len(), b.len());
337 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
338 unsafe {
339 x86_simd::euclidean_distance_f64(a, b)
340 }
341 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
342 unsafe {
343 arm_simd::euclidean_distance_f64(a, b)
344 }
345 #[cfg(not(any(
346 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
347 all(target_arch = "aarch64", feature = "simd")
348 )))]
349 scalar::euclidean_distance_f64(a, b)
350 }
351
352 fn manhattan_distance(a: &[Self], b: &[Self]) -> Self {
353 debug_assert_eq!(a.len(), b.len());
354 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
355 unsafe {
356 x86_simd::manhattan_distance_f64(a, b)
357 }
358 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
359 unsafe {
360 arm_simd::manhattan_distance_f64(a, b)
361 }
362 #[cfg(not(any(
363 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
364 all(target_arch = "aarch64", feature = "simd")
365 )))]
366 scalar::manhattan_distance_f64(a, b)
367 }
368
369 fn norm(a: &[Self]) -> Self {
370 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
371 unsafe {
372 x86_simd::norm_f64(a)
373 }
374 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
375 unsafe {
376 arm_simd::norm_f64(a)
377 }
378 #[cfg(not(any(
379 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
380 all(target_arch = "aarch64", feature = "simd")
381 )))]
382 scalar::norm_f64(a)
383 }
384
385 fn sum(a: &[Self]) -> Self {
386 #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
387 unsafe {
388 x86_simd::sum_f64(a)
389 }
390 #[cfg(all(target_arch = "aarch64", feature = "simd"))]
391 unsafe {
392 arm_simd::sum_f64(a)
393 }
394 #[cfg(not(any(
395 all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
396 all(target_arch = "aarch64", feature = "simd")
397 )))]
398 scalar::sum_f64(a)
399 }
400
401 fn mean(a: &[Self]) -> Self {
402 if a.is_empty() {
403 return 0.0;
404 }
405 Self::sum(a) / a.len() as f64
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute tolerances for float comparisons. f64 gets a much tighter bound
    // than f32 because of its larger mantissa.
    const EPSILON_F32: f32 = 1e-5;
    const EPSILON_F64: f64 = 1e-10;

    /// Asserts `|actual - expected| < EPSILON_F32`, naming `what` on failure.
    fn assert_close_f32(actual: f32, expected: f32, what: &str) {
        assert!(
            (actual - expected).abs() < EPSILON_F32,
            "{what}: expected {expected}, got {actual}"
        );
    }

    /// Asserts `|actual - expected| < EPSILON_F64`, naming `what` on failure.
    fn assert_close_f64(actual: f64, expected: f64, what: &str) {
        assert!(
            (actual - expected).abs() < EPSILON_F64,
            "{what}: expected {expected}, got {actual}"
        );
    }

    /// Element-wise `assert_close_f32` that also requires equal lengths, so a
    /// truncated result cannot pass silently.
    fn assert_all_close_f32(actual: &[f32], expected: &[f32], what: &str) {
        assert_eq!(actual.len(), expected.len(), "{what}: length mismatch");
        for (i, (&got, &want)) in actual.iter().zip(expected).enumerate() {
            assert_close_f32(got, want, &format!("{what}[{i}]"));
        }
    }

    /// Element-wise `assert_close_f64` that also requires equal lengths.
    fn assert_all_close_f64(actual: &[f64], expected: &[f64], what: &str) {
        assert_eq!(actual.len(), expected.len(), "{what}: length mismatch");
        for (i, (&got, &want)) in actual.iter().zip(expected).enumerate() {
            assert_close_f64(got, want, &format!("{what}[{i}]"));
        }
    }

    // ----- f32 ------------------------------------------------------------

    #[test]
    fn test_f32_dot_product_basic() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 32
        assert_close_f32(f32::dot(&a, &b), 32.0, "dot");
    }

    #[test]
    fn test_f32_dot_product_zeros() {
        let a = [0.0_f32; 8];
        let b = [1.0_f32; 8];
        assert_close_f32(f32::dot(&a, &b), 0.0, "dot with zero vector");
    }

    #[test]
    fn test_f32_cosine_distance_identical_vectors() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [1.0_f32, 0.0, 0.0];
        assert_close_f32(
            f32::cosine_distance(&a, &b),
            0.0,
            "cosine distance of identical vectors",
        );
    }

    #[test]
    fn test_f32_cosine_distance_orthogonal_vectors() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.0_f32, 1.0, 0.0];
        assert_close_f32(
            f32::cosine_distance(&a, &b),
            1.0,
            "cosine distance of orthogonal vectors",
        );
    }

    #[test]
    fn test_f32_euclidean_distance() {
        let a = [0.0_f32, 0.0, 0.0];
        let b = [3.0_f32, 4.0, 0.0];
        // 3-4-5 right triangle.
        assert_close_f32(f32::euclidean_distance(&a, &b), 5.0, "euclidean");
    }

    #[test]
    fn test_f32_manhattan_distance() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 6.0, 8.0];
        // |3| + |4| + |5| = 12
        assert_close_f32(f32::manhattan_distance(&a, &b), 12.0, "manhattan");
    }

    #[test]
    fn test_f32_norm_unit_vector() {
        let a = [1.0_f32, 0.0, 0.0];
        assert_close_f32(f32::norm(&a), 1.0, "unit vector norm");
    }

    #[test]
    fn test_f32_norm_3_4_5() {
        let a = [3.0_f32, 4.0, 0.0];
        assert_close_f32(f32::norm(&a), 5.0, "norm of (3, 4, 0)");
    }

    #[test]
    fn test_f32_sum_and_mean() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        assert_close_f32(f32::sum(&a), 10.0, "sum");
        assert_close_f32(f32::mean(&a), 2.5, "mean");
    }

    #[test]
    fn test_f32_mean_empty_slice() {
        let a: [f32; 0] = [];
        assert_close_f32(f32::mean(&a), 0.0, "mean of empty slice");
    }

    #[test]
    fn test_f32_add_element_wise() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert_all_close_f32(&f32::add(&a, &b), &[5.0, 7.0, 9.0], "add");
    }

    #[test]
    fn test_f32_sub_element_wise() {
        let a = [5.0_f32, 7.0, 9.0];
        let b = [1.0_f32, 2.0, 3.0];
        assert_all_close_f32(&f32::sub(&a, &b), &[4.0, 5.0, 6.0], "sub");
    }

    #[test]
    fn test_f32_mul_element_wise() {
        let a = [2.0_f32, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0];
        assert_all_close_f32(&f32::mul(&a, &b), &[10.0, 18.0, 28.0], "mul");
    }

    // ----- f64 ------------------------------------------------------------

    #[test]
    fn test_f64_dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        assert_close_f64(f64::dot(&a, &b), 32.0, "dot");
    }

    #[test]
    fn test_f64_euclidean_distance_zero() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [1.0_f64, 2.0, 3.0];
        assert_close_f64(
            f64::euclidean_distance(&a, &b),
            0.0,
            "euclidean distance of identical vectors",
        );
    }

    #[test]
    fn test_f64_cosine_distance_opposite_vectors() {
        let a = [1.0_f64, 0.0, 0.0];
        let b = [-1.0_f64, 0.0, 0.0];
        assert_close_f64(
            f64::cosine_distance(&a, &b),
            2.0,
            "cosine distance of opposite vectors",
        );
    }

    #[test]
    fn test_f64_manhattan_distance_symmetry() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 6.0, 8.0];
        let d_ab = f64::manhattan_distance(&a, &b);
        let d_ba = f64::manhattan_distance(&b, &a);
        assert_close_f64(d_ba, d_ab, "manhattan distance symmetry");
    }

    #[test]
    fn test_f64_norm_of_standard_basis() {
        let a = [0.0_f64, 0.0, 1.0, 0.0];
        assert_close_f64(f64::norm(&a), 1.0, "norm of standard basis vector");
    }

    #[test]
    fn test_f64_sum_large_slice() {
        // 1 + 2 + ... + 100 = 100 * 101 / 2
        let a: Vec<f64> = (1..=100).map(f64::from).collect();
        assert_close_f64(f64::sum(&a), 5050.0, "sum of 1..=100");
    }

    #[test]
    fn test_f64_mean_empty_slice() {
        let a: [f64; 0] = [];
        assert_close_f64(f64::mean(&a), 0.0, "mean of empty slice");
    }

    #[test]
    fn test_f64_add_sub_roundtrip() {
        let a = [3.0_f64, 1.0, 4.0, 1.0, 5.0];
        let b = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let added = f64::add(&a, &b);
        let recovered = f64::sub(&added, &b);
        assert_all_close_f64(&recovered, &a, "add-sub roundtrip");
    }

    // ----- metric properties ---------------------------------------------

    /// For any vector x, ||x||_2 <= ||x||_1, so the Euclidean distance between
    /// two points never exceeds their Manhattan distance. (Previously misnamed
    /// as a triangle-inequality test; see `test_triangle_inequality`.)
    #[test]
    fn test_euclidean_bounded_by_manhattan() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 6.0, 8.0];
        let euclidean = f32::euclidean_distance(&a, &b);
        let manhattan = f32::manhattan_distance(&a, &b);
        assert!(
            euclidean <= manhattan + EPSILON_F32,
            "Euclidean should be <= Manhattan: {euclidean} vs {manhattan}"
        );
    }

    /// d(a, c) <= d(a, b) + d(b, c) must hold for both metrics. The Manhattan
    /// case is chosen to hit equality (4 = 2 + 2).
    #[test]
    fn test_triangle_inequality() {
        let a = [0.0_f64, 0.0];
        let b = [1.0_f64, 1.0];
        let c = [3.0_f64, 1.0];

        let (ab, bc, ac) = (
            f64::euclidean_distance(&a, &b),
            f64::euclidean_distance(&b, &c),
            f64::euclidean_distance(&a, &c),
        );
        assert!(
            ac <= ab + bc + EPSILON_F64,
            "Euclidean triangle inequality violated: {ac} > {ab} + {bc}"
        );

        let (ab, bc, ac) = (
            f64::manhattan_distance(&a, &b),
            f64::manhattan_distance(&b, &c),
            f64::manhattan_distance(&a, &c),
        );
        assert!(
            ac <= ab + bc + EPSILON_F64,
            "Manhattan triangle inequality violated: {ac} > {ab} + {bc}"
        );
    }

    #[test]
    fn test_cosine_distance_range() {
        let a = [1.0_f32, 0.5, 0.25];
        let b = [0.5_f32, 1.0, 2.0];
        let result = f32::cosine_distance(&a, &b);
        assert!(
            result >= 0.0,
            "Cosine distance should be non-negative, got {result}"
        );
        assert!(
            result <= 2.0 + EPSILON_F32,
            "Cosine distance should be <= 2.0, got {result}"
        );
    }
}
691}