oxiblas_core/simd/multiver.rs
1//! Runtime function multi-versioning infrastructure for OxiBLAS.
2//!
3//! This module builds on top of [`crate::simd::dispatch`] to provide:
4//!
//! - Extended [`SimdCapabilityInfo`] with the field names used by the
//!   multi-versioning layer (`has_avx512f`, `has_avx2`, `has_sse42`, `has_fma`,
//!   `has_neon`, `cache_line_bytes`, `vector_width_bytes`) and a `detect()`
//!   that returns `&'static Self` when `std` is enabled.
9//! - [`simd_dispatch_caps!`] macro for named-arm dispatch.
10//! - [`SimdDispatcher`] trait for multi-versioned computations.
11//! - [`KernelSelector`] / [`GemmKernelKind`] for startup kernel selection.
12//!
13//! # no_std
14//!
15//! Under `no_std`, `SimdCapabilityInfo::detect()` returns a freshly derived
16//! value on every call because `OnceLock` is unavailable. All fields are set
17//! from compile-time `cfg!(target_feature = …)` constants, which the compiler
18//! constant-folds away.
19
20#[cfg(feature = "std")]
21use std::sync::OnceLock;
22
23use crate::simd::dispatch::{SimdCapabilities as LegacyCaps, SimdLevel as LegacyLevel};
24
25// ---------------------------------------------------------------------------
26// SimdCapabilityInfo — enhanced capability struct
27// ---------------------------------------------------------------------------
28
/// Snapshot of the SIMD feature flags and basic memory parameters used by the
/// multi-versioning layer, under the field names that layer expects.
///
/// Values are normally obtained through [`SimdCapabilityInfo::detect`]. With
/// the `std` feature that function hands out a `&'static Self` stored in a
/// [`OnceLock`], so the value is built at most once per process. Without
/// `std` it returns a new `Self` by value on each call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SimdCapabilityInfo {
    // ---- x86-64 feature flags --------------------------------------------
    /// `true` when SSE4.2 is available (x86-64 only).
    pub has_sse42: bool,
    /// `true` when AVX is available (x86-64 only).
    pub has_avx: bool,
    /// `true` when AVX2 is available (x86-64 only).
    pub has_avx2: bool,
    /// `true` when FMA (fused multiply-add) is available (x86-64 only).
    pub has_fma: bool,
    /// `true` when AVX-512 Foundation is available (x86-64 only).
    pub has_avx512f: bool,
    /// `true` when the AVX-512 Byte & Word instructions are available
    /// (x86-64 only).
    pub has_avx512bw: bool,
    /// `true` when the AVX-512 Vector Length extensions are available
    /// (x86-64 only).
    pub has_avx512vl: bool,

    // ---- ARM feature flags -----------------------------------------------
    /// `true` when NEON is available. Always `true` on AArch64.
    pub has_neon: bool,
    /// `true` when SVE (Scalable Vector Extension) is available.
    pub has_sve: bool,

    // ---- Memory topology -------------------------------------------------
    /// Size of one cache line in bytes (typically 64 on modern CPUs).
    pub cache_line_bytes: usize,
    /// Size in bytes of the widest supported SIMD vector register.
    pub vector_width_bytes: usize,
}
72
73impl SimdCapabilityInfo {
74 // ------------------------------------------------------------------
75 // Public entry-points
76 // ------------------------------------------------------------------
77
78 /// Detect (or derive) capabilities for the current CPU and return a
79 /// reference to the process-wide cached value.
80 ///
81 /// The first call performs runtime detection and stores the result.
82 /// Subsequent calls return the same `&'static` reference.
83 #[cfg(feature = "std")]
84 #[inline]
85 pub fn detect() -> &'static Self {
86 static INFO: OnceLock<SimdCapabilityInfo> = OnceLock::new();
87 INFO.get_or_init(Self::compute)
88 }
89
90 /// Derive capabilities from compile-time target features (no_std path).
91 ///
92 /// Returns a fresh value on every call. The computation consists only of
93 /// `cfg!` evaluations which the compiler constant-folds.
94 #[cfg(not(feature = "std"))]
95 #[inline]
96 pub fn detect() -> Self {
97 Self::compute()
98 }
99
100 // ------------------------------------------------------------------
101 // Internal construction
102 // ------------------------------------------------------------------
103
104 fn compute() -> Self {
105 let legacy = simd_caps();
106 Self::from_legacy(legacy)
107 }
108
109 /// Build from the legacy [`SimdCapabilities`] already present in
110 /// `dispatch.rs`. This ensures the two structs never diverge in their
111 /// detection logic.
112 #[cfg(feature = "std")]
113 fn from_legacy(legacy: &LegacyCaps) -> Self {
114 // SSE4.2 detection: the legacy struct only tracks SSE3. We check for
115 // SSE4.2 directly when std is available so that runtime detection is
116 // accurate.
117 #[cfg(all(target_arch = "x86_64", feature = "std"))]
118 let has_sse42 = is_x86_feature_detected!("sse4.2");
119 #[cfg(not(target_arch = "x86_64"))]
120 let has_sse42 = false;
121
122 let has_avx512f = legacy.has_avx512f;
123 let has_avx2 = legacy.has_avx2;
124
125 let vector_width_bytes = if has_avx512f {
126 64
127 } else if has_avx2 {
128 32
129 } else if has_sse42 || legacy.has_neon {
130 16
131 } else {
132 8
133 };
134
135 Self {
136 has_sse42,
137 has_avx: legacy.has_avx,
138 has_avx2,
139 has_fma: legacy.has_fma,
140 has_avx512f,
141 has_avx512bw: legacy.has_avx512bw,
142 has_avx512vl: legacy.has_avx512vl,
143 has_neon: legacy.has_neon,
144 has_sve: legacy.has_sve,
145 cache_line_bytes: 64,
146 vector_width_bytes,
147 }
148 }
149
150 /// Build from the legacy [`SimdCapabilities`] (no_std path, passed by
151 /// value).
152 #[cfg(not(feature = "std"))]
153 fn from_legacy(legacy: LegacyCaps) -> Self {
154 let has_sse42 = cfg!(target_feature = "sse4.2");
155 let has_avx512f = legacy.has_avx512f;
156 let has_avx2 = legacy.has_avx2;
157
158 let vector_width_bytes: usize = if has_avx512f {
159 64
160 } else if has_avx2 {
161 32
162 } else if has_sse42 || legacy.has_neon {
163 16
164 } else {
165 8
166 };
167
168 Self {
169 has_sse42,
170 has_avx: legacy.has_avx,
171 has_avx2,
172 has_fma: legacy.has_fma,
173 has_avx512f,
174 has_avx512bw: legacy.has_avx512bw,
175 has_avx512vl: legacy.has_avx512vl,
176 has_neon: legacy.has_neon,
177 has_sve: legacy.has_sve,
178 cache_line_bytes: 64,
179 vector_width_bytes,
180 }
181 }
182
183 // ------------------------------------------------------------------
184 // Capability queries
185 // ------------------------------------------------------------------
186
187 /// Returns `true` when AVX-512F, BW, and VL are all present.
188 #[inline]
189 pub fn has_avx512_full(&self) -> bool {
190 self.has_avx512f && self.has_avx512bw && self.has_avx512vl
191 }
192
193 /// Returns `true` when both AVX2 and FMA are present.
194 #[inline]
195 pub fn has_avx2_fma(&self) -> bool {
196 self.has_avx2 && self.has_fma
197 }
198
199 /// Number of `f64` elements that fit in the widest supported SIMD register.
200 ///
201 /// For AVX-512 this is 8; for AVX2 / NEON-128 this is 4 / 2; for scalar 1.
202 #[inline]
203 pub fn f64_simd_width(&self) -> usize {
204 self.vector_width_bytes / core::mem::size_of::<f64>()
205 }
206
207 /// Number of `f32` elements that fit in the widest supported SIMD register.
208 ///
209 /// Always twice `f64_simd_width()`.
210 #[inline]
211 pub fn f32_simd_width(&self) -> usize {
212 self.vector_width_bytes / core::mem::size_of::<f32>()
213 }
214
215 /// Returns the [`LegacyLevel`] that best summarises these capabilities.
216 #[inline]
217 pub fn optimal_level(&self) -> LegacyLevel {
218 if self.has_avx512_full() {
219 LegacyLevel::Avx512
220 } else if self.has_avx2_fma() {
221 LegacyLevel::Avx2
222 } else if self.has_avx {
223 LegacyLevel::Avx
224 } else if self.has_sse42 {
225 LegacyLevel::Sse42
226 } else if self.has_neon {
227 LegacyLevel::Neon
228 } else if self.has_sve {
229 LegacyLevel::Sve
230 } else {
231 LegacyLevel::Scalar
232 }
233 }
234}
235
236// ---------------------------------------------------------------------------
237// simd_dispatch_caps! macro
238// ---------------------------------------------------------------------------
239
/// Evaluates exactly one of five named arms, picked from the capabilities
/// given as the first argument.
///
/// The first argument is a [`SimdCapabilityInfo`] reference (or value); on
/// `std`-enabled builds [`SimdCapabilityInfo::detect()`] provides one. It is
/// evaluated once. All five arms must be supplied, in the order shown below,
/// and a trailing comma after the last arm is optional.
///
/// # Priority
///
/// The arms are tried from top to bottom and the first match wins:
///
/// 1. `avx512` — AVX-512F + BW + VL
/// 2. `avx2`   — AVX2 + FMA (256-bit)
/// 3. `sse42`  — SSE4.2 (128-bit)
/// 4. `neon`   — AArch64 NEON
/// 5. `scalar` — portable fallback
///
/// # Example
///
/// ```rust
/// use oxiblas_core::simd::multiver::{SimdCapabilityInfo, simd_dispatch_caps};
///
/// # #[cfg(feature = "std")]
/// let result: &str = simd_dispatch_caps!(
///     SimdCapabilityInfo::detect(),
///     avx512 => "avx512",
///     avx2   => "avx2",
///     sse42  => "sse42",
///     neon   => "neon",
///     scalar => "scalar",
/// );
/// ```
#[macro_export]
macro_rules! simd_dispatch_caps {
    (
        $caps:expr,
        avx512 => $avx512:expr,
        avx2 => $avx2:expr,
        sse42 => $sse42:expr,
        neon => $neon:expr,
        scalar => $scalar:expr $(,)?
    ) => {{
        // Bind once so the capability expression is not re-evaluated per arm.
        let detected = $caps;
        // Guards run top to bottom and stop at the first `true`, so only the
        // predicates ahead of the chosen arm are evaluated.
        match () {
            _ if detected.has_avx512_full() => { $avx512 }
            _ if detected.has_avx2_fma() => { $avx2 }
            _ if detected.has_sse42 => { $sse42 }
            _ if detected.has_neon => { $neon }
            _ => { $scalar }
        }
    }};
}
293
294// Make the macro accessible as `crate::simd::multiver::simd_dispatch_caps`.
295pub use simd_dispatch_caps;
296
297// ---------------------------------------------------------------------------
298// SimdDispatcher trait
299// ---------------------------------------------------------------------------
300
301/// Trait for types that provide architecture-specialised implementations of a
302/// single computation via function multi-versioning.
303///
304/// Implement the four required methods and call [`SimdDispatcher::dispatch`]
305/// to have the runtime select the fastest available path automatically.
306///
307/// # Design note
308///
309/// The trait uses `&self` receivers so that the dispatch object can carry all
310/// input data as fields, keeping call-sites clean.
311///
312/// # Example
313///
314/// ```rust
315/// use oxiblas_core::simd::multiver::SimdDispatcher;
316///
317/// struct ScalarSum<'a>(&'a [f64]);
318///
319/// impl SimdDispatcher for ScalarSum<'_> {
320/// type Output = f64;
321/// fn dispatch_avx512(&self) -> f64 { self.dispatch_scalar() }
322/// fn dispatch_avx2(&self) -> f64 { self.dispatch_scalar() }
323/// fn dispatch_neon(&self) -> f64 { self.dispatch_scalar() }
324/// fn dispatch_scalar(&self) -> f64 { self.0.iter().copied().sum() }
325/// }
326///
327/// assert_eq!(ScalarSum(&[1.0, 2.0, 3.0]).dispatch(), 6.0);
328/// ```
329pub trait SimdDispatcher {
330 /// The type returned by the computation.
331 type Output;
332
333 /// AVX-512F+BW+VL specialised implementation.
334 fn dispatch_avx512(&self) -> Self::Output;
335
336 /// AVX2 + FMA specialised implementation.
337 fn dispatch_avx2(&self) -> Self::Output;
338
339 /// NEON (AArch64) specialised implementation.
340 fn dispatch_neon(&self) -> Self::Output;
341
342 /// Portable scalar fallback.
343 fn dispatch_scalar(&self) -> Self::Output;
344
345 /// Select and call the best available implementation for the current CPU.
346 ///
347 /// The selection is based on [`SimdCapabilityInfo::detect`], which caches
348 /// the result in a process-wide static (std builds) or recomputes from
349 /// compile-time flags (no_std builds).
350 fn dispatch(&self) -> Self::Output {
351 #[cfg(feature = "std")]
352 let caps = SimdCapabilityInfo::detect();
353 #[cfg(not(feature = "std"))]
354 let caps = SimdCapabilityInfo::detect();
355
356 if caps.has_avx512_full() {
357 self.dispatch_avx512()
358 } else if caps.has_avx2_fma() {
359 self.dispatch_avx2()
360 } else if caps.has_neon {
361 self.dispatch_neon()
362 } else {
363 self.dispatch_scalar()
364 }
365 }
366}
367
368// ---------------------------------------------------------------------------
369// GemmKernelKind
370// ---------------------------------------------------------------------------
371
/// The microkernel variants a GEMM operation can be routed to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GemmKernelKind {
    /// AVX-512 microkernel (512-bit registers, x86-64).
    Avx512,
    /// AVX2 + FMA microkernel (256-bit registers, x86-64).
    Avx2,
    /// NEON microkernel (128-bit registers, AArch64).
    Neon,
    /// Portable scalar microkernel (fallback for all targets).
    Scalar,
}

impl GemmKernelKind {
    /// Returns a short human-readable label for this kernel kind.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Avx512 => "AVX-512",
            Self::Avx2 => "AVX2+FMA",
            Self::Neon => "NEON",
            Self::Scalar => "scalar",
        }
    }

    /// Returns `true` for every kind except [`GemmKernelKind::Scalar`].
    #[inline]
    pub const fn is_simd(self) -> bool {
        // Listing every variant means a newly added kind must be classified
        // here before the code compiles.
        match self {
            Self::Scalar => false,
            Self::Avx512 | Self::Avx2 | Self::Neon => true,
        }
    }
}
403
404// ---------------------------------------------------------------------------
405// KernelSelector
406// ---------------------------------------------------------------------------
407
408/// Selects the optimal GEMM microkernel for `f64` and `f32` based on the CPU
409/// capabilities detected at runtime (or compile-time on no_std).
410///
411/// Call [`KernelSelector::select`] once at startup to obtain the globally
412/// cached selector; subsequent calls return the same reference (std) or a
413/// freshly computed identical value (no_std).
414#[derive(Debug, Clone, Copy, PartialEq, Eq)]
415pub struct KernelSelector {
416 /// Best microkernel kind for double-precision (f64) GEMM.
417 pub gemm_f64_kernel: GemmKernelKind,
418 /// Best microkernel kind for single-precision (f32) GEMM.
419 pub gemm_f32_kernel: GemmKernelKind,
420}
421
422impl KernelSelector {
423 fn from_caps(caps: &SimdCapabilityInfo) -> Self {
424 let kind = if caps.has_avx512_full() {
425 GemmKernelKind::Avx512
426 } else if caps.has_avx2_fma() {
427 GemmKernelKind::Avx2
428 } else if caps.has_neon {
429 GemmKernelKind::Neon
430 } else {
431 GemmKernelKind::Scalar
432 };
433
434 Self {
435 gemm_f64_kernel: kind,
436 gemm_f32_kernel: kind,
437 }
438 }
439
440 /// Returns a reference to the globally cached [`KernelSelector`].
441 ///
442 /// The first call performs detection; all subsequent calls return the same
443 /// `&'static` reference.
444 #[cfg(feature = "std")]
445 pub fn select() -> &'static Self {
446 static KERNEL_SEL: OnceLock<KernelSelector> = OnceLock::new();
447 KERNEL_SEL.get_or_init(|| Self::from_caps(SimdCapabilityInfo::detect()))
448 }
449
450 /// Recomputes the [`KernelSelector`] from compile-time target features
451 /// (no_std path).
452 #[cfg(not(feature = "std"))]
453 pub fn select() -> Self {
454 Self::from_caps(&SimdCapabilityInfo::detect())
455 }
456}
457
458// ---------------------------------------------------------------------------
459// Convenience re-exports from dispatch.rs
460// ---------------------------------------------------------------------------
461
462pub use crate::simd::dispatch::{
463 SimdCapabilities, SimdLevel, has_avx2_fma, has_avx512, has_neon, optimal_simd_level, simd_caps,
464};
465
466// ---------------------------------------------------------------------------
// Tests
468// ---------------------------------------------------------------------------
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the kernel selector as an owned value in both configurations:
    /// `select()` yields `&'static KernelSelector` with `std` and an owned
    /// `KernelSelector` without it.
    #[cfg(feature = "std")]
    fn current_selector() -> KernelSelector {
        *KernelSelector::select()
    }

    /// Returns the kernel selector as an owned value (no_std variant).
    #[cfg(not(feature = "std"))]
    fn current_selector() -> KernelSelector {
        KernelSelector::select()
    }

    // ------------------------------------------------------------------
    // 1. Detection does not panic, cache_line_bytes is sane
    // ------------------------------------------------------------------
    #[test]
    fn test_detect_does_not_panic_and_cache_line_sane() {
        let caps = SimdCapabilityInfo::detect();
        // Cache line must be at least 8 bytes and a power of two.
        assert!(caps.cache_line_bytes >= 8);
        assert!(caps.cache_line_bytes.is_power_of_two());
        // vector_width_bytes must also be a power of two.
        assert!(caps.vector_width_bytes.is_power_of_two());
    }

    // ------------------------------------------------------------------
    // 2. AArch64: NEON always present
    // ------------------------------------------------------------------
    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_neon_always_true() {
        let caps = SimdCapabilityInfo::detect();
        assert!(caps.has_neon, "NEON is mandatory on AArch64");
        assert!(!caps.has_avx2, "AVX2 must not appear on AArch64");
        assert!(!caps.has_avx512f, "AVX-512 must not appear on AArch64");
    }

    // ------------------------------------------------------------------
    // 3. x86-64: flag hierarchy must be self-consistent
    // ------------------------------------------------------------------
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_x86_64_flag_hierarchy() {
        let caps = SimdCapabilityInfo::detect();
        assert!(!caps.has_neon, "NEON must not appear on x86-64");
        // AVX2 implies AVX (architectural requirement).
        if caps.has_avx2 {
            assert!(caps.has_avx, "AVX2 requires AVX");
        }
        // AVX-512 implies SSE4.2 on all known shipping CPUs.
        if caps.has_avx512f {
            assert!(caps.has_sse42, "AVX-512 implies SSE4.2");
        }
    }

    // ------------------------------------------------------------------
    // 4. f64/f32 simd widths are derived from vector_width_bytes
    // ------------------------------------------------------------------
    #[test]
    fn test_simd_width_derivation() {
        let caps = SimdCapabilityInfo::detect();
        assert_eq!(
            caps.f64_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f64>()
        );
        assert_eq!(
            caps.f32_simd_width(),
            caps.vector_width_bytes / core::mem::size_of::<f32>()
        );
        // f32 must fit twice as many elements as f64 in the same register.
        assert_eq!(caps.f32_simd_width(), caps.f64_simd_width() * 2);
    }

    // ------------------------------------------------------------------
    // 5. vector_width_bytes agrees with reported capabilities
    // ------------------------------------------------------------------
    #[test]
    fn test_vector_width_matches_capability_tier() {
        let caps = SimdCapabilityInfo::detect();
        // Mirrors every tier of the width selection, including the 128-bit
        // and fallback tiers, so no branch is left unchecked.
        let expected = if caps.has_avx512f {
            64
        } else if caps.has_avx2 {
            32
        } else if caps.has_sse42 || caps.has_neon {
            16
        } else {
            8
        };
        assert_eq!(caps.vector_width_bytes, expected);
    }

    // ------------------------------------------------------------------
    // 6. simd_caps() is idempotent — same pointer on repeated calls (std)
    // ------------------------------------------------------------------
    #[cfg(feature = "std")]
    #[test]
    fn test_simd_caps_stable_pointer() {
        let a = simd_caps();
        let b = simd_caps();
        assert!(
            core::ptr::eq(a, b),
            "simd_caps() must return a stable &'static"
        );
    }

    // ------------------------------------------------------------------
    // 7. SimdCapabilityInfo::detect() is stable pointer on std builds
    // ------------------------------------------------------------------
    #[cfg(feature = "std")]
    #[test]
    fn test_capability_info_stable_pointer() {
        let a = SimdCapabilityInfo::detect();
        let b = SimdCapabilityInfo::detect();
        assert!(
            core::ptr::eq(a, b),
            "detect() must return a stable &'static"
        );
    }

    // ------------------------------------------------------------------
    // 8. optimal_level() is consistent with the capability flags
    // ------------------------------------------------------------------
    #[test]
    fn test_optimal_level_consistent_with_flags() {
        let caps = SimdCapabilityInfo::detect();
        let level = caps.optimal_level();
        // Each arm asserts the level's own flag and the absence of every
        // level that `optimal_level()` tests before it.
        match level {
            LegacyLevel::Avx512 => assert!(caps.has_avx512_full()),
            LegacyLevel::Avx2 => {
                assert!(!caps.has_avx512_full());
                assert!(caps.has_avx2_fma());
            }
            LegacyLevel::Avx => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(caps.has_avx);
            }
            LegacyLevel::Sse42 => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(!caps.has_avx);
                assert!(caps.has_sse42);
            }
            LegacyLevel::Neon => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(!caps.has_avx);
                assert!(!caps.has_sse42);
                assert!(caps.has_neon);
            }
            LegacyLevel::Sve => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(!caps.has_avx);
                assert!(!caps.has_sse42);
                assert!(!caps.has_neon);
                assert!(caps.has_sve);
            }
            LegacyLevel::Scalar => {
                assert!(!caps.has_avx512_full());
                assert!(!caps.has_avx2_fma());
                assert!(!caps.has_avx);
                assert!(!caps.has_sse42);
                assert!(!caps.has_neon);
                assert!(!caps.has_sve);
            }
        }
    }

    // ------------------------------------------------------------------
    // 9. KernelSelector::select() returns valid GemmKernelKind values
    // ------------------------------------------------------------------
    #[test]
    fn test_kernel_selector_valid_kinds() {
        let sel = current_selector();

        assert!(matches!(
            sel.gemm_f64_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
        assert!(matches!(
            sel.gemm_f32_kernel,
            GemmKernelKind::Avx512
                | GemmKernelKind::Avx2
                | GemmKernelKind::Neon
                | GemmKernelKind::Scalar
        ));
    }

    // ------------------------------------------------------------------
    // 10. KernelSelector agrees with SimdCapabilityInfo on which tier to use
    // ------------------------------------------------------------------
    #[test]
    fn test_kernel_selector_agrees_with_capability_info() {
        let caps = SimdCapabilityInfo::detect();
        let sel = current_selector();

        let expected = if caps.has_avx512_full() {
            GemmKernelKind::Avx512
        } else if caps.has_avx2_fma() {
            GemmKernelKind::Avx2
        } else if caps.has_neon {
            GemmKernelKind::Neon
        } else {
            GemmKernelKind::Scalar
        };

        assert_eq!(sel.gemm_f64_kernel, expected);
        assert_eq!(sel.gemm_f32_kernel, expected);
    }

    // ------------------------------------------------------------------
    // 11. simd_dispatch_caps! macro selects a branch consistent with caps
    // ------------------------------------------------------------------
    #[test]
    fn test_simd_dispatch_caps_macro_branch_selection() {
        // One binding serves both configurations: `detect()` yields a
        // reference with `std` and a value without it, and `&caps` works for
        // the macro either way.
        let caps = SimdCapabilityInfo::detect();

        let chosen: u32 = simd_dispatch_caps!(
            &caps,
            avx512 => 512u32,
            avx2 => 256u32,
            sse42 => 128u32,
            neon => 1000u32,
            scalar => 1u32,
        );

        // Verify the chosen branch is consistent with the flags.
        if caps.has_avx512_full() {
            assert_eq!(chosen, 512);
        } else if caps.has_avx2_fma() {
            assert_eq!(chosen, 256);
        } else if caps.has_sse42 {
            assert_eq!(chosen, 128);
        } else if caps.has_neon {
            assert_eq!(chosen, 1000);
        } else {
            assert_eq!(chosen, 1);
        }
    }

    // ------------------------------------------------------------------
    // 12. SimdDispatcher trait: reference implementation is correct
    // ------------------------------------------------------------------

    /// Dot product of two slices; every SIMD path delegates to the scalar
    /// one so the tests behave the same on any host.
    struct DotProduct<'a> {
        x: &'a [f64],
        y: &'a [f64],
    }

    impl SimdDispatcher for DotProduct<'_> {
        type Output = f64;

        fn dispatch_avx512(&self) -> f64 {
            // Use scalar path for portability in the test.
            self.dispatch_scalar()
        }

        fn dispatch_avx2(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_neon(&self) -> f64 {
            self.dispatch_scalar()
        }

        fn dispatch_scalar(&self) -> f64 {
            self.x.iter().zip(self.y.iter()).map(|(a, b)| a * b).sum()
        }
    }

    #[test]
    fn test_simd_dispatcher_trait_correctness() {
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let y = [5.0_f64, 6.0, 7.0, 8.0];
        // 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
        let result = DotProduct { x: &x, y: &y }.dispatch();
        assert!((result - 70.0).abs() < f64::EPSILON);
    }

    // ------------------------------------------------------------------
    // 13. GemmKernelKind::name() returns non-empty strings
    // ------------------------------------------------------------------
    #[test]
    fn test_gemm_kernel_kind_names_non_empty() {
        for kind in [
            GemmKernelKind::Avx512,
            GemmKernelKind::Avx2,
            GemmKernelKind::Neon,
            GemmKernelKind::Scalar,
        ] {
            assert!(!kind.name().is_empty());
        }
    }

    // ------------------------------------------------------------------
    // 14. GemmKernelKind::is_simd() is false only for Scalar
    // ------------------------------------------------------------------
    #[test]
    fn test_gemm_kernel_kind_is_simd() {
        assert!(GemmKernelKind::Avx512.is_simd());
        assert!(GemmKernelKind::Avx2.is_simd());
        assert!(GemmKernelKind::Neon.is_simd());
        assert!(!GemmKernelKind::Scalar.is_simd());
    }

    // ------------------------------------------------------------------
    // 15. has_avx512/has_avx2_fma/has_neon helpers agree with detect()
    // ------------------------------------------------------------------
    #[test]
    fn test_free_helper_fns_agree_with_detect() {
        let caps = SimdCapabilityInfo::detect();
        // detect() copies these flags from simd_caps(); the free helpers are
        // expected to report from the same source, so a positive result from
        // detect() must not be contradicted by them.
        if caps.has_avx512_full() {
            assert!(has_avx512());
        }
        if caps.has_avx2_fma() {
            assert!(has_avx2_fma());
        }
        if caps.has_neon {
            assert!(has_neon());
        }
    }

    // ------------------------------------------------------------------
    // 16. SimdDispatcher: dispatch() matches dispatch_scalar(), including
    //     on empty input
    // ------------------------------------------------------------------
    #[test]
    fn test_dispatcher_scalar_ground_truth() {
        // Every path of DotProduct delegates to the scalar one, so the two
        // results must be bit-identical whichever path dispatch() takes.
        let x = [1.5_f64, -2.0, 0.25];
        let y = [4.0_f64, 0.5, 8.0];
        let dot = DotProduct { x: &x, y: &y };
        assert_eq!(dot.dispatch(), dot.dispatch_scalar());

        let x = [0.0_f64; 0];
        let y = [0.0_f64; 0];
        let result = DotProduct { x: &x, y: &y }.dispatch();
        assert_eq!(result, 0.0, "empty dot product must be zero");
    }
}