Skip to main content

oxiblas_core/simd/
multiver.rs

//! Runtime function multi-versioning infrastructure for OxiBLAS.
//!
//! This module builds on top of [`crate::simd::dispatch`] to provide:
//!
//! - [`SimdCapabilityInfo`], an extended capability struct exposing the field
//!   names used by the multi-versioning layer (`has_avx512f`, `has_avx2`,
//!   `has_sse42`, `has_fma`, `has_neon`, `cache_line_bytes`,
//!   `vector_width_bytes`) and a `detect()` that returns `&'static Self` when
//!   `std` is enabled.
//! - [`simd_dispatch_caps!`] macro for named-arm dispatch.
//! - [`SimdDispatcher`] trait for multi-versioned computations.
//! - [`KernelSelector`] / [`GemmKernelKind`] for startup kernel selection.
//!
//! # no_std
//!
//! Under `no_std`, `SimdCapabilityInfo::detect()` returns a freshly derived
//! value on every call because `OnceLock` is unavailable.  All fields are set
//! from compile-time `cfg!(target_feature = …)` constants, which the compiler
//! constant-folds away.
19
20#[cfg(feature = "std")]
21use std::sync::OnceLock;
22
23use crate::simd::dispatch::{SimdCapabilities as LegacyCaps, SimdLevel as LegacyLevel};
24
25// ---------------------------------------------------------------------------
26// SimdCapabilityInfo — enhanced capability struct
27// ---------------------------------------------------------------------------
28
/// Snapshot of the SIMD features and basic memory parameters of the CPU,
/// using the field names expected by the multi-versioning layer.
///
/// With the `std` feature enabled, [`SimdCapabilityInfo::detect`] hands out a
/// `&'static Self` that is initialised exactly once per process through a
/// `std::sync::OnceLock`.  Without `std` it returns a fresh `Self` by value,
/// derived from compile-time `cfg!` constants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdCapabilityInfo {
    // ------------------------------------------------------------------
    // x86-64 feature flags
    // ------------------------------------------------------------------
    /// `true` when SSE4.2 is available (x86-64 only).
    pub has_sse42: bool,
    /// `true` when AVX is available (x86-64 only).
    pub has_avx: bool,
    /// `true` when AVX2 is available (x86-64 only).
    pub has_avx2: bool,
    /// `true` when FMA (fused multiply-add) is available (x86-64 only).
    pub has_fma: bool,
    /// `true` when AVX-512 Foundation is available (x86-64 only).
    pub has_avx512f: bool,
    /// `true` when the AVX-512 Byte & Word instructions are available
    /// (x86-64 only).
    pub has_avx512bw: bool,
    /// `true` when the AVX-512 Vector Length extensions are available
    /// (x86-64 only).
    pub has_avx512vl: bool,

    // ------------------------------------------------------------------
    // ARM feature flags
    // ------------------------------------------------------------------
    /// `true` when NEON is available.  Always `true` on AArch64.
    pub has_neon: bool,
    /// `true` when SVE (Scalable Vector Extension) is available.
    pub has_sve: bool,

    // ------------------------------------------------------------------
    // Memory topology
    // ------------------------------------------------------------------
    /// Size of one cache line in bytes (typically 64 on modern CPUs).
    pub cache_line_bytes: usize,
    /// Size in bytes of the widest SIMD vector register this struct reports
    /// as usable.
    pub vector_width_bytes: usize,
}
72
73impl SimdCapabilityInfo {
74    // ------------------------------------------------------------------
75    // Public entry-points
76    // ------------------------------------------------------------------
77
78    /// Detect (or derive) capabilities for the current CPU and return a
79    /// reference to the process-wide cached value.
80    ///
81    /// The first call performs runtime detection and stores the result.
82    /// Subsequent calls return the same `&'static` reference.
83    #[cfg(feature = "std")]
84    #[inline]
85    pub fn detect() -> &'static Self {
86        static INFO: OnceLock<SimdCapabilityInfo> = OnceLock::new();
87        INFO.get_or_init(Self::compute)
88    }
89
90    /// Derive capabilities from compile-time target features (no_std path).
91    ///
92    /// Returns a fresh value on every call.  The computation consists only of
93    /// `cfg!` evaluations which the compiler constant-folds.
94    #[cfg(not(feature = "std"))]
95    #[inline]
96    pub fn detect() -> Self {
97        Self::compute()
98    }
99
100    // ------------------------------------------------------------------
101    // Internal construction
102    // ------------------------------------------------------------------
103
104    fn compute() -> Self {
105        let legacy = simd_caps();
106        Self::from_legacy(legacy)
107    }
108
109    /// Build from the legacy [`SimdCapabilities`] already present in
110    /// `dispatch.rs`.  This ensures the two structs never diverge in their
111    /// detection logic.
112    #[cfg(feature = "std")]
113    fn from_legacy(legacy: &LegacyCaps) -> Self {
114        // SSE4.2 detection: the legacy struct only tracks SSE3.  We check for
115        // SSE4.2 directly when std is available so that runtime detection is
116        // accurate.
117        #[cfg(all(target_arch = "x86_64", feature = "std"))]
118        let has_sse42 = is_x86_feature_detected!("sse4.2");
119        #[cfg(not(target_arch = "x86_64"))]
120        let has_sse42 = false;
121
122        let has_avx512f = legacy.has_avx512f;
123        let has_avx2 = legacy.has_avx2;
124
125        let vector_width_bytes = if has_avx512f {
126            64
127        } else if has_avx2 {
128            32
129        } else if has_sse42 || legacy.has_neon {
130            16
131        } else {
132            8
133        };
134
135        Self {
136            has_sse42,
137            has_avx: legacy.has_avx,
138            has_avx2,
139            has_fma: legacy.has_fma,
140            has_avx512f,
141            has_avx512bw: legacy.has_avx512bw,
142            has_avx512vl: legacy.has_avx512vl,
143            has_neon: legacy.has_neon,
144            has_sve: legacy.has_sve,
145            cache_line_bytes: 64,
146            vector_width_bytes,
147        }
148    }
149
150    /// Build from the legacy [`SimdCapabilities`] (no_std path, passed by
151    /// value).
152    #[cfg(not(feature = "std"))]
153    fn from_legacy(legacy: LegacyCaps) -> Self {
154        let has_sse42 = cfg!(target_feature = "sse4.2");
155        let has_avx512f = legacy.has_avx512f;
156        let has_avx2 = legacy.has_avx2;
157
158        let vector_width_bytes: usize = if has_avx512f {
159            64
160        } else if has_avx2 {
161            32
162        } else if has_sse42 || legacy.has_neon {
163            16
164        } else {
165            8
166        };
167
168        Self {
169            has_sse42,
170            has_avx: legacy.has_avx,
171            has_avx2,
172            has_fma: legacy.has_fma,
173            has_avx512f,
174            has_avx512bw: legacy.has_avx512bw,
175            has_avx512vl: legacy.has_avx512vl,
176            has_neon: legacy.has_neon,
177            has_sve: legacy.has_sve,
178            cache_line_bytes: 64,
179            vector_width_bytes,
180        }
181    }
182
183    // ------------------------------------------------------------------
184    // Capability queries
185    // ------------------------------------------------------------------
186
187    /// Returns `true` when AVX-512F, BW, and VL are all present.
188    #[inline]
189    pub fn has_avx512_full(&self) -> bool {
190        self.has_avx512f && self.has_avx512bw && self.has_avx512vl
191    }
192
193    /// Returns `true` when both AVX2 and FMA are present.
194    #[inline]
195    pub fn has_avx2_fma(&self) -> bool {
196        self.has_avx2 && self.has_fma
197    }
198
199    /// Number of `f64` elements that fit in the widest supported SIMD register.
200    ///
201    /// For AVX-512 this is 8; for AVX2 / NEON-128 this is 4 / 2; for scalar 1.
202    #[inline]
203    pub fn f64_simd_width(&self) -> usize {
204        self.vector_width_bytes / core::mem::size_of::<f64>()
205    }
206
207    /// Number of `f32` elements that fit in the widest supported SIMD register.
208    ///
209    /// Always twice `f64_simd_width()`.
210    #[inline]
211    pub fn f32_simd_width(&self) -> usize {
212        self.vector_width_bytes / core::mem::size_of::<f32>()
213    }
214
215    /// Returns the [`LegacyLevel`] that best summarises these capabilities.
216    #[inline]
217    pub fn optimal_level(&self) -> LegacyLevel {
218        if self.has_avx512_full() {
219            LegacyLevel::Avx512
220        } else if self.has_avx2_fma() {
221            LegacyLevel::Avx2
222        } else if self.has_avx {
223            LegacyLevel::Avx
224        } else if self.has_sse42 {
225            LegacyLevel::Sse42
226        } else if self.has_neon {
227            LegacyLevel::Neon
228        } else if self.has_sve {
229            LegacyLevel::Sve
230        } else {
231            LegacyLevel::Scalar
232        }
233    }
234}
235
236// ---------------------------------------------------------------------------
237// simd_dispatch_caps! macro
238// ---------------------------------------------------------------------------
239
/// Dispatch to the best available SIMD implementation, selecting among five
/// named arms in priority order: `avx512`, `avx2`, `sse42`, `neon`, `scalar`.
///
/// The first argument must be a [`SimdCapabilityInfo`] reference (or value);
/// on `std`-enabled builds use [`SimdCapabilityInfo::detect()`].  It is
/// evaluated exactly once.  All five arms must be supplied, in the order
/// shown, and only the expression of the selected arm is evaluated.
///
/// # Priority
///
/// 1. `avx512`  — AVX-512F + BW + VL
/// 2. `avx2`    — AVX2 + FMA (256-bit)
/// 3. `sse42`   — SSE4.2 (128-bit)
/// 4. `neon`    — AArch64 NEON
/// 5. `scalar`  — portable fallback
///
/// # Example
///
/// ```rust
/// use oxiblas_core::simd::multiver::{SimdCapabilityInfo, simd_dispatch_caps};
///
/// # #[cfg(feature = "std")]
/// let result: &str = simd_dispatch_caps!(
///     SimdCapabilityInfo::detect(),
///     avx512  => "avx512",
///     avx2    => "avx2",
///     sse42   => "sse42",
///     neon    => "neon",
///     scalar  => "scalar",
/// );
/// ```
#[macro_export]
macro_rules! simd_dispatch_caps {
    (
        $caps:expr,
        avx512  => $avx512:expr,
        avx2    => $avx2:expr,
        sse42   => $sse42:expr,
        neon    => $neon:expr,
        scalar  => $scalar:expr $(,)?
    ) => {{
        // Bind once so that `$caps` is not re-evaluated per branch.  Macro
        // hygiene keeps this name invisible to the caller's arm expressions.
        let caps = $caps;
        if caps.has_avx512_full() {
            $avx512
        } else if caps.has_avx2_fma() {
            $avx2
        } else if caps.has_sse42 {
            $sse42
        } else if caps.has_neon {
            $neon
        } else {
            $scalar
        }
    }};
}
293
294// Make the macro accessible as `crate::simd::multiver::simd_dispatch_caps`.
295pub use simd_dispatch_caps;
296
297// ---------------------------------------------------------------------------
298// SimdDispatcher trait
299// ---------------------------------------------------------------------------
300
301/// Trait for types that provide architecture-specialised implementations of a
302/// single computation via function multi-versioning.
303///
304/// Implement the four required methods and call [`SimdDispatcher::dispatch`]
305/// to have the runtime select the fastest available path automatically.
306///
307/// # Design note
308///
309/// The trait uses `&self` receivers so that the dispatch object can carry all
310/// input data as fields, keeping call-sites clean.
311///
312/// # Example
313///
314/// ```rust
315/// use oxiblas_core::simd::multiver::SimdDispatcher;
316///
317/// struct ScalarSum<'a>(&'a [f64]);
318///
319/// impl SimdDispatcher for ScalarSum<'_> {
320///     type Output = f64;
321///     fn dispatch_avx512(&self) -> f64 { self.dispatch_scalar() }
322///     fn dispatch_avx2(&self)   -> f64 { self.dispatch_scalar() }
323///     fn dispatch_neon(&self)   -> f64 { self.dispatch_scalar() }
324///     fn dispatch_scalar(&self) -> f64 { self.0.iter().copied().sum() }
325/// }
326///
327/// assert_eq!(ScalarSum(&[1.0, 2.0, 3.0]).dispatch(), 6.0);
328/// ```
329pub trait SimdDispatcher {
330    /// The type returned by the computation.
331    type Output;
332
333    /// AVX-512F+BW+VL specialised implementation.
334    fn dispatch_avx512(&self) -> Self::Output;
335
336    /// AVX2 + FMA specialised implementation.
337    fn dispatch_avx2(&self) -> Self::Output;
338
339    /// NEON (AArch64) specialised implementation.
340    fn dispatch_neon(&self) -> Self::Output;
341
342    /// Portable scalar fallback.
343    fn dispatch_scalar(&self) -> Self::Output;
344
345    /// Select and call the best available implementation for the current CPU.
346    ///
347    /// The selection is based on [`SimdCapabilityInfo::detect`], which caches
348    /// the result in a process-wide static (std builds) or recomputes from
349    /// compile-time flags (no_std builds).
350    fn dispatch(&self) -> Self::Output {
351        #[cfg(feature = "std")]
352        let caps = SimdCapabilityInfo::detect();
353        #[cfg(not(feature = "std"))]
354        let caps = SimdCapabilityInfo::detect();
355
356        if caps.has_avx512_full() {
357            self.dispatch_avx512()
358        } else if caps.has_avx2_fma() {
359            self.dispatch_avx2()
360        } else if caps.has_neon {
361            self.dispatch_neon()
362        } else {
363            self.dispatch_scalar()
364        }
365    }
366}
367
368// ---------------------------------------------------------------------------
369// GemmKernelKind
370// ---------------------------------------------------------------------------
371
/// Names the microkernel variant used for GEMM operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmKernelKind {
    /// AVX-512 microkernel (512-bit registers, x86-64).
    Avx512,
    /// AVX2 + FMA microkernel (256-bit registers, x86-64).
    Avx2,
    /// NEON microkernel (128-bit registers, AArch64).
    Neon,
    /// Portable scalar microkernel (fallback for all targets).
    Scalar,
}

impl GemmKernelKind {
    /// Returns a short human-readable label for this kernel kind.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Avx512 => "AVX-512",
            Self::Avx2 => "AVX2+FMA",
            Self::Neon => "NEON",
            Self::Scalar => "scalar",
        }
    }

    /// Returns `true` for every kind except [`GemmKernelKind::Scalar`].
    #[inline]
    pub const fn is_simd(self) -> bool {
        !matches!(self, Self::Scalar)
    }
}
403
404// ---------------------------------------------------------------------------
405// KernelSelector
406// ---------------------------------------------------------------------------
407
408/// Selects the optimal GEMM microkernel for `f64` and `f32` based on the CPU
409/// capabilities detected at runtime (or compile-time on no_std).
410///
411/// Call [`KernelSelector::select`] once at startup to obtain the globally
412/// cached selector; subsequent calls return the same reference (std) or a
413/// freshly computed identical value (no_std).
414#[derive(Debug, Clone, Copy, PartialEq, Eq)]
415pub struct KernelSelector {
416    /// Best microkernel kind for double-precision (f64) GEMM.
417    pub gemm_f64_kernel: GemmKernelKind,
418    /// Best microkernel kind for single-precision (f32) GEMM.
419    pub gemm_f32_kernel: GemmKernelKind,
420}
421
422impl KernelSelector {
423    fn from_caps(caps: &SimdCapabilityInfo) -> Self {
424        let kind = if caps.has_avx512_full() {
425            GemmKernelKind::Avx512
426        } else if caps.has_avx2_fma() {
427            GemmKernelKind::Avx2
428        } else if caps.has_neon {
429            GemmKernelKind::Neon
430        } else {
431            GemmKernelKind::Scalar
432        };
433
434        Self {
435            gemm_f64_kernel: kind,
436            gemm_f32_kernel: kind,
437        }
438    }
439
440    /// Returns a reference to the globally cached [`KernelSelector`].
441    ///
442    /// The first call performs detection; all subsequent calls return the same
443    /// `&'static` reference.
444    #[cfg(feature = "std")]
445    pub fn select() -> &'static Self {
446        static KERNEL_SEL: OnceLock<KernelSelector> = OnceLock::new();
447        KERNEL_SEL.get_or_init(|| Self::from_caps(SimdCapabilityInfo::detect()))
448    }
449
450    /// Recomputes the [`KernelSelector`] from compile-time target features
451    /// (no_std path).
452    #[cfg(not(feature = "std"))]
453    pub fn select() -> Self {
454        Self::from_caps(&SimdCapabilityInfo::detect())
455    }
456}
457
458// ---------------------------------------------------------------------------
459// Convenience re-exports from dispatch.rs
460// ---------------------------------------------------------------------------
461
462pub use crate::simd::dispatch::{
463    SimdCapabilities, SimdLevel, has_avx2_fma, has_avx512, has_neon, optimal_simd_level, simd_caps,
464};
465
466// ---------------------------------------------------------------------------
467// Tests (>= 10 tests as required)
468// ---------------------------------------------------------------------------
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned copy of the current selector (std: dereferences the cached
    /// `&'static` value).
    #[cfg(feature = "std")]
    fn current_selector() -> KernelSelector {
        *KernelSelector::select()
    }

    /// Owned copy of the current selector (no_std: `select()` already returns
    /// a value).
    #[cfg(not(feature = "std"))]
    fn current_selector() -> KernelSelector {
        KernelSelector::select()
    }

    // ------------------------------------------------------------------
    // 1. Detection does not panic, cache_line_bytes is sane
    // ------------------------------------------------------------------
    #[test]
    fn test_detect_does_not_panic_and_cache_line_sane() {
        let caps = SimdCapabilityInfo::detect();
        // Cache line must be at least 8 bytes and a power of two.
        assert!(caps.cache_line_bytes >= 8);
        assert!(caps.cache_line_bytes.is_power_of_two());
        // vector_width_bytes must also be a power of two.
        assert!(caps.vector_width_bytes.is_power_of_two());
    }

    // ------------------------------------------------------------------
    // 2. AArch64: NEON always present
    // ------------------------------------------------------------------
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_neon_always_true() {
        let caps = SimdCapabilityInfo::detect();
        assert!(caps.has_neon, "NEON is mandatory on AArch64");
        assert!(!caps.has_avx2, "AVX2 must not appear on AArch64");
        assert!(!caps.has_avx512f, "AVX-512 must not appear on AArch64");
    }

    // ------------------------------------------------------------------
    // 3. x86-64: flag hierarchy must be self-consistent
    // ------------------------------------------------------------------
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_x86_64_flag_hierarchy() {
        let caps = SimdCapabilityInfo::detect();
        assert!(!caps.has_neon, "NEON must not appear on x86-64");
        // AVX2 implies AVX (architectural requirement).
        if caps.has_avx2 {
            assert!(caps.has_avx, "AVX2 requires AVX");
        }
        // AVX-512 implies SSE4.2 on all known shipping CPUs.
        if caps.has_avx512f {
            assert!(caps.has_sse42, "AVX-512 implies SSE4.2");
        }
    }

    // ------------------------------------------------------------------
    // 4. f64/f32 simd widths are derived from vector_width_bytes
    // ------------------------------------------------------------------
    #[test]
    fn test_simd_width_derivation() {
        let caps = SimdCapabilityInfo::detect();
        assert_eq!(
            caps.f64_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f64>()
        );
        assert_eq!(
            caps.f32_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f32>()
        );
        // f32 must fit twice as many elements as f64 in the same register.
        assert_eq!(caps.f32_simd_width(), caps.f64_simd_width() * 2);
    }

    // ------------------------------------------------------------------
    // 5. vector_width_bytes agrees with reported capabilities
    // ------------------------------------------------------------------
    #[test]
    fn test_vector_width_matches_capability_tier() {
        let caps = SimdCapabilityInfo::detect();
        if caps.has_avx512f {
            assert_eq!(caps.vector_width_bytes, 64);
        } else if caps.has_avx2 {
            assert_eq!(caps.vector_width_bytes, 32);
        }
    }

    // ------------------------------------------------------------------
    // 6. simd_caps() is idempotent — same pointer on repeated calls (std)
    // ------------------------------------------------------------------
    #[cfg(feature = "std")]
    #[test]
    fn test_simd_caps_stable_pointer() {
        let a = simd_caps();
        let b = simd_caps();
        assert!(
            core::ptr::eq(a, b),
            "simd_caps() must return a stable &'static"
        );
    }

    // ------------------------------------------------------------------
    // 7. SimdCapabilityInfo::detect() is stable pointer on std builds
    // ------------------------------------------------------------------
    #[cfg(feature = "std")]
    #[test]
    fn test_capability_info_stable_pointer() {
        let a = SimdCapabilityInfo::detect();
        let b = SimdCapabilityInfo::detect();
        assert!(
            core::ptr::eq(a, b),
            "detect() must return a stable &'static"
        );
    }

    // ------------------------------------------------------------------
    // 8. optimal_level() is consistent with the capability flags
    // ------------------------------------------------------------------
    #[test]
    fn test_optimal_level_consistent_with_flags() {
        let caps = SimdCapabilityInfo::detect();
        match caps.optimal_level() {
            LegacyLevel::Avx512 => assert!(caps.has_avx512_full()),
            LegacyLevel::Avx2 => {
                assert!(!caps.has_avx512_full());
                assert!(caps.has_avx2_fma());
            }
            LegacyLevel::Avx => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(caps.has_avx);
            }
            LegacyLevel::Sse42 => {
                assert!(!caps.has_avx);
                assert!(caps.has_sse42);
            }
            LegacyLevel::Neon => {
                assert!(caps.has_neon);
                assert!(!caps.has_avx);
            }
            LegacyLevel::Sve => {
                assert!(caps.has_sve);
                assert!(!caps.has_neon);
            }
            LegacyLevel::Scalar => {
                assert!(!caps.has_avx);
                assert!(!caps.has_neon);
                assert!(!caps.has_sve);
            }
        }
    }

    // ------------------------------------------------------------------
    // 9. KernelSelector::select() returns valid GemmKernelKind values
    // ------------------------------------------------------------------
    #[test]
    fn test_kernel_selector_valid_kinds() {
        let sel = current_selector();

        assert!(matches!(
            sel.gemm_f64_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
        assert!(matches!(
            sel.gemm_f32_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
    }

    // ------------------------------------------------------------------
    // 10. KernelSelector agrees with SimdCapabilityInfo on which tier to use
    // ------------------------------------------------------------------
    #[test]
    fn test_kernel_selector_agrees_with_capability_info() {
        let caps = SimdCapabilityInfo::detect();
        let sel = current_selector();

        if caps.has_avx512_full() {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Avx512);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Avx512);
        } else if caps.has_avx2_fma() {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Avx2);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Avx2);
        } else if caps.has_neon {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Neon);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Neon);
        } else {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Scalar);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Scalar);
        }
    }

    // ------------------------------------------------------------------
    // 11. simd_dispatch_caps! macro selects a branch consistent with caps
    // ------------------------------------------------------------------
    #[test]
    fn test_simd_dispatch_caps_macro_branch_selection() {
        // Reference on std builds, owned value on no_std; `&caps` works for
        // both.
        let caps = SimdCapabilityInfo::detect();

        let chosen: u32 = simd_dispatch_caps!(
            &caps,
            avx512  => 512u32,
            avx2    => 256u32,
            sse42   => 128u32,
            neon    => 1000u32,
            scalar  => 1u32,
        );

        // Verify the chosen branch is consistent with the flags.
        if caps.has_avx512_full() {
            assert_eq!(chosen, 512);
        } else if caps.has_avx2_fma() {
            assert_eq!(chosen, 256);
        } else if caps.has_sse42 {
            assert_eq!(chosen, 128);
        } else if caps.has_neon {
            assert_eq!(chosen, 1000);
        } else {
            assert_eq!(chosen, 1);
        }
    }

    // ------------------------------------------------------------------
    // 12. SimdDispatcher trait: reference implementation is correct
    // ------------------------------------------------------------------
    struct DotProduct<'a> {
        x: &'a [f64],
        y: &'a [f64],
    }

    impl SimdDispatcher for DotProduct<'_> {
        type Output = f64;

        fn dispatch_avx512(&self) -> f64 {
            // Use scalar path for portability in the test.
            self.dispatch_scalar()
        }

        fn dispatch_avx2(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_neon(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_scalar(&self) -> f64 {
            self.x.iter().zip(self.y.iter()).map(|(a, b)| a * b).sum()
        }
    }

    #[test]
    fn test_simd_dispatcher_trait_correctness() {
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let y = [5.0_f64, 6.0, 7.0, 8.0];
        // 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
        let result = DotProduct { x: &x, y: &y }.dispatch();
        assert!((result - 70.0).abs() < f64::EPSILON);
    }

    // ------------------------------------------------------------------
    // 13. GemmKernelKind::name() returns non-empty strings
    // ------------------------------------------------------------------
    #[test]
    fn test_gemm_kernel_kind_names_non_empty() {
        for kind in [
            GemmKernelKind::Avx512,
            GemmKernelKind::Avx2,
            GemmKernelKind::Neon,
            GemmKernelKind::Scalar,
        ] {
            assert!(!kind.name().is_empty());
        }
    }

    // ------------------------------------------------------------------
    // 14. GemmKernelKind::is_simd() is false only for Scalar
    // ------------------------------------------------------------------
    #[test]
    fn test_gemm_kernel_kind_is_simd() {
        assert!(GemmKernelKind::Avx512.is_simd());
        assert!(GemmKernelKind::Avx2.is_simd());
        assert!(GemmKernelKind::Neon.is_simd());
        assert!(!GemmKernelKind::Scalar.is_simd());
    }

    // ------------------------------------------------------------------
    // 15. has_avx512/has_avx2_fma/has_neon helpers agree with detect()
    // ------------------------------------------------------------------
    #[test]
    fn test_free_helper_fns_agree_with_detect() {
        let caps = SimdCapabilityInfo::detect();
        // The free helpers are re-exported from the legacy dispatch module;
        // here we only check that they never contradict detect().
        if caps.has_avx512_full() {
            assert!(has_avx512());
        }
        if caps.has_avx2_fma() {
            assert!(has_avx2_fma());
        }
        if caps.has_neon {
            assert!(has_neon());
        }
    }

    // ------------------------------------------------------------------
    // 16. SimdDispatcher: empty input dispatches to a zero result
    // ------------------------------------------------------------------
    #[test]
    fn test_dispatcher_scalar_ground_truth() {
        let x = [0.0_f64; 0];
        let y = [0.0_f64; 0];
        let result = DotProduct { x: &x, y: &y }.dispatch();
        assert_eq!(result, 0.0, "empty dot product must be zero");
    }
}