Skip to main content

oxirs_core/
simd.rs

1//! SIMD operations abstraction for OxiRS
2//!
3//! This module provides unified SIMD operations across the OxiRS ecosystem.
4//! All SIMD operations must go through this module - direct SIMD usage in other modules is forbidden.
5
/// Numeric slice kernels shared across OxiRS.
///
/// All methods are associated functions (no `self` receiver) that operate on
/// slices of the implementing element type, so call sites read like
/// `f32::dot(&a, &b)`. The implementations in this module forward each call to
/// a backend chosen at compile time (x86 SIMD, ARM NEON, or portable scalar).
///
/// Binary operations expect `a` and `b` to have the same length.
///
/// Each method carries a `Self: Sized` bound because `[Self]` is only a valid
/// slice type when the element type is sized.
pub trait SimdOps {
    // ----- element-wise operations (allocate a new vector) -----

    /// Returns a vector whose `i`-th element is `a[i] + b[i]`.
    fn add(a: &[Self], b: &[Self]) -> Vec<Self>
    where
        Self: Sized;

    /// Returns a vector whose `i`-th element is `a[i] - b[i]`.
    fn sub(a: &[Self], b: &[Self]) -> Vec<Self>
    where
        Self: Sized;

    /// Returns a vector whose `i`-th element is `a[i] * b[i]`.
    fn mul(a: &[Self], b: &[Self]) -> Vec<Self>
    where
        Self: Sized;

    // ----- pairwise reductions -----

    /// Inner product: the sum of `a[i] * b[i]` over all indices.
    fn dot(a: &[Self], b: &[Self]) -> Self
    where
        Self: Sized;

    /// Cosine distance, defined as `1 - cosine_similarity(a, b)`.
    fn cosine_distance(a: &[Self], b: &[Self]) -> Self
    where
        Self: Sized;

    /// Euclidean (L2) distance between `a` and `b`.
    fn euclidean_distance(a: &[Self], b: &[Self]) -> Self
    where
        Self: Sized;

    /// Manhattan (L1) distance between `a` and `b`.
    fn manhattan_distance(a: &[Self], b: &[Self]) -> Self
    where
        Self: Sized;

    // ----- single-slice reductions -----

    /// Euclidean (L2) norm of `a`.
    fn norm(a: &[Self]) -> Self
    where
        Self: Sized;

    /// Sum of every element of `a`.
    fn sum(a: &[Self]) -> Self
    where
        Self: Sized;

    /// Arithmetic mean of `a`.
    ///
    /// The `f32`/`f64` implementations in this module return zero for an
    /// empty slice.
    fn mean(a: &[Self]) -> Self
    where
        Self: Sized;
}
58
59// Platform-specific SIMD implementations
60#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
61mod x86_simd;
62
63// ARM NEON SIMD support for Apple Silicon and ARM processors
64#[cfg(all(target_arch = "aarch64", feature = "simd"))]
65mod arm_simd;
66
67// Generic scalar fallback implementation
68mod scalar;
69
70// Export the appropriate implementation based on features and platform
71#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
72pub use x86_simd::*;
73
74#[cfg(all(target_arch = "aarch64", feature = "simd"))]
75pub use arm_simd::*;
76
77#[cfg(not(feature = "simd"))]
78pub use scalar::*;
79
80// Fallback to scalar on unsupported platforms
81#[cfg(all(
82    feature = "simd",
83    not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))
84))]
85pub use scalar::*;
86
87/// SIMD implementation for f32
88impl SimdOps for f32 {
89    fn add(a: &[Self], b: &[Self]) -> Vec<Self> {
90        debug_assert_eq!(a.len(), b.len());
91        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
92        unsafe {
93            x86_simd::add_f32(a, b)
94        }
95        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
96        unsafe {
97            arm_simd::add_f32(a, b)
98        }
99        #[cfg(not(any(
100            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
101            all(target_arch = "aarch64", feature = "simd")
102        )))]
103        scalar::add_f32(a, b)
104    }
105
106    fn sub(a: &[Self], b: &[Self]) -> Vec<Self> {
107        debug_assert_eq!(a.len(), b.len());
108        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
109        unsafe {
110            x86_simd::sub_f32(a, b)
111        }
112        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
113        unsafe {
114            arm_simd::sub_f32(a, b)
115        }
116        #[cfg(not(any(
117            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
118            all(target_arch = "aarch64", feature = "simd")
119        )))]
120        scalar::sub_f32(a, b)
121    }
122
123    fn mul(a: &[Self], b: &[Self]) -> Vec<Self> {
124        debug_assert_eq!(a.len(), b.len());
125        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
126        unsafe {
127            x86_simd::mul_f32(a, b)
128        }
129        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
130        unsafe {
131            arm_simd::mul_f32(a, b)
132        }
133        #[cfg(not(any(
134            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
135            all(target_arch = "aarch64", feature = "simd")
136        )))]
137        scalar::mul_f32(a, b)
138    }
139
140    fn dot(a: &[Self], b: &[Self]) -> Self {
141        debug_assert_eq!(a.len(), b.len());
142        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
143        unsafe {
144            x86_simd::dot_f32(a, b)
145        }
146        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
147        unsafe {
148            arm_simd::dot_f32(a, b)
149        }
150        #[cfg(not(any(
151            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
152            all(target_arch = "aarch64", feature = "simd")
153        )))]
154        scalar::dot_f32(a, b)
155    }
156
157    fn cosine_distance(a: &[Self], b: &[Self]) -> Self {
158        debug_assert_eq!(a.len(), b.len());
159        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
160        unsafe {
161            x86_simd::cosine_distance_f32(a, b)
162        }
163        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
164        unsafe {
165            arm_simd::cosine_distance_f32(a, b)
166        }
167        #[cfg(not(any(
168            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
169            all(target_arch = "aarch64", feature = "simd")
170        )))]
171        scalar::cosine_distance_f32(a, b)
172    }
173
174    fn euclidean_distance(a: &[Self], b: &[Self]) -> Self {
175        debug_assert_eq!(a.len(), b.len());
176        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
177        unsafe {
178            x86_simd::euclidean_distance_f32(a, b)
179        }
180        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
181        unsafe {
182            arm_simd::euclidean_distance_f32(a, b)
183        }
184        #[cfg(not(any(
185            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
186            all(target_arch = "aarch64", feature = "simd")
187        )))]
188        scalar::euclidean_distance_f32(a, b)
189    }
190
191    fn manhattan_distance(a: &[Self], b: &[Self]) -> Self {
192        debug_assert_eq!(a.len(), b.len());
193        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
194        unsafe {
195            x86_simd::manhattan_distance_f32(a, b)
196        }
197        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
198        unsafe {
199            arm_simd::manhattan_distance_f32(a, b)
200        }
201        #[cfg(not(any(
202            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
203            all(target_arch = "aarch64", feature = "simd")
204        )))]
205        scalar::manhattan_distance_f32(a, b)
206    }
207
208    fn norm(a: &[Self]) -> Self {
209        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
210        unsafe {
211            x86_simd::norm_f32(a)
212        }
213        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
214        unsafe {
215            arm_simd::norm_f32(a)
216        }
217        #[cfg(not(any(
218            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
219            all(target_arch = "aarch64", feature = "simd")
220        )))]
221        scalar::norm_f32(a)
222    }
223
224    fn sum(a: &[Self]) -> Self {
225        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
226        unsafe {
227            x86_simd::sum_f32(a)
228        }
229        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
230        unsafe {
231            arm_simd::sum_f32(a)
232        }
233        #[cfg(not(any(
234            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
235            all(target_arch = "aarch64", feature = "simd")
236        )))]
237        scalar::sum_f32(a)
238    }
239
240    fn mean(a: &[Self]) -> Self {
241        if a.is_empty() {
242            return 0.0;
243        }
244        Self::sum(a) / a.len() as f32
245    }
246}
247
248/// SIMD implementation for f64
249impl SimdOps for f64 {
250    fn add(a: &[Self], b: &[Self]) -> Vec<Self> {
251        debug_assert_eq!(a.len(), b.len());
252        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
253        unsafe {
254            x86_simd::add_f64(a, b)
255        }
256        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
257        unsafe {
258            arm_simd::add_f64(a, b)
259        }
260        #[cfg(not(any(
261            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
262            all(target_arch = "aarch64", feature = "simd")
263        )))]
264        scalar::add_f64(a, b)
265    }
266
267    fn sub(a: &[Self], b: &[Self]) -> Vec<Self> {
268        debug_assert_eq!(a.len(), b.len());
269        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
270        unsafe {
271            x86_simd::sub_f64(a, b)
272        }
273        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
274        unsafe {
275            arm_simd::sub_f64(a, b)
276        }
277        #[cfg(not(any(
278            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
279            all(target_arch = "aarch64", feature = "simd")
280        )))]
281        scalar::sub_f64(a, b)
282    }
283
284    fn mul(a: &[Self], b: &[Self]) -> Vec<Self> {
285        debug_assert_eq!(a.len(), b.len());
286        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
287        unsafe {
288            x86_simd::mul_f64(a, b)
289        }
290        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
291        unsafe {
292            arm_simd::mul_f64(a, b)
293        }
294        #[cfg(not(any(
295            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
296            all(target_arch = "aarch64", feature = "simd")
297        )))]
298        scalar::mul_f64(a, b)
299    }
300
301    fn dot(a: &[Self], b: &[Self]) -> Self {
302        debug_assert_eq!(a.len(), b.len());
303        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
304        unsafe {
305            x86_simd::dot_f64(a, b)
306        }
307        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
308        unsafe {
309            arm_simd::dot_f64(a, b)
310        }
311        #[cfg(not(any(
312            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
313            all(target_arch = "aarch64", feature = "simd")
314        )))]
315        scalar::dot_f64(a, b)
316    }
317
318    fn cosine_distance(a: &[Self], b: &[Self]) -> Self {
319        debug_assert_eq!(a.len(), b.len());
320        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
321        unsafe {
322            x86_simd::cosine_distance_f64(a, b)
323        }
324        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
325        unsafe {
326            arm_simd::cosine_distance_f64(a, b)
327        }
328        #[cfg(not(any(
329            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
330            all(target_arch = "aarch64", feature = "simd")
331        )))]
332        scalar::cosine_distance_f64(a, b)
333    }
334
335    fn euclidean_distance(a: &[Self], b: &[Self]) -> Self {
336        debug_assert_eq!(a.len(), b.len());
337        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
338        unsafe {
339            x86_simd::euclidean_distance_f64(a, b)
340        }
341        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
342        unsafe {
343            arm_simd::euclidean_distance_f64(a, b)
344        }
345        #[cfg(not(any(
346            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
347            all(target_arch = "aarch64", feature = "simd")
348        )))]
349        scalar::euclidean_distance_f64(a, b)
350    }
351
352    fn manhattan_distance(a: &[Self], b: &[Self]) -> Self {
353        debug_assert_eq!(a.len(), b.len());
354        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
355        unsafe {
356            x86_simd::manhattan_distance_f64(a, b)
357        }
358        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
359        unsafe {
360            arm_simd::manhattan_distance_f64(a, b)
361        }
362        #[cfg(not(any(
363            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
364            all(target_arch = "aarch64", feature = "simd")
365        )))]
366        scalar::manhattan_distance_f64(a, b)
367    }
368
369    fn norm(a: &[Self]) -> Self {
370        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
371        unsafe {
372            x86_simd::norm_f64(a)
373        }
374        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
375        unsafe {
376            arm_simd::norm_f64(a)
377        }
378        #[cfg(not(any(
379            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
380            all(target_arch = "aarch64", feature = "simd")
381        )))]
382        scalar::norm_f64(a)
383    }
384
385    fn sum(a: &[Self]) -> Self {
386        #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
387        unsafe {
388            x86_simd::sum_f64(a)
389        }
390        #[cfg(all(target_arch = "aarch64", feature = "simd"))]
391        unsafe {
392            arm_simd::sum_f64(a)
393        }
394        #[cfg(not(any(
395            all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"),
396            all(target_arch = "aarch64", feature = "simd")
397        )))]
398        scalar::sum_f64(a)
399    }
400
401    fn mean(a: &[Self]) -> Self {
402        if a.is_empty() {
403            return 0.0;
404        }
405        Self::sum(a) / a.len() as f64
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON_F32: f32 = 1e-5;
    const EPSILON_F64: f64 = 1e-10;

    // --- f32 tests ---

    #[test]
    fn test_f32_dot_product_basic() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let result = f32::dot(&a, &b);
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert!(
            (result - 32.0_f32).abs() < EPSILON_F32,
            "Expected 32.0, got {result}"
        );
    }

    #[test]
    fn test_f32_dot_product_zeros() {
        let a = [0.0_f32; 8];
        let b = [1.0_f32; 8];
        let result = f32::dot(&a, &b);
        assert!(
            (result - 0.0_f32).abs() < EPSILON_F32,
            "Expected 0.0, got {result}"
        );
    }

    #[test]
    fn test_f32_cosine_distance_identical_vectors() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [1.0_f32, 0.0, 0.0];
        // cosine_distance = 1 - cosine_similarity = 1 - 1 = 0
        let result = f32::cosine_distance(&a, &b);
        assert!(
            result.abs() < EPSILON_F32,
            "Identical vectors should have cosine distance 0, got {result}"
        );
    }

    #[test]
    fn test_f32_cosine_distance_orthogonal_vectors() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.0_f32, 1.0, 0.0];
        // cosine_similarity = 0, so cosine_distance = 1
        let result = f32::cosine_distance(&a, &b);
        assert!(
            (result - 1.0_f32).abs() < EPSILON_F32,
            "Orthogonal vectors should have cosine distance 1, got {result}"
        );
    }

    #[test]
    fn test_f32_euclidean_distance() {
        let a = [0.0_f32, 0.0, 0.0];
        let b = [3.0_f32, 4.0, 0.0];
        // sqrt(9 + 16) = 5
        let result = f32::euclidean_distance(&a, &b);
        assert!(
            (result - 5.0_f32).abs() < EPSILON_F32,
            "Expected 5.0, got {result}"
        );
    }

    #[test]
    fn test_f32_manhattan_distance() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 6.0, 8.0];
        // |1-4| + |2-6| + |3-8| = 3 + 4 + 5 = 12
        let result = f32::manhattan_distance(&a, &b);
        assert!(
            (result - 12.0_f32).abs() < EPSILON_F32,
            "Expected 12.0, got {result}"
        );
    }

    #[test]
    fn test_f32_norm_unit_vector() {
        let a = [1.0_f32, 0.0, 0.0];
        let result = f32::norm(&a);
        assert!(
            (result - 1.0_f32).abs() < EPSILON_F32,
            "Unit vector norm should be 1.0, got {result}"
        );
    }

    #[test]
    fn test_f32_norm_3_4_5() {
        let a = [3.0_f32, 4.0, 0.0];
        let result = f32::norm(&a);
        assert!(
            (result - 5.0_f32).abs() < EPSILON_F32,
            "Expected norm 5.0, got {result}"
        );
    }

    #[test]
    fn test_f32_sum_and_mean() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let sum = f32::sum(&a);
        let mean = f32::mean(&a);
        assert!(
            (sum - 10.0_f32).abs() < EPSILON_F32,
            "Expected sum 10.0, got {sum}"
        );
        assert!(
            (mean - 2.5_f32).abs() < EPSILON_F32,
            "Expected mean 2.5, got {mean}"
        );
    }

    #[test]
    fn test_f32_mean_empty_slice() {
        let a: [f32; 0] = [];
        let result = f32::mean(&a);
        assert!(
            (result - 0.0_f32).abs() < EPSILON_F32,
            "Mean of empty slice should be 0.0, got {result}"
        );
    }

    #[test]
    fn test_f32_add_element_wise() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let result = f32::add(&a, &b);
        assert_eq!(result.len(), 3);
        assert!((result[0] - 5.0_f32).abs() < EPSILON_F32);
        assert!((result[1] - 7.0_f32).abs() < EPSILON_F32);
        assert!((result[2] - 9.0_f32).abs() < EPSILON_F32);
    }

    #[test]
    fn test_f32_sub_element_wise() {
        let a = [5.0_f32, 7.0, 9.0];
        let b = [1.0_f32, 2.0, 3.0];
        let result = f32::sub(&a, &b);
        assert_eq!(result.len(), 3);
        assert!((result[0] - 4.0_f32).abs() < EPSILON_F32);
        assert!((result[1] - 5.0_f32).abs() < EPSILON_F32);
        assert!((result[2] - 6.0_f32).abs() < EPSILON_F32);
    }

    #[test]
    fn test_f32_mul_element_wise() {
        let a = [2.0_f32, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0];
        let result = f32::mul(&a, &b);
        assert_eq!(result.len(), 3);
        assert!((result[0] - 10.0_f32).abs() < EPSILON_F32);
        assert!((result[1] - 18.0_f32).abs() < EPSILON_F32);
        assert!((result[2] - 28.0_f32).abs() < EPSILON_F32);
    }

    // --- f64 tests ---

    #[test]
    fn test_f64_dot_product_basic() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 5.0, 6.0];
        let result = f64::dot(&a, &b);
        assert!(
            (result - 32.0_f64).abs() < EPSILON_F64,
            "Expected 32.0, got {result}"
        );
    }

    #[test]
    fn test_f64_euclidean_distance_zero() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [1.0_f64, 2.0, 3.0];
        let result = f64::euclidean_distance(&a, &b);
        assert!(
            result.abs() < EPSILON_F64,
            "Identical vectors should have distance 0, got {result}"
        );
    }

    #[test]
    fn test_f64_cosine_distance_opposite_vectors() {
        // Opposite vectors: a = [1,0,0], b = [-1,0,0]
        // cosine_similarity = -1, cosine_distance = 1 - (-1) = 2
        let a = [1.0_f64, 0.0, 0.0];
        let b = [-1.0_f64, 0.0, 0.0];
        let result = f64::cosine_distance(&a, &b);
        assert!(
            (result - 2.0_f64).abs() < EPSILON_F64,
            "Opposite vectors should have cosine distance 2.0, got {result}"
        );
    }

    #[test]
    fn test_f64_manhattan_distance_symmetry() {
        let a = [1.0_f64, 2.0, 3.0];
        let b = [4.0_f64, 6.0, 8.0];
        let d_ab = f64::manhattan_distance(&a, &b);
        let d_ba = f64::manhattan_distance(&b, &a);
        assert!(
            (d_ab - d_ba).abs() < EPSILON_F64,
            "Manhattan distance should be symmetric"
        );
    }

    #[test]
    fn test_f64_norm_of_standard_basis() {
        let a = [0.0_f64, 0.0, 1.0, 0.0];
        let result = f64::norm(&a);
        assert!(
            (result - 1.0_f64).abs() < EPSILON_F64,
            "Norm of standard basis vector should be 1.0, got {result}"
        );
    }

    #[test]
    fn test_f64_sum_large_slice() {
        // Verify sum of arithmetic series: 1+2+...+100 = 5050
        let a: Vec<f64> = (1..=100).map(|x| x as f64).collect();
        let result = f64::sum(&a);
        assert!(
            (result - 5050.0_f64).abs() < EPSILON_F64,
            "Expected 5050.0, got {result}"
        );
    }

    #[test]
    fn test_f64_mean_empty_slice() {
        let a: [f64; 0] = [];
        let result = f64::mean(&a);
        assert!(
            result.abs() < EPSILON_F64,
            "Mean of empty slice should be 0.0, got {result}"
        );
    }

    #[test]
    fn test_f64_add_sub_roundtrip() {
        let a = [3.0_f64, 1.0, 4.0, 1.0, 5.0];
        let b = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let added = f64::add(&a, &b);
        let subtracted = f64::sub(added.as_slice(), &b);
        for (orig, recovered) in a.iter().zip(subtracted.iter()) {
            assert!(
                (orig - recovered).abs() < EPSILON_F64,
                "Add-sub roundtrip failed: {orig} vs {recovered}"
            );
        }
    }

    // --- Remainder-handling tests ---
    //
    // SIMD kernels typically process a fixed number of lanes per step and then
    // finish with a scalar tail. Length 19 is larger than 16 and not a multiple
    // of 4, 8 or 16, so for common lane widths both the vectorised body and the
    // tail run. The inputs are small integers, so every intermediate value is
    // exactly representable and the reference result does not depend on the
    // order in which a kernel accumulates.

    const ODD_LEN: usize = 19;

    /// `a = [0, 1, ..., 18]`, `b` cycles through `[-2, -1, 0, 1, 2]`.
    fn odd_length_inputs_f32() -> (Vec<f32>, Vec<f32>) {
        let a: Vec<f32> = (0..ODD_LEN).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..ODD_LEN).map(|i| (i % 5) as f32 - 2.0).collect();
        (a, b)
    }

    /// Same values as [`odd_length_inputs_f32`], in `f64`.
    fn odd_length_inputs_f64() -> (Vec<f64>, Vec<f64>) {
        let a: Vec<f64> = (0..ODD_LEN).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..ODD_LEN).map(|i| (i % 5) as f64 - 2.0).collect();
        (a, b)
    }

    #[test]
    fn test_f32_odd_length_matches_reference() {
        let (a, b) = odd_length_inputs_f32();

        let added = f32::add(&a, &b);
        let subtracted = f32::sub(&a, &b);
        let multiplied = f32::mul(&a, &b);
        assert_eq!(added.len(), ODD_LEN);
        assert_eq!(subtracted.len(), ODD_LEN);
        assert_eq!(multiplied.len(), ODD_LEN);
        for (i, (x, y)) in a.iter().zip(&b).enumerate() {
            assert!((added[i] - (x + y)).abs() < EPSILON_F32, "add mismatch at {i}");
            assert!((subtracted[i] - (x - y)).abs() < EPSILON_F32, "sub mismatch at {i}");
            assert!((multiplied[i] - x * y).abs() < EPSILON_F32, "mul mismatch at {i}");
        }

        let expected_dot: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let expected_sum: f32 = a.iter().sum();
        let expected_manhattan: f32 = a.iter().zip(&b).map(|(x, y)| (x - y).abs()).sum();
        let expected_euclidean = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt();
        let expected_norm = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        let expected_cosine = 1.0 - expected_dot / (expected_norm * norm_b);

        let dot = f32::dot(&a, &b);
        assert!((dot - expected_dot).abs() < EPSILON_F32, "dot: {dot} vs {expected_dot}");
        let sum = f32::sum(&a);
        assert!((sum - expected_sum).abs() < EPSILON_F32, "sum: {sum} vs {expected_sum}");
        let mean = f32::mean(&a);
        // 171 / 19 = 9 exactly.
        assert!((mean - 9.0).abs() < EPSILON_F32, "mean: {mean}");
        let manhattan = f32::manhattan_distance(&a, &b);
        assert!(
            (manhattan - expected_manhattan).abs() < EPSILON_F32,
            "manhattan: {manhattan} vs {expected_manhattan}"
        );
        // Square roots of values in the tens: compare relative to magnitude.
        let euclidean = f32::euclidean_distance(&a, &b);
        assert!(
            (euclidean - expected_euclidean).abs() <= EPSILON_F32 * expected_euclidean,
            "euclidean: {euclidean} vs {expected_euclidean}"
        );
        let norm = f32::norm(&a);
        assert!(
            (norm - expected_norm).abs() <= EPSILON_F32 * expected_norm,
            "norm: {norm} vs {expected_norm}"
        );
        let cosine = f32::cosine_distance(&a, &b);
        assert!(
            (cosine - expected_cosine).abs() < 1e-4,
            "cosine: {cosine} vs {expected_cosine}"
        );
    }

    #[test]
    fn test_f64_odd_length_matches_reference() {
        let (a, b) = odd_length_inputs_f64();

        let added = f64::add(&a, &b);
        let subtracted = f64::sub(&a, &b);
        let multiplied = f64::mul(&a, &b);
        assert_eq!(added.len(), ODD_LEN);
        assert_eq!(subtracted.len(), ODD_LEN);
        assert_eq!(multiplied.len(), ODD_LEN);
        for (i, (x, y)) in a.iter().zip(&b).enumerate() {
            assert!((added[i] - (x + y)).abs() < EPSILON_F64, "add mismatch at {i}");
            assert!((subtracted[i] - (x - y)).abs() < EPSILON_F64, "sub mismatch at {i}");
            assert!((multiplied[i] - x * y).abs() < EPSILON_F64, "mul mismatch at {i}");
        }

        let expected_dot: f64 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let expected_sum: f64 = a.iter().sum();
        let expected_manhattan: f64 = a.iter().zip(&b).map(|(x, y)| (x - y).abs()).sum();
        let expected_euclidean = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f64>()
            .sqrt();
        let expected_norm = a.iter().map(|x| x * x).sum::<f64>().sqrt();

        assert!((f64::dot(&a, &b) - expected_dot).abs() < EPSILON_F64);
        assert!((f64::sum(&a) - expected_sum).abs() < EPSILON_F64);
        assert!((f64::mean(&a) - 9.0).abs() < EPSILON_F64);
        assert!((f64::manhattan_distance(&a, &b) - expected_manhattan).abs() < EPSILON_F64);
        assert!(
            (f64::euclidean_distance(&a, &b) - expected_euclidean).abs()
                <= EPSILON_F64 * expected_euclidean
        );
        assert!((f64::norm(&a) - expected_norm).abs() <= EPSILON_F64 * expected_norm);
    }

    // --- Cross-type consistency tests ---

    #[test]
    fn test_euclidean_not_greater_than_manhattan() {
        // For any two vectors, Euclidean (L2) distance <= Manhattan (L1) distance.
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 6.0, 8.0];
        let euclidean = f32::euclidean_distance(&a, &b);
        let manhattan = f32::manhattan_distance(&a, &b);
        assert!(
            euclidean <= manhattan + EPSILON_F32,
            "Euclidean should be <= Manhattan: {euclidean} vs {manhattan}"
        );
    }

    #[test]
    fn test_cosine_distance_range() {
        // cosine_distance should be in [0, 2] for real vectors
        let a = [1.0_f32, 0.5, 0.25];
        let b = [0.5_f32, 1.0, 2.0];
        let result = f32::cosine_distance(&a, &b);
        assert!(
            result >= 0.0,
            "Cosine distance should be non-negative, got {result}"
        );
        assert!(
            result <= 2.0 + EPSILON_F32,
            "Cosine distance should be <= 2.0, got {result}"
        );
    }
}