// oxiblas_core/src/simd.rs
//! SIMD abstraction layer for OxiBLAS.
//!
//! This module provides a unified interface over architecture-specific SIMD
//! intrinsics from `core::arch`. It supports:
//! - x86_64: AVX2 (256-bit), AVX512F (512-bit), SSE4.2 (128-bit)
//! - AArch64: NEON (128-bit), 256-bit emulated
//! - WASM32: SIMD128 (128-bit), 256-bit emulated
//! - Scalar fallback for unsupported platforms
//!
//! The design uses runtime feature detection to dispatch to the best
//! available implementation.
//!
//! # Complex SIMD
//!
//! The `complex` submodule provides SIMD types for complex numbers in
//! interleaved format `[re0, im0, re1, im1, ...]`.
18#[cfg(target_arch = "x86_64")]
19pub mod x86_64;
20
21#[cfg(target_arch = "aarch64")]
22pub mod aarch64;
23
24#[cfg(target_arch = "wasm32")]
25pub mod wasm32;
26
27pub mod complex;
28pub mod dispatch;
29pub mod multiver;
30pub mod scalar;
31
32use crate::scalar::{Field, Real, Scalar};
33
/// SIMD capability level detected at runtime.
///
/// Variants are declared weakest-to-strongest, so the derived
/// `PartialOrd`/`Ord` lets callers compare and cap levels
/// (e.g. `level >= SimdLevel::Simd256`); `detect_simd_level` relies on
/// this ordering when applying the `max-simd-*` feature limits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// No SIMD, scalar operations only
    Scalar,
    /// 128-bit SIMD (SSE2 on x86, NEON on ARM)
    Simd128,
    /// 256-bit SIMD (AVX2 on x86)
    Simd256,
    /// 512-bit SIMD (AVX512F on x86)
    Simd512,
}
46
47impl SimdLevel {
48    /// Returns the number of lanes for a given scalar type.
49    #[inline]
50    pub const fn lanes<T: Scalar>(self) -> usize {
51        match self {
52            SimdLevel::Scalar => 1,
53            SimdLevel::Simd128 => 16 / core::mem::size_of::<T>(),
54            SimdLevel::Simd256 => 32 / core::mem::size_of::<T>(),
55            SimdLevel::Simd512 => 64 / core::mem::size_of::<T>(),
56        }
57    }
58
59    /// Returns the register width in bytes.
60    #[inline]
61    pub const fn width_bytes(self) -> usize {
62        match self {
63            SimdLevel::Scalar => 8, // Treat as 64-bit for alignment
64            SimdLevel::Simd128 => 16,
65            SimdLevel::Simd256 => 32,
66            SimdLevel::Simd512 => 64,
67        }
68    }
69}
70
71/// Detects the best available SIMD level at runtime.
72///
73/// This function respects the following feature flags:
74/// - `force-scalar`: Always returns `SimdLevel::Scalar` (useful for debugging)
75/// - `max-simd-128`: Limits maximum to `SimdLevel::Simd128`
76/// - `max-simd-256`: Limits maximum to `SimdLevel::Simd256`
77#[inline]
78pub fn detect_simd_level() -> SimdLevel {
79    // Feature flag: force scalar operations (useful for debugging)
80    #[cfg(feature = "force-scalar")]
81    {
82        SimdLevel::Scalar
83    }
84
85    #[cfg(not(feature = "force-scalar"))]
86    {
87        let detected = detect_simd_level_raw();
88
89        // Apply maximum SIMD level limits from feature flags
90        #[cfg(feature = "max-simd-128")]
91        {
92            return if detected > SimdLevel::Simd128 {
93                SimdLevel::Simd128
94            } else {
95                detected
96            };
97        }
98
99        #[cfg(feature = "max-simd-256")]
100        #[cfg(not(feature = "max-simd-128"))]
101        {
102            return if detected > SimdLevel::Simd256 {
103                SimdLevel::Simd256
104            } else {
105                detected
106            };
107        }
108
109        #[cfg(not(any(feature = "max-simd-128", feature = "max-simd-256")))]
110        {
111            detected
112        }
113    }
114}
115
116/// Raw SIMD level detection without feature flag limits.
117///
118/// This is the internal detection function that returns the actual
119/// hardware SIMD capability.
120#[inline]
121pub fn detect_simd_level_raw() -> SimdLevel {
122    // On x86_64 with std, use runtime feature detection
123    #[cfg(all(target_arch = "x86_64", feature = "std"))]
124    {
125        if is_x86_feature_detected!("avx512f") {
126            SimdLevel::Simd512
127        } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
128            SimdLevel::Simd256
129        } else if is_x86_feature_detected!("sse2") {
130            SimdLevel::Simd128
131        } else {
132            SimdLevel::Scalar
133        }
134    }
135
136    // On x86_64 without std, use compile-time target features only
137    #[cfg(all(target_arch = "x86_64", not(feature = "std")))]
138    {
139        #[cfg(target_feature = "avx512f")]
140        {
141            SimdLevel::Simd512
142        }
143        #[cfg(all(
144            target_feature = "avx2",
145            target_feature = "fma",
146            not(target_feature = "avx512f")
147        ))]
148        {
149            SimdLevel::Simd256
150        }
151        #[cfg(all(
152            target_feature = "sse2",
153            not(target_feature = "avx2"),
154            not(target_feature = "avx512f")
155        ))]
156        {
157            SimdLevel::Simd128
158        }
159        #[cfg(not(any(
160            target_feature = "sse2",
161            target_feature = "avx2",
162            target_feature = "avx512f"
163        )))]
164        {
165            SimdLevel::Scalar
166        }
167    }
168
169    #[cfg(target_arch = "aarch64")]
170    {
171        // NEON is always available on AArch64
172        SimdLevel::Simd128
173    }
174
175    #[cfg(target_arch = "wasm32")]
176    {
177        // WASM SIMD128 when simd128 feature is enabled
178        #[cfg(target_feature = "simd128")]
179        {
180            SimdLevel::Simd128
181        }
182        #[cfg(not(target_feature = "simd128"))]
183        {
184            SimdLevel::Scalar
185        }
186    }
187
188    #[cfg(not(any(
189        target_arch = "x86_64",
190        target_arch = "aarch64",
191        target_arch = "wasm32"
192    )))]
193    {
194        SimdLevel::Scalar
195    }
196}
197
/// Trait for SIMD-capable scalar types.
///
/// This trait provides the interface for types that can be vectorized
/// using SIMD operations. Implementors bind the concrete register types
/// used for 256-bit and 512-bit processing.
///
/// NOTE(review): there is no `Simd128` associated type here even though
/// `SimdLevel::Simd128` exists — presumably 128-bit paths are handled by
/// the per-arch modules directly; confirm this is intentional.
pub trait SimdScalar: Field {
    /// The 256-bit SIMD register type for this scalar (e.g., AVX2 on x86-64).
    type Simd256: SimdRegister<Scalar = Self>;
    /// The 512-bit SIMD register type for this scalar (e.g., AVX-512 on x86-64).
    type Simd512: SimdRegister<Scalar = Self>;

    /// Number of elements that fit in a 256-bit register.
    const LANES_256: usize = 32 / core::mem::size_of::<Self>();

    /// Number of elements that fit in a 512-bit register.
    const LANES_512: usize = 64 / core::mem::size_of::<Self>();
}
214
/// Trait for SIMD register types.
///
/// This provides a unified interface for SIMD operations across
/// different architectures and vector widths. All lane-wise operations
/// take and return registers by value (`Copy`), so chains like
/// `a.mul(b).add(c)` build expression trees with no aliasing concerns.
pub trait SimdRegister: Copy + Clone + Send + Sync {
    /// The scalar type this register holds.
    type Scalar: SimdScalar;

    /// Number of lanes in this register.
    const LANES: usize;

    /// Creates a register with all lanes set to zero.
    fn zero() -> Self;

    /// Creates a register with all lanes set to the same value.
    fn splat(value: Self::Scalar) -> Self;

    /// Loads from an aligned pointer.
    ///
    /// # Safety
    /// The pointer must be aligned to the register width and point to
    /// at least LANES valid elements.
    unsafe fn load_aligned(ptr: *const Self::Scalar) -> Self;

    /// Loads from an unaligned pointer.
    ///
    /// # Safety
    /// The pointer must point to at least LANES valid elements.
    unsafe fn load_unaligned(ptr: *const Self::Scalar) -> Self;

    /// Stores to an aligned pointer.
    ///
    /// # Safety
    /// The pointer must be aligned to the register width and point to
    /// at least LANES valid writable elements.
    unsafe fn store_aligned(self, ptr: *mut Self::Scalar);

    /// Stores to an unaligned pointer.
    ///
    /// # Safety
    /// The pointer must point to at least LANES valid writable elements.
    unsafe fn store_unaligned(self, ptr: *mut Self::Scalar);

    /// Element-wise addition.
    fn add(self, other: Self) -> Self;

    /// Element-wise subtraction.
    fn sub(self, other: Self) -> Self;

    /// Element-wise multiplication.
    fn mul(self, other: Self) -> Self;

    /// Element-wise division.
    fn div(self, other: Self) -> Self;

    /// Fused multiply-add: self * a + b
    fn mul_add(self, a: Self, b: Self) -> Self;

    /// Fused multiply-subtract: self * a - b
    fn mul_sub(self, a: Self, b: Self) -> Self;

    /// Fused negative multiply-add: -(self * a) + b = b - self * a
    fn neg_mul_add(self, a: Self, b: Self) -> Self;

    /// Horizontal sum of all lanes.
    fn reduce_sum(self) -> Self::Scalar;

    /// Horizontal maximum of all lanes (for real types).
    fn reduce_max(self) -> Self::Scalar
    where
        Self::Scalar: Real;

    /// Horizontal minimum of all lanes (for real types).
    fn reduce_min(self) -> Self::Scalar
    where
        Self::Scalar: Real;

    /// Extracts a single lane.
    ///
    /// NOTE(review): behavior for `index >= LANES` is unspecified here;
    /// implementations presumably panic or truncate — confirm per backend.
    fn extract(self, index: usize) -> Self::Scalar;

    /// Inserts a value into a single lane, returning the updated register.
    fn insert(self, index: usize, value: Self::Scalar) -> Self;
}
298
/// Extension trait for masked SIMD operations.
///
/// Masks enable safe handling of partial vectors (e.g. loop tails) by
/// restricting loads/stores to selected lanes.
pub trait SimdMask: SimdRegister {
    /// The mask type for this register.
    type Mask: Copy + Clone;

    /// Creates a mask from a boolean array.
    ///
    /// NOTE(review): expected `bools.len()` relative to `LANES` is not
    /// stated here — confirm whether short slices are zero-extended.
    fn mask_from_bools(bools: &[bool]) -> Self::Mask;

    /// Masked load: only loads elements where mask is true; other lanes
    /// take their value from `default`.
    ///
    /// # Safety
    /// For lanes where mask is true, the corresponding pointer element must be valid.
    unsafe fn load_masked(ptr: *const Self::Scalar, mask: Self::Mask, default: Self) -> Self;

    /// Masked store: only stores elements where mask is true.
    ///
    /// # Safety
    /// For lanes where mask is true, the corresponding pointer element must be valid and writable.
    unsafe fn store_masked(self, ptr: *mut Self::Scalar, mask: Self::Mask);

    /// Blends two registers based on mask: if mask\[i\] then a\[i\] else b\[i\].
    fn blend(mask: Self::Mask, a: Self, b: Self) -> Self;
}
322
/// Helper struct for iterating over SIMD chunks with proper head/body/tail handling.
///
/// Partitions a `len`-element range into three consecutive regions:
/// `[0, head_end)` unaligned head, `[head_end, body_end)` aligned body of
/// whole vectors, `[body_end, len)` tail. Invariant: `head_end <= body_end <= len`.
#[derive(Debug, Clone, Copy)]
pub struct SimdChunks {
    /// Total number of elements.
    pub len: usize,
    /// Number of lanes per SIMD register.
    pub lanes: usize,
    /// Index where head (unaligned prefix) ends.
    pub head_end: usize,
    /// Index where body (aligned middle) ends.
    pub body_end: usize,
}
335
336impl SimdChunks {
337    /// Creates a new chunk iterator for the given length and alignment.
338    #[inline]
339    pub fn new<T: Scalar>(ptr: *const T, len: usize, level: SimdLevel) -> Self {
340        let lanes = level.lanes::<T>();
341        let align = level.width_bytes();
342
343        if lanes <= 1 || len < lanes * 2 {
344            // Not worth SIMD, treat everything as head
345            return SimdChunks {
346                len,
347                lanes,
348                head_end: len,
349                body_end: len,
350            };
351        }
352
353        let addr = ptr as usize;
354        let misalign = addr % align;
355
356        let head_end = if misalign == 0 {
357            0
358        } else {
359            let elements_to_align = (align - misalign) / core::mem::size_of::<T>();
360            elements_to_align.min(len)
361        };
362
363        let remaining = len - head_end;
364        let full_vectors = remaining / lanes;
365        let body_end = head_end + full_vectors * lanes;
366
367        SimdChunks {
368            len,
369            lanes,
370            head_end,
371            body_end,
372        }
373    }
374
375    /// Returns the number of head elements (before aligned body).
376    #[inline]
377    pub fn head_len(&self) -> usize {
378        self.head_end
379    }
380
381    /// Returns the number of body elements (aligned middle).
382    #[inline]
383    pub fn body_len(&self) -> usize {
384        self.body_end - self.head_end
385    }
386
387    /// Returns the number of tail elements (after aligned body).
388    #[inline]
389    pub fn tail_len(&self) -> usize {
390        self.len - self.body_end
391    }
392
393    /// Returns the number of full SIMD vectors in the body.
394    #[inline]
395    pub fn body_vectors(&self) -> usize {
396        self.body_len() / self.lanes
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_simd_level() {
        let level = detect_simd_level();
        println!("Detected SIMD level: {:?}", level);

        // When force-scalar is enabled, should always be Scalar
        #[cfg(feature = "force-scalar")]
        {
            assert_eq!(level, SimdLevel::Scalar);
            // But raw detection should still show hardware capability
            let raw = detect_simd_level_raw();
            println!("Raw hardware SIMD level: {:?}", raw);
        }

        // Without force-scalar, should detect hardware SIMD
        #[cfg(not(feature = "force-scalar"))]
        {
            // SSE2 is part of the x86_64 baseline, so at least Simd128
            // even when a max-simd-* feature caps the level.
            #[cfg(target_arch = "x86_64")]
            assert!(level >= SimdLevel::Simd128);

            // NEON is mandatory on AArch64; detection reports exactly 128-bit.
            #[cfg(target_arch = "aarch64")]
            assert_eq!(level, SimdLevel::Simd128);
        }
    }

    #[test]
    fn test_simd_level_lanes() {
        // lanes = register width in bytes / size_of element
        assert_eq!(SimdLevel::Simd256.lanes::<f64>(), 4);
        assert_eq!(SimdLevel::Simd256.lanes::<f32>(), 8);
        assert_eq!(SimdLevel::Simd512.lanes::<f64>(), 8);
        assert_eq!(SimdLevel::Simd512.lanes::<f32>(), 16);
    }

    #[test]
    fn test_simd_chunks() {
        // Create a pointer with known alignment
        let data: Vec<f64> = vec![0.0; 100];
        let ptr = data.as_ptr();

        let chunks = SimdChunks::new(ptr, 100, SimdLevel::Simd256);
        println!(
            "Chunks: head_end={}, body_end={}",
            chunks.head_end, chunks.body_end
        );

        // Verify that head + body + tail = len
        assert_eq!(
            chunks.head_len() + chunks.body_len() + chunks.tail_len(),
            100
        );
    }

    // =============================================================================
    // Comprehensive SIMD correctness tests
    // =============================================================================

    /// Test scalar fallback FMA accuracy.
    #[test]
    fn test_scalar_fma_accuracy() {
        use crate::simd::scalar::ScalarF64;

        // Inputs are constructed so the product's low-order bits matter:
        // (1 + 1e-15)^2 - (1 + 2e-15) is ~1e-30, below f64 epsilon of 1.0.
        let a = ScalarF64::splat(1.0 + 1e-15);
        let b = ScalarF64::splat(1.0 + 1e-15);
        let c = ScalarF64::splat(-(1.0 + 2e-15));

        // FMA should preserve more precision than separate mul+add
        let fma_result = a.mul_add(b, c);
        let mul_add_result = a.mul(b).add(c);

        // Both should be very small but may differ slightly
        assert!(fma_result.0.abs() < 1e-14);
        assert!(mul_add_result.0.abs() < 1e-14);
    }

    /// Test load/store roundtrip.
    #[test]
    fn test_load_store_roundtrip() {
        use crate::simd::scalar::ScalarF64;

        let values = [42.0f64, 1.5, -3.5, 1000.0];

        for &val in &values {
            // For a 1-lane scalar register, splat/sum/extract all agree.
            let v = ScalarF64::splat(val);
            assert_eq!(v.reduce_sum(), val);
            assert_eq!(v.extract(0), val);
        }
    }

    /// Test arithmetic identities.
    #[test]
    fn test_arithmetic_identities() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};

        // Test with f64
        let a = ScalarF64::splat(5.0);
        let zero = ScalarF64::zero();
        let one = ScalarF64::splat(1.0);

        // a + 0 = a
        assert_eq!(a.add(zero).0, 5.0);
        // a - 0 = a
        assert_eq!(a.sub(zero).0, 5.0);
        // a * 1 = a
        assert_eq!(a.mul(one).0, 5.0);
        // a / 1 = a
        assert_eq!(a.div(one).0, 5.0);
        // a * 0 = 0
        assert_eq!(a.mul(zero).0, 0.0);

        // Test with f32
        let a32 = ScalarF32::splat(5.0);
        let zero32 = ScalarF32::zero();
        let one32 = ScalarF32::splat(1.0);

        assert_eq!(a32.add(zero32).0, 5.0);
        assert_eq!(a32.mul(one32).0, 5.0);
    }

    /// Test reduction operations.
    #[test]
    fn test_reductions() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};

        // For scalar types, all reductions return the same value
        let a = ScalarF64::splat(42.0);
        assert_eq!(a.reduce_sum(), 42.0);
        assert_eq!(a.reduce_max(), 42.0);
        assert_eq!(a.reduce_min(), 42.0);

        let b = ScalarF32::splat(-3.5);
        assert_eq!(b.reduce_sum(), -3.5);
        assert_eq!(b.reduce_max(), -3.5);
        assert_eq!(b.reduce_min(), -3.5);
    }

    /// Test negative value handling.
    #[test]
    fn test_negative_values() {
        use crate::simd::scalar::ScalarF64;

        let neg = ScalarF64::splat(-5.0);
        let pos = ScalarF64::splat(3.0);

        // -5 + 3 = -2
        assert_eq!(neg.add(pos).0, -2.0);
        // -5 * 3 = -15
        assert_eq!(neg.mul(pos).0, -15.0);
        // -5 - 3 = -8
        assert_eq!(neg.sub(pos).0, -8.0);
    }

    /// Test FMA variants.
    #[test]
    fn test_fma_variants() {
        use crate::simd::scalar::ScalarF64;

        let a = ScalarF64::splat(2.0);
        let b = ScalarF64::splat(3.0);
        let c = ScalarF64::splat(4.0);

        // mul_add: a * b + c = 2 * 3 + 4 = 10
        assert_eq!(a.mul_add(b, c).0, 10.0);

        // mul_sub: a * b - c = 2 * 3 - 4 = 2
        assert_eq!(a.mul_sub(b, c).0, 2.0);

        // neg_mul_add: -(a * b) + c = -6 + 4 = -2
        assert_eq!(a.neg_mul_add(b, c).0, -2.0);
    }

    /// Test insert/extract operations.
    #[test]
    fn test_insert_extract() {
        use crate::simd::scalar::ScalarF64;

        let a = ScalarF64::splat(1.0);
        let b = a.insert(0, 42.0);
        assert_eq!(b.extract(0), 42.0);
    }

    /// Platform-specific tests for native SIMD.
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_simd_correctness() {
        use crate::simd::aarch64::{F32x4, F64x2, F64x4};

        // Test F64x2
        let a = F64x2::splat(2.0);
        let b = F64x2::splat(3.0);

        let sum = a.add(b);
        assert_eq!(sum.extract(0), 5.0);
        assert_eq!(sum.extract(1), 5.0);

        let fma = a.mul_add(b, F64x2::splat(1.0));
        assert_eq!(fma.extract(0), 7.0); // 2*3 + 1

        // Test F64x4 (emulated on NEON as two 128-bit registers)
        let c = F64x4::splat(2.0);
        let d = F64x4::splat(3.0);

        assert_eq!(c.add(d).reduce_sum(), 20.0); // 4 * 5.0

        // Test F32x4
        let e = F32x4::splat(2.0);
        let f = F32x4::splat(3.0);

        assert_eq!(e.add(f).reduce_sum(), 20.0); // 4 * 5.0
    }
}