oxibonsai_kernels/dispatch.rs

//! Runtime kernel dispatch with CPU feature detection.
//!
//! Uses SciRS2-Core's SIMD capability detection to select the best available
//! kernel implementation at runtime. Falls back to the scalar reference
//! when no SIMD acceleration is available.
//!
//! The selection hierarchy (highest priority first):
//! 1. GPU (Metal / CUDA — only with the `gpu` feature, tried first in `auto_detect`)
//! 2. AVX-512F (x86-64 only)
//! 3. AVX2 + FMA (x86-64 only)
//! 4. NEON (AArch64 only)
//! 5. Reference (scalar — always available)
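//!
//! # Example
//!
//! A minimal usage sketch (assuming this module is publicly reachable at the
//! path shown; adjust the `use` if the crate re-exports differ):
//!
//! ```no_run
//! use oxibonsai_kernels::dispatch::KernelDispatcher;
//!
//! let dispatcher = KernelDispatcher::auto_detect();
//! println!("selected kernel tier: {}", dispatcher.tier());
//! ```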

use crate::dequant;
use crate::error::KernelResult;
use crate::gemm;
use crate::gemv;
use crate::traits::{Fp8Kernel, OneBitKernel, TernaryKernel};
use crate::weight_cache::GpuWeightHandle;
use oxibonsai_core::tensor::BlockQ1_0G128;
use oxibonsai_core::{BlockFP8E4M3, BlockFP8E5M2};
#[cfg(feature = "gpu")]
use std::sync::Arc;

/// Kernel implementation tier, ordered from slowest to fastest.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelTier {
    /// Pure scalar Rust — correctness reference.
    Reference,
    /// AVX2 + FMA (256-bit SIMD, x86-64).
    #[cfg(target_arch = "x86_64")]
    Avx2,
    /// AVX-512F + AVX-512BW + AVX-512VL (512-bit SIMD, x86-64).
    #[cfg(target_arch = "x86_64")]
    Avx512,
    /// NEON (128-bit SIMD, AArch64).
    #[cfg(target_arch = "aarch64")]
    Neon,
    /// GPU-accelerated (Metal / CUDA via scirs2-core).
    #[cfg(feature = "gpu")]
    Gpu,
}

impl std::fmt::Display for KernelTier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Reference => write!(f, "reference"),
            #[cfg(target_arch = "x86_64")]
            Self::Avx2 => write!(f, "avx2+fma"),
            #[cfg(target_arch = "x86_64")]
            Self::Avx512 => write!(f, "avx512f+bw+vl"),
            #[cfg(target_arch = "aarch64")]
            Self::Neon => write!(f, "neon"),
            #[cfg(feature = "gpu")]
            Self::Gpu => write!(f, "gpu"),
        }
    }
}
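
// Illustrative note: the `Display` impl above backs the structured log
// fields used later in this file, e.g. `tracing::info!(tier = %tier, ...)`
// in `auto_detect` renders `KernelTier::Avx2` as "avx2+fma".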

/// Dispatches kernel calls to the best available implementation.
///
/// Uses [`scirs2_core::simd::detect::CpuFeatures`] for CPU feature
/// detection, ensuring consistent SIMD dispatch across the COOLJAPAN ecosystem.
pub struct KernelDispatcher {
    tier: KernelTier,
    /// GPU backend handle, available when the `gpu` feature is enabled and a
    /// hardware-accelerated backend was detected at construction time.
    #[cfg(feature = "gpu")]
    gpu_backend: Option<Arc<dyn crate::gpu_backend::GpuBackendTrait>>,
}

impl std::fmt::Debug for KernelDispatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // GPU backend is a trait object without Debug; show only the tier.
        f.debug_struct("KernelDispatcher")
            .field("tier", &self.tier)
            .finish_non_exhaustive()
    }
}

impl KernelDispatcher {
    /// Create a dispatcher that auto-detects the best available kernel tier.
    ///
    /// Queries SciRS2-Core's cached `CpuFeatures` to determine the
    /// optimal tier for the current CPU.
    pub fn auto_detect() -> Self {
        // Try GPU first when the feature is compiled in.
        #[cfg(feature = "gpu")]
        {
            let backend = crate::gpu_backend::select_backend();
            if backend.is_accelerated() {
                tracing::info!(backend = backend.name(), "GPU backend available");
                return Self {
                    tier: KernelTier::Gpu,
                    gpu_backend: Some(Arc::from(backend)),
                };
            }
        }

        let caps = scirs2_core::simd::detect::get_cpu_features();
        let tier = Self::select_tier(caps);
        tracing::info!(tier = %tier, "selected kernel tier");
        Self {
            tier,
            #[cfg(feature = "gpu")]
            gpu_backend: None,
        }
    }

    /// Create a dispatcher with a specific tier (for testing/benchmarks).
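    ///
    /// # Example
    ///
    /// A minimal sketch (the `Reference` tier exists on every target; the
    /// import path assumes this module is publicly reachable as shown):
    ///
    /// ```no_run
    /// use oxibonsai_kernels::dispatch::{KernelDispatcher, KernelTier};
    ///
    /// let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
    /// assert_eq!(dispatcher.tier(), KernelTier::Reference);
    /// ```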
    pub fn with_tier(tier: KernelTier) -> Self {
        Self {
            tier,
            #[cfg(feature = "gpu")]
            gpu_backend: None,
        }
    }

    /// Create a dispatcher that uses the GPU backend with the given handle.
    #[cfg(feature = "gpu")]
    pub fn with_gpu(backend: Arc<dyn crate::gpu_backend::GpuBackendTrait>) -> Self {
        Self {
            tier: KernelTier::Gpu,
            gpu_backend: Some(backend),
        }
    }

    /// Get the selected kernel tier.
    pub fn tier(&self) -> KernelTier {
        self.tier
    }

    /// Select the best tier based on detected capabilities.
    ///
    /// On x86_64, validates scirs2_core detection against Rust's built-in
    /// `is_x86_feature_detected!` macros to ensure correct tier selection.
    /// This prevents issues where scirs2_core might incorrectly detect
    /// CPU features on certain platforms (e.g., Windows AMD CPUs).
    fn select_tier(caps: &scirs2_core::simd::detect::CpuFeatures) -> KernelTier {
        #[cfg(target_arch = "x86_64")]
        {
            // Use Rust's built-in feature detection as the source of truth;
            // scirs2_core detection may be unreliable on some platforms.
            let has_avx512f = is_x86_feature_detected!("avx512f");
            let has_avx512bw = is_x86_feature_detected!("avx512bw");
            let has_avx512vl = is_x86_feature_detected!("avx512vl");
            let has_avx2 = is_x86_feature_detected!("avx2");
            let has_fma = is_x86_feature_detected!("fma");

            // Log if there's a mismatch between scirs2_core and std detection.
            if caps.has_avx512f != has_avx512f {
                tracing::warn!(
                    scirs2_avx512f = caps.has_avx512f,
                    std_avx512f = has_avx512f,
                    "CPU feature detection mismatch for AVX-512F, using std detection"
                );
            }
            if caps.has_avx2 != has_avx2 || caps.has_fma != has_fma {
                tracing::warn!(
                    scirs2_avx2 = caps.has_avx2,
                    scirs2_fma = caps.has_fma,
                    std_avx2 = has_avx2,
                    std_fma = has_fma,
                    "CPU feature detection mismatch for AVX2/FMA, using std detection"
                );
            }

            // AVX-512 requires all three: avx512f, avx512bw, and avx512vl.
            if has_avx512f && has_avx512bw && has_avx512vl {
                tracing::debug!("AVX-512 (F+BW+VL) detected, selecting AVX-512 tier");
                return KernelTier::Avx512;
            }
            if has_avx2 && has_fma {
                tracing::debug!("AVX2 + FMA detected, selecting AVX2 tier");
                return KernelTier::Avx2;
            }

            // Log the fallback to the reference tier.
            tracing::warn!(
                has_avx512f,
                has_avx512bw,
                has_avx512vl,
                has_avx2,
                has_fma,
                "No SIMD acceleration available, falling back to reference tier (this will be slow)"
            );
        }

        #[cfg(target_arch = "aarch64")]
        {
            if caps.has_neon {
                return KernelTier::Neon;
            }
        }

        // Suppress unused-variable warning on architectures with no SIMD paths.
        let _ = caps;
        KernelTier::Reference
    }
}

/// Minimum number of rows before the GPU path is worthwhile.
///
/// Below this threshold the overhead of host-to-device transfer exceeds the
/// compute savings, so we fall back to the best SIMD tier.
#[cfg(feature = "gpu")]
const GPU_MIN_ROWS: usize = 1024;
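
// Decision sketch (illustrative, mirroring the checks in `gemv`/`gemm`
// below): a GPU-tier call with n_rows < GPU_MIN_ROWS routes straight to the
// cpu_gemv/cpu_gemm helpers; larger calls try the GPU backend first and
// fall back to the same CPU path on any backend error.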

impl KernelDispatcher {
    /// Return the best CPU-only tier for use as GPU fallback.
    ///
    /// Uses Rust's built-in `is_x86_feature_detected!` macros directly
    /// for reliable detection, bypassing scirs2_core, which may have issues
    /// on certain platforms.
    #[cfg(feature = "gpu")]
    fn cpu_tier() -> KernelTier {
        #[cfg(target_arch = "x86_64")]
        {
            let has_avx512f = is_x86_feature_detected!("avx512f");
            let has_avx512bw = is_x86_feature_detected!("avx512bw");
            let has_avx512vl = is_x86_feature_detected!("avx512vl");
            let has_avx2 = is_x86_feature_detected!("avx2");
            let has_fma = is_x86_feature_detected!("fma");

            if has_avx512f && has_avx512bw && has_avx512vl {
                return KernelTier::Avx512;
            }
            if has_avx2 && has_fma {
                return KernelTier::Avx2;
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            // NEON is always available on AArch64.
            return KernelTier::Neon;
        }

        #[allow(unreachable_code)]
        KernelTier::Reference
    }

    /// Dispatch a `dequant` call using the best *CPU* tier.
    #[cfg(feature = "gpu")]
    fn cpu_dequant(blocks: &[BlockQ1_0G128], output: &mut [f32]) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => dequant::dequant_1bit_g128(blocks, output),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe { crate::simd_avx2::dequant_1bit_g128_avx2(blocks, output) },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::dequant_1bit_g128_avx512(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe { crate::simd_neon::dequant_1bit_g128_neon(blocks, output) },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => dequant::dequant_1bit_g128(blocks, output),
        }
    }

    /// Dispatch a `gemv` call using the best *CPU* tier.
    #[cfg(feature = "gpu")]
    fn cpu_gemv(
        blocks: &[BlockQ1_0G128],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemv_1bit_g128_avx2_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemv_1bit_g128_avx512_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemv_1bit_g128_neon_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
        }
    }

    /// Dispatch a `gemm` call using the best *CPU* tier.
    #[cfg(feature = "gpu")]
    fn cpu_gemm(
        blocks: &[BlockQ1_0G128],
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemm_1bit_g128_avx2_prefetch(blocks, input, output, m, n_rows, k)
            },
            // No prefetch variant for AVX-512 GEMM — keep non-prefetch.
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemm_1bit_g128_avx512(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemm_1bit_g128_neon_prefetch(blocks, input, output, m, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
        }
    }

    /// Dispatch a `dequant_ternary` call using the best *CPU* tier.
    #[cfg(feature = "gpu")]
    fn cpu_dequant_ternary(
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        output: &mut [f32],
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::dequant_tq2_0_g128_avx2(blocks, output)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::dequant_tq2_0_g128_avx512(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::dequant_tq2_0_g128_neon(blocks, output)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
        }
    }

    /// Dispatch a `gemv_ternary` call using the best *CPU* tier.
    #[cfg(feature = "gpu")]
    fn cpu_gemv_ternary(
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => {
                crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
            }
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemv_tq2_0_g128_avx2_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemv_tq2_0_g128_avx512_prefetch(
                    blocks, input, output, n_rows, k,
                )
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemv_tq2_0_g128_neon_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => {
                crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
            }
        }
    }

    /// Dispatch a `gemm_ternary` call using the best *CPU* tier.
    #[cfg(feature = "gpu")]
    fn cpu_gemm_ternary(
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => {
                crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
            }
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemm_tq2_0_g128_avx2(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemm_tq2_0_g128_avx512(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemm_tq2_0_g128_neon(blocks, input, output, m, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => {
                crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
            }
        }
    }

    /// Reinterpret a slice of `BlockQ1_0G128` as raw bytes (zero-copy).
    ///
    /// # Safety
    /// `BlockQ1_0G128` is `#[repr(C)]` with a well-defined 18-byte layout,
    /// so this reinterpreting cast is sound.
    #[cfg(feature = "gpu")]
    fn blocks_as_bytes(blocks: &[BlockQ1_0G128]) -> &[u8] {
        let ptr = blocks.as_ptr() as *const u8;
        let len = std::mem::size_of_val(blocks);
        // SAFETY: BlockQ1_0G128 is repr(C), POD-like, with no padding.
        unsafe { std::slice::from_raw_parts(ptr, len) }
    }
}

impl OneBitKernel for KernelDispatcher {
    fn dequant(&self, blocks: &[BlockQ1_0G128], output: &mut [f32]) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => dequant::dequant_1bit_g128(blocks, output),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe { crate::simd_avx2::dequant_1bit_g128_avx2(blocks, output) },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::dequant_1bit_g128_avx512(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe { crate::simd_neon::dequant_1bit_g128_neon(blocks, output) },
            // GPU dequant is not worth the transfer cost — use the best CPU path.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => Self::cpu_dequant(blocks, output),
        }
    }

    fn gemv(
        &self,
        blocks: &[BlockQ1_0G128],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemv_1bit_g128_avx2_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemv_1bit_g128_avx512_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemv_1bit_g128_neon_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => {
                if n_rows < GPU_MIN_ROWS {
                    return Self::cpu_gemv(blocks, input, output, n_rows, k);
                }
                if let Some(ref backend) = self.gpu_backend {
                    let bytes = Self::blocks_as_bytes(blocks);
                    match backend.gemv_q1_g128(bytes, input, n_rows, k) {
                        Ok(result) => {
                            let copy_len = output.len().min(result.len());
                            output[..copy_len].copy_from_slice(&result[..copy_len]);
                            return Ok(());
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "GPU gemv failed, falling back to CPU");
                            return Self::cpu_gemv(blocks, input, output, n_rows, k);
                        }
                    }
                }
                Self::cpu_gemv(blocks, input, output, n_rows, k)
            }
        }
    }

    fn gemm(
        &self,
        blocks: &[BlockQ1_0G128],
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemm_1bit_g128_avx2_prefetch(blocks, input, output, m, n_rows, k)
            },
            // No prefetch variant for AVX-512 GEMM — keep non-prefetch.
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemm_1bit_g128_avx512(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemm_1bit_g128_neon_prefetch(blocks, input, output, m, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => {
                if n_rows < GPU_MIN_ROWS {
                    return Self::cpu_gemm(blocks, input, output, m, n_rows, k);
                }
                if let Some(ref backend) = self.gpu_backend {
                    let bytes = Self::blocks_as_bytes(blocks);
                    match backend.gemm_q1_g128(bytes, input, m, n_rows, k) {
                        Ok(result) => {
                            let copy_len = output.len().min(result.len());
                            output[..copy_len].copy_from_slice(&result[..copy_len]);
                            return Ok(());
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "GPU gemm failed, falling back to CPU");
                            return Self::cpu_gemm(blocks, input, output, m, n_rows, k);
                        }
                    }
                }
                Self::cpu_gemm(blocks, input, output, m, n_rows, k)
            }
        }
    }

    fn name(&self) -> &'static str {
        match self.tier {
            KernelTier::Reference => "Q1_0_g128 reference (scalar)",
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => "Q1_0_g128 AVX2+FMA (256-bit)",
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => "Q1_0_g128 AVX-512 (512-bit)",
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => "Q1_0_g128 NEON (128-bit)",
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => "Q1_0_g128 GPU (accelerated)",
        }
    }

    fn is_gpu_accelerated(&self) -> bool {
        #[cfg(feature = "gpu")]
        let answer = self.tier == KernelTier::Gpu;
        #[cfg(not(feature = "gpu"))]
        let answer = false;
        answer
    }

    fn upload_weights(&self, blocks: &[BlockQ1_0G128]) -> Option<GpuWeightHandle> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                let bytes = Self::blocks_as_bytes(blocks);
                match backend.upload_weights_raw(bytes) {
                    Ok(handle) => return Some(handle),
                    Err(e) => {
                        tracing::warn!(error = %e, "failed to upload weights to GPU");
                    }
                }
            }
        }
        let _ = blocks;
        None
    }

    fn gemv_cached(
        &self,
        handle: GpuWeightHandle,
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                match backend.gemv_q1_g128_cached(handle, input, n_rows, k) {
                    Ok(result) => {
                        let len = output.len().min(result.len());
                        output[..len].copy_from_slice(&result[..len]);
                        return Ok(());
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "cached GPU gemv failed, cannot fall back without blocks");
                        return Err(crate::error::KernelError::GpuError(e.to_string()));
                    }
                }
            }
        }
        let _ = (handle, input, output, n_rows, k);
        Err(crate::error::KernelError::UnsupportedOperation(
            "gemv_cached requires GPU tier".into(),
        ))
    }

    fn batch_attn_phase(
        &self,
        hidden: &[f32],
        norm_weight: &[f32],
        norm_eps: f32,
        qkv_handle: GpuWeightHandle,
        q_rows: usize,
        k_rows: usize,
        h: usize,
    ) -> KernelResult<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>> {
        // Disabled: CPU RMSNorm + a single fused GEMV is faster than the
        // GPU batch (dispatch_no_wait + dispatch) for only 2 operations.
        // The GPU batch creates 4 new Metal buffers per call; the fallback
        // reuses pre-allocated io_input/output buffers.
        let _ = (hidden, norm_weight, norm_eps, qkv_handle, q_rows, k_rows, h);
        Ok(None)
    }

    fn batch_ffn_phase(
        &self,
        hidden: &mut [f32],
        attn_out: &[f32],
        norm_weight: &[f32],
        norm_eps: f32,
        attn_proj_handle: GpuWeightHandle,
        gate_up_handle: GpuWeightHandle,
        down_handle: GpuWeightHandle,
        h: usize,
        intermediate: usize,
        attn_proj_k: usize,
    ) -> KernelResult<bool> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                match backend.batch_ffn_phase(
                    hidden,
                    attn_out,
                    norm_weight,
                    norm_eps,
                    attn_proj_handle,
                    gate_up_handle,
                    down_handle,
                    h,
                    intermediate,
                    attn_proj_k,
                ) {
                    Ok(true) => return Ok(true),
                    Ok(false) => return Ok(false),
                    Err(e) => {
                        tracing::warn!(error = %e, "batch FFN phase failed, falling back");
                        return Ok(false);
                    }
                }
            }
        }
        let _ = (
            hidden,
            attn_out,
            norm_weight,
            norm_eps,
            attn_proj_handle,
            gate_up_handle,
            down_handle,
            h,
            intermediate,
            attn_proj_k,
        );
        Ok(false)
    }
}

impl TernaryKernel for KernelDispatcher {
    fn dequant_ternary_g128(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        output: &mut [f32],
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::dequant_tq2_0_g128_avx2(blocks, output)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::dequant_tq2_0_g128_avx512(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::dequant_tq2_0_g128_neon(blocks, output)
            },
            // No ternary GPU kernels — fall back to the best CPU SIMD tier.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => Self::cpu_dequant_ternary(blocks, output),
        }
    }

    fn gemv_ternary_g128(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => {
                crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
            }
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemv_tq2_0_g128_avx2_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemv_tq2_0_g128_avx512_prefetch(
                    blocks, input, output, n_rows, k,
                )
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemv_tq2_0_g128_neon_prefetch(blocks, input, output, n_rows, k)
            },
            // No ternary GPU kernels — fall back to the best CPU SIMD tier.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => Self::cpu_gemv_ternary(blocks, input, output, n_rows, k),
        }
    }

    fn gemm_ternary_g128(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => {
                crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
            }
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemm_tq2_0_g128_avx2(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemm_tq2_0_g128_avx512(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemm_tq2_0_g128_neon(blocks, input, output, m, n_rows, k)
            },
            // No ternary GPU kernels — fall back to the best CPU SIMD tier.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => Self::cpu_gemm_ternary(blocks, input, output, m, n_rows, k),
        }
    }

    fn upload_weights_ternary(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
    ) -> Option<GpuWeightHandle> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                match backend.upload_weights_ternary(blocks) {
                    Ok(handle) => return Some(handle),
                    Err(e) => {
                        // Some backends (e.g. NativeCudaBackend without a TQ2 kernel)
                        // legitimately don't support ternary uploads. We get called
                        // once per ternary weight tensor at model load — log just
                        // once to avoid hundreds of identical warnings.
                        use std::sync::atomic::{AtomicBool, Ordering};
                        static WARNED: AtomicBool = AtomicBool::new(false);
                        if !WARNED.swap(true, Ordering::Relaxed) {
                            tracing::warn!(
                                error = %e,
                                backend = backend.name(),
                                "ternary weight GPU upload not supported by backend; \
                                 falling back to CPU SIMD for ternary GEMV (this message \
                                 is shown once per process)"
                            );
                        }
                    }
                }
            }
        }
        let _ = blocks;
        None
    }

    fn gemv_ternary_g128_cached(
        &self,
        handle: GpuWeightHandle,
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                match backend.gemv_tq2_g128_cached(handle, input, n_rows, k) {
                    Ok(result) => {
                        let len = output.len().min(result.len());
                        output[..len].copy_from_slice(&result[..len]);
                        return Ok(());
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "cached GPU ternary gemv failed, cannot fall back without blocks");
                        return Err(crate::error::KernelError::GpuError(e.to_string()));
                    }
                }
            }
        }
        let _ = (handle, input, output, n_rows, k);
        Err(crate::error::KernelError::UnsupportedOperation(
            "gemv_ternary_g128_cached requires GPU tier".into(),
        ))
    }
}

// Compile-time size checks: BlockFP8E4M3/E5M2 must be exactly BLOCK_FP8_BYTES (34).
// This ensures the raw-pointer casts in the Metal and CUDA GPU dispatch paths are sound.
const _: () =
    assert!(std::mem::size_of::<oxibonsai_core::BlockFP8E4M3>() == oxibonsai_core::BLOCK_FP8_BYTES);
const _: () =
    assert!(std::mem::size_of::<oxibonsai_core::BlockFP8E5M2>() == oxibonsai_core::BLOCK_FP8_BYTES);

impl Fp8Kernel for KernelDispatcher {
    /// Dequantize FP8 E4M3FN blocks — tier-aware SIMD dispatch.
    fn dequant_fp8_e4m3(&self, blocks: &[BlockFP8E4M3], output: &mut [f32]) -> KernelResult<()> {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::dequant_fp8_e4m3_avx512(blocks, output)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::dequant_fp8_e4m3_avx2(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::dequant_fp8_e4m3_neon(blocks, output)
            },
            _ => crate::dequant_fp8::dequant_fp8_e4m3(blocks, output),
        }
    }

    /// Dequantize FP8 E5M2 blocks — tier-aware SIMD dispatch.
    fn dequant_fp8_e5m2(&self, blocks: &[BlockFP8E5M2], output: &mut [f32]) -> KernelResult<()> {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::dequant_fp8_e5m2_avx512(blocks, output)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::dequant_fp8_e5m2_avx2(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::dequant_fp8_e5m2_neon(blocks, output)
            },
            _ => crate::dequant_fp8::dequant_fp8_e5m2(blocks, output),
        }
    }

    /// FP8 E4M3FN GEMV — tier-aware SIMD dispatch with optional GPU acceleration.
    ///
    /// Dispatch priority on the `KernelTier::Gpu` path:
    /// 1. Metal (macOS + `metal` feature) — `metal_gemv_fp8_e4m3`.
    /// 2. CUDA (Linux/Windows + `native-cuda` feature) — `cuda_gemv_fp8_e4m3`.
    /// 3. CPU SIMD fallback (AVX-512 / AVX2 / NEON / scalar).
    ///
    /// The raw-byte cast of `blocks` to `*const u8` is sound because
    /// `BlockFP8E4M3` is `#[repr(C)]` with size `BLOCK_FP8_BYTES = 34`.
    fn gemv_fp8_e4m3(
        &self,
        blocks: &[BlockFP8E4M3],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        // GPU dispatch via Metal — macOS only, `metal` feature.
        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            // SAFETY: BlockFP8E4M3 is repr(C) with size BLOCK_FP8_BYTES (= 34).
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    blocks.as_ptr().cast::<u8>(),
                    blocks.len() * oxibonsai_core::BLOCK_FP8_BYTES,
                )
            };
            match crate::gpu_backend::metal_gemv_fp8_e4m3(bytes, input, output, n_rows, k) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // No Metal device or compile failure: fall through to the CPU SIMD path.
                    let msg = e.to_string();
                    if !msg.contains("no Metal-capable GPU device") {
                        tracing::warn!(
                            error = %e,
                            "Metal FP8 E4M3 GEMV failed, falling back to CPU SIMD"
                        );
                    }
                }
            }
        }

        // GPU dispatch via CUDA NVRTC — Linux/Windows only, `native-cuda` feature.
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        {
            // SAFETY: BlockFP8E4M3 is repr(C) with size BLOCK_FP8_BYTES (= 34),
            // validated by the compile-time assert above. The slice lifetime is
            // tied to `blocks`, which outlives this call.
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    blocks.as_ptr().cast::<u8>(),
                    blocks.len() * oxibonsai_core::BLOCK_FP8_BYTES,
                )
            };
            match crate::gpu_backend::cuda_gemv_fp8_e4m3(bytes, input, output, n_rows, k) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // No CUDA device — fall through to the CPU SIMD path silently.
                    // Any other error is logged at warn level and we fall through.
                    let msg = e.to_string();
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA FP8 E4M3 GEMV failed, falling back to CPU SIMD"
                        );
                    }
                }
            }
        }

        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::gemv_fp8_e4m3_avx512(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::gemv_fp8_e4m3_avx2(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::gemv_fp8_e4m3_neon(blocks, input, output, n_rows, k)
            },
            _ => crate::gemv_fp8::gemv_fp8_e4m3(blocks, input, output, n_rows, k),
        }
    }

    /// FP8 E5M2 GEMV — tier-aware SIMD dispatch with optional GPU acceleration.
    ///
    /// Mirrors [`gemv_fp8_e4m3`](Self::gemv_fp8_e4m3): Metal → CUDA → CPU SIMD.
    /// The raw-byte cast is sound because `BlockFP8E5M2` is `#[repr(C)]` with
    /// size `BLOCK_FP8_BYTES = 34`.
    fn gemv_fp8_e5m2(
        &self,
        blocks: &[BlockFP8E5M2],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        // GPU dispatch via Metal — macOS only, `metal` feature.
        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            // SAFETY: BlockFP8E5M2 is repr(C) with size BLOCK_FP8_BYTES (= 34).
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    blocks.as_ptr().cast::<u8>(),
                    blocks.len() * oxibonsai_core::BLOCK_FP8_BYTES,
                )
            };
            match crate::gpu_backend::metal_gemv_fp8_e5m2(bytes, input, output, n_rows, k) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    let msg = e.to_string();
                    if !msg.contains("no Metal-capable GPU device") {
                        tracing::warn!(
                            error = %e,
                            "Metal FP8 E5M2 GEMV failed, falling back to CPU SIMD"
                        );
                    }
                }
            }
        }

        // GPU dispatch via CUDA NVRTC — Linux/Windows only, `native-cuda` feature.
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        {
            // SAFETY: BlockFP8E5M2 is repr(C) with size BLOCK_FP8_BYTES (= 34),
            // validated by the compile-time assert above.
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    blocks.as_ptr().cast::<u8>(),
                    blocks.len() * oxibonsai_core::BLOCK_FP8_BYTES,
                )
            };
            match crate::gpu_backend::cuda_gemv_fp8_e5m2(bytes, input, output, n_rows, k) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    let msg = e.to_string();
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA FP8 E5M2 GEMV failed, falling back to CPU SIMD"
                        );
                    }
                }
            }
        }

        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::gemv_fp8_e5m2_avx512(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::gemv_fp8_e5m2_avx2(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::gemv_fp8_e5m2_neon(blocks, input, output, n_rows, k)
            },
            _ => crate::gemv_fp8::gemv_fp8_e5m2(blocks, input, output, n_rows, k),
        }
    }

    /// FP8 E4M3FN GEMM — tier-aware SIMD dispatch.
    fn gemm_fp8_e4m3(
        &self,
        blocks: &[BlockFP8E4M3],
        inputs: &[f32],
        outputs: &mut [f32],
        n_rows: usize,
        k: usize,
        batch: usize,
    ) -> KernelResult<()> {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::gemm_fp8_e4m3_avx512(
                    blocks, inputs, outputs, n_rows, k, batch,
                )
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::gemm_fp8_e4m3_avx2(blocks, inputs, outputs, n_rows, k, batch)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::gemm_fp8_e4m3_neon(blocks, inputs, outputs, n_rows, k, batch)
            },
            _ => crate::gemm_fp8::gemm_fp8_e4m3(blocks, inputs, outputs, n_rows, k, batch),
        }
    }

    /// FP8 E5M2 GEMM — tier-aware SIMD dispatch.
    fn gemm_fp8_e5m2(
        &self,
        blocks: &[BlockFP8E5M2],
        inputs: &[f32],
        outputs: &mut [f32],
        n_rows: usize,
        k: usize,
        batch: usize,
    ) -> KernelResult<()> {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::gemm_fp8_e5m2_avx512(
                    blocks, inputs, outputs, n_rows, k, batch,
                )
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::gemm_fp8_e5m2_avx2(blocks, inputs, outputs, n_rows, k, batch)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::gemm_fp8_e5m2_neon(blocks, inputs, outputs, n_rows, k, batch)
            },
            _ => crate::gemm_fp8::gemm_fp8_e5m2(blocks, inputs, outputs, n_rows, k, batch),
        }
    }

    fn name_fp8(&self) -> &'static str {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => "fp8_avx512",
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => "fp8_avx2",
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => "fp8_neon",
            _ => "fp8_reference",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn auto_detect_creates_dispatcher() {
        let dispatcher = KernelDispatcher::auto_detect();
        // Should pick the best available tier for this machine; Reference
        // only when no SIMD (or GPU) acceleration exists.
        let _tier = dispatcher.tier();
        let _name = dispatcher.name();
    }

    /// Verify that CPU feature detection uses std's is_x86_feature_detected!
    /// and not scirs2_core, which may have issues on some platforms.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn cpu_feature_detection_uses_std() {
        // This test verifies the fix for GitHub issue #4:
        // Token generation hangs at 100% CPU on Windows AMD CPUs.
        //
        // The issue was that scirs2_core might incorrectly detect CPU features,
        // causing the wrong kernel tier to be selected.

        let has_avx2 = is_x86_feature_detected!("avx2");
        let has_fma = is_x86_feature_detected!("fma");

        let dispatcher = KernelDispatcher::auto_detect();
        let tier = dispatcher.tier();

        // If std detects AVX2+FMA, the tier should be at least Avx2 (or GPU if
        // available). The original bug (#4) was the dispatcher falling back to
        // Reference on AVX2+FMA hardware; GPU is acceptable since it's strictly
        // faster than AVX2 for the workloads where dispatch matters.
        if has_avx2 && has_fma {
            #[cfg(feature = "gpu")]
            let acceptable = matches!(
                tier,
                KernelTier::Avx2 | KernelTier::Avx512 | KernelTier::Gpu
            );
            #[cfg(not(feature = "gpu"))]
            let acceptable = matches!(tier, KernelTier::Avx2 | KernelTier::Avx512);
            assert!(
                acceptable,
                "Expected AVX2/AVX-512/GPU tier when AVX2+FMA detected, got {:?}",
                tier
            );
        }
    }

    #[test]
    fn reference_tier_works() {
        let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
        assert_eq!(dispatcher.tier(), KernelTier::Reference);
        assert_eq!(dispatcher.name(), "Q1_0_g128 reference (scalar)");
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn avx2_tier_name() {
        if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
            return;
        }
        let dispatcher = KernelDispatcher::with_tier(KernelTier::Avx2);
        assert_eq!(dispatcher.tier(), KernelTier::Avx2);
        assert_eq!(dispatcher.name(), "Q1_0_g128 AVX2+FMA (256-bit)");
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn neon_tier_name() {
        let dispatcher = KernelDispatcher::with_tier(KernelTier::Neon);
        assert_eq!(dispatcher.tier(), KernelTier::Neon);
        assert_eq!(dispatcher.name(), "Q1_0_g128 NEON (128-bit)");
    }

    #[test]
    fn dispatcher_exposes_ternary_gemv() {
        use crate::TernaryKernel;
        use half::f16;
        use oxibonsai_core::BlockTQ2_0_g128;

        let dispatcher = KernelDispatcher::auto_detect();

        // row 0: all +1 (qs=0xAA = 0b10101010 → four 0b10 codes per byte → +1)
        // row 1: all -1 (qs=0x00 = 0b00000000 → four 0b00 codes per byte → -1)
        let block_pos = BlockTQ2_0_g128 {
            qs: [0xAA; 32],
            d: f16::from_f32(1.0),
        };
        let block_neg = BlockTQ2_0_g128 {
            qs: [0x00; 32],
            d: f16::from_f32(1.0),
        };
        let blocks = vec![block_pos, block_neg];
        let input = vec![1.0f32; 128];
        let mut output = vec![0.0f32; 2];

        dispatcher
            .gemv_ternary_g128(&blocks, &input, &mut output, 2, 128)
            .expect("gemv_ternary_g128 should succeed");
        assert!(
            (output[0] - 128.0).abs() < 1.0,
            "row0 expected ~128.0, got {}",
            output[0]
        );
        assert!(
            (output[1] + 128.0).abs() < 1.0,
            "row1 expected ~-128.0, got {}",
            output[1]
        );
    }

    #[test]
    fn dispatcher_ternary_reference_tier() {
        use crate::TernaryKernel;
        use half::f16;
        use oxibonsai_core::BlockTQ2_0_g128;

        let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
        let blocks = vec![BlockTQ2_0_g128 {
            qs: [0xAA; 32],
            d: f16::from_f32(1.0),
        }];
        let input = vec![1.0f32; 128];
        let mut output = vec![0.0f32; 1];

        dispatcher
            .gemv_ternary_g128(&blocks, &input, &mut output, 1, 128)
            .expect("gemv_ternary_g128 should succeed");
        assert!((output[0] - 128.0).abs() < 1.0);
    }
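
    // Added sketch: `is_gpu_accelerated` (from `OneBitKernel`) should report
    // false for any CPU-only tier built via `with_tier`; only tier selection
    // is exercised here, no kernel math.
    #[test]
    fn reference_tier_is_not_gpu_accelerated() {
        use crate::traits::OneBitKernel;

        let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
        assert!(!dispatcher.is_gpu_accelerated());
    }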

    #[test]
    fn ternary_upload_non_gpu_returns_none() {
        use crate::TernaryKernel;
        use half::f16;
        use oxibonsai_core::BlockTQ2_0_g128;

        // Reference tier has no GPU — upload_weights_ternary must return None.
        let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
        let block = BlockTQ2_0_g128 {
            qs: [0xAAu8; 32],
            d: f16::from_f32(1.0),
        };
        let handle = dispatcher.upload_weights_ternary(&[block]);
        assert!(
            handle.is_none(),
            "expected None for non-GPU tier, got {:?}",
            handle
        );
    }
}
1269}