Skip to main content

sklears_simd/
safe_simd.rs

1//! Type-safe SIMD abstractions with compile-time guarantees
2//!
3//! This module provides zero-cost abstractions for SIMD operations with enhanced type safety,
4//! compile-time lane validation, and phantom types for SIMD width verification.
5
6use core::marker::PhantomData;
7use core::ops::{Add, Div, Mul, Sub};
8
9#[cfg(feature = "no-std")]
10use alloc::vec::Vec;
11#[cfg(not(feature = "no-std"))]
12use std::vec::Vec;
13
/// Zero-sized marker that carries a SIMD lane count in the type system.
///
/// The type has no fields; `WIDTH` exists only as a const generic tag that other
/// types embed through `PhantomData`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdWidth<const WIDTH: usize>;

/// Fixed-width vector of `WIDTH` lanes backed by a plain array.
///
/// The lane count is part of the type, so two vectors of different widths are
/// different types and cannot be mixed in the operations defined on this type.
/// Vectors are `Copy` and comparable with `==` whenever the lane type is.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SafeSimdVector<T, const WIDTH: usize> {
    /// Lane values, lane 0 first.
    data: [T; WIDTH],
    /// Ties the vector to its `SimdWidth` marker without storing anything.
    _phantom: PhantomData<SimdWidth<WIDTH>>,
}

impl<T, const WIDTH: usize> SafeSimdVector<T, WIDTH>
where
    T: Copy + Default,
{
    /// Wraps an array of exactly `WIDTH` lanes.
    ///
    /// The array length is checked by the type system, so this cannot fail.
    pub const fn new(data: [T; WIDTH]) -> Self {
        Self {
            data,
            _phantom: PhantomData,
        }
    }

    /// Creates a vector with every lane set to `value`.
    pub fn splat(value: T) -> Self {
        Self::new([value; WIDTH])
    }

    /// Returns the number of lanes, i.e. the `WIDTH` const parameter.
    pub const fn width(&self) -> usize {
        WIDTH
    }

    /// Borrows the lanes as a slice of length `WIDTH`.
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// Mutably borrows the lanes as a slice of length `WIDTH`.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Consumes the vector and returns the underlying array.
    pub fn into_array(self) -> [T; WIDTH] {
        self.data
    }

    /// Loads the first `WIDTH` elements of `slice`.
    ///
    /// The length is checked at run time: returns `None` when `slice` holds fewer
    /// than `WIDTH` elements. Elements beyond the first `WIDTH` are ignored.
    pub fn from_slice(slice: &[T]) -> Option<Self> {
        // `get` yields `None` for a too-short slice, which `?` forwards.
        let head = slice.get(..WIDTH)?;
        let mut data = [T::default(); WIDTH];
        data.copy_from_slice(head);
        Some(Self::new(data))
    }

    /// Returns the value in `lane`, or `None` when `lane >= WIDTH`.
    pub fn extract_lane(&self, lane: usize) -> Option<T> {
        self.data.get(lane).copied()
    }

    /// Overwrites `lane` with `value`.
    ///
    /// Returns `true` on success and `false` (leaving the vector untouched) when
    /// `lane >= WIDTH`.
    pub fn replace_lane(&mut self, lane: usize, value: T) -> bool {
        match self.data.get_mut(lane) {
            Some(slot) => {
                *slot = value;
                true
            }
            None => false,
        }
    }
}
94
/// Type-safe SIMD operations with lane validation
///
/// An operation consumes a borrowed [`SafeSimdVector`] of exactly `WIDTH` lanes of
/// `T`; the lane count is fixed by the trait's const parameter, so an
/// implementation cannot be applied to a vector of a different width.
pub trait SimdOperation<T, const WIDTH: usize> {
    /// Result type produced by [`apply`](SimdOperation::apply).
    type Output;

    /// Runs the operation on `input` and returns its result.
    fn apply(&self, input: &SafeSimdVector<T, WIDTH>) -> Self::Output;
}
101
102/// Zero-cost abstraction for element-wise operations
103pub struct ElementWiseOp<F, T, const WIDTH: usize> {
104    func: F,
105    _phantom: PhantomData<(T, SimdWidth<WIDTH>)>,
106}
107
108impl<F, T, const WIDTH: usize> ElementWiseOp<F, T, WIDTH>
109where
110    F: Fn(T) -> T,
111    T: Copy,
112{
113    pub const fn new(func: F) -> Self {
114        Self {
115            func,
116            _phantom: PhantomData,
117        }
118    }
119}
120
121impl<F, T, const WIDTH: usize> SimdOperation<T, WIDTH> for ElementWiseOp<F, T, WIDTH>
122where
123    F: Fn(T) -> T,
124    T: Copy + Default,
125{
126    type Output = SafeSimdVector<T, WIDTH>;
127
128    fn apply(&self, input: &SafeSimdVector<T, WIDTH>) -> Self::Output {
129        let mut result = [T::default(); WIDTH];
130        for (r, &val) in result.iter_mut().zip(input.data.iter()) {
131            *r = (self.func)(val);
132        }
133        SafeSimdVector::new(result)
134    }
135}
136
137/// Compile-time validated SIMD arithmetic operations
138impl<T, const WIDTH: usize> Add for SafeSimdVector<T, WIDTH>
139where
140    T: Add<Output = T> + Copy + Default,
141{
142    type Output = Self;
143
144    fn add(self, rhs: Self) -> Self::Output {
145        let mut result = [T::default(); WIDTH];
146        for (r, (a, b)) in result.iter_mut().zip(self.data.iter().zip(rhs.data.iter())) {
147            *r = *a + *b;
148        }
149        Self::new(result)
150    }
151}
152
153impl<T, const WIDTH: usize> Sub for SafeSimdVector<T, WIDTH>
154where
155    T: Sub<Output = T> + Copy + Default,
156{
157    type Output = Self;
158
159    fn sub(self, rhs: Self) -> Self::Output {
160        let mut result = [T::default(); WIDTH];
161        for (r, (a, b)) in result.iter_mut().zip(self.data.iter().zip(rhs.data.iter())) {
162            *r = *a - *b;
163        }
164        Self::new(result)
165    }
166}
167
168impl<T, const WIDTH: usize> Mul for SafeSimdVector<T, WIDTH>
169where
170    T: Mul<Output = T> + Copy + Default,
171{
172    type Output = Self;
173
174    fn mul(self, rhs: Self) -> Self::Output {
175        let mut result = [T::default(); WIDTH];
176        for (r, (a, b)) in result.iter_mut().zip(self.data.iter().zip(rhs.data.iter())) {
177            *r = *a * *b;
178        }
179        Self::new(result)
180    }
181}
182
183impl<T, const WIDTH: usize> Div for SafeSimdVector<T, WIDTH>
184where
185    T: Div<Output = T> + Copy + Default,
186{
187    type Output = Self;
188
189    fn div(self, rhs: Self) -> Self::Output {
190        let mut result = [T::default(); WIDTH];
191        for (r, (a, b)) in result.iter_mut().zip(self.data.iter().zip(rhs.data.iter())) {
192            *r = *a / *b;
193        }
194        Self::new(result)
195    }
196}
197
/// Type-safe SIMD width constants for common architectures
///
/// Every alias is an instantiation of [`SimdWidth`](super::SimdWidth) with the
/// lane count given in its trailing comment. Aliases with the same lane count are
/// the same type: `Sse` and `AvxF64` are both `SimdWidth<4>`, and `Avx` and
/// `Avx512F64` are both `SimdWidth<8>`.
pub mod widths {
    /// A single lane.
    pub type Scalar = super::SimdWidth<1>;
    // Lane counts for 32-bit elements.
    pub type Sse = super::SimdWidth<4>; // 128-bit / 32-bit = 4 lanes
    pub type Avx = super::SimdWidth<8>; // 256-bit / 32-bit = 8 lanes
    pub type Avx512 = super::SimdWidth<16>; // 512-bit / 32-bit = 16 lanes

    // Lane counts for 64-bit elements.
    pub type SseF64 = super::SimdWidth<2>; // 128-bit / 64-bit = 2 lanes
    pub type AvxF64 = super::SimdWidth<4>; // 256-bit / 64-bit = 4 lanes
    pub type Avx512F64 = super::SimdWidth<8>; // 512-bit / 64-bit = 8 lanes
}
209
/// Type-safe vector types for common SIMD widths
///
/// Naming scheme: `Simd<Element>x<Lanes>`, each an alias of
/// [`SafeSimdVector`] with that element type and lane count.
pub type SimdF32x4 = SafeSimdVector<f32, 4>;
/// Eight `f32` lanes.
pub type SimdF32x8 = SafeSimdVector<f32, 8>;
/// Sixteen `f32` lanes.
pub type SimdF32x16 = SafeSimdVector<f32, 16>;

/// Two `f64` lanes.
pub type SimdF64x2 = SafeSimdVector<f64, 2>;
/// Four `f64` lanes.
pub type SimdF64x4 = SafeSimdVector<f64, 4>;
/// Eight `f64` lanes.
pub type SimdF64x8 = SafeSimdVector<f64, 8>;

/// Four `u32` lanes.
pub type SimdU32x4 = SafeSimdVector<u32, 4>;
/// Eight `u32` lanes.
pub type SimdU32x8 = SafeSimdVector<u32, 8>;
/// Sixteen `u32` lanes.
pub type SimdU32x16 = SafeSimdVector<u32, 16>;
222
223/// Compile-time SIMD capability checking
224pub mod capabilities {
225
226    /// Trait for SIMD capability validation at compile time
227    pub trait SimdCapable<const WIDTH: usize> {
228        fn is_supported() -> bool;
229        fn best_width() -> usize;
230    }
231
232    /// X86 SIMD capabilities
233    pub struct X86Simd;
234
235    impl SimdCapable<4> for X86Simd {
236        fn is_supported() -> bool {
237            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
238            {
239                crate::simd_feature_detected!("sse")
240            }
241            #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
242            {
243                false
244            }
245        }
246
247        fn best_width() -> usize {
248            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
249            {
250                if crate::simd_feature_detected!("avx512f") {
251                    16
252                } else if crate::simd_feature_detected!("avx2") {
253                    8
254                } else if crate::simd_feature_detected!("sse") {
255                    4
256                } else {
257                    1
258                }
259            }
260            #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
261            {
262                1
263            }
264        }
265    }
266
267    impl SimdCapable<8> for X86Simd {
268        fn is_supported() -> bool {
269            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
270            {
271                crate::simd_feature_detected!("avx2")
272            }
273            #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
274            {
275                false
276            }
277        }
278
279        fn best_width() -> usize {
280            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
281            {
282                if crate::simd_feature_detected!("avx512f") {
283                    16
284                } else if crate::simd_feature_detected!("avx2") {
285                    8
286                } else if crate::simd_feature_detected!("sse") {
287                    4
288                } else {
289                    1
290                }
291            }
292            #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
293            {
294                1
295            }
296        }
297    }
298
299    impl SimdCapable<16> for X86Simd {
300        fn is_supported() -> bool {
301            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
302            {
303                crate::simd_feature_detected!("avx512f")
304            }
305            #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
306            {
307                false
308            }
309        }
310
311        fn best_width() -> usize {
312            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
313            {
314                if crate::simd_feature_detected!("avx512f") {
315                    16
316                } else if crate::simd_feature_detected!("avx2") {
317                    8
318                } else if crate::simd_feature_detected!("sse") {
319                    4
320                } else {
321                    1
322                }
323            }
324            #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
325            {
326                1
327            }
328        }
329    }
330
331    /// ARM NEON capabilities
332    pub struct ArmSimd;
333
334    impl SimdCapable<4> for ArmSimd {
335        fn is_supported() -> bool {
336            #[cfg(target_arch = "aarch64")]
337            {
338                true // NEON is always available on AArch64
339            }
340            #[cfg(not(target_arch = "aarch64"))]
341            {
342                false
343            }
344        }
345
346        fn best_width() -> usize {
347            #[cfg(target_arch = "aarch64")]
348            {
349                4
350            }
351            #[cfg(not(target_arch = "aarch64"))]
352            {
353                1
354            }
355        }
356    }
357}
358
359/// Zero-cost wrapper for optimized operations with compile-time dispatch
360pub struct OptimizedSimdOp<T, const WIDTH: usize> {
361    _phantom: PhantomData<(T, SimdWidth<WIDTH>)>,
362}
363
364impl<T, const WIDTH: usize> Default for OptimizedSimdOp<T, WIDTH> {
365    fn default() -> Self {
366        Self::new()
367    }
368}
369
370impl<T, const WIDTH: usize> OptimizedSimdOp<T, WIDTH> {
371    pub const fn new() -> Self {
372        Self {
373            _phantom: PhantomData,
374        }
375    }
376
377    /// Dot product with compile-time width validation
378    pub fn dot_product(a: &SafeSimdVector<T, WIDTH>, b: &SafeSimdVector<T, WIDTH>) -> T
379    where
380        T: Mul<Output = T> + Add<Output = T> + Default + Copy,
381    {
382        let mut result = T::default();
383        for i in 0..WIDTH {
384            result = result + (a.data[i] * b.data[i]);
385        }
386        result
387    }
388
389    /// Element-wise multiplication with type safety
390    pub fn element_wise_multiply(
391        a: &SafeSimdVector<T, WIDTH>,
392        b: &SafeSimdVector<T, WIDTH>,
393    ) -> SafeSimdVector<T, WIDTH>
394    where
395        T: Mul<Output = T> + Default + Copy,
396    {
397        let mut result = [T::default(); WIDTH];
398        for (r, (x, y)) in result.iter_mut().zip(a.data.iter().zip(b.data.iter())) {
399            *r = *x * *y;
400        }
401        SafeSimdVector::new(result)
402    }
403
404    /// Reduction operations with compile-time lane validation
405    pub fn horizontal_sum(vector: &SafeSimdVector<T, WIDTH>) -> T
406    where
407        T: Add<Output = T> + Default + Copy,
408    {
409        let mut sum = T::default();
410        for i in 0..WIDTH {
411            sum = sum + vector.data[i];
412        }
413        sum
414    }
415
416    /// Find maximum with compile-time guarantees
417    pub fn horizontal_max(vector: &SafeSimdVector<T, WIDTH>) -> T
418    where
419        T: PartialOrd + Copy,
420    {
421        let mut max = vector.data[0];
422        for i in 1..WIDTH {
423            if vector.data[i] > max {
424                max = vector.data[i];
425            }
426        }
427        max
428    }
429}
430
431/// Compile-time validated slice operations
432pub struct SafeSliceOps;
433
434impl SafeSliceOps {
435    /// Process slices with compile-time SIMD width validation
436    pub fn process_slice_vectorized<T, F, const CHUNK_SIZE: usize>(
437        data: &[T],
438        mut func: F,
439    ) -> Vec<T>
440    where
441        T: Copy + Default,
442        F: FnMut(&SafeSimdVector<T, CHUNK_SIZE>) -> SafeSimdVector<T, CHUNK_SIZE>,
443        [(); CHUNK_SIZE]:,
444    {
445        let mut result = Vec::with_capacity(data.len());
446
447        // Process complete chunks
448        for chunk in data.chunks_exact(CHUNK_SIZE) {
449            if let Some(simd_chunk) = SafeSimdVector::<T, CHUNK_SIZE>::from_slice(chunk) {
450                let processed = func(&simd_chunk);
451                result.extend_from_slice(processed.as_slice());
452            }
453        }
454
455        // Handle remainder by processing each element individually
456        let remainder_start = data.len() - (data.len() % CHUNK_SIZE);
457        for &item in &data[remainder_start..] {
458            // Create a single-element "SIMD" vector and process it
459            let mut single_data = [T::default(); CHUNK_SIZE];
460            single_data[0] = item;
461            if let Some(simd_single) = SafeSimdVector::<T, CHUNK_SIZE>::from_slice(&single_data) {
462                let processed = func(&simd_single);
463                result.push(processed.as_slice()[0]);
464            }
465        }
466
467        result
468    }
469
470    /// Safe dot product with compile-time width checking
471    pub fn dot_product_safe<T, const WIDTH: usize>(a: &[T], b: &[T]) -> Option<T>
472    where
473        T: Mul<Output = T> + Add<Output = T> + Default + Copy,
474        [(); WIDTH]:,
475    {
476        if a.len() != b.len() || a.len() < WIDTH {
477            return None;
478        }
479
480        let mut result = T::default();
481        let chunks_a = a.chunks_exact(WIDTH);
482        let chunks_b = b.chunks_exact(WIDTH);
483
484        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
485            if let (Some(vec_a), Some(vec_b)) = (
486                SafeSimdVector::<T, WIDTH>::from_slice(chunk_a),
487                SafeSimdVector::<T, WIDTH>::from_slice(chunk_b),
488            ) {
489                result = result + OptimizedSimdOp::<T, WIDTH>::dot_product(&vec_a, &vec_b);
490            }
491        }
492
493        // Handle remainder
494        let remainder = a.len() % WIDTH;
495        for i in 0..remainder {
496            let idx = a.len() - remainder + i;
497            result = result + (a[idx] * b[idx]);
498        }
499
500        Some(result)
501    }
502}
503
/// Trait for type-safe SIMD operations with zero-cost abstractions
///
/// Unlike [`SimdOperation`], which takes a fixed-width vector, this trait takes
/// an arbitrary-length slice. No implementation appears in this file.
pub trait TypeSafeSimd<T> {
    /// Result type produced by [`apply_safe`](TypeSafeSimd::apply_safe).
    type Output;

    /// Runs the operation over `input` and returns its result.
    fn apply_safe(&self, input: &[T]) -> Self::Output;
}
510
511/// Implementation for common mathematical operations
512pub struct SafeMathOps;
513
514impl SafeMathOps {
515    /// Type-safe square root with compile-time width
516    pub fn sqrt_vectorized<const WIDTH: usize>(data: &[f32]) -> Vec<f32>
517    where
518        [(); WIDTH]:,
519    {
520        SafeSliceOps::process_slice_vectorized::<f32, _, WIDTH>(data, |chunk| {
521            let op = ElementWiseOp::new(|x: f32| x.sqrt());
522            op.apply(chunk)
523        })
524    }
525
526    /// Type-safe exponential with compile-time width
527    pub fn exp_vectorized<const WIDTH: usize>(data: &[f32]) -> Vec<f32>
528    where
529        [(); WIDTH]:,
530    {
531        SafeSliceOps::process_slice_vectorized::<f32, _, WIDTH>(data, |chunk| {
532            let op = ElementWiseOp::new(|x: f32| x.exp());
533            op.apply(chunk)
534        })
535    }
536
537    /// Type-safe polynomial evaluation
538    pub fn polynomial_vectorized<const WIDTH: usize>(data: &[f32], coefficients: &[f32]) -> Vec<f32>
539    where
540        [(); WIDTH]:,
541    {
542        SafeSliceOps::process_slice_vectorized::<f32, _, WIDTH>(data, |chunk| {
543            let op = ElementWiseOp::new(|x: f32| {
544                coefficients
545                    .iter()
546                    .rev()
547                    .fold(0.0, |acc, &coeff| acc * x + coeff)
548            });
549            op.apply(chunk)
550        })
551    }
552}
553
554// Removed static assertions module - using runtime checks instead
555
// Unit tests for the type-safe SIMD abstractions. The module is compiled only
// for std test builds, so `vec!` and `Vec` come from the std prelude and no
// `alloc` import is needed here.
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Construction keeps the lanes and reports the const width.
    #[test]
    fn test_safe_simd_vector_creation() {
        let vec = SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);
        assert_eq!(vec.width(), 4);
        assert_eq!(vec.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
    }

    /// `+` and `-` operate lane by lane.
    #[test]
    fn test_safe_simd_arithmetic() {
        let a = SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);
        let b = SimdF32x4::new([5.0, 6.0, 7.0, 8.0]);

        let sum = a + b;
        assert_eq!(sum.as_slice(), &[6.0, 8.0, 10.0, 12.0]);

        let diff = SimdF32x4::new([10.0, 12.0, 14.0, 16.0]) - SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);
        assert_eq!(diff.as_slice(), &[9.0, 10.0, 11.0, 12.0]);
    }

    /// In-range lanes are returned; an out-of-range lane yields `None`.
    #[test]
    fn test_lane_access() {
        let vec = SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);
        assert_eq!(vec.extract_lane(0), Some(1.0));
        assert_eq!(vec.extract_lane(1), Some(2.0));
        assert_eq!(vec.extract_lane(3), Some(4.0));
        assert_eq!(vec.extract_lane(4), None); // Out of bounds
    }

    /// The chunked dot product matches the scalar reference.
    #[test]
    fn test_dot_product_safe() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0f32, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let result = SafeSliceOps::dot_product_safe::<f32, 4>(&a, &b);
        assert!(result.is_some());

        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert_eq!(result.expect("operation should succeed"), expected);
    }

    /// An element-wise closure is applied to every lane.
    #[test]
    fn test_element_wise_operations() {
        let vec = SimdF32x4::new([1.0, 4.0, 9.0, 16.0]);
        let op = ElementWiseOp::new(|x: f32| x.sqrt());
        let result = op.apply(&vec);

        assert_eq!(result.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
    }

    /// Horizontal sum and max reduce across all lanes.
    #[test]
    fn test_horizontal_operations() {
        let vec = SimdF32x4::new([1.0, 2.0, 3.0, 4.0]);

        let sum = OptimizedSimdOp::<f32, 4>::horizontal_sum(&vec);
        assert_eq!(sum, 10.0);

        let max = OptimizedSimdOp::<f32, 4>::horizontal_max(&vec);
        assert_eq!(max, 4.0);
    }

    /// Six inputs with a width of four exercise one full chunk plus a
    /// two-element remainder.
    #[test]
    fn test_safe_math_operations() {
        let data = vec![1.0, 4.0, 9.0, 16.0, 25.0, 36.0];
        let result = SafeMathOps::sqrt_vectorized::<4>(&data);

        let expected: Vec<f32> = data.iter().map(|x| x.sqrt()).collect();
        // `zip` stops at the shorter side, so check the length explicitly;
        // otherwise dropped remainder elements would go unnoticed.
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert!(
                (a - b).abs() < 1e-4,
                "sqrt({}) = {}, expected {}, diff = {}",
                a * a,
                a,
                b,
                (a - b).abs()
            ); // Even more lenient with debug info
        }
    }

    /// Too-short slices are rejected; longer slices are truncated to the width.
    #[test]
    fn test_from_slice_validation() {
        let data = vec![1.0f32, 2.0, 3.0];
        let vec = SimdF32x4::from_slice(&data);
        assert!(vec.is_none()); // Not enough elements

        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let vec = SimdF32x4::from_slice(&data);
        assert!(vec.is_some());
        assert_eq!(
            vec.expect("slice operation should succeed").as_slice(),
            &[1.0, 2.0, 3.0, 4.0]
        );
    }

    /// Capability queries must not panic on any platform.
    #[test]
    fn test_capability_detection() {
        use super::capabilities::SimdCapable;

        // This will vary by platform, but should not panic
        let _sse_supported = <capabilities::X86Simd as SimdCapable<4>>::is_supported();
        let _best_width = <capabilities::X86Simd as SimdCapable<4>>::best_width();

        // ARM test (will be false on x86, true on ARM)
        let _neon_supported = <capabilities::ArmSimd as SimdCapable<4>>::is_supported();
    }

    /// `splat` plus lane-wise multiply produce the expected constant vector.
    #[test]
    fn test_zero_cost_abstractions() {
        // These operations should compile to efficient code
        let a = SimdF32x4::splat(2.0);
        let b = SimdF32x4::splat(3.0);

        let result = OptimizedSimdOp::<f32, 4>::element_wise_multiply(&a, &b);
        assert_eq!(result.as_slice(), &[6.0, 6.0, 6.0, 6.0]);
    }
}