1use crate::dequant;
14use crate::error::KernelResult;
15use crate::gemm;
16use crate::gemv;
17use crate::traits::{Fp8Kernel, OneBitKernel, TernaryKernel};
18use crate::weight_cache::GpuWeightHandle;
19use oxibonsai_core::tensor::BlockQ1_0G128;
20use oxibonsai_core::{BlockFP8E4M3, BlockFP8E5M2};
21#[cfg(feature = "gpu")]
22use std::sync::Arc;
23
/// Execution tier a [`KernelDispatcher`] routes kernel calls to.
///
/// Variants are conditionally compiled: only the tiers that can exist on the
/// current target architecture / feature set are present, so exhaustive
/// `match`es over this enum stay valid on every platform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelTier {
    /// Portable scalar fallback; always available on every target.
    Reference,
    /// 256-bit AVX2 + FMA SIMD path (x86_64 only).
    #[cfg(target_arch = "x86_64")]
    Avx2,
    /// 512-bit AVX-512 (F + BW + VL) SIMD path (x86_64 only).
    #[cfg(target_arch = "x86_64")]
    Avx512,
    /// 128-bit NEON SIMD path (aarch64 only).
    #[cfg(target_arch = "aarch64")]
    Neon,
    /// GPU backend path (only when the `gpu` feature is enabled).
    #[cfg(feature = "gpu")]
    Gpu,
}
42
43impl std::fmt::Display for KernelTier {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match self {
46 Self::Reference => write!(f, "reference"),
47 #[cfg(target_arch = "x86_64")]
48 Self::Avx2 => write!(f, "avx2+fma"),
49 #[cfg(target_arch = "x86_64")]
50 Self::Avx512 => write!(f, "avx512f+bw+vl"),
51 #[cfg(target_arch = "aarch64")]
52 Self::Neon => write!(f, "neon"),
53 #[cfg(feature = "gpu")]
54 Self::Gpu => write!(f, "gpu"),
55 }
56 }
57}
58
/// Runtime dispatcher that routes quantized dequant/GEMV/GEMM calls to the
/// kernel tier selected at construction time.
pub struct KernelDispatcher {
    // The active tier; fixed for the lifetime of the dispatcher.
    tier: KernelTier,
    // Backend used when `tier == KernelTier::Gpu`; `None` on CPU-only tiers.
    #[cfg(feature = "gpu")]
    gpu_backend: Option<Arc<dyn crate::gpu_backend::GpuBackendTrait>>,
}
70
71impl std::fmt::Debug for KernelDispatcher {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("KernelDispatcher")
75 .field("tier", &self.tier)
76 .finish_non_exhaustive()
77 }
78}
79
impl KernelDispatcher {
    /// Builds a dispatcher for the best tier available on this machine.
    ///
    /// Priority: an accelerated GPU backend (when the `gpu` feature is on),
    /// then the best CPU SIMD tier via [`Self::select_tier`], else the scalar
    /// reference tier. The chosen tier is logged at `info`.
    pub fn auto_detect() -> Self {
        #[cfg(feature = "gpu")]
        {
            let backend = crate::gpu_backend::select_backend();
            // Only take the GPU path when the backend reports real
            // acceleration; otherwise fall through to CPU detection.
            if backend.is_accelerated() {
                tracing::info!(backend = backend.name(), "GPU backend available");
                return Self {
                    tier: KernelTier::Gpu,
                    gpu_backend: Some(Arc::from(backend)),
                };
            }
        }

        let caps = scirs2_core::simd::detect::get_cpu_features();
        let tier = Self::select_tier(caps);
        tracing::info!(tier = %tier, "selected kernel tier");
        Self {
            tier,
            #[cfg(feature = "gpu")]
            gpu_backend: None,
        }
    }

    /// Builds a dispatcher pinned to the given tier, with no GPU backend.
    ///
    /// Callers are expected to ensure the tier's CPU features are actually
    /// present before dispatching (see the test suite for the pattern).
    pub fn with_tier(tier: KernelTier) -> Self {
        Self {
            tier,
            #[cfg(feature = "gpu")]
            gpu_backend: None,
        }
    }

    /// Builds a GPU-tier dispatcher backed by the supplied backend.
    #[cfg(feature = "gpu")]
    pub fn with_gpu(backend: Arc<dyn crate::gpu_backend::GpuBackendTrait>) -> Self {
        Self {
            tier: KernelTier::Gpu,
            gpu_backend: Some(backend),
        }
    }

    /// Returns the tier this dispatcher routes to.
    pub fn tier(&self) -> KernelTier {
        self.tier
    }

    /// Picks the best CPU tier from runtime feature detection.
    ///
    /// Uses the standard library's `is_x86_feature_detected!` as the source
    /// of truth on x86_64; disagreements with the `scirs2` capability report
    /// are logged at `warn` but do not affect the decision.
    fn select_tier(caps: &scirs2_core::simd::detect::CpuFeatures) -> KernelTier {
        #[cfg(target_arch = "x86_64")]
        {
            let has_avx512f = is_x86_feature_detected!("avx512f");
            let has_avx512bw = is_x86_feature_detected!("avx512bw");
            let has_avx512vl = is_x86_feature_detected!("avx512vl");
            let has_avx2 = is_x86_feature_detected!("avx2");
            let has_fma = is_x86_feature_detected!("fma");

            if caps.has_avx512f != has_avx512f {
                tracing::warn!(
                    scirs2_avx512f = caps.has_avx512f,
                    std_avx512f = has_avx512f,
                    "CPU feature detection mismatch for AVX-512F, using std detection"
                );
            }
            if caps.has_avx2 != has_avx2 || caps.has_fma != has_fma {
                tracing::warn!(
                    scirs2_avx2 = caps.has_avx2,
                    scirs2_fma = caps.has_fma,
                    std_avx2 = has_avx2,
                    std_fma = has_fma,
                    "CPU feature detection mismatch for AVX2/FMA, using std detection"
                );
            }

            // The AVX-512 kernels need F+BW+VL together, so all three gate
            // the tier, not just AVX-512F.
            if has_avx512f && has_avx512bw && has_avx512vl {
                tracing::debug!("AVX-512 (F+BW+VL) detected, selecting AVX-512 tier");
                return KernelTier::Avx512;
            }
            if has_avx2 && has_fma {
                tracing::debug!("AVX2 + FMA detected, selecting AVX2 tier");
                return KernelTier::Avx2;
            }

            tracing::warn!(
                has_avx512f,
                has_avx512bw,
                has_avx512vl,
                has_avx2,
                has_fma,
                "No SIMD acceleration available, falling back to reference tier (this will be slow)"
            );
        }

        #[cfg(target_arch = "aarch64")]
        {
            if caps.has_neon {
                return KernelTier::Neon;
            }
        }

        // On targets where neither arch block ran, `caps` is otherwise
        // unused; this silences the warning.
        let _ = caps;
        KernelTier::Reference
    }
}
200
/// Minimum output-row count before a GEMV/GEMM is dispatched to the GPU;
/// anything smaller is routed to the CPU SIMD fallback (see the
/// `KernelTier::Gpu` arms below). Presumably tuned so transfer/launch
/// overhead does not dominate small matrices — TODO confirm empirically.
#[cfg(feature = "gpu")]
const GPU_MIN_ROWS: usize = 1024;
207
impl KernelDispatcher {
    /// Best CPU SIMD tier on this machine, used as the fallback target when
    /// the dispatcher itself is on `KernelTier::Gpu` (small workloads or GPU
    /// failures). Never returns `KernelTier::Gpu`.
    #[cfg(feature = "gpu")]
    fn cpu_tier() -> KernelTier {
        #[cfg(target_arch = "x86_64")]
        {
            let has_avx512f = is_x86_feature_detected!("avx512f");
            let has_avx512bw = is_x86_feature_detected!("avx512bw");
            let has_avx512vl = is_x86_feature_detected!("avx512vl");
            let has_avx2 = is_x86_feature_detected!("avx2");
            let has_fma = is_x86_feature_detected!("fma");

            if has_avx512f && has_avx512bw && has_avx512vl {
                return KernelTier::Avx512;
            }
            if has_avx2 && has_fma {
                return KernelTier::Avx2;
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            // No runtime check here (unlike `select_tier`): on this path
            // aarch64 is assumed to always provide NEON.
            return KernelTier::Neon;
        }

        // Unreachable on aarch64 (the block above always returns).
        #[allow(unreachable_code)]
        KernelTier::Reference
    }

    /// CPU dequantization fallback for 1-bit g128 blocks, dispatched on
    /// [`Self::cpu_tier`].
    #[cfg(feature = "gpu")]
    fn cpu_dequant(blocks: &[BlockQ1_0G128], output: &mut [f32]) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => dequant::dequant_1bit_g128(blocks, output),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe { crate::simd_avx2::dequant_1bit_g128_avx2(blocks, output) },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::dequant_1bit_g128_avx512(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe { crate::simd_neon::dequant_1bit_g128_neon(blocks, output) },
            // `cpu_tier` never returns Gpu; arm kept for exhaustiveness.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => dequant::dequant_1bit_g128(blocks, output),
        }
    }

    /// CPU GEMV fallback for 1-bit g128 blocks, dispatched on
    /// [`Self::cpu_tier`].
    #[cfg(feature = "gpu")]
    fn cpu_gemv(
        blocks: &[BlockQ1_0G128],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemv_1bit_g128_avx2_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemv_1bit_g128_avx512_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemv_1bit_g128_neon_prefetch(blocks, input, output, n_rows, k)
            },
            // `cpu_tier` never returns Gpu; arm kept for exhaustiveness.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
        }
    }

    /// CPU GEMM fallback for 1-bit g128 blocks, dispatched on
    /// [`Self::cpu_tier`].
    #[cfg(feature = "gpu")]
    fn cpu_gemm(
        blocks: &[BlockQ1_0G128],
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemm_1bit_g128_avx2_prefetch(blocks, input, output, m, n_rows, k)
            },
            // Note: AVX-512 GEMM has no `_prefetch` variant, unlike AVX2/NEON.
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemm_1bit_g128_avx512(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemm_1bit_g128_neon_prefetch(blocks, input, output, m, n_rows, k)
            },
            // `cpu_tier` never returns Gpu; arm kept for exhaustiveness.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
        }
    }

    /// CPU dequantization fallback for ternary TQ2_0 g128 blocks.
    #[cfg(feature = "gpu")]
    fn cpu_dequant_ternary(
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        output: &mut [f32],
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::dequant_tq2_0_g128_avx2(blocks, output)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::dequant_tq2_0_g128_avx512(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::dequant_tq2_0_g128_neon(blocks, output)
            },
            // `cpu_tier` never returns Gpu; arm kept for exhaustiveness.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
        }
    }

    /// CPU GEMV fallback for ternary TQ2_0 g128 blocks.
    #[cfg(feature = "gpu")]
    fn cpu_gemv_ternary(
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => {
                crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
            }
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemv_tq2_0_g128_avx2_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemv_tq2_0_g128_avx512_prefetch(
                    blocks, input, output, n_rows, k,
                )
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemv_tq2_0_g128_neon_prefetch(blocks, input, output, n_rows, k)
            },
            // `cpu_tier` never returns Gpu; arm kept for exhaustiveness.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => {
                crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
            }
        }
    }

    /// CPU GEMM fallback for ternary TQ2_0 g128 blocks.
    #[cfg(feature = "gpu")]
    fn cpu_gemm_ternary(
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match Self::cpu_tier() {
            KernelTier::Reference => {
                crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
            }
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemm_tq2_0_g128_avx2(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemm_tq2_0_g128_avx512(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemm_tq2_0_g128_neon(blocks, input, output, m, n_rows, k)
            },
            // `cpu_tier` never returns Gpu; arm kept for exhaustiveness.
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => {
                crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
            }
        }
    }

    /// Reinterprets a block slice as raw bytes for upload to a GPU backend.
    #[cfg(feature = "gpu")]
    fn blocks_as_bytes(blocks: &[BlockQ1_0G128]) -> &[u8] {
        let ptr = blocks.as_ptr() as *const u8;
        let len = std::mem::size_of_val(blocks);
        // SAFETY: `ptr`/`len` come from a valid slice, the lifetime of the
        // returned slice is tied to `blocks`, and every initialized byte is a
        // valid `u8`. Assumes `BlockQ1_0G128` contains no padding bytes —
        // TODO(review): confirm the block layout is repr(C)/packed.
        unsafe { std::slice::from_raw_parts(ptr, len) }
    }
}
423
impl OneBitKernel for KernelDispatcher {
    /// Dequantizes 1-bit g128 blocks into `output`, dispatched on the
    /// active tier. The GPU tier has no dequant kernel and routes straight
    /// to the CPU SIMD fallback.
    fn dequant(&self, blocks: &[BlockQ1_0G128], output: &mut [f32]) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => dequant::dequant_1bit_g128(blocks, output),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe { crate::simd_avx2::dequant_1bit_g128_avx2(blocks, output) },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::dequant_1bit_g128_avx512(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe { crate::simd_neon::dequant_1bit_g128_neon(blocks, output) },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => Self::cpu_dequant(blocks, output),
        }
    }

    /// GEMV over 1-bit g128 blocks (`n_rows x k` weight matrix).
    ///
    /// On the GPU tier: matrices with fewer than `GPU_MIN_ROWS` rows, a
    /// missing backend, or a backend error all fall back to the CPU SIMD
    /// path (errors are logged at `warn`).
    fn gemv(
        &self,
        blocks: &[BlockQ1_0G128],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemv_1bit_g128_avx2_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemv_1bit_g128_avx512_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemv_1bit_g128_neon_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => {
                if n_rows < GPU_MIN_ROWS {
                    return Self::cpu_gemv(blocks, input, output, n_rows, k);
                }
                if let Some(ref backend) = self.gpu_backend {
                    let bytes = Self::blocks_as_bytes(blocks);
                    match backend.gemv_q1_g128(bytes, input, n_rows, k) {
                        Ok(result) => {
                            // Copy only the overlapping prefix; if the
                            // backend returned fewer values, the tail of
                            // `output` is left untouched.
                            let copy_len = output.len().min(result.len());
                            output[..copy_len].copy_from_slice(&result[..copy_len]);
                            return Ok(());
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "GPU gemv failed, falling back to CPU");
                            return Self::cpu_gemv(blocks, input, output, n_rows, k);
                        }
                    }
                }
                Self::cpu_gemv(blocks, input, output, n_rows, k)
            }
        }
    }

    /// GEMM over 1-bit g128 blocks (`m` input columns, `n_rows x k` weight
    /// matrix). GPU-tier fallback behavior mirrors [`Self::gemv`].
    fn gemm(
        &self,
        blocks: &[BlockQ1_0G128],
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemm_1bit_g128_avx2_prefetch(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemm_1bit_g128_avx512(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemm_1bit_g128_neon_prefetch(blocks, input, output, m, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => {
                if n_rows < GPU_MIN_ROWS {
                    return Self::cpu_gemm(blocks, input, output, m, n_rows, k);
                }
                if let Some(ref backend) = self.gpu_backend {
                    let bytes = Self::blocks_as_bytes(blocks);
                    match backend.gemm_q1_g128(bytes, input, m, n_rows, k) {
                        Ok(result) => {
                            let copy_len = output.len().min(result.len());
                            output[..copy_len].copy_from_slice(&result[..copy_len]);
                            return Ok(());
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "GPU gemm failed, falling back to CPU");
                            return Self::cpu_gemm(blocks, input, output, m, n_rows, k);
                        }
                    }
                }
                Self::cpu_gemm(blocks, input, output, m, n_rows, k)
            }
        }
    }

    /// Human-readable name of the 1-bit kernel path for the active tier.
    fn name(&self) -> &'static str {
        match self.tier {
            KernelTier::Reference => "Q1_0_g128 reference (scalar)",
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => "Q1_0_g128 AVX2+FMA (256-bit)",
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => "Q1_0_g128 AVX-512 (512-bit)",
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => "Q1_0_g128 NEON (128-bit)",
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => "Q1_0_g128 GPU (accelerated)",
        }
    }

    /// `true` iff this dispatcher is on the GPU tier (always `false` when
    /// the `gpu` feature is compiled out).
    fn is_gpu_accelerated(&self) -> bool {
        #[cfg(feature = "gpu")]
        let answer = self.tier == KernelTier::Gpu;
        #[cfg(not(feature = "gpu"))]
        let answer = false;
        answer
    }

    /// Uploads weight blocks to the GPU, returning a cache handle, or `None`
    /// when not on the GPU tier / the upload fails (failure logged at `warn`).
    fn upload_weights(&self, blocks: &[BlockQ1_0G128]) -> Option<GpuWeightHandle> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                let bytes = Self::blocks_as_bytes(blocks);
                match backend.upload_weights_raw(bytes) {
                    Ok(handle) => return Some(handle),
                    Err(e) => {
                        tracing::warn!(error = %e, "failed to upload weights to GPU");
                    }
                }
            }
        }
        // Non-GPU builds/tiers: `blocks` is intentionally unused.
        let _ = blocks;
        None
    }

    /// GEMV using previously uploaded (cached) GPU weights.
    ///
    /// # Errors
    /// Returns `GpuError` if the backend call fails (no CPU fallback is
    /// possible without the original blocks) and `UnsupportedOperation`
    /// when not on the GPU tier.
    fn gemv_cached(
        &self,
        handle: GpuWeightHandle,
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                match backend.gemv_q1_g128_cached(handle, input, n_rows, k) {
                    Ok(result) => {
                        let len = output.len().min(result.len());
                        output[..len].copy_from_slice(&result[..len]);
                        return Ok(());
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "cached GPU gemv failed, cannot fallback without blocks");
                        return Err(crate::error::KernelError::GpuError(e.to_string()));
                    }
                }
            }
        }
        let _ = (handle, input, output, n_rows, k);
        Err(crate::error::KernelError::UnsupportedOperation(
            "gemv_cached requires GPU tier".into(),
        ))
    }

    /// Fused attention phase hook. Not implemented at the dispatcher level:
    /// always returns `Ok(None)`, signalling the caller to run the
    /// unfused path.
    fn batch_attn_phase(
        &self,
        hidden: &[f32],
        norm_weight: &[f32],
        norm_eps: f32,
        qkv_handle: GpuWeightHandle,
        q_rows: usize,
        k_rows: usize,
        h: usize,
    ) -> KernelResult<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>> {
        let _ = (hidden, norm_weight, norm_eps, qkv_handle, q_rows, k_rows, h);
        Ok(None)
    }

    /// Fused FFN phase on the GPU backend.
    ///
    /// Returns `Ok(true)` when the backend executed the fused phase,
    /// `Ok(false)` when it declined or errored (error logged at `warn`)
    /// or when not on the GPU tier — callers must then run the unfused path.
    fn batch_ffn_phase(
        &self,
        hidden: &mut [f32],
        attn_out: &[f32],
        norm_weight: &[f32],
        norm_eps: f32,
        attn_proj_handle: GpuWeightHandle,
        gate_up_handle: GpuWeightHandle,
        down_handle: GpuWeightHandle,
        h: usize,
        intermediate: usize,
        attn_proj_k: usize,
    ) -> KernelResult<bool> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                match backend.batch_ffn_phase(
                    hidden,
                    attn_out,
                    norm_weight,
                    norm_eps,
                    attn_proj_handle,
                    gate_up_handle,
                    down_handle,
                    h,
                    intermediate,
                    attn_proj_k,
                ) {
                    Ok(true) => return Ok(true),
                    Ok(false) => return Ok(false),
                    Err(e) => {
                        tracing::warn!(error = %e, "batch FFN phase failed, falling back");
                        return Ok(false);
                    }
                }
            }
        }
        let _ = (
            hidden,
            attn_out,
            norm_weight,
            norm_eps,
            attn_proj_handle,
            gate_up_handle,
            down_handle,
            h,
            intermediate,
            attn_proj_k,
        );
        Ok(false)
    }
}
675
impl TernaryKernel for KernelDispatcher {
    /// Dequantizes ternary TQ2_0 g128 blocks, dispatched on the active tier.
    /// The GPU tier routes to the CPU SIMD fallback (no GPU dequant kernel).
    fn dequant_ternary_g128(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        output: &mut [f32],
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::dequant_tq2_0_g128_avx2(blocks, output)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::dequant_tq2_0_g128_avx512(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::dequant_tq2_0_g128_neon(blocks, output)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => Self::cpu_dequant_ternary(blocks, output),
        }
    }

    /// GEMV over ternary TQ2_0 g128 blocks. Note the GPU tier routes the
    /// block-based (uncached) call to CPU SIMD; the GPU path for ternary
    /// weights goes through [`Self::gemv_ternary_g128_cached`].
    fn gemv_ternary_g128(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => {
                crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
            }
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemv_tq2_0_g128_avx2_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemv_tq2_0_g128_avx512_prefetch(
                    blocks, input, output, n_rows, k,
                )
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemv_tq2_0_g128_neon_prefetch(blocks, input, output, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => Self::cpu_gemv_ternary(blocks, input, output, n_rows, k),
        }
    }

    /// GEMM over ternary TQ2_0 g128 blocks; GPU tier routes to CPU SIMD.
    fn gemm_ternary_g128(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        match self.tier {
            KernelTier::Reference => {
                crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
            }
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_avx2::gemm_tq2_0_g128_avx2(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_avx512::gemm_tq2_0_g128_avx512(blocks, input, output, m, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_neon::gemm_tq2_0_g128_neon(blocks, input, output, m, n_rows, k)
            },
            #[cfg(feature = "gpu")]
            KernelTier::Gpu => Self::cpu_gemm_ternary(blocks, input, output, m, n_rows, k),
        }
    }

    /// Uploads ternary weight blocks to the GPU, returning a cache handle,
    /// or `None` when not on the GPU tier / upload is unsupported. The
    /// unsupported case warns only once per process to avoid log spam.
    fn upload_weights_ternary(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
    ) -> Option<GpuWeightHandle> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                match backend.upload_weights_ternary(blocks) {
                    Ok(handle) => return Some(handle),
                    Err(e) => {
                        use std::sync::atomic::{AtomicBool, Ordering};
                        // One-shot latch: first caller to swap in `true`
                        // emits the warning; everyone else stays quiet.
                        static WARNED: AtomicBool = AtomicBool::new(false);
                        if !WARNED.swap(true, Ordering::Relaxed) {
                            tracing::warn!(
                                error = %e,
                                backend = backend.name(),
                                "ternary weight GPU upload not supported by backend; \
                                 falling back to CPU SIMD for ternary GEMV (this message \
                                 is shown once per process)"
                            );
                        }
                    }
                }
            }
        }
        let _ = blocks;
        None
    }

    /// GEMV using previously uploaded (cached) ternary GPU weights.
    ///
    /// # Errors
    /// `GpuError` on backend failure (no fallback possible without the
    /// original blocks); `UnsupportedOperation` when not on the GPU tier.
    fn gemv_ternary_g128_cached(
        &self,
        handle: GpuWeightHandle,
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        #[cfg(feature = "gpu")]
        {
            if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
                match backend.gemv_tq2_g128_cached(handle, input, n_rows, k) {
                    Ok(result) => {
                        let len = output.len().min(result.len());
                        output[..len].copy_from_slice(&result[..len]);
                        return Ok(());
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "cached GPU ternary gemv failed, cannot fallback without blocks");
                        return Err(crate::error::KernelError::GpuError(e.to_string()));
                    }
                }
            }
        }
        let _ = (handle, input, output, n_rows, k);
        Err(crate::error::KernelError::UnsupportedOperation(
            "gemv_ternary_g128_cached requires GPU tier".into(),
        ))
    }
}
828
// Compile-time guarantees that both FP8 block structs are exactly
// `BLOCK_FP8_BYTES` long, so the raw-byte reinterpretation in the
// Metal/CUDA FP8 GEMV paths below cannot read out of bounds.
const _: () =
    assert!(std::mem::size_of::<oxibonsai_core::BlockFP8E4M3>() == oxibonsai_core::BLOCK_FP8_BYTES);
const _: () =
    assert!(std::mem::size_of::<oxibonsai_core::BlockFP8E5M2>() == oxibonsai_core::BLOCK_FP8_BYTES);
835
impl Fp8Kernel for KernelDispatcher {
    /// Dequantizes FP8 E4M3 blocks, dispatched on the active tier; the `_`
    /// arm covers Reference (and Gpu, which has no FP8 dequant kernel).
    fn dequant_fp8_e4m3(&self, blocks: &[BlockFP8E4M3], output: &mut [f32]) -> KernelResult<()> {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::dequant_fp8_e4m3_avx512(blocks, output)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::dequant_fp8_e4m3_avx2(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::dequant_fp8_e4m3_neon(blocks, output)
            },
            _ => crate::dequant_fp8::dequant_fp8_e4m3(blocks, output),
        }
    }

    /// Dequantizes FP8 E5M2 blocks; same dispatch shape as
    /// [`Self::dequant_fp8_e4m3`].
    fn dequant_fp8_e5m2(&self, blocks: &[BlockFP8E5M2], output: &mut [f32]) -> KernelResult<()> {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::dequant_fp8_e5m2_avx512(blocks, output)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::dequant_fp8_e5m2_avx2(blocks, output)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::dequant_fp8_e5m2_neon(blocks, output)
            },
            _ => crate::dequant_fp8::dequant_fp8_e5m2(blocks, output),
        }
    }

    /// GEMV over FP8 E4M3 blocks.
    ///
    /// When compiled with the `metal` (macOS) or `native-cuda`
    /// (Linux/Windows) feature, the GPU path is tried first regardless of
    /// `self.tier`; any failure (except "no device", which is silent) is
    /// logged at `warn` and execution falls through to CPU SIMD dispatch.
    fn gemv_fp8_e4m3(
        &self,
        blocks: &[BlockFP8E4M3],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            // SAFETY: the const asserts above guarantee each block is
            // exactly BLOCK_FP8_BYTES, so the byte view covers precisely
            // the slice's memory; lifetime is bounded by this block.
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    blocks.as_ptr().cast::<u8>(),
                    blocks.len() * oxibonsai_core::BLOCK_FP8_BYTES,
                )
            };
            match crate::gpu_backend::metal_gemv_fp8_e4m3(bytes, input, output, n_rows, k) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    let msg = e.to_string();
                    // Absence of a Metal GPU is an expected condition on
                    // some machines; only warn on real failures.
                    if !msg.contains("no Metal-capable GPU device") {
                        tracing::warn!(
                            error = %e,
                            "Metal FP8 E4M3 GEMV failed, falling back to CPU SIMD"
                        );
                    }
                }
            }
        }

        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        {
            // SAFETY: see the Metal branch above — size is pinned by the
            // compile-time block-size assertions.
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    blocks.as_ptr().cast::<u8>(),
                    blocks.len() * oxibonsai_core::BLOCK_FP8_BYTES,
                )
            };
            match crate::gpu_backend::cuda_gemv_fp8_e4m3(bytes, input, output, n_rows, k) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    let msg = e.to_string();
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA FP8 E4M3 GEMV failed, falling back to CPU SIMD"
                        );
                    }
                }
            }
        }

        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::gemv_fp8_e4m3_avx512(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::gemv_fp8_e4m3_avx2(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::gemv_fp8_e4m3_neon(blocks, input, output, n_rows, k)
            },
            _ => crate::gemv_fp8::gemv_fp8_e4m3(blocks, input, output, n_rows, k),
        }
    }

    /// GEMV over FP8 E5M2 blocks; GPU-first behavior mirrors
    /// [`Self::gemv_fp8_e4m3`].
    fn gemv_fp8_e5m2(
        &self,
        blocks: &[BlockFP8E5M2],
        input: &[f32],
        output: &mut [f32],
        n_rows: usize,
        k: usize,
    ) -> KernelResult<()> {
        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            // SAFETY: block size is pinned to BLOCK_FP8_BYTES by the
            // compile-time assertions above this impl.
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    blocks.as_ptr().cast::<u8>(),
                    blocks.len() * oxibonsai_core::BLOCK_FP8_BYTES,
                )
            };
            match crate::gpu_backend::metal_gemv_fp8_e5m2(bytes, input, output, n_rows, k) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    let msg = e.to_string();
                    if !msg.contains("no Metal-capable GPU device") {
                        tracing::warn!(
                            error = %e,
                            "Metal FP8 E5M2 GEMV failed, falling back to CPU SIMD"
                        );
                    }
                }
            }
        }

        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        {
            // SAFETY: see above — size pinned by compile-time assertions.
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    blocks.as_ptr().cast::<u8>(),
                    blocks.len() * oxibonsai_core::BLOCK_FP8_BYTES,
                )
            };
            match crate::gpu_backend::cuda_gemv_fp8_e5m2(bytes, input, output, n_rows, k) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    let msg = e.to_string();
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA FP8 E5M2 GEMV failed, falling back to CPU SIMD"
                        );
                    }
                }
            }
        }

        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::gemv_fp8_e5m2_avx512(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::gemv_fp8_e5m2_avx2(blocks, input, output, n_rows, k)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::gemv_fp8_e5m2_neon(blocks, input, output, n_rows, k)
            },
            _ => crate::gemv_fp8::gemv_fp8_e5m2(blocks, input, output, n_rows, k),
        }
    }

    /// Batched GEMM over FP8 E4M3 blocks (`batch` input vectors), CPU-only
    /// dispatch on the active tier.
    fn gemm_fp8_e4m3(
        &self,
        blocks: &[BlockFP8E4M3],
        inputs: &[f32],
        outputs: &mut [f32],
        n_rows: usize,
        k: usize,
        batch: usize,
    ) -> KernelResult<()> {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::gemm_fp8_e4m3_avx512(
                    blocks, inputs, outputs, n_rows, k, batch,
                )
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::gemm_fp8_e4m3_avx2(blocks, inputs, outputs, n_rows, k, batch)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::gemm_fp8_e4m3_neon(blocks, inputs, outputs, n_rows, k, batch)
            },
            _ => crate::gemm_fp8::gemm_fp8_e4m3(blocks, inputs, outputs, n_rows, k, batch),
        }
    }

    /// Batched GEMM over FP8 E5M2 blocks; same dispatch shape as
    /// [`Self::gemm_fp8_e4m3`].
    fn gemm_fp8_e5m2(
        &self,
        blocks: &[BlockFP8E5M2],
        inputs: &[f32],
        outputs: &mut [f32],
        n_rows: usize,
        k: usize,
        batch: usize,
    ) -> KernelResult<()> {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => unsafe {
                crate::simd_fp8_avx512::gemm_fp8_e5m2_avx512(
                    blocks, inputs, outputs, n_rows, k, batch,
                )
            },
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => unsafe {
                crate::simd_fp8_avx2::gemm_fp8_e5m2_avx2(blocks, inputs, outputs, n_rows, k, batch)
            },
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => unsafe {
                crate::simd_fp8_neon::gemm_fp8_e5m2_neon(blocks, inputs, outputs, n_rows, k, batch)
            },
            _ => crate::gemm_fp8::gemm_fp8_e5m2(blocks, inputs, outputs, n_rows, k, batch),
        }
    }

    /// Short identifier for the FP8 kernel path of the active tier.
    fn name_fp8(&self) -> &'static str {
        match self.tier {
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx512 => "fp8_avx512",
            #[cfg(target_arch = "x86_64")]
            KernelTier::Avx2 => "fp8_avx2",
            #[cfg(target_arch = "aarch64")]
            KernelTier::Neon => "fp8_neon",
            _ => "fp8_reference",
        }
    }
}
1117
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: construction and the basic accessors never panic.
    #[test]
    fn auto_detect_creates_dispatcher() {
        let dispatcher = KernelDispatcher::auto_detect();
        let _tier = dispatcher.tier();
        let _name = dispatcher.name();
    }

    // Verifies that tier selection agrees with std's own feature
    // detection macro on x86_64 (GPU tier also acceptable since it
    // preempts CPU detection).
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn cpu_feature_detection_uses_std() {
        let has_avx2 = is_x86_feature_detected!("avx2");
        let has_fma = is_x86_feature_detected!("fma");

        let dispatcher = KernelDispatcher::auto_detect();
        let tier = dispatcher.tier();

        if has_avx2 && has_fma {
            #[cfg(feature = "gpu")]
            let acceptable = matches!(
                tier,
                KernelTier::Avx2 | KernelTier::Avx512 | KernelTier::Gpu
            );
            #[cfg(not(feature = "gpu"))]
            let acceptable = matches!(tier, KernelTier::Avx2 | KernelTier::Avx512);
            assert!(
                acceptable,
                "Expected AVX2/AVX-512/GPU tier when AVX2+FMA detected, got {:?}",
                tier
            );
        }
    }

    #[test]
    fn reference_tier_works() {
        let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
        assert_eq!(dispatcher.tier(), KernelTier::Reference);
        assert_eq!(dispatcher.name(), "Q1_0_g128 reference (scalar)");
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn avx2_tier_name() {
        // Skip silently on machines without AVX2+FMA.
        if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
            return;
        }
        let dispatcher = KernelDispatcher::with_tier(KernelTier::Avx2);
        assert_eq!(dispatcher.tier(), KernelTier::Avx2);
        assert_eq!(dispatcher.name(), "Q1_0_g128 AVX2+FMA (256-bit)");
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn neon_tier_name() {
        let dispatcher = KernelDispatcher::with_tier(KernelTier::Neon);
        assert_eq!(dispatcher.tier(), KernelTier::Neon);
        assert_eq!(dispatcher.name(), "Q1_0_g128 NEON (128-bit)");
    }

    // End-to-end ternary GEMV through whatever tier auto-detect picked:
    // qs = 0xAA encodes an all-(+1) row, qs = 0x00 an all-(-1) row, so a
    // k=128 all-ones input should produce ~+128 and ~-128 respectively.
    #[test]
    fn dispatcher_exposes_ternary_gemv() {
        use crate::TernaryKernel;
        use half::f16;
        use oxibonsai_core::BlockTQ2_0_g128;

        let dispatcher = KernelDispatcher::auto_detect();

        let block_pos = BlockTQ2_0_g128 {
            qs: [0xAA; 32],
            d: f16::from_f32(1.0),
        };
        let block_neg = BlockTQ2_0_g128 {
            qs: [0x00; 32],
            d: f16::from_f32(1.0),
        };
        let blocks = vec![block_pos, block_neg];
        let input = vec![1.0f32; 128];
        let mut output = vec![0.0f32; 2];

        dispatcher
            .gemv_ternary_g128(&blocks, &input, &mut output, 2, 128)
            .expect("gemv_ternary_g128 should succeed");
        assert!(
            (output[0] - 128.0).abs() < 1.0,
            "row0 expected ~128.0, got {}",
            output[0]
        );
        assert!(
            (output[1] + 128.0).abs() < 1.0,
            "row1 expected ~-128.0, got {}",
            output[1]
        );
    }

    // Same computation pinned to the scalar reference tier.
    #[test]
    fn dispatcher_ternary_reference_tier() {
        use crate::TernaryKernel;
        use half::f16;
        use oxibonsai_core::BlockTQ2_0_g128;

        let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
        let blocks = vec![BlockTQ2_0_g128 {
            qs: [0xAA; 32],
            d: f16::from_f32(1.0),
        }];
        let input = vec![1.0f32; 128];
        let mut output = vec![0.0f32; 1];

        dispatcher
            .gemv_ternary_g128(&blocks, &input, &mut output, 1, 128)
            .expect("gemv_ternary_g128 should succeed");
        assert!((output[0] - 128.0).abs() < 1.0);
    }

    // Ternary weight upload is a GPU-only operation; non-GPU tiers must
    // return None rather than erroring.
    #[test]
    fn ternary_upload_non_gpu_returns_none() {
        use crate::TernaryKernel;
        use half::f16;
        use oxibonsai_core::BlockTQ2_0_g128;

        let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
        let block = BlockTQ2_0_g128 {
            qs: [0xAAu8; 32],
            d: f16::from_f32(1.0),
        };
        let handle = dispatcher.upload_weights_ternary(&[block]);
        assert!(
            handle.is_none(),
            "expected None for non-GPU tier, got {:?}",
            handle
        );
    }
}