// oxiblas_core/simd.rs

//! SIMD abstraction layer for OxiBLAS.
//!
//! This module provides a unified interface over architecture-specific SIMD
//! intrinsics from `core::arch`. It supports:
//! - x86_64: AVX2 (256-bit), AVX512F (512-bit), SSE4.2 (128-bit)
//! - AArch64: NEON (128-bit), 256-bit emulated
//! - WASM32: SIMD128 (128-bit), 256-bit emulated
//! - Scalar fallback for unsupported platforms
//!
//! The design uses runtime feature detection to dispatch to the best
//! available implementation.
//!
//! # Complex SIMD
//!
//! The `complex` submodule provides SIMD types for complex numbers in
//! interleaved format `[re0, im0, re1, im1, ...]`.
17
// Architecture-specific backends; each is compiled only when targeting
// the corresponding architecture.
#[cfg(target_arch = "x86_64")]
pub mod x86_64;

#[cfg(target_arch = "aarch64")]
pub mod aarch64;

#[cfg(target_arch = "wasm32")]
pub mod wasm32;

// Architecture-independent modules: complex-number SIMD, runtime dispatch,
// and the scalar fallback used when no SIMD is available.
pub mod complex;
pub mod dispatch;
pub mod scalar;

use crate::scalar::{Field, Real, Scalar};
32
/// SIMD capability level detected at runtime.
///
/// Variants are declared from weakest to strongest, so the derived `Ord`
/// makes comparisons such as `level >= SimdLevel::Simd128` and `min`/`max`
/// clamping meaningful (both are relied on elsewhere in this module).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// No SIMD, scalar operations only
    Scalar,
    /// 128-bit SIMD (SSE2 on x86, NEON on ARM)
    Simd128,
    /// 256-bit SIMD (AVX2 on x86)
    Simd256,
    /// 512-bit SIMD (AVX512F on x86)
    Simd512,
}
45
46impl SimdLevel {
47    /// Returns the number of lanes for a given scalar type.
48    #[inline]
49    pub const fn lanes<T: Scalar>(self) -> usize {
50        match self {
51            SimdLevel::Scalar => 1,
52            SimdLevel::Simd128 => 16 / core::mem::size_of::<T>(),
53            SimdLevel::Simd256 => 32 / core::mem::size_of::<T>(),
54            SimdLevel::Simd512 => 64 / core::mem::size_of::<T>(),
55        }
56    }
57
58    /// Returns the register width in bytes.
59    #[inline]
60    pub const fn width_bytes(self) -> usize {
61        match self {
62            SimdLevel::Scalar => 8, // Treat as 64-bit for alignment
63            SimdLevel::Simd128 => 16,
64            SimdLevel::Simd256 => 32,
65            SimdLevel::Simd512 => 64,
66        }
67    }
68}
69
70/// Detects the best available SIMD level at runtime.
71///
72/// This function respects the following feature flags:
73/// - `force-scalar`: Always returns `SimdLevel::Scalar` (useful for debugging)
74/// - `max-simd-128`: Limits maximum to `SimdLevel::Simd128`
75/// - `max-simd-256`: Limits maximum to `SimdLevel::Simd256`
76#[inline]
77pub fn detect_simd_level() -> SimdLevel {
78    // Feature flag: force scalar operations (useful for debugging)
79    #[cfg(feature = "force-scalar")]
80    {
81        SimdLevel::Scalar
82    }
83
84    #[cfg(not(feature = "force-scalar"))]
85    {
86        let detected = detect_simd_level_raw();
87
88        // Apply maximum SIMD level limits from feature flags
89        #[cfg(feature = "max-simd-128")]
90        {
91            return if detected > SimdLevel::Simd128 {
92                SimdLevel::Simd128
93            } else {
94                detected
95            };
96        }
97
98        #[cfg(feature = "max-simd-256")]
99        #[cfg(not(feature = "max-simd-128"))]
100        {
101            return if detected > SimdLevel::Simd256 {
102                SimdLevel::Simd256
103            } else {
104                detected
105            };
106        }
107
108        #[cfg(not(any(feature = "max-simd-128", feature = "max-simd-256")))]
109        {
110            detected
111        }
112    }
113}
114
/// Raw SIMD level detection without feature flag limits.
///
/// This is the internal detection function that returns the actual
/// hardware SIMD capability.
#[inline]
pub fn detect_simd_level_raw() -> SimdLevel {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            SimdLevel::Simd512
        } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // AVX2 is only used when FMA is also available; AVX2-without-FMA
            // hardware falls through to the 128-bit path below.
            SimdLevel::Simd256
        } else if is_x86_feature_detected!("sse2") {
            // NOTE(review): the module docs mention SSE4.2, but detection
            // gates on SSE2 (the x86_64 baseline) — confirm which is intended.
            SimdLevel::Simd128
        } else {
            SimdLevel::Scalar
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // NEON is always available on AArch64
        SimdLevel::Simd128
    }

    #[cfg(target_arch = "wasm32")]
    {
        // WASM SIMD128 when simd128 feature is enabled
        // (compile-time only: WASM has no runtime feature detection here).
        #[cfg(target_feature = "simd128")]
        {
            SimdLevel::Simd128
        }
        #[cfg(not(target_feature = "simd128"))]
        {
            SimdLevel::Scalar
        }
    }

    // Any other architecture: no SIMD backend exists in this crate.
    #[cfg(not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "wasm32"
    )))]
    {
        SimdLevel::Scalar
    }
}
162
/// Trait for SIMD-capable scalar types.
///
/// This trait provides the interface for types that can be vectorized
/// using SIMD operations.
///
/// NOTE(review): only 256- and 512-bit register types are associated here;
/// the 128-bit level reuses `Simd256` in `simd_dispatch!` — confirm that the
/// 128-bit backends provide an emulated 256-bit register.
pub trait SimdScalar: Field {
    /// The 256-bit SIMD register type for this scalar (e.g., AVX2 on x86-64).
    type Simd256: SimdRegister<Scalar = Self>;
    /// The 512-bit SIMD register type for this scalar (e.g., AVX-512 on x86-64).
    type Simd512: SimdRegister<Scalar = Self>;

    /// Number of elements that fit in a 256-bit register.
    /// Defaults to `32 / size_of::<Self>()`: 4 for `f64`, 8 for `f32`.
    const LANES_256: usize = 32 / core::mem::size_of::<Self>();

    /// Number of elements that fit in a 512-bit register.
    /// Defaults to `64 / size_of::<Self>()`: 8 for `f64`, 16 for `f32`.
    const LANES_512: usize = 64 / core::mem::size_of::<Self>();
}
179
/// Trait for SIMD register types.
///
/// This provides a unified interface for SIMD operations across
/// different architectures and vector widths.
///
/// All lane-wise operations take and return `Self` by value; registers are
/// `Copy`, so this is cheap.
pub trait SimdRegister: Copy + Clone + Send + Sync {
    /// The scalar type this register holds.
    type Scalar: SimdScalar;

    /// Number of lanes in this register.
    const LANES: usize;

    /// Creates a register with all lanes set to zero.
    fn zero() -> Self;

    /// Creates a register with all lanes set to the same value.
    fn splat(value: Self::Scalar) -> Self;

    /// Loads from an aligned pointer.
    ///
    /// # Safety
    /// The pointer must be aligned to the register width and point to
    /// at least LANES valid elements.
    unsafe fn load_aligned(ptr: *const Self::Scalar) -> Self;

    /// Loads from an unaligned pointer.
    ///
    /// # Safety
    /// The pointer must point to at least LANES valid elements.
    unsafe fn load_unaligned(ptr: *const Self::Scalar) -> Self;

    /// Stores to an aligned pointer.
    ///
    /// # Safety
    /// The pointer must be aligned to the register width and point to
    /// at least LANES valid writable elements.
    unsafe fn store_aligned(self, ptr: *mut Self::Scalar);

    /// Stores to an unaligned pointer.
    ///
    /// # Safety
    /// The pointer must point to at least LANES valid writable elements.
    unsafe fn store_unaligned(self, ptr: *mut Self::Scalar);

    /// Element-wise addition.
    fn add(self, other: Self) -> Self;

    /// Element-wise subtraction.
    fn sub(self, other: Self) -> Self;

    /// Element-wise multiplication.
    fn mul(self, other: Self) -> Self;

    /// Element-wise division.
    fn div(self, other: Self) -> Self;

    /// Fused multiply-add: self * a + b
    fn mul_add(self, a: Self, b: Self) -> Self;

    /// Fused multiply-subtract: self * a - b
    fn mul_sub(self, a: Self, b: Self) -> Self;

    /// Fused negative multiply-add: -(self * a) + b = b - self * a
    fn neg_mul_add(self, a: Self, b: Self) -> Self;

    /// Horizontal sum of all lanes.
    fn reduce_sum(self) -> Self::Scalar;

    /// Horizontal maximum of all lanes (for real types).
    fn reduce_max(self) -> Self::Scalar
    where
        Self::Scalar: Real;

    /// Horizontal minimum of all lanes (for real types).
    fn reduce_min(self) -> Self::Scalar
    where
        Self::Scalar: Real;

    /// Extracts a single lane.
    ///
    /// NOTE(review): behavior for `index >= LANES` (panic vs. wrap) is
    /// implementation-defined here — confirm against the backends.
    fn extract(self, index: usize) -> Self::Scalar;

    /// Inserts a value into a single lane, returning the updated register.
    fn insert(self, index: usize, value: Self::Scalar) -> Self;
}
263
/// Extension trait for masked SIMD operations.
///
/// Masks select a subset of lanes for loads, stores, and blends; lanes with
/// a false mask bit are left untouched (store) or take the `default`/`b`
/// value (load/blend).
pub trait SimdMask: SimdRegister {
    /// The mask type for this register.
    type Mask: Copy + Clone;

    /// Creates a mask from a boolean array.
    fn mask_from_bools(bools: &[bool]) -> Self::Mask;

    /// Masked load: only loads elements where mask is true.
    ///
    /// # Safety
    /// For lanes where mask is true, the corresponding pointer element must be valid.
    unsafe fn load_masked(ptr: *const Self::Scalar, mask: Self::Mask, default: Self) -> Self;

    /// Masked store: only stores elements where mask is true.
    ///
    /// # Safety
    /// For lanes where mask is true, the corresponding pointer element must be valid and writable.
    unsafe fn store_masked(self, ptr: *mut Self::Scalar, mask: Self::Mask);

    /// Blends two registers based on mask: if mask\[i\] then a\[i\] else b\[i\].
    fn blend(mask: Self::Mask, a: Self, b: Self) -> Self;
}
287
/// Helper struct for iterating over SIMD chunks with proper head/body/tail handling.
///
/// The element range `[0, len)` is split into three consecutive parts:
/// head `[0, head_end)` (scalar prefix until alignment), body
/// `[head_end, body_end)` (whole aligned SIMD vectors), and tail
/// `[body_end, len)` (scalar remainder).
#[derive(Debug, Clone, Copy)]
pub struct SimdChunks {
    /// Total number of elements.
    pub len: usize,
    /// Number of lanes per SIMD register.
    pub lanes: usize,
    /// Index where head (unaligned prefix) ends.
    pub head_end: usize,
    /// Index where body (aligned middle) ends.
    pub body_end: usize,
}
300
301impl SimdChunks {
302    /// Creates a new chunk iterator for the given length and alignment.
303    #[inline]
304    pub fn new<T: Scalar>(ptr: *const T, len: usize, level: SimdLevel) -> Self {
305        let lanes = level.lanes::<T>();
306        let align = level.width_bytes();
307
308        if lanes <= 1 || len < lanes * 2 {
309            // Not worth SIMD, treat everything as head
310            return SimdChunks {
311                len,
312                lanes,
313                head_end: len,
314                body_end: len,
315            };
316        }
317
318        let addr = ptr as usize;
319        let misalign = addr % align;
320
321        let head_end = if misalign == 0 {
322            0
323        } else {
324            let elements_to_align = (align - misalign) / core::mem::size_of::<T>();
325            elements_to_align.min(len)
326        };
327
328        let remaining = len - head_end;
329        let full_vectors = remaining / lanes;
330        let body_end = head_end + full_vectors * lanes;
331
332        SimdChunks {
333            len,
334            lanes,
335            head_end,
336            body_end,
337        }
338    }
339
340    /// Returns the number of head elements (before aligned body).
341    #[inline]
342    pub fn head_len(&self) -> usize {
343        self.head_end
344    }
345
346    /// Returns the number of body elements (aligned middle).
347    #[inline]
348    pub fn body_len(&self) -> usize {
349        self.body_end - self.head_end
350    }
351
352    /// Returns the number of tail elements (after aligned body).
353    #[inline]
354    pub fn tail_len(&self) -> usize {
355        self.len - self.body_end
356    }
357
358    /// Returns the number of full SIMD vectors in the body.
359    #[inline]
360    pub fn body_vectors(&self) -> usize {
361        self.body_len() / self.lanes
362    }
363}
364
/// Dispatch macro for runtime SIMD level selection.
///
/// This macro generates code that dispatches to the appropriate SIMD
/// implementation based on the detected CPU features.
///
/// `$body` is expanded once per match arm with `$reg` bound (via a local
/// type alias) to the register type selected for that level, so all arms
/// are compiled and the choice happens at runtime on `$level`.
#[macro_export]
macro_rules! simd_dispatch {
    ($level:expr, $scalar:ty, |$reg:ident| $body:expr) => {{
        match $level {
            $crate::simd::SimdLevel::Simd512 => {
                #[cfg(target_arch = "x86_64")]
                {
                    type $reg = <$scalar as $crate::simd::SimdScalar>::Simd512;
                    $body
                }
                #[cfg(not(target_arch = "x86_64"))]
                {
                    // Fallback to 256-bit
                    type $reg = <$scalar as $crate::simd::SimdScalar>::Simd256;
                    $body
                }
            }
            // NOTE(review): Simd128 also binds the 256-bit register type —
            // presumably the 128-bit backends expose an emulated `Simd256`;
            // confirm this is intended.
            $crate::simd::SimdLevel::Simd256 | $crate::simd::SimdLevel::Simd128 => {
                type $reg = <$scalar as $crate::simd::SimdScalar>::Simd256;
                $body
            }
            $crate::simd::SimdLevel::Scalar => {
                // Scalar fallback - process one element at a time
                type $reg = $scalar;
                $body
            }
        }
    }};
}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_simd_level() {
        let level = detect_simd_level();
        println!("Detected SIMD level: {:?}", level);

        // When force-scalar is enabled, should always be Scalar
        #[cfg(feature = "force-scalar")]
        {
            assert_eq!(level, SimdLevel::Scalar);
            // But raw detection should still show hardware capability
            let raw = detect_simd_level_raw();
            println!("Raw hardware SIMD level: {:?}", raw);
        }

        // Without force-scalar, should detect hardware SIMD
        #[cfg(not(feature = "force-scalar"))]
        {
            // SSE2 is part of the x86_64 baseline, so at least 128-bit SIMD
            // is always reported there.
            #[cfg(target_arch = "x86_64")]
            assert!(level >= SimdLevel::Simd128);

            #[cfg(target_arch = "aarch64")]
            assert_eq!(level, SimdLevel::Simd128);
        }
    }

    #[test]
    fn test_simd_level_lanes() {
        // lanes = register width in bytes / element size.
        assert_eq!(SimdLevel::Simd256.lanes::<f64>(), 4);
        assert_eq!(SimdLevel::Simd256.lanes::<f32>(), 8);
        assert_eq!(SimdLevel::Simd512.lanes::<f64>(), 8);
        assert_eq!(SimdLevel::Simd512.lanes::<f32>(), 16);
    }

    #[test]
    fn test_simd_chunks() {
        // Create a pointer with known alignment
        let data: Vec<f64> = vec![0.0; 100];
        let ptr = data.as_ptr();

        let chunks = SimdChunks::new(ptr, 100, SimdLevel::Simd256);
        println!(
            "Chunks: head_end={}, body_end={}",
            chunks.head_end, chunks.body_end
        );

        // Verify that head + body + tail = len
        assert_eq!(
            chunks.head_len() + chunks.body_len() + chunks.tail_len(),
            100
        );
    }

    // =============================================================================
    // Comprehensive SIMD correctness tests
    // =============================================================================

    /// Test scalar fallback FMA accuracy.
    #[test]
    fn test_scalar_fma_accuracy() {
        use crate::simd::scalar::ScalarF64;

        let a = ScalarF64::splat(1.0 + 1e-15);
        let b = ScalarF64::splat(1.0 + 1e-15);
        let c = ScalarF64::splat(-(1.0 + 2e-15));

        // FMA should preserve more precision than separate mul+add
        let fma_result = a.mul_add(b, c);
        let mul_add_result = a.mul(b).add(c);

        // Both should be very small but may differ slightly
        assert!(fma_result.0.abs() < 1e-14);
        assert!(mul_add_result.0.abs() < 1e-14);
    }

    /// Exercises splat, reduce_sum, and extract on the scalar fallback.
    /// NOTE(review): despite the fn name, no load/store is actually called.
    #[test]
    fn test_load_store_roundtrip() {
        use crate::simd::scalar::ScalarF64;

        let values = [42.0f64, 1.5, -3.5, 1000.0];

        for &val in &values {
            let v = ScalarF64::splat(val);
            assert_eq!(v.reduce_sum(), val);
            assert_eq!(v.extract(0), val);
        }
    }

    /// Test arithmetic identities.
    #[test]
    fn test_arithmetic_identities() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};

        // Test with f64
        let a = ScalarF64::splat(5.0);
        let zero = ScalarF64::zero();
        let one = ScalarF64::splat(1.0);

        // a + 0 = a
        assert_eq!(a.add(zero).0, 5.0);
        // a - 0 = a
        assert_eq!(a.sub(zero).0, 5.0);
        // a * 1 = a
        assert_eq!(a.mul(one).0, 5.0);
        // a / 1 = a
        assert_eq!(a.div(one).0, 5.0);
        // a * 0 = 0
        assert_eq!(a.mul(zero).0, 0.0);

        // Test with f32
        let a32 = ScalarF32::splat(5.0);
        let zero32 = ScalarF32::zero();
        let one32 = ScalarF32::splat(1.0);

        assert_eq!(a32.add(zero32).0, 5.0);
        assert_eq!(a32.mul(one32).0, 5.0);
    }

    /// Test reduction operations.
    #[test]
    fn test_reductions() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};

        // For scalar types, all reductions return the same value
        let a = ScalarF64::splat(42.0);
        assert_eq!(a.reduce_sum(), 42.0);
        assert_eq!(a.reduce_max(), 42.0);
        assert_eq!(a.reduce_min(), 42.0);

        let b = ScalarF32::splat(-3.5);
        assert_eq!(b.reduce_sum(), -3.5);
        assert_eq!(b.reduce_max(), -3.5);
        assert_eq!(b.reduce_min(), -3.5);
    }

    /// Test negative value handling.
    #[test]
    fn test_negative_values() {
        use crate::simd::scalar::ScalarF64;

        let neg = ScalarF64::splat(-5.0);
        let pos = ScalarF64::splat(3.0);

        // -5 + 3 = -2
        assert_eq!(neg.add(pos).0, -2.0);
        // -5 * 3 = -15
        assert_eq!(neg.mul(pos).0, -15.0);
        // -5 - 3 = -8
        assert_eq!(neg.sub(pos).0, -8.0);
    }

    /// Test FMA variants.
    #[test]
    fn test_fma_variants() {
        use crate::simd::scalar::ScalarF64;

        let a = ScalarF64::splat(2.0);
        let b = ScalarF64::splat(3.0);
        let c = ScalarF64::splat(4.0);

        // mul_add: a * b + c = 2 * 3 + 4 = 10
        assert_eq!(a.mul_add(b, c).0, 10.0);

        // mul_sub: a * b - c = 2 * 3 - 4 = 2
        assert_eq!(a.mul_sub(b, c).0, 2.0);

        // neg_mul_add: -(a * b) + c = -6 + 4 = -2
        assert_eq!(a.neg_mul_add(b, c).0, -2.0);
    }

    /// Test insert/extract operations.
    #[test]
    fn test_insert_extract() {
        use crate::simd::scalar::ScalarF64;

        let a = ScalarF64::splat(1.0);
        let b = a.insert(0, 42.0);
        assert_eq!(b.extract(0), 42.0);
    }

    /// Platform-specific tests for native SIMD.
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_simd_correctness() {
        use crate::simd::aarch64::{F32x4, F64x2, F64x4};

        // Test F64x2
        let a = F64x2::splat(2.0);
        let b = F64x2::splat(3.0);

        let sum = a.add(b);
        assert_eq!(sum.extract(0), 5.0);
        assert_eq!(sum.extract(1), 5.0);

        let fma = a.mul_add(b, F64x2::splat(1.0));
        assert_eq!(fma.extract(0), 7.0); // 2*3 + 1

        // Test F64x4 (emulated)
        let c = F64x4::splat(2.0);
        let d = F64x4::splat(3.0);

        assert_eq!(c.add(d).reduce_sum(), 20.0); // 4 * 5.0

        // Test F32x4
        let e = F32x4::splat(2.0);
        let f = F32x4::splat(3.0);

        assert_eq!(e.add(f).reduce_sum(), 20.0); // 4 * 5.0
    }
}