Skip to main content

oxiblas_core/simd/
dispatch.rs

//! Runtime SIMD dispatch system with function multi-versioning.
//!
//! Provides centralized CPU feature detection and caching for optimal
//! performance, plus infrastructure for dispatching to the best available
//! microkernel implementation at runtime.
//!
//! # Architecture
//!
//! - [`SimdCapabilities`]: Describes which SIMD extensions are available.
//! - [`simd_caps`]: Returns a `&'static SimdCapabilities` (std) or a freshly
//!   computed value (no_std) for the current CPU.
//! - [`simd_dispatch!`]: Macro that selects the best implementation branch
//!   based on detected capabilities.
//! - [`SimdDispatcher`]: Trait for types that provide multi-versioned
//!   implementations of a computation.
//! - [`KernelSelector`]: Chooses the optimal GEMM microkernel kind; the
//!   choice is made on the first call and cached (std) or recomputed on every
//!   call (no_std).
//!
//! # no_std note
//!
//! When the `std` feature is disabled, `SimdCapabilities::detect()` derives
//! its result from compile-time `target_feature` flags on every call (there is
//! no caching), and `simd_caps()` returns it by value rather than by
//! reference.  This is consistent with the no_std contract: no heap, no
//! globals.
24
25#[cfg(feature = "std")]
26use std::sync::OnceLock;
27
28// ---------------------------------------------------------------------------
29// SimdCapabilities
30// ---------------------------------------------------------------------------
31
/// CPU SIMD capabilities detected (or derived from compile-time flags) for
/// the current target.
///
/// On x86-64 with `std` enabled, detection uses the `is_x86_feature_detected!`
/// macro (runtime probing).  Without `std`, detection falls back to
/// compile-time `target_feature` constants, which are set by RUSTFLAGS /
/// `.cargo/config.toml`.
///
/// The struct is intentionally kept `Copy` so callers on no_std targets can
/// store it cheaply on the stack without worrying about ownership.  All
/// fields are public, so a value can also be built by hand (for example in
/// tests) and passed to anything that accepts a `&SimdCapabilities`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdCapabilities {
    // ------------------------------------------------------------------
    // x86-64 fields
    // ------------------------------------------------------------------
    /// SSE4.2 support (x86-64 only).
    pub has_sse42: bool,
    /// AVX support (x86-64 only).
    pub has_avx: bool,
    /// AVX2 support (x86-64 only).
    pub has_avx2: bool,
    /// FMA (Fused Multiply-Add) support (x86-64 only).
    pub has_fma: bool,
    /// AVX-512 Foundation support (x86-64 only).
    pub has_avx512f: bool,
    /// AVX-512 Byte & Word instructions (x86-64 only).
    pub has_avx512bw: bool,
    /// AVX-512 Vector Length extensions (x86-64 only).
    pub has_avx512vl: bool,

    // ------------------------------------------------------------------
    // ARM fields
    // ------------------------------------------------------------------
    /// NEON support.  The built-in detection reports `true` on AArch64 and
    /// `false` on every other target.
    pub has_neon: bool,
    /// SVE (Scalable Vector Extension) support.
    pub has_sve: bool,

    // ------------------------------------------------------------------
    // Memory topology
    // ------------------------------------------------------------------
    /// Size of a single cache line in bytes.  The built-in detection does not
    /// probe this and always reports 64.
    pub cache_line_bytes: usize,
    /// Width of the widest supported SIMD vector in bytes (8 is used for the
    /// scalar fallback).
    pub vector_width_bytes: usize,
}
78
79impl SimdCapabilities {
80    // ------------------------------------------------------------------
81    // Construction helpers
82    // ------------------------------------------------------------------
83
84    /// Compute the [`SimdCapabilities`] for the current CPU.
85    ///
86    /// When compiled with `std`, the result is cached in a [`OnceLock`] and
87    /// this function returns `&'static Self`; the detection work is only done
88    /// once per process.
89    ///
90    /// When compiled **without** `std`, this returns `Self` by value (derived
91    /// from compile-time `cfg!(target_feature = …)` flags).
92    #[cfg(feature = "std")]
93    pub fn detect() -> &'static Self {
94        simd_caps()
95    }
96
97    /// Compute the [`SimdCapabilities`] for the current CPU (no_std path).
98    ///
99    /// Returns a fresh value on every call because `OnceLock` is not available
100    /// without the standard library.
101    #[cfg(not(feature = "std"))]
102    pub fn detect() -> Self {
103        simd_caps()
104    }
105
106    // ------------------------------------------------------------------
107    // Internal constructor used by simd_caps_compute()
108    // ------------------------------------------------------------------
109
110    #[cfg(all(target_arch = "x86_64", feature = "std"))]
111    fn compute() -> Self {
112        let has_avx512f = is_x86_feature_detected!("avx512f");
113        let has_avx512bw = is_x86_feature_detected!("avx512bw");
114        let has_avx512vl = is_x86_feature_detected!("avx512vl");
115        let has_avx2 = is_x86_feature_detected!("avx2");
116        let has_fma = is_x86_feature_detected!("fma");
117        let has_avx = is_x86_feature_detected!("avx");
118        let has_sse42 = is_x86_feature_detected!("sse4.2");
119
120        let vector_width_bytes = if has_avx512f {
121            64
122        } else if has_avx2 {
123            32
124        } else if has_sse42 {
125            16
126        } else {
127            8
128        };
129
130        Self {
131            has_sse42,
132            has_avx,
133            has_avx2,
134            has_fma,
135            has_avx512f,
136            has_avx512bw,
137            has_avx512vl,
138            has_neon: false,
139            has_sve: false,
140            cache_line_bytes: 64,
141            vector_width_bytes,
142        }
143    }
144
145    #[cfg(all(target_arch = "x86_64", not(feature = "std")))]
146    fn compute() -> Self {
147        let has_avx512f = cfg!(target_feature = "avx512f");
148        let has_avx512bw = cfg!(target_feature = "avx512bw");
149        let has_avx512vl = cfg!(target_feature = "avx512vl");
150        let has_avx2 = cfg!(target_feature = "avx2");
151        let has_fma = cfg!(target_feature = "fma");
152        let has_avx = cfg!(target_feature = "avx");
153        let has_sse42 = cfg!(target_feature = "sse4.2");
154
155        let vector_width_bytes: usize = if has_avx512f {
156            64
157        } else if has_avx2 {
158            32
159        } else if has_sse42 {
160            16
161        } else {
162            8
163        };
164
165        Self {
166            has_sse42,
167            has_avx,
168            has_avx2,
169            has_fma,
170            has_avx512f,
171            has_avx512bw,
172            has_avx512vl,
173            has_neon: false,
174            has_sve: false,
175            cache_line_bytes: 64,
176            vector_width_bytes,
177        }
178    }
179
180    #[cfg(target_arch = "aarch64")]
181    fn compute() -> Self {
182        // NEON is mandatory on AArch64 per the architecture specification.
183        let has_sve = cfg!(target_feature = "sve");
184        Self {
185            has_sse42: false,
186            has_avx: false,
187            has_avx2: false,
188            has_fma: false,
189            has_avx512f: false,
190            has_avx512bw: false,
191            has_avx512vl: false,
192            has_neon: true,
193            has_sve,
194            cache_line_bytes: 64,
195            vector_width_bytes: 16,
196        }
197    }
198
199    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
200    fn compute() -> Self {
201        Self {
202            has_sse42: false,
203            has_avx: false,
204            has_avx2: false,
205            has_fma: false,
206            has_avx512f: false,
207            has_avx512bw: false,
208            has_avx512vl: false,
209            has_neon: false,
210            has_sve: false,
211            cache_line_bytes: 64,
212            vector_width_bytes: 8,
213        }
214    }
215
216    // ------------------------------------------------------------------
217    // Capability queries
218    // ------------------------------------------------------------------
219
220    /// Returns `true` when AVX-512F, BW, and VL are all available — a common
221    /// prerequisite for most practical AVX-512 code paths.
222    #[inline]
223    pub fn has_avx512_full(&self) -> bool {
224        self.has_avx512f && self.has_avx512bw && self.has_avx512vl
225    }
226
227    /// Returns `true` when both AVX2 and FMA are available — the usual
228    /// prerequisite for 256-bit fused multiply-add kernels.
229    #[inline]
230    pub fn has_avx2_fma(&self) -> bool {
231        self.has_avx2 && self.has_fma
232    }
233
234    /// Returns the number of `f64` elements that fit in the widest supported
235    /// SIMD vector register.
236    ///
237    /// | ISA             | f64 elements |
238    /// |-----------------|-------------|
239    /// | AVX-512         | 8           |
240    /// | AVX2 / NEON-256 | 4           |
241    /// | SSE4.2 / NEON   | 2           |
242    /// | scalar          | 1           |
243    #[inline]
244    pub fn f64_simd_width(&self) -> usize {
245        self.vector_width_bytes / core::mem::size_of::<f64>()
246    }
247
248    /// Returns the number of `f32` elements that fit in the widest supported
249    /// SIMD vector register.
250    #[inline]
251    pub fn f32_simd_width(&self) -> usize {
252        self.vector_width_bytes / core::mem::size_of::<f32>()
253    }
254
255    /// Returns the [`SimdLevel`] that best summarises the capabilities.
256    #[inline]
257    pub fn optimal_level(&self) -> SimdLevel {
258        if self.has_avx512_full() {
259            SimdLevel::Avx512
260        } else if self.has_avx2_fma() {
261            SimdLevel::Avx2
262        } else if self.has_avx {
263            SimdLevel::Avx
264        } else if self.has_sse42 {
265            SimdLevel::Sse42
266        } else if self.has_neon {
267            SimdLevel::Neon
268        } else if self.has_sve {
269            SimdLevel::Sve
270        } else {
271            SimdLevel::Scalar
272        }
273    }
274}
275
276// ---------------------------------------------------------------------------
277// SimdLevel  (kept distinct from the coarser SimdLevel in simd.rs)
278// ---------------------------------------------------------------------------
279
/// Fine-grained SIMD instruction set level, used by the dispatch layer.
///
/// The derived `Ord` follows the explicit discriminants.  Within the x86-64
/// family that ordering is meaningful (`Scalar < Sse42 < Avx < Avx2 <
/// Avx512`).  The ARM levels use the separate discriminants 10 and 11, so
/// they compare greater than every x86-64 level; comparisons across families
/// carry no meaning and should not be relied upon.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// Scalar (no SIMD).
    Scalar = 0,
    /// SSE4.2 (128-bit, x86-64).
    Sse42 = 1,
    /// AVX (256-bit, x86-64, no FMA).
    Avx = 2,
    /// AVX2 with FMA (256-bit, x86-64).
    Avx2 = 3,
    /// AVX-512F+BW+VL (512-bit, x86-64).
    Avx512 = 4,
    /// NEON (128-bit, AArch64).
    Neon = 10,
    /// SVE (scalable, AArch64).
    Sve = 11,
}

impl SimdLevel {
    /// Human-readable name of the level.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            SimdLevel::Scalar => "scalar",
            SimdLevel::Sse42 => "SSE4.2",
            SimdLevel::Avx => "AVX",
            SimdLevel::Avx2 => "AVX2+FMA",
            SimdLevel::Avx512 => "AVX-512",
            SimdLevel::Neon => "NEON",
            SimdLevel::Sve => "SVE",
        }
    }

    /// Number of `f64` elements in a vector of this level's width.
    ///
    /// SVE reports the 128-bit value because its real width is only known at
    /// runtime.
    #[inline]
    pub const fn f64_width(self) -> usize {
        match self {
            SimdLevel::Scalar => 1,
            SimdLevel::Sse42 | SimdLevel::Neon => 2,
            SimdLevel::Sve => 2, // conservative; actual width is dynamic
            SimdLevel::Avx | SimdLevel::Avx2 => 4,
            SimdLevel::Avx512 => 8,
        }
    }

    /// Number of `f32` elements in a vector of this level's width.
    ///
    /// SVE reports the 128-bit value because its real width is only known at
    /// runtime.
    #[inline]
    pub const fn f32_width(self) -> usize {
        match self {
            SimdLevel::Scalar => 1,
            SimdLevel::Sse42 | SimdLevel::Neon => 4,
            SimdLevel::Sve => 4, // conservative; actual width is dynamic
            SimdLevel::Avx | SimdLevel::Avx2 => 8,
            SimdLevel::Avx512 => 16,
        }
    }
}
344
345// ---------------------------------------------------------------------------
346// Cached global (std path)
347// ---------------------------------------------------------------------------
348
349#[cfg(feature = "std")]
350static SIMD_CAPS: OnceLock<SimdCapabilities> = OnceLock::new();
351
352/// Returns a reference to the globally cached [`SimdCapabilities`].
353///
354/// Detection is performed exactly once and the result is stored in a process-
355/// wide static.  All subsequent calls return the same reference.
356#[cfg(feature = "std")]
357#[inline]
358pub fn simd_caps() -> &'static SimdCapabilities {
359    SIMD_CAPS.get_or_init(SimdCapabilities::compute)
360}
361
362/// Returns the [`SimdCapabilities`] derived from compile-time target features.
363///
364/// Because `OnceLock` requires `std`, this no_std variant recomputes the value
365/// on every call.  The computation is a chain of `cfg!` evaluations and is
366/// expected to be fully inlined / constant-folded by the compiler.
367#[cfg(not(feature = "std"))]
368#[inline]
369pub fn simd_caps() -> SimdCapabilities {
370    SimdCapabilities::compute()
371}
372
373/// Returns the optimal [`SimdLevel`] for the current CPU.
374#[inline]
375pub fn optimal_simd_level() -> SimdLevel {
376    #[cfg(feature = "std")]
377    {
378        simd_caps().optimal_level()
379    }
380    #[cfg(not(feature = "std"))]
381    {
382        simd_caps().optimal_level()
383    }
384}
385
386/// Returns `true` when AVX-512F+BW+VL are all available.
387#[inline]
388pub fn has_avx512() -> bool {
389    #[cfg(feature = "std")]
390    {
391        simd_caps().has_avx512_full()
392    }
393    #[cfg(not(feature = "std"))]
394    {
395        simd_caps().has_avx512_full()
396    }
397}
398
399/// Returns `true` when AVX2 and FMA are both available.
400#[inline]
401pub fn has_avx2_fma() -> bool {
402    #[cfg(feature = "std")]
403    {
404        simd_caps().has_avx2_fma()
405    }
406    #[cfg(not(feature = "std"))]
407    {
408        simd_caps().has_avx2_fma()
409    }
410}
411
412/// Returns `true` when NEON is available.
413#[inline]
414pub fn has_neon() -> bool {
415    #[cfg(feature = "std")]
416    {
417        simd_caps().has_neon
418    }
419    #[cfg(not(feature = "std"))]
420    {
421        simd_caps().has_neon
422    }
423}
424
425// ---------------------------------------------------------------------------
426// simd_dispatch! macro
427// ---------------------------------------------------------------------------
428
/// Dispatch to the best available SIMD implementation for the current CPU.
///
/// The macro evaluates the provided [`SimdCapabilities`] reference (or value)
/// **once** and selects **exactly one** branch, in priority order:
///
/// 1. `avx512`  — AVX-512F+BW+VL
/// 2. `avx2`    — AVX2 + FMA (256-bit)
/// 3. `sse42`   — SSE4.2 (128-bit)
/// 4. `neon`    — AArch64 NEON
/// 5. `scalar`  — Portable scalar fallback
///
/// The branches are the arms of an ordinary `if`/`else` chain, so all five
/// expressions must type-check on every target and share one result type;
/// only the selected one is evaluated.
///
/// # Example
///
/// ```rust,ignore
/// use oxiblas_core::simd::dispatch::{simd_caps, simd_dispatch};
///
/// let result = simd_dispatch!(
///     simd_caps(),
///     avx512  => compute_avx512(&a, &b),
///     avx2    => compute_avx2(&a, &b),
///     sse42   => compute_sse42(&a, &b),
///     neon    => compute_neon(&a, &b),
///     scalar  => compute_scalar(&a, &b),
/// );
/// ```
#[macro_export]
macro_rules! simd_dispatch {
    (
        $caps:expr,
        avx512  => $avx512:expr,
        avx2    => $avx2:expr,
        sse42   => $sse42:expr,
        neon    => $neon:expr,
        scalar  => $scalar:expr $(,)?
    ) => {{
        // Bind once so a side-effecting `$caps` expression runs a single time.
        let caps = $caps;
        if caps.has_avx512_full() {
            $avx512
        } else if caps.has_avx2_fma() {
            $avx2
        } else if caps.has_sse42 {
            $sse42
        } else if caps.has_neon {
            $neon
        } else {
            $scalar
        }
    }};
}
478
479// Re-export so the macro is accessible from `oxiblas_core::simd::dispatch`.
480pub use simd_dispatch;
481
482// ---------------------------------------------------------------------------
483// SimdDispatcher trait
484// ---------------------------------------------------------------------------
485
486/// Trait for types that provide architecture-specialised implementations of a
487/// single computation.
488///
489/// Implement the four dispatch variants and call [`SimdDispatcher::dispatch`]
490/// to let the runtime choose the best one automatically.
491///
492/// # Example
493///
494/// ```rust,ignore
495/// struct DotProductF64<'a> {
496///     x: &'a [f64],
497///     y: &'a [f64],
498/// }
499///
500/// impl SimdDispatcher for DotProductF64<'_> {
501///     type Output = f64;
502///     fn dispatch_avx512(&self) -> f64 { /* AVX-512 kernel */ unimplemented!() }
503///     fn dispatch_avx2(&self)   -> f64 { /* AVX2+FMA kernel */ unimplemented!() }
504///     fn dispatch_neon(&self)   -> f64 { /* NEON kernel */ unimplemented!() }
505///     fn dispatch_scalar(&self) -> f64 { self.x.iter().zip(self.y).map(|(a,b)| a*b).sum() }
506/// }
507///
508/// let result = DotProductF64 { x: &[1.0, 2.0], y: &[3.0, 4.0] }.dispatch();
509/// assert_eq!(result, 11.0);
510/// ```
511pub trait SimdDispatcher {
512    /// The return type of the computation.
513    type Output;
514
515    /// AVX-512F+BW+VL specialised implementation.
516    fn dispatch_avx512(&self) -> Self::Output;
517
518    /// AVX2 + FMA specialised implementation.
519    fn dispatch_avx2(&self) -> Self::Output;
520
521    /// NEON (AArch64) specialised implementation.
522    fn dispatch_neon(&self) -> Self::Output;
523
524    /// Portable scalar fallback implementation.
525    fn dispatch_scalar(&self) -> Self::Output;
526
527    /// SSE4.2 (x86-64, 128-bit) specialised implementation.
528    ///
529    /// Defaults to [`dispatch_scalar`][SimdDispatcher::dispatch_scalar] when
530    /// not overridden, so implementors only need to override it when an
531    /// SSE4.2-specific path exists.
532    fn dispatch_sse42(&self) -> Self::Output {
533        self.dispatch_scalar()
534    }
535
536    /// Select and call the best available implementation.
537    ///
538    /// The selection is based on [`SimdCapabilities::detect`] which caches the
539    /// result in a process-wide static (when `std` is enabled).
540    fn dispatch(&self) -> Self::Output {
541        #[cfg(feature = "std")]
542        let caps = SimdCapabilities::detect();
543        #[cfg(not(feature = "std"))]
544        let caps = SimdCapabilities::detect();
545
546        if caps.has_avx512_full() {
547            self.dispatch_avx512()
548        } else if caps.has_avx2_fma() {
549            self.dispatch_avx2()
550        } else if caps.has_sse42 {
551            self.dispatch_sse42()
552        } else if caps.has_neon {
553            self.dispatch_neon()
554        } else {
555            self.dispatch_scalar()
556        }
557    }
558}
559
560// ---------------------------------------------------------------------------
561// KernelSelector  /  GemmKernelKind
562// ---------------------------------------------------------------------------
563
/// Identifies the microkernel variant chosen for GEMM operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmKernelKind {
    /// AVX-512 microkernel (x86-64, 512-bit registers).
    Avx512,
    /// AVX2 + FMA microkernel (x86-64, 256-bit registers).
    Avx2,
    /// SSE4.2 microkernel (x86-64, 128-bit registers).
    Sse42,
    /// NEON microkernel (AArch64, 128-bit registers).
    Neon,
    /// Portable scalar microkernel (fallback).
    Scalar,
}

impl GemmKernelKind {
    /// Returns a human-readable name for the kind, e.g. `"AVX2+FMA"`.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            GemmKernelKind::Avx512 => "AVX-512",
            GemmKernelKind::Avx2 => "AVX2+FMA",
            GemmKernelKind::Sse42 => "SSE4.2",
            GemmKernelKind::Neon => "NEON",
            GemmKernelKind::Scalar => "scalar",
        }
    }
}
592
593/// Selects the optimal GEMM microkernel for each floating-point type based on
594/// the CPU capabilities detected at runtime.
595///
596/// The selection is performed once and stored in a process-wide static (on
597/// std-enabled builds).  Call [`KernelSelector::select`] to obtain a shared
598/// reference.
599#[derive(Debug, Clone, Copy, PartialEq, Eq)]
600pub struct KernelSelector {
601    /// Best microkernel kind for double-precision GEMM.
602    pub gemm_f64_kernel: GemmKernelKind,
603    /// Best microkernel kind for single-precision GEMM.
604    pub gemm_f32_kernel: GemmKernelKind,
605}
606
607impl KernelSelector {
608    /// Choose kernel kinds from a [`SimdCapabilities`] snapshot.
609    fn from_caps(caps: &SimdCapabilities) -> Self {
610        let kind = if caps.has_avx512_full() {
611            GemmKernelKind::Avx512
612        } else if caps.has_avx2_fma() {
613            GemmKernelKind::Avx2
614        } else if caps.has_sse42 {
615            GemmKernelKind::Sse42
616        } else if caps.has_neon {
617            GemmKernelKind::Neon
618        } else {
619            GemmKernelKind::Scalar
620        };
621
622        Self {
623            gemm_f64_kernel: kind,
624            gemm_f32_kernel: kind,
625        }
626    }
627
628    /// Returns a reference to the globally cached [`KernelSelector`].
629    ///
630    /// Detection is performed exactly once per process.
631    #[cfg(feature = "std")]
632    pub fn select() -> &'static Self {
633        static KERNEL_SEL: OnceLock<KernelSelector> = OnceLock::new();
634        KERNEL_SEL.get_or_init(|| Self::from_caps(simd_caps()))
635    }
636
637    /// Recomputes the [`KernelSelector`] from compile-time features (no_std).
638    #[cfg(not(feature = "std"))]
639    pub fn select() -> Self {
640        Self::from_caps(&simd_caps())
641    }
642}
643
644// ---------------------------------------------------------------------------
645// Debugging helper (std only)
646// ---------------------------------------------------------------------------
647
/// Prints a formatted summary of the detected SIMD capabilities to stdout.
///
/// The report lists the optimal level, cache-line size, vector width and lane
/// counts, then the per-architecture feature flags (only the section for the
/// architecture being compiled for), and finally the GEMM kernel selection.
///
/// Useful for diagnostics and quick sanity checks.  Only available when the
/// `std` feature is enabled.
#[cfg(feature = "std")]
pub fn print_capabilities() {
    let caps = simd_caps();
    let level = caps.optimal_level();

    println!("=== OxiBLAS SIMD Capabilities ===");
    println!("Optimal level   : {}", level.name());
    println!("Cache line      : {} bytes", caps.cache_line_bytes);
    println!("Vector width    : {} bytes", caps.vector_width_bytes);
    println!("f64 SIMD width  : {} elements", caps.f64_simd_width());
    println!("f32 SIMD width  : {} elements", caps.f32_simd_width());

    #[cfg(target_arch = "x86_64")]
    {
        println!("x86-64 Features:");
        println!("  SSE4.2     : {}", caps.has_sse42);
        println!("  AVX        : {}", caps.has_avx);
        println!("  AVX2       : {}", caps.has_avx2);
        println!("  FMA        : {}", caps.has_fma);
        println!("  AVX-512F   : {}", caps.has_avx512f);
        println!("  AVX-512BW  : {}", caps.has_avx512bw);
        println!("  AVX-512VL  : {}", caps.has_avx512vl);
    }

    #[cfg(target_arch = "aarch64")]
    {
        println!("AArch64 Features:");
        println!("  NEON       : {}", caps.has_neon);
        println!("  SVE        : {}", caps.has_sve);
    }

    let sel = KernelSelector::select();
    println!("Kernel Selection:");
    println!("  GEMM f64   : {}", sel.gemm_f64_kernel.name());
    println!("  GEMM f32   : {}", sel.gemm_f32_kernel.name());
    println!("==================================");
}
689
690// ---------------------------------------------------------------------------
691// Tests
692// ---------------------------------------------------------------------------
693
694#[cfg(test)]
695mod tests {
696    use super::*;
697
698    // ------------------------------------------------------------------
699    // 1. Detection smoke-test: at least one SIMD feature or scalar path
700    // ------------------------------------------------------------------
701    #[test]
702    fn test_capabilities_detection_does_not_panic() {
703        // Must not panic on any supported target.
704        let caps = SimdCapabilities::detect();
705        // Cache line should be a sensible power-of-two value.
706        assert!(caps.cache_line_bytes >= 8);
707        assert!(caps.cache_line_bytes.is_power_of_two());
708    }
709
710    // ------------------------------------------------------------------
711    // 2. AArch64: NEON must always be true
712    // ------------------------------------------------------------------
713    #[cfg(target_arch = "aarch64")]
714    #[test]
715    fn test_aarch64_neon_always_present() {
716        let caps = SimdCapabilities::detect();
717        assert!(caps.has_neon, "NEON is mandatory on AArch64");
718        // x86-64 flags must be absent
719        assert!(!caps.has_avx2);
720        assert!(!caps.has_avx512f);
721    }
722
723    // ------------------------------------------------------------------
724    // 3. x86-64: flags are self-consistent
725    // ------------------------------------------------------------------
726    #[cfg(target_arch = "x86_64")]
727    #[test]
728    fn test_x86_64_flag_consistency() {
729        let caps = SimdCapabilities::detect();
730        // NEON must not be set on x86-64.
731        assert!(!caps.has_neon);
732        // If AVX2 is present, AVX must be present too (architectural requirement).
733        if caps.has_avx2 {
734            assert!(caps.has_avx, "AVX2 implies AVX");
735        }
736        // If AVX-512F is present, SSE4.2 should be present (all x86-64 CPUs
737        // with AVX-512 also support SSE4.2).
738        if caps.has_avx512f {
739            assert!(caps.has_sse42, "AVX-512 implies SSE4.2");
740        }
741    }
742
743    // ------------------------------------------------------------------
744    // 4. Vector width matches capabilities
745    // ------------------------------------------------------------------
746    #[test]
747    fn test_vector_width_consistent_with_capabilities() {
748        let caps = SimdCapabilities::detect();
749        if caps.has_avx512f {
750            assert_eq!(caps.vector_width_bytes, 64);
751        } else if caps.has_avx2 {
752            assert_eq!(caps.vector_width_bytes, 32);
753        }
754    }
755
756    // ------------------------------------------------------------------
757    // 5. f64 / f32 SIMD widths are derived from vector_width_bytes
758    // ------------------------------------------------------------------
759    #[test]
760    fn test_simd_widths_derived_correctly() {
761        let caps = SimdCapabilities::detect();
762        assert_eq!(
763            caps.f64_simd_width(),
764            caps.vector_width_bytes / core::mem::size_of::<f64>()
765        );
766        assert_eq!(
767            caps.f32_simd_width(),
768            caps.vector_width_bytes / core::mem::size_of::<f32>()
769        );
770        // f32 always holds twice as many elements as f64 in the same register.
771        assert_eq!(caps.f32_simd_width(), caps.f64_simd_width() * 2);
772    }
773
774    // ------------------------------------------------------------------
775    // 6. SimdLevel ordering is meaningful within x86-64 family
776    // ------------------------------------------------------------------
777    #[test]
778    fn test_simd_level_ordering() {
779        assert!(SimdLevel::Avx512 > SimdLevel::Avx2);
780        assert!(SimdLevel::Avx2 > SimdLevel::Avx);
781        assert!(SimdLevel::Avx > SimdLevel::Sse42);
782        assert!(SimdLevel::Sse42 > SimdLevel::Scalar);
783    }
784
785    // ------------------------------------------------------------------
786    // 7. SimdLevel vector widths are consistent
787    // ------------------------------------------------------------------
788    #[test]
789    fn test_simd_level_widths() {
790        // Scalar holds exactly 1 element per lane.
791        assert_eq!(SimdLevel::Scalar.f64_width(), 1);
792        assert_eq!(SimdLevel::Scalar.f32_width(), 1);
793        // Each wider ISA must hold more elements.
794        assert_eq!(SimdLevel::Sse42.f64_width(), 2);
795        assert_eq!(SimdLevel::Sse42.f32_width(), 4);
796        assert_eq!(SimdLevel::Avx2.f64_width(), 4);
797        assert_eq!(SimdLevel::Avx2.f32_width(), 8);
798        assert_eq!(SimdLevel::Avx512.f64_width(), 8);
799        assert_eq!(SimdLevel::Avx512.f32_width(), 16);
800        assert_eq!(SimdLevel::Neon.f64_width(), 2);
801        assert_eq!(SimdLevel::Neon.f32_width(), 4);
802    }
803
804    // ------------------------------------------------------------------
805    // 8. simd_caps() is idempotent (same result on repeated calls)
806    // ------------------------------------------------------------------
807    #[cfg(feature = "std")]
808    #[test]
809    fn test_simd_caps_cached_identity() {
810        let a = simd_caps();
811        let b = simd_caps();
812        // Must be the same static reference.
813        assert!(
814            core::ptr::eq(a, b),
815            "simd_caps() must return a stable &'static"
816        );
817    }
818
819    // ------------------------------------------------------------------
820    // 9. optimal_level() returns a variant that matches the caps flags
821    // ------------------------------------------------------------------
822    #[test]
823    fn test_optimal_level_consistent_with_flags() {
824        let caps = SimdCapabilities::detect();
825        let level = caps.optimal_level();
826        match level {
827            SimdLevel::Avx512 => assert!(caps.has_avx512_full()),
828            SimdLevel::Avx2 => {
829                assert!(!caps.has_avx512_full());
830                assert!(caps.has_avx2_fma());
831            }
832            SimdLevel::Avx => {
833                assert!(!caps.has_avx512_full());
834                assert!(!caps.has_avx2_fma());
835                assert!(caps.has_avx);
836            }
837            SimdLevel::Sse42 => {
838                assert!(!caps.has_avx);
839                assert!(caps.has_sse42);
840            }
841            SimdLevel::Neon => {
842                assert!(caps.has_neon);
843                assert!(!caps.has_avx);
844            }
845            SimdLevel::Sve => {
846                assert!(caps.has_sve);
847                assert!(!caps.has_neon);
848            }
849            SimdLevel::Scalar => {
850                assert!(!caps.has_sse42);
851                assert!(!caps.has_avx);
852                assert!(!caps.has_neon);
853                assert!(!caps.has_sve);
854            }
855        }
856    }
857
858    // ------------------------------------------------------------------
859    // 10. KernelSelector::select() does not panic and returns valid kinds
860    // ------------------------------------------------------------------
861    #[test]
862    fn test_kernel_selector_valid_kinds() {
863        #[cfg(feature = "std")]
864        let sel = *KernelSelector::select();
865        #[cfg(not(feature = "std"))]
866        let sel = KernelSelector::select();
867
868        // Kind must be one of the defined variants.
869        assert!(matches!(
870            sel.gemm_f64_kernel,
871            GemmKernelKind::Avx512
872                | GemmKernelKind::Avx2
873                | GemmKernelKind::Sse42
874                | GemmKernelKind::Neon
875                | GemmKernelKind::Scalar
876        ));
877        assert!(matches!(
878            sel.gemm_f32_kernel,
879            GemmKernelKind::Avx512
880                | GemmKernelKind::Avx2
881                | GemmKernelKind::Sse42
882                | GemmKernelKind::Neon
883                | GemmKernelKind::Scalar
884        ));
885    }
886
887    // ------------------------------------------------------------------
    // 11. KernelSelector picks the kernel implied by the detected capability flags
889    // ------------------------------------------------------------------
890    #[test]
891    fn test_kernel_selector_matches_optimal_level() {
892        let caps = SimdCapabilities::detect();
893
894        #[cfg(feature = "std")]
895        let sel = *KernelSelector::select();
896        #[cfg(not(feature = "std"))]
897        let sel = KernelSelector::select();
898
899        if caps.has_avx512_full() {
900            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Avx512);
901            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Avx512);
902        } else if caps.has_avx2_fma() {
903            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Avx2);
904            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Avx2);
905        } else if caps.has_sse42 {
906            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Sse42);
907            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Sse42);
908        } else if caps.has_neon {
909            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Neon);
910            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Neon);
911        } else {
912            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Scalar);
913            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Scalar);
914        }
915    }
916
917    // ------------------------------------------------------------------
918    // 12. simd_dispatch! macro selects a branch without panicking
919    // ------------------------------------------------------------------
920    #[test]
921    fn test_simd_dispatch_macro_selects_branch() {
922        #[cfg(feature = "std")]
923        let caps = simd_caps();
924        #[cfg(not(feature = "std"))]
925        let caps = &simd_caps();
926
927        let result: u32 = simd_dispatch!(
928            caps,
929            avx512  => 512u32,
930            avx2    => 256u32,
931            sse42   => 128u32,
932            neon    => 1000u32,
933            scalar  => 1u32,
934        );
935
936        // The result must be consistent with what optimal_level() says.
937        // Note: simd_dispatch! routes AVX (no FMA) to the avx2 branch because
938        // the macro checks has_avx2_fma() not has_avx, so AVX-only maps to sse42.
939        let expected: u32 = match caps.optimal_level() {
940            SimdLevel::Avx512 => 512,
941            SimdLevel::Avx2 => 256,
942            SimdLevel::Avx | SimdLevel::Sse42 => 128,
943            SimdLevel::Neon | SimdLevel::Sve => 1000,
944            SimdLevel::Scalar => 1,
945        };
946        assert_eq!(result, expected);
947    }
948
949    // ------------------------------------------------------------------
950    // 13. SimdDispatcher trait: scalar reference implementation
951    // ------------------------------------------------------------------
952    struct ScalarDot<'a> {
953        x: &'a [f64],
954        y: &'a [f64],
955    }
956
957    impl SimdDispatcher for ScalarDot<'_> {
958        type Output = f64;
959
960        fn dispatch_avx512(&self) -> f64 {
961            // In production this would use AVX-512 intrinsics; here we
962            // delegate to scalar so the test compiles on all platforms.
963            self.dispatch_scalar()
964        }
965
966        fn dispatch_avx2(&self) -> f64 {
967            self.dispatch_scalar()
968        }
969
970        fn dispatch_neon(&self) -> f64 {
971            self.dispatch_scalar()
972        }
973
974        fn dispatch_scalar(&self) -> f64 {
975            self.x.iter().zip(self.y.iter()).map(|(a, b)| a * b).sum()
976        }
977    }
978
979    #[test]
980    fn test_simd_dispatcher_trait_correctness() {
981        let x = [1.0_f64, 2.0, 3.0, 4.0];
982        let y = [5.0_f64, 6.0, 7.0, 8.0];
983        let dot = ScalarDot { x: &x, y: &y };
984
985        // Expected: 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
986        let result = dot.dispatch();
987        assert!((result - 70.0).abs() < f64::EPSILON);
988    }
989
990    // ------------------------------------------------------------------
991    // 14. print_capabilities does not panic (std only)
992    // ------------------------------------------------------------------
    // Smoke test: calling `print_capabilities()` must complete without
    // panicking; nothing it prints is inspected. Only compiled when the
    // `std` feature is enabled.
    #[cfg(feature = "std")]
    #[test]
    fn test_print_capabilities_does_not_panic() {
        print_capabilities();
    }
998
999    // ------------------------------------------------------------------
1000    // 15. GemmKernelKind names are non-empty strings
1001    // ------------------------------------------------------------------
1002    #[test]
1003    fn test_gemm_kernel_kind_names_non_empty() {
1004        for kind in [
1005            GemmKernelKind::Avx512,
1006            GemmKernelKind::Avx2,
1007            GemmKernelKind::Sse42,
1008            GemmKernelKind::Neon,
1009            GemmKernelKind::Scalar,
1010        ] {
1011            assert!(!kind.name().is_empty());
1012        }
1013    }
1014
1015    // ------------------------------------------------------------------
1016    // 16. has_avx512() / has_avx2_fma() / has_neon() helpers agree with caps
1017    // ------------------------------------------------------------------
1018    #[test]
1019    fn test_helper_functions_agree_with_caps() {
1020        let caps = SimdCapabilities::detect();
1021        assert_eq!(has_avx512(), caps.has_avx512_full());
1022        assert_eq!(has_avx2_fma(), caps.has_avx2_fma());
1023        assert_eq!(has_neon(), caps.has_neon);
1024    }
1025}