Skip to main content

oxibonsai_kernels/gpu_backend/
mod.rs

1//! GPU backend abstraction layer for CUDA and Metal acceleration.
2//!
3//! This module defines the [`GpuBackendTrait`] trait and provides:
4//! - [`CpuBackend`]: Always-available CPU implementation (baseline)
5//! - [`CudaBackend`]: CUDA stub (feature = "cuda", compile-only placeholder)
6//! - [`MetalBackend`]: Metal stub (feature = "metal", compile-only placeholder)
7//! - [`Scirs2Backend`]: **Real** GPU backend via scirs2-core (feature = "gpu")
8//!
9//! # Architecture
10//! All GPU operations follow the same pattern:
11//! 1. Allocate device buffers
12//! 2. Copy host → device
13//! 3. Execute kernel
14//! 4. Copy device → host
15//!
16//! The [`Scirs2Backend`] compiles Metal/CUDA kernels at runtime through
17//! scirs2-core and dispatches real GPU work.  Stub backends delegate to
18//! CPU operations.
19//!
20//! # Q1_0_g128 GPU acceleration
21//!
22//! The [`gpu_gemv_1bit`] function provides a high-level entry point for
23//! 1-bit quantised matrix-vector multiplication on the GPU.
24
25#[cfg(all(
26    feature = "native-cuda",
27    any(target_os = "linux", target_os = "windows")
28))]
29pub mod cuda_attn_kernels;
30#[cfg(all(
31    feature = "native-cuda",
32    any(target_os = "linux", target_os = "windows")
33))]
34pub mod cuda_fp8_kernels;
35#[cfg(all(
36    feature = "native-cuda",
37    any(target_os = "linux", target_os = "windows")
38))]
39pub mod cuda_fp8_prefill;
40#[cfg(all(
41    feature = "native-cuda",
42    any(target_os = "linux", target_os = "windows")
43))]
44pub mod cuda_fp8_prefill_kernels;
45#[cfg(all(
46    feature = "native-cuda",
47    any(target_os = "linux", target_os = "windows")
48))]
49pub mod cuda_full_layer;
50#[cfg(all(
51    feature = "native-cuda",
52    any(target_os = "linux", target_os = "windows")
53))]
54pub mod cuda_graph;
55#[cfg(all(
56    feature = "native-cuda",
57    any(target_os = "linux", target_os = "windows")
58))]
59pub mod cuda_k_quant_kernels;
60#[cfg(all(
61    feature = "native-cuda",
62    any(target_os = "linux", target_os = "windows")
63))]
64pub mod cuda_k_quant_prefill;
65#[cfg(all(
66    feature = "native-cuda",
67    any(target_os = "linux", target_os = "windows")
68))]
69pub mod cuda_k_quant_prefill_kernels;
70#[cfg(all(
71    feature = "native-cuda",
72    any(target_os = "linux", target_os = "windows")
73))]
74pub mod cuda_kernels;
75#[cfg(all(
76    feature = "native-cuda",
77    any(target_os = "linux", target_os = "windows")
78))]
79pub mod cuda_prefill;
80#[cfg(all(
81    feature = "native-cuda",
82    any(target_os = "linux", target_os = "windows")
83))]
84pub mod cuda_prefill_kernels;
85#[cfg(all(
86    feature = "native-cuda",
87    any(target_os = "linux", target_os = "windows")
88))]
89pub mod cuda_q_std_kernels;
90#[cfg(all(
91    feature = "native-cuda",
92    any(target_os = "linux", target_os = "windows")
93))]
94pub mod cuda_q_std_prefill;
95#[cfg(all(
96    feature = "native-cuda",
97    any(target_os = "linux", target_os = "windows")
98))]
99pub mod cuda_q_std_prefill_kernels;
100pub mod kernel_sources;
101#[cfg(all(feature = "metal", target_os = "macos"))]
102mod metal_dispatch;
103#[cfg(all(feature = "metal", target_os = "macos"))]
104pub mod metal_fp8_kernels;
105#[cfg(all(feature = "metal", target_os = "macos"))]
106pub mod metal_fp8_prefill;
107#[cfg(all(feature = "metal", target_os = "macos"))]
108pub mod metal_full_layer;
109#[cfg(all(feature = "metal", target_os = "macos"))]
110pub mod metal_graph;
111#[cfg(all(feature = "metal", target_os = "macos"))]
112mod metal_prefill;
113pub mod scirs2_backend;
114
115use thiserror::Error;
116#[allow(unused_imports)]
117use tracing::warn;
118
119#[cfg(feature = "gpu")]
120pub use scirs2_backend::Scirs2Backend;
121
122#[cfg(all(feature = "metal", target_os = "macos"))]
123pub use metal_fp8_kernels::{metal_gemv_fp8_e4m3, metal_gemv_fp8_e5m2};
124
125#[cfg(all(feature = "metal", target_os = "macos"))]
126pub use metal_fp8_prefill::{
127    metal_fused_gate_up_swiglu_fp8_e4m3, metal_fused_gate_up_swiglu_fp8_e5m2, metal_gemm_fp8_e4m3,
128    metal_gemm_fp8_e4m3_residual, metal_gemm_fp8_e5m2, metal_gemm_fp8_e5m2_residual,
129};
130
131#[cfg(all(feature = "metal", target_os = "macos"))]
132pub use metal_graph::{MetalGraph, MetalGraphError, MetalWeightHandle};
133
134#[cfg(all(feature = "metal", target_os = "macos"))]
135pub use metal_full_layer::{
136    build_cached_weights, build_cached_weights_ternary_only, print_gpu_profile_summary,
137    try_metal_ffn, try_metal_forward_greedy_ternary, try_metal_full_forward,
138    try_metal_full_forward_cached, try_metal_full_forward_ternary, try_metal_full_layer,
139    try_metal_prefill_ternary, try_metal_prefill_verify_ternary, try_metal_qkv, CachedLayerWeights,
140    CachedModelWeights, FullForwardLayerParams, FullForwardLayerParamsTernary,
141};
142
143#[cfg(all(feature = "metal", target_os = "macos"))]
144pub use metal_prefill::{
145    try_metal_full_forward_prefill, try_metal_full_forward_prefill_ternary,
146    try_metal_full_forward_prefill_verify, try_metal_full_forward_prefill_verify_ternary,
147};
148
149#[cfg(all(
150    feature = "native-cuda",
151    any(target_os = "linux", target_os = "windows")
152))]
153pub use cuda_graph::{try_cuda_ffn, try_cuda_qkv, CudaGraph, CudaGraphError, NativeCudaBackend};
154
155#[cfg(all(
156    feature = "native-cuda",
157    any(target_os = "linux", target_os = "windows")
158))]
159pub use cuda_full_layer::{
160    try_cuda_full_forward, try_cuda_full_forward_ternary,
161    try_cuda_full_forward_ternary_with_gpu_lm_head, try_cuda_full_forward_with_gpu_lm_head,
162    try_cuda_full_layer, CudaCachedLayerWeights, CudaFullForwardLayerParams,
163    CudaFullForwardLayerParamsTernary,
164};
165
166#[cfg(all(
167    feature = "native-cuda",
168    any(target_os = "linux", target_os = "windows")
169))]
170pub use cuda_prefill::{try_cuda_prefill, try_cuda_prefill_ternary};
171
172#[cfg(all(
173    feature = "native-cuda",
174    any(target_os = "linux", target_os = "windows")
175))]
176pub use cuda_fp8_kernels::{cuda_gemv_fp8_e4m3, cuda_gemv_fp8_e5m2};
177
178#[cfg(all(
179    feature = "native-cuda",
180    any(target_os = "linux", target_os = "windows")
181))]
182pub use cuda_k_quant_kernels::{
183    cuda_gemv_q2k, cuda_gemv_q3k, cuda_gemv_q4k, cuda_gemv_q5k, cuda_gemv_q6k, cuda_gemv_q8k,
184};
185#[cfg(all(
186    feature = "native-cuda",
187    any(target_os = "linux", target_os = "windows")
188))]
189pub use cuda_q_std_kernels::{cuda_gemv_q4_0, cuda_gemv_q8_0};
190
191#[cfg(all(
192    feature = "native-cuda",
193    any(target_os = "linux", target_os = "windows")
194))]
195pub use cuda_q_std_prefill::{try_cuda_prefill_q_std, CudaQStdPrefillLayerParams};
196
197#[cfg(all(
198    feature = "native-cuda",
199    any(target_os = "linux", target_os = "windows")
200))]
201pub use cuda_k_quant_prefill::{
202    try_cuda_prefill_k_quant, CudaKQuantPrefillLayerParams, KQuantFormat,
203};
204
205#[cfg(all(
206    feature = "native-cuda",
207    any(target_os = "linux", target_os = "windows")
208))]
209pub use cuda_fp8_prefill::{try_cuda_prefill_fp8, CudaFP8PrefillLayerParams};
210
211// ═══════════════════════════════════════════════════════════════════════════
212// DeviceBuffer
213// ═══════════════════════════════════════════════════════════════════════════
214
/// Opaque handle to a block of device memory holding `f32` elements.
///
/// The CPU and stub backends back this with an ordinary heap-allocated
/// `Vec<f32>`.  A future hardware backend would replace `data` with a raw
/// device pointer and keep the `Vec` only as a host-side staging buffer.
pub struct DeviceBuffer {
    /// Host-side storage; this is the backing store for every stub backend.
    pub data: Vec<f32>,
    /// Element count (`f32` values, not bytes) recorded at construction.
    pub size: usize,
    /// Logical index of the device this buffer is associated with.
    pub device_id: usize,
}

impl DeviceBuffer {
    /// Create a buffer of `size` elements, all `0.0`, tagged with `device_id`.
    pub fn new(size: usize, device_id: usize) -> Self {
        let data = vec![0.0_f32; size];
        Self { data, size, device_id }
    }

    /// Build a buffer whose contents are a copy of the host slice `data`.
    ///
    /// The recorded `size` is the length of the slice.
    pub fn from_slice(data: &[f32], device_id: usize) -> Self {
        Self {
            size: data.len(),
            data: Vec::from(data),
            device_id,
        }
    }

    /// Return an owned host-side copy of the buffer contents.
    pub fn to_vec(&self) -> Vec<f32> {
        self.data.as_slice().to_vec()
    }

    /// Element count stored in the `size` field.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Logical device index stored in the `device_id` field.
    pub fn device_id(&self) -> usize {
        self.device_id
    }
}
264
265// ═══════════════════════════════════════════════════════════════════════════
266// LaunchConfig
267// ═══════════════════════════════════════════════════════════════════════════
268
/// Kernel launch configuration (CUDA-style grid/block decomposition).
///
/// For CPU and Metal backends these values are informational only; the actual
/// parallelism strategy is determined by the backend itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LaunchConfig {
    /// Number of thread-blocks (x, y, z).
    pub grid_dim: (u32, u32, u32),
    /// Threads per block (x, y, z).
    pub block_dim: (u32, u32, u32),
    /// Dynamic shared memory per block in bytes.
    pub shared_mem_bytes: u32,
}

/// Default block size used by `for_n_elements` and `default_1d`.
const DEFAULT_BLOCK_SIZE: u32 = 256;

impl LaunchConfig {
    /// Auto-compute a 1-D launch configuration for `n` elements.
    ///
    /// Uses a block size of [`DEFAULT_BLOCK_SIZE`] threads and rounds the grid
    /// up so that `grid * block >= n`.  The grid is at least 1 (even for
    /// `n == 0`) and is clamped to `u32::MAX` blocks when `n` is too large to
    /// be covered by a `u32` grid.  `shared_mem_bytes` is set to zero.
    pub fn for_n_elements(n: usize) -> Self {
        // Do the arithmetic in `usize`: casting `n` to `u32` first would
        // silently drop the high bits on 64-bit targets.
        let block = DEFAULT_BLOCK_SIZE as usize;
        // Ceiling division written so that it cannot overflow for any `n`.
        let blocks = n / block + usize::from(n % block != 0);
        let grid = if blocks > u32::MAX as usize {
            u32::MAX
        } else {
            blocks as u32
        };
        Self {
            grid_dim: (grid.max(1), 1, 1),
            block_dim: (DEFAULT_BLOCK_SIZE, 1, 1),
            shared_mem_bytes: 0,
        }
    }

    /// A sensible default 1-D config (1 block of [`DEFAULT_BLOCK_SIZE`] threads).
    pub fn default_1d() -> Self {
        Self {
            grid_dim: (1, 1, 1),
            block_dim: (DEFAULT_BLOCK_SIZE, 1, 1),
            shared_mem_bytes: 0,
        }
    }
}
310
311// ═══════════════════════════════════════════════════════════════════════════
312// GpuError
313// ═══════════════════════════════════════════════════════════════════════════
314
315/// Error type for GPU backend operations.
316#[derive(Debug, Error)]
317pub enum GpuError {
318    /// The requested GPU/backend is not present or not compiled in.
319    #[error("GPU not available: {0}")]
320    NotAvailable(String),
321
322    /// A device-side allocation failed due to insufficient memory.
323    #[error("out of device memory: requested {requested} bytes on device {device}")]
324    OutOfMemory {
325        /// Requested allocation size in bytes.
326        requested: usize,
327        /// Device index that was targeted.
328        device: usize,
329    },
330
331    /// A kernel could not be launched (bad dimensions, missing module, etc.).
332    #[error("kernel launch failed: {0}")]
333    KernelLaunch(String),
334
335    /// The device failed to synchronise after kernel execution.
336    #[error("device synchronization failed: {0}")]
337    SyncFailed(String),
338
339    /// A parameter value is out of range or logically inconsistent.
340    #[error("invalid argument: {0}")]
341    InvalidArgument(String),
342}
343
344// ═══════════════════════════════════════════════════════════════════════════
345// GpuBackendTrait
346// ═══════════════════════════════════════════════════════════════════════════
347
348/// Core GPU backend trait.
349///
350/// Implementations of this trait provide the primitive operations required by
351/// the OxiBonsai inference engine.  The [`CpuBackend`] is always available
352/// and is used as a correctness baseline; hardware backends are feature-gated.
353///
354/// # Backwards compatibility
355///
356/// This trait was previously named `GpuBackend`.  The type alias
357/// [`GpuBackend`] preserves source compatibility.
358pub trait GpuBackendTrait: Send + Sync {
359    /// Human-readable backend identifier (e.g. `"cpu"`, `"cuda"`, `"metal"`).
360    fn name(&self) -> &'static str;
361
362    /// Returns `true` only when the backend is backed by real GPU hardware.
363    fn is_accelerated(&self) -> bool;
364
365    /// Number of logical devices available to this backend.
366    fn device_count(&self) -> usize;
367
368    /// Allocate an uninitialised (zero-filled for stubs) device buffer.
369    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError>;
370
371    /// Copy a host slice to a new device buffer and return it.
372    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError>;
373
374    /// Copy a device buffer to a new host `Vec<f32>`.
375    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError>;
376
377    /// Matrix-vector multiply: **y = A · x**.
378    ///
379    /// - `a` — row-major matrix of shape `[m, k]`
380    /// - `x` — column vector of length `k`
381    /// - Returns a buffer of length `m`.
382    fn matvec(
383        &self,
384        a: &DeviceBuffer,
385        x: &DeviceBuffer,
386        m: usize,
387        k: usize,
388        device_id: usize,
389    ) -> Result<DeviceBuffer, GpuError>;
390
391    /// Element-wise ReLU: **y_i = max(0, x_i)**.
392    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError>;
393
394    /// Softmax over the entire buffer (treated as a 1-D vector of `size` elements).
395    fn softmax(
396        &self,
397        x: &DeviceBuffer,
398        size: usize,
399        device_id: usize,
400    ) -> Result<DeviceBuffer, GpuError>;
401
402    /// Block until all previously submitted kernels on `device_id` have finished.
403    fn synchronize(&self, device_id: usize) -> Result<(), GpuError>;
404
405    /// Query device memory: returns `(free_bytes, total_bytes)`.
406    fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError>;
407
408    /// Q1_0_g128 matrix-vector product.
409    ///
410    /// Default implementation falls back to CPU dequant + scalar GEMV.
411    /// [`Scirs2Backend`] overrides this with a real GPU kernel.
412    fn gemv_q1_g128(
413        &self,
414        block_bytes: &[u8],
415        input: &[f32],
416        n_rows: usize,
417        k: usize,
418    ) -> Result<Vec<f32>, GpuError> {
419        cpu_gemv_1bit_fallback(block_bytes, input, n_rows, k)
420    }
421
422    /// Q1_0_g128 matrix-matrix product.
423    ///
424    /// Default implementation falls back to repeated [`gemv_q1_g128`](Self::gemv_q1_g128) calls.
425    fn gemm_q1_g128(
426        &self,
427        block_bytes: &[u8],
428        input: &[f32],
429        m: usize,
430        n_rows: usize,
431        k: usize,
432    ) -> Result<Vec<f32>, GpuError> {
433        let mut output = vec![0.0_f32; m * n_rows];
434        for i in 0..m {
435            let row_input = &input[i * k..(i + 1) * k];
436            let row_output = self.gemv_q1_g128(block_bytes, row_input, n_rows, k)?;
437            output[i * n_rows..(i + 1) * n_rows].copy_from_slice(&row_output);
438        }
439        Ok(output)
440    }
441
442    /// Upload weight block bytes to GPU memory and return a reusable handle.
443    ///
444    /// Default: not supported (returns `NotAvailable`).
445    fn upload_weights_raw(
446        &self,
447        _block_bytes: &[u8],
448    ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
449        Err(GpuError::NotAvailable(
450            "weight caching not supported by this backend".into(),
451        ))
452    }
453
454    /// Q1_0_g128 GEMV using a pre-uploaded GPU-resident weight buffer.
455    ///
456    /// Default: not supported (returns `NotAvailable`).
457    fn gemv_q1_g128_cached(
458        &self,
459        _handle: crate::weight_cache::GpuWeightHandle,
460        _input: &[f32],
461        _n_rows: usize,
462        _k: usize,
463    ) -> Result<Vec<f32>, GpuError> {
464        Err(GpuError::NotAvailable(
465            "cached GEMV not supported by this backend".into(),
466        ))
467    }
468
469    /// Upload TQ2_0_g128 weight blocks to GPU memory in SoA layout.
470    ///
471    /// Default: not supported (returns `NotAvailable`).
472    fn upload_weights_ternary(
473        &self,
474        _blocks: &[oxibonsai_core::BlockTQ2_0_g128],
475    ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
476        Err(GpuError::NotAvailable(
477            "ternary weight upload not supported by this backend".into(),
478        ))
479    }
480
481    /// TQ2_0_g128 GEMV using a pre-uploaded GPU-resident weight buffer.
482    ///
483    /// Default: not supported (returns `NotAvailable`).
484    fn gemv_tq2_g128_cached(
485        &self,
486        _handle: crate::weight_cache::GpuWeightHandle,
487        _input: &[f32],
488        _n_rows: usize,
489        _k: usize,
490    ) -> Result<Vec<f32>, GpuError> {
491        Err(GpuError::NotAvailable(
492            "cached ternary GEMV not supported by this backend".into(),
493        ))
494    }
495
496    /// Batch-execute attention input phase (RMSNorm + QKV) in one command buffer.
497    ///
498    /// Returns `Ok(Some((q, k, v)))` if batching succeeded, or `Ok(None)` if
499    /// not supported by this backend.
500    #[allow(clippy::too_many_arguments, clippy::type_complexity)]
501    fn batch_attn_phase(
502        &self,
503        _hidden: &[f32],
504        _norm_weight: &[f32],
505        _norm_eps: f32,
506        _qkv_handle: crate::weight_cache::GpuWeightHandle,
507        _q_rows: usize,
508        _k_rows: usize,
509        _h: usize,
510    ) -> Result<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>, GpuError> {
511        Ok(None)
512    }
513
514    /// Batch-execute FFN phase in one command buffer.
515    ///
516    /// Returns `Ok(true)` if batching succeeded and `hidden` was modified
517    /// in-place, or `Ok(false)` if not supported by this backend.
518    #[allow(clippy::too_many_arguments)]
519    fn batch_ffn_phase(
520        &self,
521        _hidden: &mut [f32],
522        _attn_out: &[f32],
523        _norm_weight: &[f32],
524        _norm_eps: f32,
525        _attn_proj_handle: crate::weight_cache::GpuWeightHandle,
526        _gate_up_handle: crate::weight_cache::GpuWeightHandle,
527        _down_handle: crate::weight_cache::GpuWeightHandle,
528        _h: usize,
529        _intermediate: usize,
530        _attn_proj_k: usize,
531    ) -> Result<bool, GpuError> {
532        Ok(false)
533    }
534}
535
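/// Backwards-compatible alias for the trait-object type `dyn GpuBackendTrait`.
///
/// Because this is a type alias rather than a trait, only type-position uses
/// such as `&GpuBackend` or `Box<GpuBackend>` keep compiling.  Code that wrote
/// `impl GpuBackend for T`, a `T: GpuBackend` bound, or `dyn GpuBackend` must
/// be migrated to [`GpuBackendTrait`].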
540pub type GpuBackend = dyn GpuBackendTrait;
541
542// ═══════════════════════════════════════════════════════════════════════════
543// CpuBackend
544// ═══════════════════════════════════════════════════════════════════════════
545
/// CPU backend — always available, no GPU required.
///
/// Implements [`GpuBackendTrait`] using plain scalar Rust operations.
pub struct CpuBackend {
    /// Simulated total device memory reported by `memory_info`.
    pub simulated_memory_bytes: usize,
}

impl CpuBackend {
    /// Default simulated memory size (4 GiB).
    ///
    /// Kept as `u64` because 4 GiB does not fit in `usize` on 32-bit targets;
    /// spelling the product as a `usize` expression is rejected by the
    /// compiler's overflow check there.
    const DEFAULT_SIMULATED_MEMORY_BYTES: u64 = 4 * 1024 * 1024 * 1024;

    /// Create a `CpuBackend` with a default simulated memory of 4 GiB.
    ///
    /// On targets whose `usize` cannot represent 4 GiB the value saturates to
    /// `usize::MAX`.
    pub fn new() -> Self {
        let bytes = if Self::DEFAULT_SIMULATED_MEMORY_BYTES > usize::MAX as u64 {
            usize::MAX
        } else {
            Self::DEFAULT_SIMULATED_MEMORY_BYTES as usize
        };
        Self::with_memory(bytes)
    }

    /// Create a `CpuBackend` with a custom simulated memory size (bytes).
    pub fn with_memory(bytes: usize) -> Self {
        Self {
            simulated_memory_bytes: bytes,
        }
    }
}

impl Default for CpuBackend {
    /// Equivalent to [`CpuBackend::new`].
    fn default() -> Self {
        Self::new()
    }
}
575
576impl GpuBackendTrait for CpuBackend {
577    fn name(&self) -> &'static str {
578        "cpu"
579    }
580
581    fn is_accelerated(&self) -> bool {
582        false
583    }
584
585    fn device_count(&self) -> usize {
586        1
587    }
588
589    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
590        Ok(DeviceBuffer::new(size, device_id))
591    }
592
593    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
594        Ok(DeviceBuffer::from_slice(src, device_id))
595    }
596
597    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
598        Ok(buf.to_vec())
599    }
600
601    fn matvec(
602        &self,
603        a: &DeviceBuffer,
604        x: &DeviceBuffer,
605        m: usize,
606        k: usize,
607        device_id: usize,
608    ) -> Result<DeviceBuffer, GpuError> {
609        if a.size() != m * k {
610            return Err(GpuError::InvalidArgument(format!(
611                "matrix buffer size {} does not match m={} k={}",
612                a.size(),
613                m,
614                k
615            )));
616        }
617        if x.size() != k {
618            return Err(GpuError::InvalidArgument(format!(
619                "vector buffer size {} does not match k={}",
620                x.size(),
621                k
622            )));
623        }
624
625        let mut result = vec![0.0_f32; m];
626        for (row, slot) in result.iter_mut().enumerate().take(m) {
627            let mut acc = 0.0_f32;
628            for col in 0..k {
629                acc += a.data[row * k + col] * x.data[col];
630            }
631            *slot = acc;
632        }
633
634        Ok(DeviceBuffer::from_slice(&result, device_id))
635    }
636
637    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
638        let result: Vec<f32> = x.data.iter().map(|&v| v.max(0.0)).collect();
639        Ok(DeviceBuffer::from_slice(&result, device_id))
640    }
641
642    fn softmax(
643        &self,
644        x: &DeviceBuffer,
645        size: usize,
646        device_id: usize,
647    ) -> Result<DeviceBuffer, GpuError> {
648        if x.size() != size {
649            return Err(GpuError::InvalidArgument(format!(
650                "buffer size {} does not match size={}",
651                x.size(),
652                size
653            )));
654        }
655        if size == 0 {
656            return Ok(DeviceBuffer::new(0, device_id));
657        }
658
659        let max_val = x.data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
660        let exps: Vec<f32> = x.data.iter().map(|&v| (v - max_val).exp()).collect();
661        let sum: f32 = exps.iter().sum();
662
663        let result: Vec<f32> = if sum == 0.0 {
664            vec![1.0 / size as f32; size]
665        } else {
666            exps.iter().map(|&e| e / sum).collect()
667        };
668
669        Ok(DeviceBuffer::from_slice(&result, device_id))
670    }
671
672    fn synchronize(&self, _device_id: usize) -> Result<(), GpuError> {
673        Ok(())
674    }
675
676    fn memory_info(&self, _device_id: usize) -> Result<(usize, usize), GpuError> {
677        let total = self.simulated_memory_bytes;
678        let free = total / 2;
679        Ok((free, total))
680    }
681}
682
683// ═══════════════════════════════════════════════════════════════════════════
684// CudaBackend (stub, feature = "cuda")
685// ═══════════════════════════════════════════════════════════════════════════
686
/// Placeholder CUDA backend, compiled only with `feature = "cuda"`.
///
/// It performs no GPU work: buffer and compute operations are forwarded to an
/// internal [`CpuBackend`], each emitting a `warn!` trace event.
/// Use [`Scirs2Backend`] for real GPU acceleration.
#[cfg(feature = "cuda")]
pub struct CudaBackend {
    /// Device count returned by `device_count()`; the stub constructor
    /// always sets it to 1.
    pub device_count: usize,
    /// CPU implementation that actually services every operation.
    cpu_fallback: CpuBackend,
}

#[cfg(feature = "cuda")]
impl CudaBackend {
    /// Construct the CUDA stub.
    ///
    /// Logs a warning that no real acceleration is active, then returns a
    /// backend with one device and a default [`CpuBackend`].  This
    /// implementation never returns `Err`.
    pub fn new() -> Result<Self, GpuError> {
        warn!("CudaBackend: CUDA stub active — no real GPU acceleration");
        let cpu_fallback = CpuBackend::new();
        Ok(Self {
            device_count: 1,
            cpu_fallback,
        })
    }
}
709
#[cfg(feature = "cuda")]
impl GpuBackendTrait for CudaBackend {
    /// Backend identifier: always `"cuda"`.
    fn name(&self) -> &'static str {
        "cuda"
    }

    /// The stub does no GPU work, so it reports `false`.
    fn is_accelerated(&self) -> bool {
        false
    }

    /// Value of the `device_count` field.
    fn device_count(&self) -> usize {
        self.device_count
    }

    /// Warn, then allocate through the CPU fallback.
    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::alloc delegating to CPU fallback");
        GpuBackendTrait::alloc(&self.cpu_fallback, size, device_id)
    }

    /// Warn, then copy `src` through the CPU fallback.
    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::host_to_device delegating to CPU fallback");
        GpuBackendTrait::host_to_device(&self.cpu_fallback, src, device_id)
    }

    /// Warn, then read the buffer back through the CPU fallback.
    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
        warn!("CudaBackend::device_to_host delegating to CPU fallback");
        GpuBackendTrait::device_to_host(&self.cpu_fallback, buf)
    }

    /// Warn, then run the matrix-vector product on the CPU fallback.
    fn matvec(
        &self,
        a: &DeviceBuffer,
        x: &DeviceBuffer,
        m: usize,
        k: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::matvec delegating to CPU fallback");
        GpuBackendTrait::matvec(&self.cpu_fallback, a, x, m, k, device_id)
    }

    /// Warn, then apply ReLU on the CPU fallback.
    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::relu delegating to CPU fallback");
        GpuBackendTrait::relu(&self.cpu_fallback, x, device_id)
    }

    /// Warn, then compute softmax on the CPU fallback.
    fn softmax(
        &self,
        x: &DeviceBuffer,
        size: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        warn!("CudaBackend::softmax delegating to CPU fallback");
        GpuBackendTrait::softmax(&self.cpu_fallback, x, size, device_id)
    }

    /// Warn, then synchronise via the CPU fallback.
    fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
        warn!("CudaBackend::synchronize delegating to CPU fallback");
        GpuBackendTrait::synchronize(&self.cpu_fallback, device_id)
    }

    /// Warn, then report the CPU fallback's simulated memory figures.
    fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
        warn!("CudaBackend::memory_info delegating to CPU fallback");
        GpuBackendTrait::memory_info(&self.cpu_fallback, device_id)
    }
}
776
777// ═══════════════════════════════════════════════════════════════════════════
778// MetalBackend (stub, feature = "metal", macOS only)
779// ═══════════════════════════════════════════════════════════════════════════
780
/// Placeholder Metal backend, compiled only with `feature = "metal"` on macOS.
///
/// It performs no GPU work: buffer and compute operations are forwarded to an
/// internal [`CpuBackend`], each emitting a `warn!` trace event.
/// Use [`Scirs2Backend`] for real GPU acceleration.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub struct MetalBackend {
    /// Device count returned by `device_count()`; the stub constructor
    /// always sets it to 1.
    pub device_count: usize,
    /// CPU implementation that actually services every operation.
    cpu_fallback: CpuBackend,
}

#[cfg(all(feature = "metal", target_os = "macos"))]
impl MetalBackend {
    /// Construct the Metal stub.
    ///
    /// Logs a warning that no real acceleration is active, then returns a
    /// backend with one device and a default [`CpuBackend`].  This
    /// implementation never returns `Err`.
    pub fn new() -> Result<Self, GpuError> {
        warn!("MetalBackend: Metal stub active — no real GPU acceleration");
        let cpu_fallback = CpuBackend::new();
        Ok(Self {
            device_count: 1,
            cpu_fallback,
        })
    }
}
802
#[cfg(all(feature = "metal", target_os = "macos"))]
impl GpuBackendTrait for MetalBackend {
    /// Backend identifier: always `"metal"`.
    fn name(&self) -> &'static str {
        "metal"
    }

    /// The stub does no GPU work, so it reports `false`.
    fn is_accelerated(&self) -> bool {
        false
    }

    /// Value of the `device_count` field.
    fn device_count(&self) -> usize {
        self.device_count
    }

    /// Warn, then allocate through the CPU fallback.
    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::alloc delegating to CPU fallback");
        GpuBackendTrait::alloc(&self.cpu_fallback, size, device_id)
    }

    /// Warn, then copy `src` through the CPU fallback.
    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::host_to_device delegating to CPU fallback");
        GpuBackendTrait::host_to_device(&self.cpu_fallback, src, device_id)
    }

    /// Warn, then read the buffer back through the CPU fallback.
    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
        warn!("MetalBackend::device_to_host delegating to CPU fallback");
        GpuBackendTrait::device_to_host(&self.cpu_fallback, buf)
    }

    /// Warn, then run the matrix-vector product on the CPU fallback.
    fn matvec(
        &self,
        a: &DeviceBuffer,
        x: &DeviceBuffer,
        m: usize,
        k: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::matvec delegating to CPU fallback");
        GpuBackendTrait::matvec(&self.cpu_fallback, a, x, m, k, device_id)
    }

    /// Warn, then apply ReLU on the CPU fallback.
    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::relu delegating to CPU fallback");
        GpuBackendTrait::relu(&self.cpu_fallback, x, device_id)
    }

    /// Warn, then compute softmax on the CPU fallback.
    fn softmax(
        &self,
        x: &DeviceBuffer,
        size: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        warn!("MetalBackend::softmax delegating to CPU fallback");
        GpuBackendTrait::softmax(&self.cpu_fallback, x, size, device_id)
    }

    /// Warn, then synchronise via the CPU fallback.
    fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
        warn!("MetalBackend::synchronize delegating to CPU fallback");
        GpuBackendTrait::synchronize(&self.cpu_fallback, device_id)
    }

    /// Warn, then report the CPU fallback's simulated memory figures.
    fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
        warn!("MetalBackend::memory_info delegating to CPU fallback");
        GpuBackendTrait::memory_info(&self.cpu_fallback, device_id)
    }
}
869
870// ═══════════════════════════════════════════════════════════════════════════
871// Scirs2BackendHandle (singleton wrapper)
872// ═══════════════════════════════════════════════════════════════════════════
873
/// Newtype over `Arc<Scirs2Backend>` that implements [`GpuBackendTrait`].
///
/// Wrapping the shared `Arc` lets the process-wide singleton backend be handed
/// out wherever a `Box<dyn GpuBackendTrait>` is expected
/// (e.g. [`select_backend`]).  Only compiled with `feature = "gpu"`.
#[cfg(feature = "gpu")]
pub(crate) struct Scirs2BackendHandle(pub(crate) std::sync::Arc<Scirs2Backend>);
880
881#[cfg(feature = "gpu")]
882impl GpuBackendTrait for Scirs2BackendHandle {
883    fn name(&self) -> &'static str {
884        self.0.name()
885    }
886    fn is_accelerated(&self) -> bool {
887        self.0.is_accelerated()
888    }
889    fn device_count(&self) -> usize {
890        self.0.device_count()
891    }
892    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
893        self.0.alloc(size, device_id)
894    }
895    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
896        self.0.host_to_device(src, device_id)
897    }
898    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
899        self.0.device_to_host(buf)
900    }
901    fn matvec(
902        &self,
903        a: &DeviceBuffer,
904        x: &DeviceBuffer,
905        m: usize,
906        k: usize,
907        device_id: usize,
908    ) -> Result<DeviceBuffer, GpuError> {
909        self.0.matvec(a, x, m, k, device_id)
910    }
911    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
912        self.0.relu(x, device_id)
913    }
914    fn softmax(
915        &self,
916        x: &DeviceBuffer,
917        size: usize,
918        device_id: usize,
919    ) -> Result<DeviceBuffer, GpuError> {
920        self.0.softmax(x, size, device_id)
921    }
922    fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
923        self.0.synchronize(device_id)
924    }
925    fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
926        self.0.memory_info(device_id)
927    }
928    fn gemv_q1_g128(
929        &self,
930        block_bytes: &[u8],
931        input: &[f32],
932        n_rows: usize,
933        k: usize,
934    ) -> Result<Vec<f32>, GpuError> {
935        self.0.gemv_q1_g128(block_bytes, input, n_rows, k)
936    }
937    fn gemm_q1_g128(
938        &self,
939        block_bytes: &[u8],
940        input: &[f32],
941        m: usize,
942        n_rows: usize,
943        k: usize,
944    ) -> Result<Vec<f32>, GpuError> {
945        self.0.gemm_q1_g128(block_bytes, input, m, n_rows, k)
946    }
947    fn upload_weights_raw(
948        &self,
949        block_bytes: &[u8],
950    ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
951        self.0.upload_weights(block_bytes)
952    }
953    fn gemv_q1_g128_cached(
954        &self,
955        handle: crate::weight_cache::GpuWeightHandle,
956        input: &[f32],
957        n_rows: usize,
958        k: usize,
959    ) -> Result<Vec<f32>, GpuError> {
960        self.0.gemv_q1_g128_cached(handle, input, n_rows, k)
961    }
962
963    fn upload_weights_ternary(
964        &self,
965        blocks: &[oxibonsai_core::BlockTQ2_0_g128],
966    ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
967        self.0.upload_weights_ternary(blocks)
968    }
969
970    fn gemv_tq2_g128_cached(
971        &self,
972        handle: crate::weight_cache::GpuWeightHandle,
973        input: &[f32],
974        n_rows: usize,
975        k: usize,
976    ) -> Result<Vec<f32>, GpuError> {
977        self.0.gemv_tq2_g128_cached(handle, input, n_rows, k)
978    }
979
980    fn batch_attn_phase(
981        &self,
982        hidden: &[f32],
983        norm_weight: &[f32],
984        norm_eps: f32,
985        qkv_handle: crate::weight_cache::GpuWeightHandle,
986        q_rows: usize,
987        k_rows: usize,
988        h: usize,
989    ) -> Result<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>, GpuError> {
990        match self
991            .0
992            .batch_attn_phase(hidden, norm_weight, norm_eps, qkv_handle, q_rows, k_rows, h)
993        {
994            Ok(result) => Ok(Some(result)),
995            Err(e) => {
996                tracing::warn!(error = %e, "batch_attn_phase failed, falling back");
997                Ok(None)
998            }
999        }
1000    }
1001
1002    fn batch_ffn_phase(
1003        &self,
1004        hidden: &mut [f32],
1005        attn_out: &[f32],
1006        norm_weight: &[f32],
1007        norm_eps: f32,
1008        attn_proj_handle: crate::weight_cache::GpuWeightHandle,
1009        gate_up_handle: crate::weight_cache::GpuWeightHandle,
1010        down_handle: crate::weight_cache::GpuWeightHandle,
1011        h: usize,
1012        intermediate: usize,
1013        attn_proj_k: usize,
1014    ) -> Result<bool, GpuError> {
1015        match self.0.batch_ffn_phase(
1016            hidden,
1017            attn_out,
1018            norm_weight,
1019            norm_eps,
1020            attn_proj_handle,
1021            gate_up_handle,
1022            down_handle,
1023            h,
1024            intermediate,
1025            attn_proj_k,
1026        ) {
1027            Ok(()) => Ok(true),
1028            Err(e) => {
1029                tracing::warn!(error = %e, "batch_ffn_phase failed, falling back");
1030                Ok(false)
1031            }
1032        }
1033    }
1034}
1035
1036// ═══════════════════════════════════════════════════════════════════════════
1037// select_backend
1038// ═══════════════════════════════════════════════════════════════════════════
1039
1040/// Select the best available backend automatically.
1041///
1042/// Priority order (highest to lowest):
1043/// 1. [`Scirs2Backend`] (feature = "gpu") — Metal-accelerated via scirs2-core
1044/// 2. `NativeCudaBackend` (feature = "native-cuda") — direct cudarc CUDA
1045/// 3. CUDA stub (feature = "cuda", no "native-cuda") — falls back to CPU
1046/// 4. Metal stub (feature = "metal", macOS only) — falls back to CPU
1047/// 5. [`CpuBackend`] (always available)
1048///
1049/// If initialisation fails at any level the function falls through to the
1050/// next option, ultimately always returning a functional `CpuBackend`.
1051pub fn select_backend() -> Box<dyn GpuBackendTrait> {
1052    // `select_backend` may be called several times in a process (model load,
1053    // engine init, tests). The "scirs2 not accelerated" / "init failed" warnings
1054    // are properties of the host environment, not of any individual call site,
1055    // so emit each variant at most once per process.
1056    #[cfg(feature = "gpu")]
1057    use std::sync::atomic::{AtomicBool, Ordering};
1058    #[cfg(feature = "gpu")]
1059    fn warn_once(flag: &AtomicBool, msg: impl FnOnce()) {
1060        if !flag.swap(true, Ordering::Relaxed) {
1061            msg();
1062        }
1063    }
1064
1065    // ── 1. Try scirs2-core GPU backend (Metal on macOS) ─────────────────
1066    #[cfg(feature = "gpu")]
1067    {
1068        static SCIRS2_NOT_ACCEL: AtomicBool = AtomicBool::new(false);
1069        static SCIRS2_INIT_FAIL: AtomicBool = AtomicBool::new(false);
1070        match Scirs2Backend::global() {
1071            Ok(b) => {
1072                if b.is_accelerated() {
1073                    return Box::new(Scirs2BackendHandle(b));
1074                }
1075                // scirs2-core returned a CPU context; skip and try stubs.
1076                warn_once(&SCIRS2_NOT_ACCEL, || {
1077                    warn!(
1078                        "select_backend: Scirs2Backend is not accelerated (backend={}), trying next",
1079                        b.backend_name()
1080                    );
1081                });
1082            }
1083            Err(e) => {
1084                warn_once(&SCIRS2_INIT_FAIL, || {
1085                    warn!("select_backend: Scirs2Backend init failed ({e}), trying next");
1086                });
1087            }
1088        }
1089    }
1090
1091    // ── 2. Native CUDA backend (direct cudarc, Linux/Windows) ───────────
1092    #[cfg(all(
1093        feature = "native-cuda",
1094        any(target_os = "linux", target_os = "windows")
1095    ))]
1096    {
1097        match NativeCudaBackend::new() {
1098            Ok(b) => {
1099                tracing::info!("select_backend: NativeCudaBackend initialised");
1100                return Box::new(b);
1101            }
1102            Err(e) => {
1103                warn!("select_backend: NativeCudaBackend init failed ({e}), trying next");
1104            }
1105        }
1106    }
1107
1108    // ── 3. CUDA stub ─────────────────────────────────────────────────────
1109    #[cfg(feature = "cuda")]
1110    {
1111        match CudaBackend::new() {
1112            Ok(b) => {
1113                return Box::new(b);
1114            }
1115            Err(e) => {
1116                warn!("select_backend: CUDA init failed ({e}), trying next");
1117            }
1118        }
1119    }
1120
1121    // ── 3. Metal stub ───────────────────────────────────────────────────
1122    #[cfg(all(feature = "metal", target_os = "macos"))]
1123    {
1124        match MetalBackend::new() {
1125            Ok(b) => {
1126                return Box::new(b);
1127            }
1128            Err(e) => {
1129                warn!("select_backend: Metal init failed ({e}), trying CPU");
1130            }
1131        }
1132    }
1133
1134    // ── 4. CPU fallback ─────────────────────────────────────────────────
1135    Box::new(CpuBackend::new())
1136}
1137
1138// ═══════════════════════════════════════════════════════════════════════════
1139// gpu_matmul utility
1140// ═══════════════════════════════════════════════════════════════════════════
1141
1142/// Perform a general matrix multiplication **C = A · B** using a GPU backend.
1143///
1144/// - `a` — row-major `[m, k]` matrix (length `m * k`)
1145/// - `b` — row-major `[k, n]` matrix (length `k * n`)
1146/// - Returns a row-major `[m, n]` matrix (length `m * n`)
1147///
1148/// This is implemented as `n` calls to `backend.matvec` (one per column of B)
1149/// and is provided as a convenience for callers that do not wish to manage
1150/// `DeviceBuffer` objects directly.
1151pub fn gpu_matmul(
1152    backend: &dyn GpuBackendTrait,
1153    a: &[f32],
1154    b: &[f32],
1155    m: usize,
1156    k: usize,
1157    n: usize,
1158    device_id: usize,
1159) -> Result<Vec<f32>, GpuError> {
1160    if a.len() != m * k {
1161        return Err(GpuError::InvalidArgument(format!(
1162            "a.len()={} does not match m={} k={}",
1163            a.len(),
1164            m,
1165            k
1166        )));
1167    }
1168    if b.len() != k * n {
1169        return Err(GpuError::InvalidArgument(format!(
1170            "b.len()={} does not match k={} n={}",
1171            b.len(),
1172            k,
1173            n
1174        )));
1175    }
1176
1177    let a_buf = backend.host_to_device(a, device_id)?;
1178
1179    let mut c = vec![0.0_f32; m * n];
1180
1181    for col in 0..n {
1182        let b_col: Vec<f32> = (0..k).map(|row| b[row * n + col]).collect();
1183        let x_buf = backend.host_to_device(&b_col, device_id)?;
1184        let y_buf = backend.matvec(&a_buf, &x_buf, m, k, device_id)?;
1185        let y = backend.device_to_host(&y_buf)?;
1186
1187        for row in 0..m {
1188            c[row * n + col] = y[row];
1189        }
1190    }
1191
1192    backend.synchronize(device_id)?;
1193    Ok(c)
1194}
1195
1196// ═══════════════════════════════════════════════════════════════════════════
1197// gpu_gemv_1bit — high-level Q1_0_g128 GPU GEMV
1198// ═══════════════════════════════════════════════════════════════════════════
1199
1200/// Perform Q1_0_g128 matrix-vector multiply on the GPU.
1201///
1202/// This is the primary entry point for callers that have raw
1203/// `BlockQ1_0G128` data and want GPU-accelerated inference.
1204///
1205/// # Arguments
1206/// - `block_bytes` — `&[u8]` raw bytes of `BlockQ1_0G128[]` (18 bytes each)
1207/// - `input` — `&[f32]` of length `k`
1208/// - `n_rows` — number of weight matrix rows
1209/// - `k` — input dimension (must be a multiple of 128)
1210///
1211/// # Returns
1212/// `Vec<f32>` of length `n_rows`, or falls back to CPU dequant+GEMV if GPU
1213/// is not available.
1214///
1215/// # Feature gates
1216/// Requires the `gpu` feature.  Without it, this function is still available
1217/// but always uses the CPU fallback path.
1218pub fn gpu_gemv_1bit(
1219    block_bytes: &[u8],
1220    input: &[f32],
1221    n_rows: usize,
1222    k: usize,
1223) -> Result<Vec<f32>, GpuError> {
1224    #[cfg(feature = "gpu")]
1225    {
1226        match Scirs2Backend::global() {
1227            Ok(backend) => {
1228                if backend.is_accelerated() {
1229                    return backend.gemv_q1_g128(block_bytes, input, n_rows, k);
1230                }
1231                // Non-accelerated (CPU) scirs2 context — use our own CPU path.
1232            }
1233            Err(e) => {
1234                warn!("gpu_gemv_1bit: GPU init failed ({e}), using CPU fallback");
1235            }
1236        }
1237    }
1238
1239    // CPU fallback: dequant + scalar GEMV.
1240    cpu_gemv_1bit_fallback(block_bytes, input, n_rows, k)
1241}
1242
1243/// CPU fallback for Q1_0_g128 GEMV.
1244///
1245/// Dequantises blocks inline and computes the dot-product per row.
1246fn cpu_gemv_1bit_fallback(
1247    block_bytes: &[u8],
1248    input: &[f32],
1249    n_rows: usize,
1250    k: usize,
1251) -> Result<Vec<f32>, GpuError> {
1252    if k == 0 || k % 128 != 0 {
1253        return Err(GpuError::InvalidArgument(format!(
1254            "k={k} must be a positive multiple of 128"
1255        )));
1256    }
1257    if input.len() != k {
1258        return Err(GpuError::InvalidArgument(format!(
1259            "input.len()={} != k={}",
1260            input.len(),
1261            k
1262        )));
1263    }
1264    let blocks_per_row = k / 128;
1265    let block_size = 18_usize;
1266    let expected = n_rows * blocks_per_row * block_size;
1267    if block_bytes.len() < expected {
1268        return Err(GpuError::InvalidArgument(format!(
1269            "block_bytes too small: {} < {}",
1270            block_bytes.len(),
1271            expected,
1272        )));
1273    }
1274
1275    let mut output = vec![0.0_f32; n_rows];
1276
1277    for (row, output_val) in output.iter_mut().enumerate().take(n_rows) {
1278        let mut sum = 0.0_f32;
1279        for b in 0..blocks_per_row {
1280            let block_idx = row * blocks_per_row + b;
1281            let off = block_idx * block_size;
1282
1283            // Read FP16 scale factor (little-endian).
1284            let d_bits = u16::from_le_bytes([block_bytes[off], block_bytes[off + 1]]);
1285            let scale = half::f16::from_bits(d_bits).to_f32();
1286
1287            let input_base = b * 128;
1288            // Process 4 × u32 = 128 bits.
1289            for w in 0..4_usize {
1290                let byte_off = off + 2 + w * 4;
1291                let bits = u32::from_le_bytes([
1292                    block_bytes[byte_off],
1293                    block_bytes[byte_off + 1],
1294                    block_bytes[byte_off + 2],
1295                    block_bytes[byte_off + 3],
1296                ]);
1297                let base = input_base + w * 32;
1298                for i in 0..32_usize {
1299                    let sign = if (bits >> i) & 1 == 1 {
1300                        1.0_f32
1301                    } else {
1302                        -1.0_f32
1303                    };
1304                    sum += scale * sign * input[base + i];
1305                }
1306            }
1307        }
1308        *output_val = sum;
1309    }
1310
1311    Ok(output)
1312}
1313
1314// ═══════════════════════════════════════════════════════════════════════════
1315// Unit tests
1316// ═══════════════════════════════════════════════════════════════════════════
1317
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one 18-byte Q1_0_g128 block whose FP16 scale is 1.0 and whose
    /// 16 sign bytes are all `sign_byte` (`0xFF` → every weight `+1`,
    /// `0x00` → every weight `-1`).
    fn unit_scale_block(sign_byte: u8) -> Vec<u8> {
        let mut block = vec![sign_byte; 18];
        let scale_bits = half::f16::from_f32(1.0).to_bits();
        block[..2].copy_from_slice(&scale_bits.to_le_bytes());
        block
    }

    // ── DeviceBuffer / LaunchConfig / CpuBackend basics ─────────────────

    #[test]
    fn device_buffer_new_zeroed() {
        let zeroed = DeviceBuffer::new(4, 0);
        assert_eq!(zeroed.size(), 4);
        assert_eq!(zeroed.device_id(), 0);
        assert!(zeroed.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn device_buffer_from_slice_roundtrip() {
        let src = [1.0_f32, 2.0, 3.0];
        assert_eq!(DeviceBuffer::from_slice(&src, 1).to_vec(), src);
    }

    #[test]
    fn launch_config_for_zero_elements() {
        // Zero elements must still produce a launchable grid.
        assert_eq!(LaunchConfig::for_n_elements(0).grid_dim.0, 1);
    }

    #[test]
    fn cpu_softmax_empty() {
        let empty = DeviceBuffer::new(0, 0);
        let out = CpuBackend::new()
            .softmax(&empty, 0, 0)
            .expect("softmax empty");
        assert_eq!(out.size(), 0);
    }

    // ── CPU fallback GEMV tests ─────────────────────────────────────────

    #[test]
    fn cpu_gemv_1bit_identity_scale() {
        // 1 row, k=128, every weight = +1 (scale 1.0, all bits set), so the
        // output is simply the sum of the input: sum(0..128) = 8128.
        let block = unit_scale_block(0xFF);
        let input: Vec<f32> = (0..128).map(|i| i as f32).collect();
        let expected: f32 = input.iter().sum();

        let result =
            cpu_gemv_1bit_fallback(&block, &input, 1, 128).expect("cpu_gemv_1bit_fallback");
        assert!(
            (result[0] - expected).abs() < 1e-2,
            "got {} expected {}",
            result[0],
            expected,
        );
    }

    #[test]
    fn cpu_gemv_1bit_negative_scale() {
        // All sign bits clear → every weight = -scale. With scale = 1.0 and
        // an all-ones input the output is -1 * 128 * 1.0 = -128.
        let block = unit_scale_block(0x00);
        let input = vec![1.0_f32; 128];

        let result =
            cpu_gemv_1bit_fallback(&block, &input, 1, 128).expect("cpu_gemv_1bit_fallback");
        assert!(
            (result[0] - (-128.0)).abs() < 1e-2,
            "got {} expected -128",
            result[0],
        );
    }

    #[test]
    fn cpu_gemv_1bit_bad_k() {
        // k = 64 is not a multiple of 128 and must be rejected.
        assert!(cpu_gemv_1bit_fallback(&[], &[], 0, 64).is_err());
    }

    #[test]
    fn gpu_gemv_1bit_without_gpu() {
        // gpu_gemv_1bit should fall back to CPU.
        let block = unit_scale_block(0xFF);
        let input: Vec<f32> = vec![1.0_f32; 128];

        let result = gpu_gemv_1bit(&block, &input, 1, 128).expect("gpu_gemv_1bit");
        assert!((result[0] - 128.0).abs() < 1e-2, "got {}", result[0]);
    }
}