1#[cfg(all(
26 feature = "native-cuda",
27 any(target_os = "linux", target_os = "windows")
28))]
29pub mod cuda_attn_kernels;
30#[cfg(all(
31 feature = "native-cuda",
32 any(target_os = "linux", target_os = "windows")
33))]
34pub mod cuda_fp8_kernels;
35#[cfg(all(
36 feature = "native-cuda",
37 any(target_os = "linux", target_os = "windows")
38))]
39pub mod cuda_fp8_prefill;
40#[cfg(all(
41 feature = "native-cuda",
42 any(target_os = "linux", target_os = "windows")
43))]
44pub mod cuda_fp8_prefill_kernels;
45#[cfg(all(
46 feature = "native-cuda",
47 any(target_os = "linux", target_os = "windows")
48))]
49pub mod cuda_full_layer;
50#[cfg(all(
51 feature = "native-cuda",
52 any(target_os = "linux", target_os = "windows")
53))]
54pub mod cuda_graph;
55#[cfg(all(
56 feature = "native-cuda",
57 any(target_os = "linux", target_os = "windows")
58))]
59pub mod cuda_k_quant_kernels;
60#[cfg(all(
61 feature = "native-cuda",
62 any(target_os = "linux", target_os = "windows")
63))]
64pub mod cuda_k_quant_prefill;
65#[cfg(all(
66 feature = "native-cuda",
67 any(target_os = "linux", target_os = "windows")
68))]
69pub mod cuda_k_quant_prefill_kernels;
70#[cfg(all(
71 feature = "native-cuda",
72 any(target_os = "linux", target_os = "windows")
73))]
74pub mod cuda_kernels;
75#[cfg(all(
76 feature = "native-cuda",
77 any(target_os = "linux", target_os = "windows")
78))]
79pub mod cuda_prefill;
80#[cfg(all(
81 feature = "native-cuda",
82 any(target_os = "linux", target_os = "windows")
83))]
84pub mod cuda_prefill_kernels;
85#[cfg(all(
86 feature = "native-cuda",
87 any(target_os = "linux", target_os = "windows")
88))]
89pub mod cuda_q_std_kernels;
90#[cfg(all(
91 feature = "native-cuda",
92 any(target_os = "linux", target_os = "windows")
93))]
94pub mod cuda_q_std_prefill;
95#[cfg(all(
96 feature = "native-cuda",
97 any(target_os = "linux", target_os = "windows")
98))]
99pub mod cuda_q_std_prefill_kernels;
100pub mod kernel_sources;
101#[cfg(all(feature = "metal", target_os = "macos"))]
102mod metal_dispatch;
103#[cfg(all(feature = "metal", target_os = "macos"))]
104pub mod metal_fp8_kernels;
105#[cfg(all(feature = "metal", target_os = "macos"))]
106pub mod metal_fp8_prefill;
107#[cfg(all(feature = "metal", target_os = "macos"))]
108pub mod metal_full_layer;
109#[cfg(all(feature = "metal", target_os = "macos"))]
110pub mod metal_graph;
111#[cfg(all(feature = "metal", target_os = "macos"))]
112mod metal_prefill;
113pub mod scirs2_backend;
114
115use thiserror::Error;
116#[allow(unused_imports)]
117use tracing::warn;
118
119#[cfg(feature = "gpu")]
120pub use scirs2_backend::Scirs2Backend;
121
122#[cfg(all(feature = "metal", target_os = "macos"))]
123pub use metal_fp8_kernels::{metal_gemv_fp8_e4m3, metal_gemv_fp8_e5m2};
124
125#[cfg(all(feature = "metal", target_os = "macos"))]
126pub use metal_fp8_prefill::{
127 metal_fused_gate_up_swiglu_fp8_e4m3, metal_fused_gate_up_swiglu_fp8_e5m2, metal_gemm_fp8_e4m3,
128 metal_gemm_fp8_e4m3_residual, metal_gemm_fp8_e5m2, metal_gemm_fp8_e5m2_residual,
129};
130
131#[cfg(all(feature = "metal", target_os = "macos"))]
132pub use metal_graph::{MetalGraph, MetalGraphError, MetalWeightHandle};
133
134#[cfg(all(feature = "metal", target_os = "macos"))]
135pub use metal_full_layer::{
136 build_cached_weights, build_cached_weights_ternary_only, print_gpu_profile_summary,
137 try_metal_ffn, try_metal_forward_greedy_ternary, try_metal_full_forward,
138 try_metal_full_forward_cached, try_metal_full_forward_ternary, try_metal_full_layer,
139 try_metal_prefill_ternary, try_metal_prefill_verify_ternary, try_metal_qkv, CachedLayerWeights,
140 CachedModelWeights, FullForwardLayerParams, FullForwardLayerParamsTernary,
141};
142
143#[cfg(all(feature = "metal", target_os = "macos"))]
144pub use metal_prefill::{
145 try_metal_full_forward_prefill, try_metal_full_forward_prefill_ternary,
146 try_metal_full_forward_prefill_verify, try_metal_full_forward_prefill_verify_ternary,
147};
148
149#[cfg(all(
150 feature = "native-cuda",
151 any(target_os = "linux", target_os = "windows")
152))]
153pub use cuda_graph::{try_cuda_ffn, try_cuda_qkv, CudaGraph, CudaGraphError, NativeCudaBackend};
154
155#[cfg(all(
156 feature = "native-cuda",
157 any(target_os = "linux", target_os = "windows")
158))]
159pub use cuda_full_layer::{
160 try_cuda_full_forward, try_cuda_full_forward_ternary,
161 try_cuda_full_forward_ternary_with_gpu_lm_head, try_cuda_full_forward_with_gpu_lm_head,
162 try_cuda_full_layer, CudaCachedLayerWeights, CudaFullForwardLayerParams,
163 CudaFullForwardLayerParamsTernary,
164};
165
166#[cfg(all(
167 feature = "native-cuda",
168 any(target_os = "linux", target_os = "windows")
169))]
170pub use cuda_prefill::{try_cuda_prefill, try_cuda_prefill_ternary};
171
172#[cfg(all(
173 feature = "native-cuda",
174 any(target_os = "linux", target_os = "windows")
175))]
176pub use cuda_fp8_kernels::{cuda_gemv_fp8_e4m3, cuda_gemv_fp8_e5m2};
177
178#[cfg(all(
179 feature = "native-cuda",
180 any(target_os = "linux", target_os = "windows")
181))]
182pub use cuda_k_quant_kernels::{
183 cuda_gemv_q2k, cuda_gemv_q3k, cuda_gemv_q4k, cuda_gemv_q5k, cuda_gemv_q6k, cuda_gemv_q8k,
184};
185#[cfg(all(
186 feature = "native-cuda",
187 any(target_os = "linux", target_os = "windows")
188))]
189pub use cuda_q_std_kernels::{cuda_gemv_q4_0, cuda_gemv_q8_0};
190
191#[cfg(all(
192 feature = "native-cuda",
193 any(target_os = "linux", target_os = "windows")
194))]
195pub use cuda_q_std_prefill::{try_cuda_prefill_q_std, CudaQStdPrefillLayerParams};
196
197#[cfg(all(
198 feature = "native-cuda",
199 any(target_os = "linux", target_os = "windows")
200))]
201pub use cuda_k_quant_prefill::{
202 try_cuda_prefill_k_quant, CudaKQuantPrefillLayerParams, KQuantFormat,
203};
204
205#[cfg(all(
206 feature = "native-cuda",
207 any(target_os = "linux", target_os = "windows")
208))]
209pub use cuda_fp8_prefill::{try_cuda_prefill_fp8, CudaFP8PrefillLayerParams};
210
/// Buffer of `f32` values tagged with the device it belongs to.
///
/// In this module the storage is an ordinary host-side `Vec<f32>`; `size`
/// and `device_id` are recorded next to it when the buffer is constructed.
pub struct DeviceBuffer {
    /// Element storage.
    pub data: Vec<f32>,
    /// Element count recorded at construction time.
    pub size: usize,
    /// Identifier of the device this buffer is associated with.
    pub device_id: usize,
}

impl DeviceBuffer {
    /// Creates a buffer holding `size` zeros for device `device_id`.
    pub fn new(size: usize, device_id: usize) -> Self {
        let data = vec![0.0_f32; size];
        Self {
            data,
            size,
            device_id,
        }
    }

    /// Creates a buffer containing a copy of `data` for device `device_id`.
    pub fn from_slice(data: &[f32], device_id: usize) -> Self {
        Self {
            size: data.len(),
            data: Vec::from(data),
            device_id,
        }
    }

    /// Returns a fresh copy of the stored elements.
    pub fn to_vec(&self) -> Vec<f32> {
        self.data.to_vec()
    }

    /// Returns the recorded element count (the `size` field).
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns the recorded device identifier.
    pub fn device_id(&self) -> usize {
        self.device_id
    }
}
264
/// Kernel launch geometry: grid dimensions, block dimensions and the amount
/// of shared memory requested, in bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LaunchConfig {
    /// Number of blocks along each grid axis.
    pub grid_dim: (u32, u32, u32),
    /// Number of threads along each block axis.
    pub block_dim: (u32, u32, u32),
    /// Shared memory requested for the launch, in bytes.
    pub shared_mem_bytes: u32,
}

/// Threads per block used by the one-dimensional launch helpers.
const DEFAULT_BLOCK_SIZE: u32 = 256;

impl LaunchConfig {
    /// Builds a 1-D configuration with enough `DEFAULT_BLOCK_SIZE`-wide blocks
    /// to cover `n` elements.
    ///
    /// The grid always contains at least one block (also for `n == 0`). The
    /// block count is computed in 64-bit arithmetic and clamped to `u32::MAX`,
    /// so element counts above `u32::MAX` are no longer truncated by a
    /// narrowing cast and counts close to `u32::MAX` are rounded up correctly.
    pub fn for_n_elements(n: usize) -> Self {
        let block = DEFAULT_BLOCK_SIZE;
        // `usize -> u64` is lossless on every target with pointers of at most 64 bits.
        let elements = n as u64;
        let per_block = u64::from(block);
        // Ceiling division without the `n + (block - 1)` overflow hazard.
        let blocks_needed = elements / per_block + u64::from(elements % per_block != 0);
        // The value is clamped into `u32` range before the narrowing cast.
        let grid = blocks_needed.min(u64::from(u32::MAX)).max(1) as u32;
        Self {
            grid_dim: (grid, 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        }
    }

    /// Builds a single-block 1-D configuration with `DEFAULT_BLOCK_SIZE`
    /// threads and no shared memory.
    pub fn default_1d() -> Self {
        Self {
            grid_dim: (1, 1, 1),
            block_dim: (DEFAULT_BLOCK_SIZE, 1, 1),
            shared_mem_bytes: 0,
        }
    }
}
310
311#[derive(Debug, Error)]
317pub enum GpuError {
318 #[error("GPU not available: {0}")]
320 NotAvailable(String),
321
322 #[error("out of device memory: requested {requested} bytes on device {device}")]
324 OutOfMemory {
325 requested: usize,
327 device: usize,
329 },
330
331 #[error("kernel launch failed: {0}")]
333 KernelLaunch(String),
334
335 #[error("device synchronization failed: {0}")]
337 SyncFailed(String),
338
339 #[error("invalid argument: {0}")]
341 InvalidArgument(String),
342}
343
344pub trait GpuBackendTrait: Send + Sync {
359 fn name(&self) -> &'static str;
361
362 fn is_accelerated(&self) -> bool;
364
365 fn device_count(&self) -> usize;
367
368 fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError>;
370
371 fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError>;
373
374 fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError>;
376
377 fn matvec(
383 &self,
384 a: &DeviceBuffer,
385 x: &DeviceBuffer,
386 m: usize,
387 k: usize,
388 device_id: usize,
389 ) -> Result<DeviceBuffer, GpuError>;
390
391 fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError>;
393
394 fn softmax(
396 &self,
397 x: &DeviceBuffer,
398 size: usize,
399 device_id: usize,
400 ) -> Result<DeviceBuffer, GpuError>;
401
402 fn synchronize(&self, device_id: usize) -> Result<(), GpuError>;
404
405 fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError>;
407
408 fn gemv_q1_g128(
413 &self,
414 block_bytes: &[u8],
415 input: &[f32],
416 n_rows: usize,
417 k: usize,
418 ) -> Result<Vec<f32>, GpuError> {
419 cpu_gemv_1bit_fallback(block_bytes, input, n_rows, k)
420 }
421
422 fn gemm_q1_g128(
426 &self,
427 block_bytes: &[u8],
428 input: &[f32],
429 m: usize,
430 n_rows: usize,
431 k: usize,
432 ) -> Result<Vec<f32>, GpuError> {
433 let mut output = vec![0.0_f32; m * n_rows];
434 for i in 0..m {
435 let row_input = &input[i * k..(i + 1) * k];
436 let row_output = self.gemv_q1_g128(block_bytes, row_input, n_rows, k)?;
437 output[i * n_rows..(i + 1) * n_rows].copy_from_slice(&row_output);
438 }
439 Ok(output)
440 }
441
442 fn upload_weights_raw(
446 &self,
447 _block_bytes: &[u8],
448 ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
449 Err(GpuError::NotAvailable(
450 "weight caching not supported by this backend".into(),
451 ))
452 }
453
454 fn gemv_q1_g128_cached(
458 &self,
459 _handle: crate::weight_cache::GpuWeightHandle,
460 _input: &[f32],
461 _n_rows: usize,
462 _k: usize,
463 ) -> Result<Vec<f32>, GpuError> {
464 Err(GpuError::NotAvailable(
465 "cached GEMV not supported by this backend".into(),
466 ))
467 }
468
469 fn upload_weights_ternary(
473 &self,
474 _blocks: &[oxibonsai_core::BlockTQ2_0_g128],
475 ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
476 Err(GpuError::NotAvailable(
477 "ternary weight upload not supported by this backend".into(),
478 ))
479 }
480
481 fn gemv_tq2_g128_cached(
485 &self,
486 _handle: crate::weight_cache::GpuWeightHandle,
487 _input: &[f32],
488 _n_rows: usize,
489 _k: usize,
490 ) -> Result<Vec<f32>, GpuError> {
491 Err(GpuError::NotAvailable(
492 "cached ternary GEMV not supported by this backend".into(),
493 ))
494 }
495
496 #[allow(clippy::too_many_arguments, clippy::type_complexity)]
501 fn batch_attn_phase(
502 &self,
503 _hidden: &[f32],
504 _norm_weight: &[f32],
505 _norm_eps: f32,
506 _qkv_handle: crate::weight_cache::GpuWeightHandle,
507 _q_rows: usize,
508 _k_rows: usize,
509 _h: usize,
510 ) -> Result<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>, GpuError> {
511 Ok(None)
512 }
513
514 #[allow(clippy::too_many_arguments)]
519 fn batch_ffn_phase(
520 &self,
521 _hidden: &mut [f32],
522 _attn_out: &[f32],
523 _norm_weight: &[f32],
524 _norm_eps: f32,
525 _attn_proj_handle: crate::weight_cache::GpuWeightHandle,
526 _gate_up_handle: crate::weight_cache::GpuWeightHandle,
527 _down_handle: crate::weight_cache::GpuWeightHandle,
528 _h: usize,
529 _intermediate: usize,
530 _attn_proj_k: usize,
531 ) -> Result<bool, GpuError> {
532 Ok(false)
533 }
534}
535
536pub type GpuBackend = dyn GpuBackendTrait;
541
/// Backend that executes every operation on the host CPU.
pub struct CpuBackend {
    /// Total amount of memory, in bytes, that this backend reports as available.
    pub simulated_memory_bytes: usize,
}

impl CpuBackend {
    /// Creates a backend reporting 4 GiB of simulated memory.
    ///
    /// The size is computed in `u64` and clamped to `usize::MAX`, so the
    /// constant no longer overflows (a compile-time error) on targets whose
    /// `usize` is narrower than 64 bits. On 64-bit targets the value is
    /// unchanged: 4 294 967 296 bytes.
    pub fn new() -> Self {
        let four_gib = 4_u64 << 30;
        // `usize::MAX as u64` is lossless; after the clamp the narrowing cast cannot truncate.
        let bytes = four_gib.min(usize::MAX as u64) as usize;
        Self {
            simulated_memory_bytes: bytes,
        }
    }

    /// Creates a backend reporting `bytes` of simulated memory.
    pub fn with_memory(bytes: usize) -> Self {
        Self {
            simulated_memory_bytes: bytes,
        }
    }
}

impl Default for CpuBackend {
    /// Equivalent to [`CpuBackend::new`].
    fn default() -> Self {
        Self::new()
    }
}
575
576impl GpuBackendTrait for CpuBackend {
577 fn name(&self) -> &'static str {
578 "cpu"
579 }
580
581 fn is_accelerated(&self) -> bool {
582 false
583 }
584
585 fn device_count(&self) -> usize {
586 1
587 }
588
589 fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
590 Ok(DeviceBuffer::new(size, device_id))
591 }
592
593 fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
594 Ok(DeviceBuffer::from_slice(src, device_id))
595 }
596
597 fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
598 Ok(buf.to_vec())
599 }
600
601 fn matvec(
602 &self,
603 a: &DeviceBuffer,
604 x: &DeviceBuffer,
605 m: usize,
606 k: usize,
607 device_id: usize,
608 ) -> Result<DeviceBuffer, GpuError> {
609 if a.size() != m * k {
610 return Err(GpuError::InvalidArgument(format!(
611 "matrix buffer size {} does not match m={} k={}",
612 a.size(),
613 m,
614 k
615 )));
616 }
617 if x.size() != k {
618 return Err(GpuError::InvalidArgument(format!(
619 "vector buffer size {} does not match k={}",
620 x.size(),
621 k
622 )));
623 }
624
625 let mut result = vec![0.0_f32; m];
626 for (row, slot) in result.iter_mut().enumerate().take(m) {
627 let mut acc = 0.0_f32;
628 for col in 0..k {
629 acc += a.data[row * k + col] * x.data[col];
630 }
631 *slot = acc;
632 }
633
634 Ok(DeviceBuffer::from_slice(&result, device_id))
635 }
636
637 fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
638 let result: Vec<f32> = x.data.iter().map(|&v| v.max(0.0)).collect();
639 Ok(DeviceBuffer::from_slice(&result, device_id))
640 }
641
642 fn softmax(
643 &self,
644 x: &DeviceBuffer,
645 size: usize,
646 device_id: usize,
647 ) -> Result<DeviceBuffer, GpuError> {
648 if x.size() != size {
649 return Err(GpuError::InvalidArgument(format!(
650 "buffer size {} does not match size={}",
651 x.size(),
652 size
653 )));
654 }
655 if size == 0 {
656 return Ok(DeviceBuffer::new(0, device_id));
657 }
658
659 let max_val = x.data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
660 let exps: Vec<f32> = x.data.iter().map(|&v| (v - max_val).exp()).collect();
661 let sum: f32 = exps.iter().sum();
662
663 let result: Vec<f32> = if sum == 0.0 {
664 vec![1.0 / size as f32; size]
665 } else {
666 exps.iter().map(|&e| e / sum).collect()
667 };
668
669 Ok(DeviceBuffer::from_slice(&result, device_id))
670 }
671
672 fn synchronize(&self, _device_id: usize) -> Result<(), GpuError> {
673 Ok(())
674 }
675
676 fn memory_info(&self, _device_id: usize) -> Result<(usize, usize), GpuError> {
677 let total = self.simulated_memory_bytes;
678 let free = total / 2;
679 Ok((free, total))
680 }
681}
682
/// Placeholder CUDA backend.
///
/// Every operation logs a warning and is forwarded to an internal
/// [`CpuBackend`]; no GPU work is performed.
#[cfg(feature = "cuda")]
pub struct CudaBackend {
    /// Number of devices this stub reports.
    pub device_count: usize,
    /// CPU implementation that actually carries out the work.
    cpu_fallback: CpuBackend,
}

#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Creates the stub with one reported device; currently always succeeds.
    pub fn new() -> Result<Self, GpuError> {
        warn!("CudaBackend: CUDA stub active — no real GPU acceleration");
        let stub = Self {
            device_count: 1,
            cpu_fallback: CpuBackend::new(),
        };
        Ok(stub)
    }
}

/// All operations delegate to the CPU fallback after emitting a warning.
#[cfg(feature = "cuda")]
impl GpuBackendTrait for CudaBackend {
    fn name(&self) -> &'static str {
        "cuda"
    }

    /// Always `false`: the stub never runs on a GPU.
    fn is_accelerated(&self) -> bool {
        false
    }

    fn device_count(&self) -> usize {
        self.device_count
    }

    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::alloc delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.alloc(size, device_id)
    }

    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::host_to_device delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.host_to_device(src, device_id)
    }

    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
        warn!("CudaBackend::device_to_host delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.device_to_host(buf)
    }

    fn matvec(
        &self,
        a: &DeviceBuffer,
        x: &DeviceBuffer,
        m: usize,
        k: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::matvec delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.matvec(a, x, m, k, device_id)
    }

    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::relu delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.relu(x, device_id)
    }

    fn softmax(
        &self,
        x: &DeviceBuffer,
        size: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::softmax delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.softmax(x, size, device_id)
    }

    fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
        warn!("CudaBackend::synchronize delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.synchronize(device_id)
    }

    fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
        warn!("CudaBackend::memory_info delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.memory_info(device_id)
    }
}
776
/// Placeholder Metal backend.
///
/// Every operation logs a warning and is forwarded to an internal
/// [`CpuBackend`]; no GPU work is performed by this type.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub struct MetalBackend {
    /// Number of devices this stub reports.
    pub device_count: usize,
    /// CPU implementation that actually carries out the work.
    cpu_fallback: CpuBackend,
}

#[cfg(all(feature = "metal", target_os = "macos"))]
impl MetalBackend {
    /// Creates the stub with one reported device; currently always succeeds.
    pub fn new() -> Result<Self, GpuError> {
        warn!("MetalBackend: Metal stub active — no real GPU acceleration");
        let stub = Self {
            device_count: 1,
            cpu_fallback: CpuBackend::new(),
        };
        Ok(stub)
    }
}

/// All operations delegate to the CPU fallback after emitting a warning.
#[cfg(all(feature = "metal", target_os = "macos"))]
impl GpuBackendTrait for MetalBackend {
    fn name(&self) -> &'static str {
        "metal"
    }

    /// Always `false`: the stub never runs on a GPU.
    fn is_accelerated(&self) -> bool {
        false
    }

    fn device_count(&self) -> usize {
        self.device_count
    }

    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::alloc delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.alloc(size, device_id)
    }

    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::host_to_device delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.host_to_device(src, device_id)
    }

    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
        warn!("MetalBackend::device_to_host delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.device_to_host(buf)
    }

    fn matvec(
        &self,
        a: &DeviceBuffer,
        x: &DeviceBuffer,
        m: usize,
        k: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::matvec delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.matvec(a, x, m, k, device_id)
    }

    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::relu delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.relu(x, device_id)
    }

    fn softmax(
        &self,
        x: &DeviceBuffer,
        size: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::softmax delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.softmax(x, size, device_id)
    }

    fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
        warn!("MetalBackend::synchronize delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.synchronize(device_id)
    }

    fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
        warn!("MetalBackend::memory_info delegating to CPU fallback");
        let cpu = &self.cpu_fallback;
        cpu.memory_info(device_id)
    }
}
869
/// Newtype around a shared [`Scirs2Backend`] so that it can be handed out as a
/// boxed [`GpuBackendTrait`] object.
#[cfg(feature = "gpu")]
pub(crate) struct Scirs2BackendHandle(pub(crate) std::sync::Arc<Scirs2Backend>);

/// Forwards every call to the wrapped backend.
///
/// The two batched phases additionally translate a backend error into the
/// "not handled" result (`Ok(None)` / `Ok(false)`) after logging a warning.
#[cfg(feature = "gpu")]
impl GpuBackendTrait for Scirs2BackendHandle {
    fn name(&self) -> &'static str {
        self.0.name()
    }

    fn is_accelerated(&self) -> bool {
        self.0.is_accelerated()
    }

    fn device_count(&self) -> usize {
        self.0.device_count()
    }

    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        self.0.alloc(size, device_id)
    }

    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
        self.0.host_to_device(src, device_id)
    }

    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
        self.0.device_to_host(buf)
    }

    fn matvec(
        &self,
        a: &DeviceBuffer,
        x: &DeviceBuffer,
        m: usize,
        k: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        self.0.matvec(a, x, m, k, device_id)
    }

    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        self.0.relu(x, device_id)
    }

    fn softmax(
        &self,
        x: &DeviceBuffer,
        size: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        self.0.softmax(x, size, device_id)
    }

    fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
        self.0.synchronize(device_id)
    }

    fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
        self.0.memory_info(device_id)
    }

    fn gemv_q1_g128(
        &self,
        block_bytes: &[u8],
        input: &[f32],
        n_rows: usize,
        k: usize,
    ) -> Result<Vec<f32>, GpuError> {
        self.0.gemv_q1_g128(block_bytes, input, n_rows, k)
    }

    fn gemm_q1_g128(
        &self,
        block_bytes: &[u8],
        input: &[f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> Result<Vec<f32>, GpuError> {
        self.0.gemm_q1_g128(block_bytes, input, m, n_rows, k)
    }

    /// Forwards to the wrapped backend's `upload_weights`.
    fn upload_weights_raw(
        &self,
        block_bytes: &[u8],
    ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
        self.0.upload_weights(block_bytes)
    }

    fn gemv_q1_g128_cached(
        &self,
        handle: crate::weight_cache::GpuWeightHandle,
        input: &[f32],
        n_rows: usize,
        k: usize,
    ) -> Result<Vec<f32>, GpuError> {
        self.0.gemv_q1_g128_cached(handle, input, n_rows, k)
    }

    fn upload_weights_ternary(
        &self,
        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
    ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
        self.0.upload_weights_ternary(blocks)
    }

    fn gemv_tq2_g128_cached(
        &self,
        handle: crate::weight_cache::GpuWeightHandle,
        input: &[f32],
        n_rows: usize,
        k: usize,
    ) -> Result<Vec<f32>, GpuError> {
        self.0.gemv_tq2_g128_cached(handle, input, n_rows, k)
    }

    /// Runs the backend's attention phase; a backend error is logged and
    /// reported as `Ok(None)` so the caller can fall back.
    fn batch_attn_phase(
        &self,
        hidden: &[f32],
        norm_weight: &[f32],
        norm_eps: f32,
        qkv_handle: crate::weight_cache::GpuWeightHandle,
        q_rows: usize,
        k_rows: usize,
        h: usize,
    ) -> Result<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>, GpuError> {
        let outcome = self
            .0
            .batch_attn_phase(hidden, norm_weight, norm_eps, qkv_handle, q_rows, k_rows, h);
        outcome.map(Some).or_else(|e| {
            tracing::warn!(error = %e, "batch_attn_phase failed, falling back");
            Ok(None)
        })
    }

    /// Runs the backend's feed-forward phase; a backend error is logged and
    /// reported as `Ok(false)` so the caller can fall back.
    fn batch_ffn_phase(
        &self,
        hidden: &mut [f32],
        attn_out: &[f32],
        norm_weight: &[f32],
        norm_eps: f32,
        attn_proj_handle: crate::weight_cache::GpuWeightHandle,
        gate_up_handle: crate::weight_cache::GpuWeightHandle,
        down_handle: crate::weight_cache::GpuWeightHandle,
        h: usize,
        intermediate: usize,
        attn_proj_k: usize,
    ) -> Result<bool, GpuError> {
        let outcome = self.0.batch_ffn_phase(
            hidden,
            attn_out,
            norm_weight,
            norm_eps,
            attn_proj_handle,
            gate_up_handle,
            down_handle,
            h,
            intermediate,
            attn_proj_k,
        );
        outcome.map(|()| true).or_else(|e| {
            tracing::warn!(error = %e, "batch_ffn_phase failed, falling back");
            Ok(false)
        })
    }
}
1035
1036pub fn select_backend() -> Box<dyn GpuBackendTrait> {
1052 #[cfg(feature = "gpu")]
1057 use std::sync::atomic::{AtomicBool, Ordering};
1058 #[cfg(feature = "gpu")]
1059 fn warn_once(flag: &AtomicBool, msg: impl FnOnce()) {
1060 if !flag.swap(true, Ordering::Relaxed) {
1061 msg();
1062 }
1063 }
1064
1065 #[cfg(feature = "gpu")]
1067 {
1068 static SCIRS2_NOT_ACCEL: AtomicBool = AtomicBool::new(false);
1069 static SCIRS2_INIT_FAIL: AtomicBool = AtomicBool::new(false);
1070 match Scirs2Backend::global() {
1071 Ok(b) => {
1072 if b.is_accelerated() {
1073 return Box::new(Scirs2BackendHandle(b));
1074 }
1075 warn_once(&SCIRS2_NOT_ACCEL, || {
1077 warn!(
1078 "select_backend: Scirs2Backend is not accelerated (backend={}), trying next",
1079 b.backend_name()
1080 );
1081 });
1082 }
1083 Err(e) => {
1084 warn_once(&SCIRS2_INIT_FAIL, || {
1085 warn!("select_backend: Scirs2Backend init failed ({e}), trying next");
1086 });
1087 }
1088 }
1089 }
1090
1091 #[cfg(all(
1093 feature = "native-cuda",
1094 any(target_os = "linux", target_os = "windows")
1095 ))]
1096 {
1097 match NativeCudaBackend::new() {
1098 Ok(b) => {
1099 tracing::info!("select_backend: NativeCudaBackend initialised");
1100 return Box::new(b);
1101 }
1102 Err(e) => {
1103 warn!("select_backend: NativeCudaBackend init failed ({e}), trying next");
1104 }
1105 }
1106 }
1107
1108 #[cfg(feature = "cuda")]
1110 {
1111 match CudaBackend::new() {
1112 Ok(b) => {
1113 return Box::new(b);
1114 }
1115 Err(e) => {
1116 warn!("select_backend: CUDA init failed ({e}), trying next");
1117 }
1118 }
1119 }
1120
1121 #[cfg(all(feature = "metal", target_os = "macos"))]
1123 {
1124 match MetalBackend::new() {
1125 Ok(b) => {
1126 return Box::new(b);
1127 }
1128 Err(e) => {
1129 warn!("select_backend: Metal init failed ({e}), trying CPU");
1130 }
1131 }
1132 }
1133
1134 Box::new(CpuBackend::new())
1136}
1137
1138pub fn gpu_matmul(
1152 backend: &dyn GpuBackendTrait,
1153 a: &[f32],
1154 b: &[f32],
1155 m: usize,
1156 k: usize,
1157 n: usize,
1158 device_id: usize,
1159) -> Result<Vec<f32>, GpuError> {
1160 if a.len() != m * k {
1161 return Err(GpuError::InvalidArgument(format!(
1162 "a.len()={} does not match m={} k={}",
1163 a.len(),
1164 m,
1165 k
1166 )));
1167 }
1168 if b.len() != k * n {
1169 return Err(GpuError::InvalidArgument(format!(
1170 "b.len()={} does not match k={} n={}",
1171 b.len(),
1172 k,
1173 n
1174 )));
1175 }
1176
1177 let a_buf = backend.host_to_device(a, device_id)?;
1178
1179 let mut c = vec![0.0_f32; m * n];
1180
1181 for col in 0..n {
1182 let b_col: Vec<f32> = (0..k).map(|row| b[row * n + col]).collect();
1183 let x_buf = backend.host_to_device(&b_col, device_id)?;
1184 let y_buf = backend.matvec(&a_buf, &x_buf, m, k, device_id)?;
1185 let y = backend.device_to_host(&y_buf)?;
1186
1187 for row in 0..m {
1188 c[row * n + col] = y[row];
1189 }
1190 }
1191
1192 backend.synchronize(device_id)?;
1193 Ok(c)
1194}
1195
1196pub fn gpu_gemv_1bit(
1219 block_bytes: &[u8],
1220 input: &[f32],
1221 n_rows: usize,
1222 k: usize,
1223) -> Result<Vec<f32>, GpuError> {
1224 #[cfg(feature = "gpu")]
1225 {
1226 match Scirs2Backend::global() {
1227 Ok(backend) => {
1228 if backend.is_accelerated() {
1229 return backend.gemv_q1_g128(block_bytes, input, n_rows, k);
1230 }
1231 }
1233 Err(e) => {
1234 warn!("gpu_gemv_1bit: GPU init failed ({e}), using CPU fallback");
1235 }
1236 }
1237 }
1238
1239 cpu_gemv_1bit_fallback(block_bytes, input, n_rows, k)
1241}
1242
1243fn cpu_gemv_1bit_fallback(
1247 block_bytes: &[u8],
1248 input: &[f32],
1249 n_rows: usize,
1250 k: usize,
1251) -> Result<Vec<f32>, GpuError> {
1252 if k == 0 || k % 128 != 0 {
1253 return Err(GpuError::InvalidArgument(format!(
1254 "k={k} must be a positive multiple of 128"
1255 )));
1256 }
1257 if input.len() != k {
1258 return Err(GpuError::InvalidArgument(format!(
1259 "input.len()={} != k={}",
1260 input.len(),
1261 k
1262 )));
1263 }
1264 let blocks_per_row = k / 128;
1265 let block_size = 18_usize;
1266 let expected = n_rows * blocks_per_row * block_size;
1267 if block_bytes.len() < expected {
1268 return Err(GpuError::InvalidArgument(format!(
1269 "block_bytes too small: {} < {}",
1270 block_bytes.len(),
1271 expected,
1272 )));
1273 }
1274
1275 let mut output = vec![0.0_f32; n_rows];
1276
1277 for (row, output_val) in output.iter_mut().enumerate().take(n_rows) {
1278 let mut sum = 0.0_f32;
1279 for b in 0..blocks_per_row {
1280 let block_idx = row * blocks_per_row + b;
1281 let off = block_idx * block_size;
1282
1283 let d_bits = u16::from_le_bytes([block_bytes[off], block_bytes[off + 1]]);
1285 let scale = half::f16::from_bits(d_bits).to_f32();
1286
1287 let input_base = b * 128;
1288 for w in 0..4_usize {
1290 let byte_off = off + 2 + w * 4;
1291 let bits = u32::from_le_bytes([
1292 block_bytes[byte_off],
1293 block_bytes[byte_off + 1],
1294 block_bytes[byte_off + 2],
1295 block_bytes[byte_off + 3],
1296 ]);
1297 let base = input_base + w * 32;
1298 for i in 0..32_usize {
1299 let sign = if (bits >> i) & 1 == 1 {
1300 1.0_f32
1301 } else {
1302 -1.0_f32
1303 };
1304 sum += scale * sign * input[base + i];
1305 }
1306 }
1307 }
1308 *output_val = sum;
1309 }
1310
1311 Ok(output)
1312}
1313
/// Unit tests for the buffer type, launch configuration, CPU softmax and the
/// 1-bit GEMV paths.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one 18-byte block: an `f16` scale of 1.0 followed by 16 bytes
    /// that are all `sign_fill`.
    fn unit_scale_block(sign_fill: u8) -> Vec<u8> {
        let scale_bytes = half::f16::from_f32(1.0).to_bits().to_le_bytes();
        let mut block = vec![0u8; 18];
        block[0] = scale_bytes[0];
        block[1] = scale_bytes[1];
        block[2..18].fill(sign_fill);
        block
    }

    #[test]
    fn device_buffer_new_zeroed() {
        let buf = DeviceBuffer::new(4, 0);
        assert_eq!(buf.size(), 4);
        assert_eq!(buf.device_id(), 0);
        assert!(buf.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn device_buffer_from_slice_roundtrip() {
        let src = [1.0_f32, 2.0, 3.0];
        let buf = DeviceBuffer::from_slice(&src, 1);
        assert_eq!(buf.to_vec(), src);
    }

    #[test]
    fn launch_config_for_zero_elements() {
        // Even an empty launch gets one block.
        let cfg = LaunchConfig::for_n_elements(0);
        assert_eq!(cfg.grid_dim.0, 1);
    }

    #[test]
    fn cpu_softmax_empty() {
        let backend = CpuBackend::new();
        let buf = DeviceBuffer::new(0, 0);
        let out = backend.softmax(&buf, 0, 0).expect("softmax empty");
        assert_eq!(out.size(), 0);
    }

    #[test]
    fn cpu_gemv_1bit_identity_scale() {
        // Scale 1.0 and every sign bit set: the result is the plain input sum.
        let block = unit_scale_block(0xFF);

        let input: Vec<f32> = (0..128).map(|i| i as f32).collect();
        let expected: f32 = input.iter().sum();
        let result =
            cpu_gemv_1bit_fallback(&block, &input, 1, 128).expect("cpu_gemv_1bit_fallback");
        assert!(
            (result[0] - expected).abs() < 1e-2,
            "got {} expected {}",
            result[0],
            expected,
        );
    }

    #[test]
    fn cpu_gemv_1bit_negative_scale() {
        // Scale 1.0 and every sign bit clear: each of the 128 ones counts as -1.
        let block = unit_scale_block(0x00);

        let input = vec![1.0_f32; 128];
        let result =
            cpu_gemv_1bit_fallback(&block, &input, 1, 128).expect("cpu_gemv_1bit_fallback");
        assert!(
            (result[0] - (-128.0)).abs() < 1e-2,
            "got {} expected -128",
            result[0],
        );
    }

    #[test]
    fn cpu_gemv_1bit_bad_k() {
        // 64 is not a multiple of 128.
        let result = cpu_gemv_1bit_fallback(&[], &[], 0, 64);
        assert!(result.is_err());
    }

    #[test]
    fn gpu_gemv_1bit_without_gpu() {
        let block = unit_scale_block(0xFF);

        let input: Vec<f32> = vec![1.0_f32; 128];
        let result = gpu_gemv_1bit(&block, &input, 1, 128).expect("gpu_gemv_1bit");
        assert!((result[0] - 128.0).abs() < 1e-2, "got {}", result[0]);
    }
}