oxiblas_core/simd/dispatch.rs
1//! Runtime SIMD dispatch system with function multi-versioning.
2//!
3//! Provides centralized CPU feature detection and caching for optimal
4//! performance, plus infrastructure for dispatching to the best available
5//! microkernel implementation at runtime.
6//!
7//! # Architecture
8//!
9//! - [`SimdCapabilities`]: Describes which SIMD extensions are available.
10//! - [`simd_caps`]: Returns a `&'static SimdCapabilities` (std) or a freshly
11//! computed value (no_std) for the current CPU.
12//! - [`simd_dispatch!`]: Macro that selects the best implementation branch
13//! based on detected capabilities.
14//! - [`SimdDispatcher`]: Trait for types that provide multi-versioned
15//! implementations of a computation.
16//! - [`KernelSelector`]: Chooses the optimal GEMM microkernel kind at startup.
17//!
18//! # no_std note
19//!
20//! When the `std` feature is disabled, `SimdCapabilities::detect()` returns a
21//! freshly computed value derived from compile-time `target_feature` flags on
22//! every call (no caching), and `simd_caps()` returns that value by value.
23//! This is consistent with the no_std contract: no heap, no globals.
24
25#[cfg(feature = "std")]
26use std::sync::OnceLock;
27
28// ---------------------------------------------------------------------------
29// SimdCapabilities
30// ---------------------------------------------------------------------------
31
32/// CPU SIMD capabilities detected (or derived from compile-time flags) at
33/// startup.
34///
35/// On x86-64 with `std` enabled, detection uses the `is_x86_feature_detected!`
36/// macro which reads CPUID and is cached by the standard library. On targets
37/// that do not have `std`, detection falls back to compile-time
38/// `target_feature` constants, which are set by RUSTFLAGS / `.cargo/config.toml`.
39///
40/// The struct is intentionally kept `Copy` so callers on no_std targets can
41/// store it cheaply on the stack without worrying about ownership.
42#[derive(Debug, Clone, Copy, PartialEq, Eq)]
43pub struct SimdCapabilities {
44 // ------------------------------------------------------------------
45 // x86-64 fields
46 // ------------------------------------------------------------------
47 /// SSE4.2 support (x86-64 only).
48 pub has_sse42: bool,
49 /// AVX support (x86-64 only).
50 pub has_avx: bool,
51 /// AVX2 support (x86-64 only).
52 pub has_avx2: bool,
53 /// FMA (Fused Multiply-Add) support (x86-64 only).
54 pub has_fma: bool,
55 /// AVX-512 Foundation support (x86-64 only).
56 pub has_avx512f: bool,
57 /// AVX-512 Byte & Word instructions (x86-64 only).
58 pub has_avx512bw: bool,
59 /// AVX-512 Vector Length extensions (x86-64 only).
60 pub has_avx512vl: bool,
61
62 // ------------------------------------------------------------------
63 // ARM fields
64 // ------------------------------------------------------------------
65 /// NEON support. Always `true` on AArch64; may be `true` on 32-bit ARM.
66 pub has_neon: bool,
67 /// SVE (Scalable Vector Extension) support.
68 pub has_sve: bool,
69
70 // ------------------------------------------------------------------
71 // Memory topology
72 // ------------------------------------------------------------------
73 /// Size of a single cache line in bytes (typically 64).
74 pub cache_line_bytes: usize,
75 /// Width of the widest supported SIMD vector in bytes.
76 pub vector_width_bytes: usize,
77}
78
79impl SimdCapabilities {
80 // ------------------------------------------------------------------
81 // Construction helpers
82 // ------------------------------------------------------------------
83
84 /// Compute the [`SimdCapabilities`] for the current CPU.
85 ///
86 /// When compiled with `std`, the result is cached in a [`OnceLock`] and
87 /// this function returns `&'static Self`; the detection work is only done
88 /// once per process.
89 ///
90 /// When compiled **without** `std`, this returns `Self` by value (derived
91 /// from compile-time `cfg!(target_feature = …)` flags).
92 #[cfg(feature = "std")]
93 pub fn detect() -> &'static Self {
94 simd_caps()
95 }
96
97 /// Compute the [`SimdCapabilities`] for the current CPU (no_std path).
98 ///
99 /// Returns a fresh value on every call because `OnceLock` is not available
100 /// without the standard library.
101 #[cfg(not(feature = "std"))]
102 pub fn detect() -> Self {
103 simd_caps()
104 }
105
106 // ------------------------------------------------------------------
107 // Internal constructor used by simd_caps_compute()
108 // ------------------------------------------------------------------
109
110 #[cfg(all(target_arch = "x86_64", feature = "std"))]
111 fn compute() -> Self {
112 let has_avx512f = is_x86_feature_detected!("avx512f");
113 let has_avx512bw = is_x86_feature_detected!("avx512bw");
114 let has_avx512vl = is_x86_feature_detected!("avx512vl");
115 let has_avx2 = is_x86_feature_detected!("avx2");
116 let has_fma = is_x86_feature_detected!("fma");
117 let has_avx = is_x86_feature_detected!("avx");
118 let has_sse42 = is_x86_feature_detected!("sse4.2");
119
120 let vector_width_bytes = if has_avx512f {
121 64
122 } else if has_avx2 {
123 32
124 } else if has_sse42 {
125 16
126 } else {
127 8
128 };
129
130 Self {
131 has_sse42,
132 has_avx,
133 has_avx2,
134 has_fma,
135 has_avx512f,
136 has_avx512bw,
137 has_avx512vl,
138 has_neon: false,
139 has_sve: false,
140 cache_line_bytes: 64,
141 vector_width_bytes,
142 }
143 }
144
145 #[cfg(all(target_arch = "x86_64", not(feature = "std")))]
146 fn compute() -> Self {
147 let has_avx512f = cfg!(target_feature = "avx512f");
148 let has_avx512bw = cfg!(target_feature = "avx512bw");
149 let has_avx512vl = cfg!(target_feature = "avx512vl");
150 let has_avx2 = cfg!(target_feature = "avx2");
151 let has_fma = cfg!(target_feature = "fma");
152 let has_avx = cfg!(target_feature = "avx");
153 let has_sse42 = cfg!(target_feature = "sse4.2");
154
155 let vector_width_bytes: usize = if has_avx512f {
156 64
157 } else if has_avx2 {
158 32
159 } else if has_sse42 {
160 16
161 } else {
162 8
163 };
164
165 Self {
166 has_sse42,
167 has_avx,
168 has_avx2,
169 has_fma,
170 has_avx512f,
171 has_avx512bw,
172 has_avx512vl,
173 has_neon: false,
174 has_sve: false,
175 cache_line_bytes: 64,
176 vector_width_bytes,
177 }
178 }
179
180 #[cfg(target_arch = "aarch64")]
181 fn compute() -> Self {
182 // NEON is mandatory on AArch64 per the architecture specification.
183 let has_sve = cfg!(target_feature = "sve");
184 Self {
185 has_sse42: false,
186 has_avx: false,
187 has_avx2: false,
188 has_fma: false,
189 has_avx512f: false,
190 has_avx512bw: false,
191 has_avx512vl: false,
192 has_neon: true,
193 has_sve,
194 cache_line_bytes: 64,
195 vector_width_bytes: 16,
196 }
197 }
198
199 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
200 fn compute() -> Self {
201 Self {
202 has_sse42: false,
203 has_avx: false,
204 has_avx2: false,
205 has_fma: false,
206 has_avx512f: false,
207 has_avx512bw: false,
208 has_avx512vl: false,
209 has_neon: false,
210 has_sve: false,
211 cache_line_bytes: 64,
212 vector_width_bytes: 8,
213 }
214 }
215
216 // ------------------------------------------------------------------
217 // Capability queries
218 // ------------------------------------------------------------------
219
220 /// Returns `true` when AVX-512F, BW, and VL are all available — a common
221 /// prerequisite for most practical AVX-512 code paths.
222 #[inline]
223 pub fn has_avx512_full(&self) -> bool {
224 self.has_avx512f && self.has_avx512bw && self.has_avx512vl
225 }
226
227 /// Returns `true` when both AVX2 and FMA are available — the usual
228 /// prerequisite for 256-bit fused multiply-add kernels.
229 #[inline]
230 pub fn has_avx2_fma(&self) -> bool {
231 self.has_avx2 && self.has_fma
232 }
233
234 /// Returns the number of `f64` elements that fit in the widest supported
235 /// SIMD vector register.
236 ///
237 /// | ISA | f64 elements |
238 /// |-----------------|-------------|
239 /// | AVX-512 | 8 |
240 /// | AVX2 / NEON-256 | 4 |
241 /// | SSE4.2 / NEON | 2 |
242 /// | scalar | 1 |
243 #[inline]
244 pub fn f64_simd_width(&self) -> usize {
245 self.vector_width_bytes / core::mem::size_of::<f64>()
246 }
247
248 /// Returns the number of `f32` elements that fit in the widest supported
249 /// SIMD vector register.
250 #[inline]
251 pub fn f32_simd_width(&self) -> usize {
252 self.vector_width_bytes / core::mem::size_of::<f32>()
253 }
254
255 /// Returns the [`SimdLevel`] that best summarises the capabilities.
256 #[inline]
257 pub fn optimal_level(&self) -> SimdLevel {
258 if self.has_avx512_full() {
259 SimdLevel::Avx512
260 } else if self.has_avx2_fma() {
261 SimdLevel::Avx2
262 } else if self.has_avx {
263 SimdLevel::Avx
264 } else if self.has_sse42 {
265 SimdLevel::Sse42
266 } else if self.has_neon {
267 SimdLevel::Neon
268 } else if self.has_sve {
269 SimdLevel::Sve
270 } else {
271 SimdLevel::Scalar
272 }
273 }
274}
275
276// ---------------------------------------------------------------------------
277// SimdLevel (kept distinct from the coarser SimdLevel in simd.rs)
278// ---------------------------------------------------------------------------
279
/// Instruction-set tier used by the dispatch layer to label the best SIMD
/// extension a CPU offers.
///
/// `Ord` is derived from the explicit discriminants below. The x86-64 tiers
/// are numbered 0–4 in increasing capability; the ARM tiers start at 10 so
/// that nobody accidentally depends on how the two families compare with each
/// other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// Scalar (no SIMD).
    Scalar = 0,
    /// SSE4.2 (128-bit, x86-64).
    Sse42 = 1,
    /// AVX (256-bit, x86-64, no FMA).
    Avx = 2,
    /// AVX2 with FMA (256-bit, x86-64).
    Avx2 = 3,
    /// AVX-512F+BW+VL (512-bit, x86-64).
    Avx512 = 4,
    /// NEON (128-bit, AArch64).
    Neon = 10,
    /// SVE (scalable, AArch64).
    Sve = 11,
}

impl SimdLevel {
    /// Display name of the tier, suitable for logs and diagnostics.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Scalar => "scalar",
            Self::Sse42 => "SSE4.2",
            Self::Avx => "AVX",
            Self::Avx2 => "AVX2+FMA",
            Self::Avx512 => "AVX-512",
            Self::Neon => "NEON",
            Self::Sve => "SVE",
        }
    }

    /// How many `f64` lanes one vector of this tier holds.
    ///
    /// SVE reports the same 2 lanes as NEON: a conservative figure, since the
    /// real SVE width is only known at run time.
    #[inline]
    pub const fn f64_width(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Sse42 | Self::Neon | Self::Sve => 2,
            Self::Avx | Self::Avx2 => 4,
            Self::Avx512 => 8,
        }
    }

    /// How many `f32` lanes one vector of this tier holds.
    ///
    /// SVE reports the same 4 lanes as NEON: a conservative figure, since the
    /// real SVE width is only known at run time.
    #[inline]
    pub const fn f32_width(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Sse42 | Self::Neon | Self::Sve => 4,
            Self::Avx | Self::Avx2 => 8,
            Self::Avx512 => 16,
        }
    }
}
344
345// ---------------------------------------------------------------------------
346// Cached global (std path)
347// ---------------------------------------------------------------------------
348
349#[cfg(feature = "std")]
350static SIMD_CAPS: OnceLock<SimdCapabilities> = OnceLock::new();
351
352/// Returns a reference to the globally cached [`SimdCapabilities`].
353///
354/// Detection is performed exactly once and the result is stored in a process-
355/// wide static. All subsequent calls return the same reference.
356#[cfg(feature = "std")]
357#[inline]
358pub fn simd_caps() -> &'static SimdCapabilities {
359 SIMD_CAPS.get_or_init(SimdCapabilities::compute)
360}
361
362/// Returns the [`SimdCapabilities`] derived from compile-time target features.
363///
364/// Because `OnceLock` requires `std`, this no_std variant recomputes the value
365/// on every call. The computation is a chain of `cfg!` evaluations and is
366/// expected to be fully inlined / constant-folded by the compiler.
367#[cfg(not(feature = "std"))]
368#[inline]
369pub fn simd_caps() -> SimdCapabilities {
370 SimdCapabilities::compute()
371}
372
373/// Returns the optimal [`SimdLevel`] for the current CPU.
374#[inline]
375pub fn optimal_simd_level() -> SimdLevel {
376 #[cfg(feature = "std")]
377 {
378 simd_caps().optimal_level()
379 }
380 #[cfg(not(feature = "std"))]
381 {
382 simd_caps().optimal_level()
383 }
384}
385
386/// Returns `true` when AVX-512F+BW+VL are all available.
387#[inline]
388pub fn has_avx512() -> bool {
389 #[cfg(feature = "std")]
390 {
391 simd_caps().has_avx512_full()
392 }
393 #[cfg(not(feature = "std"))]
394 {
395 simd_caps().has_avx512_full()
396 }
397}
398
399/// Returns `true` when AVX2 and FMA are both available.
400#[inline]
401pub fn has_avx2_fma() -> bool {
402 #[cfg(feature = "std")]
403 {
404 simd_caps().has_avx2_fma()
405 }
406 #[cfg(not(feature = "std"))]
407 {
408 simd_caps().has_avx2_fma()
409 }
410}
411
412/// Returns `true` when NEON is available.
413#[inline]
414pub fn has_neon() -> bool {
415 #[cfg(feature = "std")]
416 {
417 simd_caps().has_neon
418 }
419 #[cfg(not(feature = "std"))]
420 {
421 simd_caps().has_neon
422 }
423}
424
425// ---------------------------------------------------------------------------
426// simd_dispatch! macro
427// ---------------------------------------------------------------------------
428
/// Picks one of five caller-supplied expressions according to a
/// [`SimdCapabilities`] value (or reference) and evaluates only that one.
///
/// The capability expression is evaluated once. The checks then run in this
/// fixed order and the first match wins:
///
/// 1. `avx512` — `has_avx512_full()` (AVX-512F+BW+VL)
/// 2. `avx2` — `has_avx2_fma()` (AVX2 + FMA, 256-bit)
/// 3. `sse42` — `has_sse42` (SSE4.2, 128-bit)
/// 4. `neon` — `has_neon` (AArch64 NEON)
/// 5. `scalar` — portable fallback when nothing above matched
///
/// There is no dedicated AVX-only or SVE arm: a CPU with AVX but without
/// AVX2+FMA falls through to `sse42` when SSE4.2 is reported, and `has_sve`
/// is never consulted.
///
/// All five arms must be given, in the order shown; a trailing comma is
/// optional.
///
/// # Example
///
/// ```rust,ignore
/// use oxiblas_core::simd::dispatch::{simd_caps, simd_dispatch};
///
/// let result = simd_dispatch!(
///     simd_caps(),
///     avx512 => compute_avx512(&a, &b),
///     avx2 => compute_avx2(&a, &b),
///     sse42 => compute_sse42(&a, &b),
///     neon => compute_neon(&a, &b),
///     scalar => compute_scalar(&a, &b),
/// );
/// ```
#[macro_export]
macro_rules! simd_dispatch {
    (
        $caps:expr,
        avx512 => $avx512:expr,
        avx2 => $avx2:expr,
        sse42 => $sse42:expr,
        neon => $neon:expr,
        scalar => $scalar:expr $(,)?
    ) => {{
        // Bind once so the capability expression is not re-evaluated per check.
        let detected = $caps;
        // Guards are tried top to bottom and lazily, giving the priority order.
        match () {
            _ if detected.has_avx512_full() => $avx512,
            _ if detected.has_avx2_fma() => $avx2,
            _ if detected.has_sse42 => $sse42,
            _ if detected.has_neon => $neon,
            _ => $scalar,
        }
    }};
}
478
479// Re-export so the macro is accessible from `oxiblas_core::simd::dispatch`.
480pub use simd_dispatch;
481
482// ---------------------------------------------------------------------------
483// SimdDispatcher trait
484// ---------------------------------------------------------------------------
485
486/// Trait for types that provide architecture-specialised implementations of a
487/// single computation.
488///
489/// Implement the four dispatch variants and call [`SimdDispatcher::dispatch`]
490/// to let the runtime choose the best one automatically.
491///
492/// # Example
493///
494/// ```rust,ignore
495/// struct DotProductF64<'a> {
496/// x: &'a [f64],
497/// y: &'a [f64],
498/// }
499///
500/// impl SimdDispatcher for DotProductF64<'_> {
501/// type Output = f64;
502/// fn dispatch_avx512(&self) -> f64 { /* AVX-512 kernel */ unimplemented!() }
503/// fn dispatch_avx2(&self) -> f64 { /* AVX2+FMA kernel */ unimplemented!() }
504/// fn dispatch_neon(&self) -> f64 { /* NEON kernel */ unimplemented!() }
505/// fn dispatch_scalar(&self) -> f64 { self.x.iter().zip(self.y).map(|(a,b)| a*b).sum() }
506/// }
507///
508/// let result = DotProductF64 { x: &[1.0, 2.0], y: &[3.0, 4.0] }.dispatch();
509/// assert_eq!(result, 11.0);
510/// ```
511pub trait SimdDispatcher {
512 /// The return type of the computation.
513 type Output;
514
515 /// AVX-512F+BW+VL specialised implementation.
516 fn dispatch_avx512(&self) -> Self::Output;
517
518 /// AVX2 + FMA specialised implementation.
519 fn dispatch_avx2(&self) -> Self::Output;
520
521 /// NEON (AArch64) specialised implementation.
522 fn dispatch_neon(&self) -> Self::Output;
523
524 /// Portable scalar fallback implementation.
525 fn dispatch_scalar(&self) -> Self::Output;
526
527 /// SSE4.2 (x86-64, 128-bit) specialised implementation.
528 ///
529 /// Defaults to [`dispatch_scalar`][SimdDispatcher::dispatch_scalar] when
530 /// not overridden, so implementors only need to override it when an
531 /// SSE4.2-specific path exists.
532 fn dispatch_sse42(&self) -> Self::Output {
533 self.dispatch_scalar()
534 }
535
536 /// Select and call the best available implementation.
537 ///
538 /// The selection is based on [`SimdCapabilities::detect`] which caches the
539 /// result in a process-wide static (when `std` is enabled).
540 fn dispatch(&self) -> Self::Output {
541 #[cfg(feature = "std")]
542 let caps = SimdCapabilities::detect();
543 #[cfg(not(feature = "std"))]
544 let caps = SimdCapabilities::detect();
545
546 if caps.has_avx512_full() {
547 self.dispatch_avx512()
548 } else if caps.has_avx2_fma() {
549 self.dispatch_avx2()
550 } else if caps.has_sse42 {
551 self.dispatch_sse42()
552 } else if caps.has_neon {
553 self.dispatch_neon()
554 } else {
555 self.dispatch_scalar()
556 }
557 }
558}
559
560// ---------------------------------------------------------------------------
561// KernelSelector / GemmKernelKind
562// ---------------------------------------------------------------------------
563
/// Tag naming which GEMM microkernel family has been picked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmKernelKind {
    /// AVX-512 microkernel (x86-64, 512-bit registers).
    Avx512,
    /// AVX2 + FMA microkernel (x86-64, 256-bit registers).
    Avx2,
    /// SSE4.2 microkernel (x86-64, 128-bit registers).
    Sse42,
    /// NEON microkernel (AArch64, 128-bit registers).
    Neon,
    /// Portable scalar microkernel (fallback).
    Scalar,
}

impl GemmKernelKind {
    /// Display name of the kernel family, suitable for logs and diagnostics.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Avx512 => "AVX-512",
            Self::Avx2 => "AVX2+FMA",
            Self::Sse42 => "SSE4.2",
            Self::Neon => "NEON",
            Self::Scalar => "scalar",
        }
    }
}
592
593/// Selects the optimal GEMM microkernel for each floating-point type based on
594/// the CPU capabilities detected at runtime.
595///
596/// The selection is performed once and stored in a process-wide static (on
597/// std-enabled builds). Call [`KernelSelector::select`] to obtain a shared
598/// reference.
599#[derive(Debug, Clone, Copy, PartialEq, Eq)]
600pub struct KernelSelector {
601 /// Best microkernel kind for double-precision GEMM.
602 pub gemm_f64_kernel: GemmKernelKind,
603 /// Best microkernel kind for single-precision GEMM.
604 pub gemm_f32_kernel: GemmKernelKind,
605}
606
607impl KernelSelector {
608 /// Choose kernel kinds from a [`SimdCapabilities`] snapshot.
609 fn from_caps(caps: &SimdCapabilities) -> Self {
610 let kind = if caps.has_avx512_full() {
611 GemmKernelKind::Avx512
612 } else if caps.has_avx2_fma() {
613 GemmKernelKind::Avx2
614 } else if caps.has_sse42 {
615 GemmKernelKind::Sse42
616 } else if caps.has_neon {
617 GemmKernelKind::Neon
618 } else {
619 GemmKernelKind::Scalar
620 };
621
622 Self {
623 gemm_f64_kernel: kind,
624 gemm_f32_kernel: kind,
625 }
626 }
627
628 /// Returns a reference to the globally cached [`KernelSelector`].
629 ///
630 /// Detection is performed exactly once per process.
631 #[cfg(feature = "std")]
632 pub fn select() -> &'static Self {
633 static KERNEL_SEL: OnceLock<KernelSelector> = OnceLock::new();
634 KERNEL_SEL.get_or_init(|| Self::from_caps(simd_caps()))
635 }
636
637 /// Recomputes the [`KernelSelector`] from compile-time features (no_std).
638 #[cfg(not(feature = "std"))]
639 pub fn select() -> Self {
640 Self::from_caps(&simd_caps())
641 }
642}
643
644// ---------------------------------------------------------------------------
645// Debugging helper (std only)
646// ---------------------------------------------------------------------------
647
/// Writes a human-readable report of the detected SIMD capabilities to
/// stdout.
///
/// The report lists the optimal level, cache-line and vector widths, the
/// per-architecture feature flags (only the section matching the compilation
/// target is emitted) and the selected GEMM kernels. Intended for diagnostics
/// and quick sanity checks; only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
pub fn print_capabilities() {
    let detected = simd_caps();
    let kernels = KernelSelector::select();

    println!("=== OxiBLAS SIMD Capabilities ===");
    println!("Optimal level : {}", detected.optimal_level().name());
    println!("Cache line : {} bytes", detected.cache_line_bytes);
    println!("Vector width : {} bytes", detected.vector_width_bytes);
    println!("f64 SIMD width : {} elements", detected.f64_simd_width());
    println!("f32 SIMD width : {} elements", detected.f32_simd_width());

    // x86-64 targets: one line per x86 feature flag.
    #[cfg(target_arch = "x86_64")]
    {
        println!("x86-64 Features:");
        println!(" SSE4.2 : {}", detected.has_sse42);
        println!(" AVX : {}", detected.has_avx);
        println!(" AVX2 : {}", detected.has_avx2);
        println!(" FMA : {}", detected.has_fma);
        println!(" AVX-512F : {}", detected.has_avx512f);
        println!(" AVX-512BW : {}", detected.has_avx512bw);
        println!(" AVX-512VL : {}", detected.has_avx512vl);
    }

    // AArch64 targets: one line per ARM feature flag.
    #[cfg(target_arch = "aarch64")]
    {
        println!("AArch64 Features:");
        println!(" NEON : {}", detected.has_neon);
        println!(" SVE : {}", detected.has_sve);
    }

    println!("Kernel Selection:");
    println!(" GEMM f64 : {}", kernels.gemm_f64_kernel.name());
    println!(" GEMM f32 : {}", kernels.gemm_f32_kernel.name());
    println!("==================================");
}
689
690// ---------------------------------------------------------------------------
691// Tests
692// ---------------------------------------------------------------------------
693
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-built snapshot with every feature flag cleared.
    ///
    /// The synthetic tests below derive other snapshots from it with struct
    /// update syntax, so branch selection can be checked for CPUs other than
    /// the one running the test suite.
    fn scalar_caps() -> SimdCapabilities {
        SimdCapabilities {
            has_sse42: false,
            has_avx: false,
            has_avx2: false,
            has_fma: false,
            has_avx512f: false,
            has_avx512bw: false,
            has_avx512vl: false,
            has_neon: false,
            has_sve: false,
            cache_line_bytes: 64,
            vector_width_bytes: 8,
        }
    }

    /// Runs `simd_dispatch!` with distinct numeric tags per branch.
    fn dispatch_tag(caps: &SimdCapabilities) -> u32 {
        simd_dispatch!(
            caps,
            avx512 => 512u32,
            avx2 => 256u32,
            sse42 => 128u32,
            neon => 1000u32,
            scalar => 1u32,
        )
    }

    // ------------------------------------------------------------------
    // 1. Detection smoke-test: at least one SIMD feature or scalar path
    // ------------------------------------------------------------------
    #[test]
    fn test_capabilities_detection_does_not_panic() {
        // Must not panic on any supported target.
        let caps = SimdCapabilities::detect();
        // Cache line should be a sensible power-of-two value.
        assert!(caps.cache_line_bytes >= 8);
        assert!(caps.cache_line_bytes.is_power_of_two());
    }

    // ------------------------------------------------------------------
    // 2. AArch64: NEON must always be true
    // ------------------------------------------------------------------
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_neon_always_present() {
        let caps = SimdCapabilities::detect();
        assert!(caps.has_neon, "NEON is mandatory on AArch64");
        // x86-64 flags must be absent
        assert!(!caps.has_avx2);
        assert!(!caps.has_avx512f);
    }

    // ------------------------------------------------------------------
    // 3. x86-64: flags are self-consistent
    // ------------------------------------------------------------------
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_x86_64_flag_consistency() {
        let caps = SimdCapabilities::detect();
        // NEON must not be set on x86-64.
        assert!(!caps.has_neon);
        // If AVX2 is present, AVX must be present too (architectural requirement).
        if caps.has_avx2 {
            assert!(caps.has_avx, "AVX2 implies AVX");
        }
        // If AVX-512F is present, SSE4.2 should be present (all x86-64 CPUs
        // with AVX-512 also support SSE4.2).
        if caps.has_avx512f {
            assert!(caps.has_sse42, "AVX-512 implies SSE4.2");
        }
    }

    // ------------------------------------------------------------------
    // 4. Vector width matches capabilities
    // ------------------------------------------------------------------
    #[test]
    fn test_vector_width_consistent_with_capabilities() {
        let caps = SimdCapabilities::detect();
        if caps.has_avx512f {
            assert_eq!(caps.vector_width_bytes, 64);
        } else if caps.has_avx2 {
            assert_eq!(caps.vector_width_bytes, 32);
        }
    }

    // ------------------------------------------------------------------
    // 5. f64 / f32 SIMD widths are derived from vector_width_bytes
    // ------------------------------------------------------------------
    #[test]
    fn test_simd_widths_derived_correctly() {
        let caps = SimdCapabilities::detect();
        assert_eq!(
            caps.f64_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f64>()
        );
        assert_eq!(
            caps.f32_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f32>()
        );
        // f32 always holds twice as many elements as f64 in the same register.
        assert_eq!(caps.f32_simd_width(), caps.f64_simd_width() * 2);
    }

    // ------------------------------------------------------------------
    // 6. SimdLevel ordering is meaningful within x86-64 family
    // ------------------------------------------------------------------
    #[test]
    fn test_simd_level_ordering() {
        assert!(SimdLevel::Avx512 > SimdLevel::Avx2);
        assert!(SimdLevel::Avx2 > SimdLevel::Avx);
        assert!(SimdLevel::Avx > SimdLevel::Sse42);
        assert!(SimdLevel::Sse42 > SimdLevel::Scalar);
    }

    // ------------------------------------------------------------------
    // 7. SimdLevel vector widths are consistent
    // ------------------------------------------------------------------
    #[test]
    fn test_simd_level_widths() {
        // Scalar reports a width of exactly one element.
        assert_eq!(SimdLevel::Scalar.f64_width(), 1);
        assert_eq!(SimdLevel::Scalar.f32_width(), 1);
        // Every level is pinned to its documented element count.
        assert_eq!(SimdLevel::Sse42.f64_width(), 2);
        assert_eq!(SimdLevel::Sse42.f32_width(), 4);
        assert_eq!(SimdLevel::Avx.f64_width(), 4);
        assert_eq!(SimdLevel::Avx.f32_width(), 8);
        assert_eq!(SimdLevel::Avx2.f64_width(), 4);
        assert_eq!(SimdLevel::Avx2.f32_width(), 8);
        assert_eq!(SimdLevel::Avx512.f64_width(), 8);
        assert_eq!(SimdLevel::Avx512.f32_width(), 16);
        assert_eq!(SimdLevel::Neon.f64_width(), 2);
        assert_eq!(SimdLevel::Neon.f32_width(), 4);
        // SVE reports the same conservative figures as NEON.
        assert_eq!(SimdLevel::Sve.f64_width(), 2);
        assert_eq!(SimdLevel::Sve.f32_width(), 4);
    }

    // ------------------------------------------------------------------
    // 8. simd_caps() is idempotent (same result on repeated calls)
    // ------------------------------------------------------------------
    #[cfg(feature = "std")]
    #[test]
    fn test_simd_caps_cached_identity() {
        let a = simd_caps();
        let b = simd_caps();
        // Must be the same static reference.
        assert!(
            core::ptr::eq(a, b),
            "simd_caps() must return a stable &'static"
        );
    }

    // ------------------------------------------------------------------
    // 9. optimal_level() returns a variant that matches the caps flags
    // ------------------------------------------------------------------
    #[test]
    fn test_optimal_level_consistent_with_flags() {
        let caps = SimdCapabilities::detect();
        let level = caps.optimal_level();
        match level {
            SimdLevel::Avx512 => assert!(caps.has_avx512_full()),
            SimdLevel::Avx2 => {
                assert!(!caps.has_avx512_full());
                assert!(caps.has_avx2_fma());
            }
            SimdLevel::Avx => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(caps.has_avx);
            }
            SimdLevel::Sse42 => {
                assert!(!caps.has_avx);
                assert!(caps.has_sse42);
            }
            SimdLevel::Neon => {
                assert!(caps.has_neon);
                assert!(!caps.has_avx);
            }
            SimdLevel::Sve => {
                // NEON is checked before SVE, so Sve implies NEON is absent.
                assert!(caps.has_sve);
                assert!(!caps.has_neon);
            }
            SimdLevel::Scalar => {
                assert!(!caps.has_sse42);
                assert!(!caps.has_avx);
                assert!(!caps.has_neon);
                assert!(!caps.has_sve);
            }
        }
    }

    // ------------------------------------------------------------------
    // 10. KernelSelector::select() does not panic and returns valid kinds
    // ------------------------------------------------------------------
    #[test]
    fn test_kernel_selector_valid_kinds() {
        #[cfg(feature = "std")]
        let sel = *KernelSelector::select();
        #[cfg(not(feature = "std"))]
        let sel = KernelSelector::select();

        // Kind must be one of the defined variants.
        assert!(matches!(
            sel.gemm_f64_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Sse42
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
        assert!(matches!(
            sel.gemm_f32_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Sse42
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
    }

    // ------------------------------------------------------------------
    // 11. KernelSelector agrees with the detected capability flags
    // ------------------------------------------------------------------
    #[test]
    fn test_kernel_selector_matches_optimal_level() {
        let caps = SimdCapabilities::detect();

        #[cfg(feature = "std")]
        let sel = *KernelSelector::select();
        #[cfg(not(feature = "std"))]
        let sel = KernelSelector::select();

        if caps.has_avx512_full() {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Avx512);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Avx512);
        } else if caps.has_avx2_fma() {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Avx2);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Avx2);
        } else if caps.has_sse42 {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Sse42);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Sse42);
        } else if caps.has_neon {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Neon);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Neon);
        } else {
            assert_eq!(sel.gemm_f64_kernel, GemmKernelKind::Scalar);
            assert_eq!(sel.gemm_f32_kernel, GemmKernelKind::Scalar);
        }
    }

    // ------------------------------------------------------------------
    // 12. simd_dispatch! macro selects a branch without panicking
    // ------------------------------------------------------------------
    #[test]
    fn test_simd_dispatch_macro_selects_branch() {
        #[cfg(feature = "std")]
        let caps = simd_caps();
        #[cfg(not(feature = "std"))]
        let caps = &simd_caps();

        let result: u32 = simd_dispatch!(
            caps,
            avx512 => 512u32,
            avx2 => 256u32,
            sse42 => 128u32,
            neon => 1000u32,
            scalar => 1u32,
        );

        // The result must be consistent with what optimal_level() says.
        //
        // simd_dispatch! has no dedicated AVX branch: a CPU with AVX but
        // without AVX2+FMA fails the has_avx2_fma() check and falls through
        // to the sse42 branch, so both Avx and Sse42 expect 128.
        //
        // It has no SVE branch either. optimal_level() only reports Sve when
        // has_neon is false, and in that case the macro reaches the scalar
        // branch, so Sve expects 1 rather than the NEON tag.
        let expected: u32 = match caps.optimal_level() {
            SimdLevel::Avx512 => 512,
            SimdLevel::Avx2 => 256,
            SimdLevel::Avx | SimdLevel::Sse42 => 128,
            SimdLevel::Neon => 1000,
            SimdLevel::Sve | SimdLevel::Scalar => 1,
        };
        assert_eq!(result, expected);
    }

    // ------------------------------------------------------------------
    // 13. SimdDispatcher trait: scalar reference implementation
    // ------------------------------------------------------------------
    struct ScalarDot<'a> {
        x: &'a [f64],
        y: &'a [f64],
    }

    impl SimdDispatcher for ScalarDot<'_> {
        type Output = f64;

        fn dispatch_avx512(&self) -> f64 {
            // In production this would use AVX-512 intrinsics; here we
            // delegate to scalar so the test compiles on all platforms.
            self.dispatch_scalar()
        }

        fn dispatch_avx2(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_neon(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_scalar(&self) -> f64 {
            self.x.iter().zip(self.y.iter()).map(|(a, b)| a * b).sum()
        }
    }

    #[test]
    fn test_simd_dispatcher_trait_correctness() {
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let y = [5.0_f64, 6.0, 7.0, 8.0];
        let dot = ScalarDot { x: &x, y: &y };

        // Expected: 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
        let result = dot.dispatch();
        assert!((result - 70.0).abs() < f64::EPSILON);
    }

    // ------------------------------------------------------------------
    // 14. print_capabilities does not panic (std only)
    // ------------------------------------------------------------------
    #[cfg(feature = "std")]
    #[test]
    fn test_print_capabilities_does_not_panic() {
        print_capabilities();
    }

    // ------------------------------------------------------------------
    // 15. GemmKernelKind names are non-empty strings
    // ------------------------------------------------------------------
    #[test]
    fn test_gemm_kernel_kind_names_non_empty() {
        for kind in [
            GemmKernelKind::Avx512,
            GemmKernelKind::Avx2,
            GemmKernelKind::Sse42,
            GemmKernelKind::Neon,
            GemmKernelKind::Scalar,
        ] {
            assert!(!kind.name().is_empty());
        }
    }

    // ------------------------------------------------------------------
    // 16. has_avx512() / has_avx2_fma() / has_neon() helpers agree with caps
    // ------------------------------------------------------------------
    #[test]
    fn test_helper_functions_agree_with_caps() {
        let caps = SimdCapabilities::detect();
        assert_eq!(has_avx512(), caps.has_avx512_full());
        assert_eq!(has_avx2_fma(), caps.has_avx2_fma());
        assert_eq!(has_neon(), caps.has_neon);
    }

    // ------------------------------------------------------------------
    // 17. optimal_level() priority, checked on hand-built snapshots so that
    //     every branch is covered regardless of the host CPU
    // ------------------------------------------------------------------
    #[test]
    fn test_optimal_level_priority_with_synthetic_caps() {
        let scalar = scalar_caps();
        assert_eq!(scalar.optimal_level(), SimdLevel::Scalar);

        let sse42 = SimdCapabilities {
            has_sse42: true,
            ..scalar
        };
        assert_eq!(sse42.optimal_level(), SimdLevel::Sse42);

        let avx = SimdCapabilities {
            has_avx: true,
            ..sse42
        };
        assert_eq!(avx.optimal_level(), SimdLevel::Avx);

        // AVX2 without FMA does not qualify for the Avx2 level.
        let avx2_no_fma = SimdCapabilities {
            has_avx2: true,
            ..avx
        };
        assert_eq!(avx2_no_fma.optimal_level(), SimdLevel::Avx);

        let avx2_fma = SimdCapabilities {
            has_fma: true,
            ..avx2_no_fma
        };
        assert_eq!(avx2_fma.optimal_level(), SimdLevel::Avx2);

        // AVX-512F alone is not "full" AVX-512; BW and VL are required too.
        let avx512f_only = SimdCapabilities {
            has_avx512f: true,
            ..avx2_fma
        };
        assert!(!avx512f_only.has_avx512_full());
        assert_eq!(avx512f_only.optimal_level(), SimdLevel::Avx2);

        let avx512_full = SimdCapabilities {
            has_avx512bw: true,
            has_avx512vl: true,
            ..avx512f_only
        };
        assert!(avx512_full.has_avx512_full());
        assert_eq!(avx512_full.optimal_level(), SimdLevel::Avx512);

        let neon = SimdCapabilities {
            has_neon: true,
            ..scalar
        };
        assert_eq!(neon.optimal_level(), SimdLevel::Neon);

        let sve_without_neon = SimdCapabilities {
            has_sve: true,
            ..scalar
        };
        assert_eq!(sve_without_neon.optimal_level(), SimdLevel::Sve);
    }

    // ------------------------------------------------------------------
    // 18. simd_dispatch! and KernelSelector::from_caps pick matching
    //     branches on hand-built snapshots
    // ------------------------------------------------------------------
    #[test]
    fn test_dispatch_paths_with_synthetic_caps() {
        let scalar = scalar_caps();
        let sse42 = SimdCapabilities {
            has_sse42: true,
            ..scalar
        };
        let avx_only = SimdCapabilities {
            has_avx: true,
            ..sse42
        };
        let avx2_fma = SimdCapabilities {
            has_avx2: true,
            has_fma: true,
            ..avx_only
        };
        let avx512_full = SimdCapabilities {
            has_avx512f: true,
            has_avx512bw: true,
            has_avx512vl: true,
            ..avx2_fma
        };
        let neon = SimdCapabilities {
            has_neon: true,
            ..scalar
        };

        let cases = [
            (scalar, 1u32, GemmKernelKind::Scalar),
            (sse42, 128, GemmKernelKind::Sse42),
            // AVX without AVX2+FMA falls through to the SSE4.2 path.
            (avx_only, 128, GemmKernelKind::Sse42),
            (avx2_fma, 256, GemmKernelKind::Avx2),
            (avx512_full, 512, GemmKernelKind::Avx512),
            (neon, 1000, GemmKernelKind::Neon),
        ];

        for (caps, tag, kind) in cases {
            assert_eq!(dispatch_tag(&caps), tag);
            let sel = KernelSelector::from_caps(&caps);
            assert_eq!(sel.gemm_f64_kernel, kind);
            assert_eq!(sel.gemm_f32_kernel, kind);
        }
    }
}
1025}