#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub mod cuda_attn_kernels;
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub mod cuda_full_layer;
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub mod cuda_graph;
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub mod cuda_kernels;
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub mod cuda_prefill;
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub mod cuda_prefill_kernels;
pub mod kernel_sources;
#[cfg(all(feature = "metal", target_os = "macos"))]
mod metal_dispatch;
#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod metal_full_layer;
#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod metal_graph;
#[cfg(all(feature = "metal", target_os = "macos"))]
mod metal_prefill;
pub mod scirs2_backend;
use thiserror::Error;
#[allow(unused_imports)]
use tracing::warn;
#[cfg(feature = "gpu")]
pub use scirs2_backend::Scirs2Backend;
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use metal_graph::{
try_metal_ffn, try_metal_qkv, MetalGraph, MetalGraphError, MetalWeightHandle,
};
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use metal_full_layer::{
build_cached_weights, print_gpu_profile_summary, try_metal_full_forward,
try_metal_full_forward_cached, try_metal_full_forward_ternary, try_metal_full_layer,
CachedLayerWeights, CachedModelWeights, FullForwardLayerParams, FullForwardLayerParamsTernary,
};
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use metal_prefill::{try_metal_full_forward_prefill, try_metal_full_forward_prefill_verify};
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub use cuda_graph::{try_cuda_ffn, try_cuda_qkv, CudaGraph, CudaGraphError, NativeCudaBackend};
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub use cuda_full_layer::{
try_cuda_full_forward, try_cuda_full_forward_with_gpu_lm_head, try_cuda_full_layer,
CudaCachedLayerWeights, CudaFullForwardLayerParams,
};
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub use cuda_prefill::try_cuda_prefill;
/// Host-side container standing in for a block of device memory holding `f32` values.
///
/// All fields are public. The constructors set `size` to the number of
/// elements stored in `data`.
pub struct DeviceBuffer {
    /// Element storage.
    pub data: Vec<f32>,
    /// Element count recorded when the buffer was constructed.
    pub size: usize,
    /// Identifier of the device this buffer is associated with.
    pub device_id: usize,
}

impl DeviceBuffer {
    /// Creates a buffer of `size` elements, every one initialised to `0.0`.
    pub fn new(size: usize, device_id: usize) -> Self {
        let data = vec![0.0_f32; size];
        DeviceBuffer {
            data,
            size,
            device_id,
        }
    }

    /// Creates a buffer containing a copy of `data`.
    pub fn from_slice(data: &[f32], device_id: usize) -> Self {
        DeviceBuffer {
            data: data.to_vec(),
            size: data.len(),
            device_id,
        }
    }

    /// Returns a fresh copy of the stored elements.
    pub fn to_vec(&self) -> Vec<f32> {
        self.data.as_slice().to_vec()
    }

    /// Returns the recorded element count (the `size` field).
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns the associated device identifier.
    pub fn device_id(&self) -> usize {
        self.device_id
    }
}
/// Kernel launch geometry: grid dimensions, block dimensions and the
/// requested amount of shared memory in bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LaunchConfig {
    /// Number of blocks along each of the three grid axes.
    pub grid_dim: (u32, u32, u32),
    /// Number of threads along each of the three block axes.
    pub block_dim: (u32, u32, u32),
    /// Shared memory request, in bytes.
    pub shared_mem_bytes: u32,
}

/// Threads per block used by the one-dimensional launch helpers.
const DEFAULT_BLOCK_SIZE: u32 = 256;

impl LaunchConfig {
    /// Builds a one-dimensional launch configuration covering `n` elements
    /// with blocks of [`DEFAULT_BLOCK_SIZE`] threads.
    ///
    /// The grid size is `ceil(n / DEFAULT_BLOCK_SIZE)`, with a minimum of one
    /// block (so `n == 0` still yields a launchable configuration). The
    /// computation is carried out in 64-bit arithmetic so that element counts
    /// above `u32::MAX` are not truncated; if the resulting block count does
    /// not fit in a `u32` it is clamped to `u32::MAX`.
    pub fn for_n_elements(n: usize) -> Self {
        let block = u64::from(DEFAULT_BLOCK_SIZE);
        // `usize` is at most 64 bits on supported targets; saturate defensively otherwise.
        let n = u64::try_from(n).unwrap_or(u64::MAX);
        // Ceiling division without the `n + block - 1` overflow hazard.
        let blocks = n / block + u64::from(n % block != 0);
        let grid = u32::try_from(blocks).unwrap_or(u32::MAX).max(1);
        Self {
            grid_dim: (grid, 1, 1),
            block_dim: (DEFAULT_BLOCK_SIZE, 1, 1),
            shared_mem_bytes: 0,
        }
    }

    /// Returns a configuration with a single block of [`DEFAULT_BLOCK_SIZE`]
    /// threads and no shared memory.
    pub fn default_1d() -> Self {
        Self {
            grid_dim: (1, 1, 1),
            block_dim: (DEFAULT_BLOCK_SIZE, 1, 1),
            shared_mem_bytes: 0,
        }
    }
}
/// Errors returned by the backends and helper functions in this module.
///
/// The `#[error(...)]` attributes define the `Display` text of each variant.
#[derive(Debug, Error)]
pub enum GpuError {
/// A backend or capability is unavailable; the payload is a human-readable reason.
#[error("GPU not available: {0}")]
NotAvailable(String),
/// An allocation of `requested` bytes on device `device` could not be satisfied.
#[error("out of device memory: requested {requested} bytes on device {device}")]
OutOfMemory {
requested: usize,
device: usize,
},
/// A kernel launch failed; the payload describes the failure.
#[error("kernel launch failed: {0}")]
KernelLaunch(String),
/// Device synchronization failed; the payload describes the failure.
#[error("device synchronization failed: {0}")]
SyncFailed(String),
/// A caller-supplied argument (for example a mismatched buffer size) was rejected.
#[error("invalid argument: {0}")]
InvalidArgument(String),
}
/// Common interface implemented by every compute backend in this module.
///
/// The first group of methods is required. The remaining methods have
/// default implementations that either run on the CPU or report that the
/// capability is not supported, so a backend only overrides what it can
/// accelerate.
pub trait GpuBackendTrait: Send + Sync {
    /// Short static identifier of the backend (e.g. `"cpu"`, `"cuda"`, `"metal"`).
    fn name(&self) -> &'static str;

    /// Whether the backend performs real hardware acceleration.
    fn is_accelerated(&self) -> bool;

    /// Number of devices the backend reports.
    fn device_count(&self) -> usize;

    /// Allocates a buffer of `size` elements associated with `device_id`.
    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError>;

    /// Copies `src` into a new buffer associated with `device_id`.
    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError>;

    /// Copies the contents of `buf` back into a host vector.
    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError>;

    /// Matrix-vector product of `a` with `x`, producing `m` values.
    ///
    /// The CPU reference implementation in this file treats `a` as a
    /// row-major `m x k` matrix and `x` as a vector of `k` elements.
    fn matvec(
        &self,
        a: &DeviceBuffer,
        x: &DeviceBuffer,
        m: usize,
        k: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError>;

    /// Element-wise ReLU of `x`.
    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError>;

    /// Softmax over the `size` elements of `x`.
    fn softmax(
        &self,
        x: &DeviceBuffer,
        size: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError>;

    /// Waits for outstanding work on `device_id` to complete.
    fn synchronize(&self, device_id: usize) -> Result<(), GpuError>;

    /// Memory statistics for `device_id`.
    ///
    /// The CPU reference implementation returns the pair as `(free, total)`.
    fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError>;

    /// 1-bit quantised matrix-vector product (groups of 128 weights per block).
    ///
    /// The default implementation runs the CPU reference routine
    /// `cpu_gemv_1bit_fallback`.
    fn gemv_q1_g128(
        &self,
        block_bytes: &[u8],
        input: &[f32],
        n_rows: usize,
        k: usize,
    ) -> Result<Vec<f32>, GpuError> {
        cpu_gemv_1bit_fallback(block_bytes, input, n_rows, k)
    }

    /// Batched variant of [`gemv_q1_g128`](Self::gemv_q1_g128).
    ///
    /// `input` is read as `m` consecutive rows of `k` values; the result holds
    /// `m * n_rows` values, with the output for input row `i` stored at
    /// `[i * n_rows, (i + 1) * n_rows)`. The default implementation calls
    /// `gemv_q1_g128` once per input row.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::InvalidArgument`] when `m * k` or `m * n_rows`
    /// overflows `usize` or when `input` holds fewer than `m * k` values, and
    /// [`GpuError::KernelLaunch`] when the per-row GEMV yields a result whose
    /// length is not `n_rows`. Errors from `gemv_q1_g128` are propagated.
    fn gemm_q1_g128(
        &self,
        block_bytes: &[u8],
        input: &[f32],
        m: usize,
        n_rows: usize,
        k: usize,
    ) -> Result<Vec<f32>, GpuError> {
        // Validate shapes up front so malformed input yields an error rather
        // than a slice-index panic inside the loop.
        let needed = m.checked_mul(k).ok_or_else(|| {
            GpuError::InvalidArgument(format!("m={m} * k={k} overflows usize"))
        })?;
        if input.len() < needed {
            return Err(GpuError::InvalidArgument(format!(
                "input.len()={} is smaller than m*k={}",
                input.len(),
                needed
            )));
        }
        let out_len = m.checked_mul(n_rows).ok_or_else(|| {
            GpuError::InvalidArgument(format!("m={m} * n_rows={n_rows} overflows usize"))
        })?;

        let mut output = vec![0.0_f32; out_len];
        for i in 0..m {
            let row_input = &input[i * k..(i + 1) * k];
            let row_output = self.gemv_q1_g128(block_bytes, row_input, n_rows, k)?;
            // `copy_from_slice` panics on a length mismatch; surface it as an error.
            if row_output.len() != n_rows {
                return Err(GpuError::KernelLaunch(format!(
                    "gemv_q1_g128 returned {} values, expected n_rows={}",
                    row_output.len(),
                    n_rows
                )));
            }
            output[i * n_rows..(i + 1) * n_rows].copy_from_slice(&row_output);
        }
        Ok(output)
    }

    /// Uploads raw quantised weight bytes and returns a handle for later use.
    ///
    /// The default implementation reports that weight caching is unsupported.
    fn upload_weights_raw(
        &self,
        _block_bytes: &[u8],
    ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
        Err(GpuError::NotAvailable(
            "weight caching not supported by this backend".into(),
        ))
    }

    /// 1-bit GEMV against weights previously uploaded via a handle.
    ///
    /// The default implementation reports that cached GEMV is unsupported.
    fn gemv_q1_g128_cached(
        &self,
        _handle: crate::weight_cache::GpuWeightHandle,
        _input: &[f32],
        _n_rows: usize,
        _k: usize,
    ) -> Result<Vec<f32>, GpuError> {
        Err(GpuError::NotAvailable(
            "cached GEMV not supported by this backend".into(),
        ))
    }

    /// Uploads ternary (`BlockTQ2_0_g128`) weight blocks and returns a handle.
    ///
    /// The default implementation reports that ternary upload is unsupported.
    fn upload_weights_ternary(
        &self,
        _blocks: &[oxibonsai_core::BlockTQ2_0_g128],
    ) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
        Err(GpuError::NotAvailable(
            "ternary weight upload not supported by this backend".into(),
        ))
    }

    /// Ternary GEMV against weights previously uploaded via a handle.
    ///
    /// The default implementation reports that cached ternary GEMV is unsupported.
    fn gemv_tq2_g128_cached(
        &self,
        _handle: crate::weight_cache::GpuWeightHandle,
        _input: &[f32],
        _n_rows: usize,
        _k: usize,
    ) -> Result<Vec<f32>, GpuError> {
        Err(GpuError::NotAvailable(
            "cached ternary GEMV not supported by this backend".into(),
        ))
    }

    /// Batched attention-phase computation returning three vectors on success.
    ///
    /// The default implementation returns `Ok(None)`, i.e. no result was
    /// produced. NOTE(review): presumably callers treat `None` as "use the
    /// unbatched path" and the tuple is `(q, k, v)` — confirm against callers.
    // Signature mirrors the backend entry point, hence the argument count and tuple type.
    #[allow(clippy::too_many_arguments, clippy::type_complexity)]
    fn batch_attn_phase(
        &self,
        _hidden: &[f32],
        _norm_weight: &[f32],
        _norm_eps: f32,
        _qkv_handle: crate::weight_cache::GpuWeightHandle,
        _q_rows: usize,
        _k_rows: usize,
        _h: usize,
    ) -> Result<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>, GpuError> {
        Ok(None)
    }

    /// Batched feed-forward-phase computation operating on `hidden` in place.
    ///
    /// Returns `Ok(true)` when the backend handled the phase. The default
    /// implementation does nothing and returns `Ok(false)`.
    // Signature mirrors the backend entry point, hence the argument count.
    #[allow(clippy::too_many_arguments)]
    fn batch_ffn_phase(
        &self,
        _hidden: &mut [f32],
        _attn_out: &[f32],
        _norm_weight: &[f32],
        _norm_eps: f32,
        _attn_proj_handle: crate::weight_cache::GpuWeightHandle,
        _gate_up_handle: crate::weight_cache::GpuWeightHandle,
        _down_handle: crate::weight_cache::GpuWeightHandle,
        _h: usize,
        _intermediate: usize,
        _attn_proj_k: usize,
    ) -> Result<bool, GpuError> {
        Ok(false)
    }
}

/// Trait-object alias for [`GpuBackendTrait`].
pub type GpuBackend = dyn GpuBackendTrait;
/// Pure-CPU backend used as the reference implementation and final fallback.
pub struct CpuBackend {
    /// Total memory, in bytes, that `memory_info` reports for this backend.
    pub simulated_memory_bytes: usize,
}

impl CpuBackend {
    /// Creates a backend that reports 4 GiB of simulated memory.
    pub fn new() -> Self {
        Self::with_memory(4 * 1024 * 1024 * 1024)
    }

    /// Creates a backend that reports `bytes` of simulated memory.
    pub fn with_memory(bytes: usize) -> Self {
        CpuBackend {
            simulated_memory_bytes: bytes,
        }
    }
}

impl Default for CpuBackend {
    /// Equivalent to [`CpuBackend::new`].
    fn default() -> Self {
        CpuBackend::new()
    }
}
/// Reference implementation of the backend interface that runs entirely on the host.
impl GpuBackendTrait for CpuBackend {
    fn name(&self) -> &'static str {
        "cpu"
    }

    fn is_accelerated(&self) -> bool {
        false
    }

    fn device_count(&self) -> usize {
        1
    }

    /// Returns a zero-filled host buffer of `size` elements.
    fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        Ok(DeviceBuffer::new(size, device_id))
    }

    /// Copies `src` into a new host buffer.
    fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
        Ok(DeviceBuffer::from_slice(src, device_id))
    }

    /// Returns a copy of the buffer contents.
    fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
        Ok(buf.to_vec())
    }

    /// Row-major `m x k` matrix times a `k`-vector, accumulated left to right in `f32`.
    ///
    /// Rejects buffers whose recorded size does not match `m * k` / `k`.
    fn matvec(
        &self,
        a: &DeviceBuffer,
        x: &DeviceBuffer,
        m: usize,
        k: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        let matrix_len = a.size();
        if matrix_len != m * k {
            return Err(GpuError::InvalidArgument(format!(
                "matrix buffer size {} does not match m={} k={}",
                matrix_len, m, k
            )));
        }
        let vector_len = x.size();
        if vector_len != k {
            return Err(GpuError::InvalidArgument(format!(
                "vector buffer size {} does not match k={}",
                vector_len, k
            )));
        }

        let product: Vec<f32> = (0..m)
            .map(|row| {
                let weights = &a.data[row * k..(row + 1) * k];
                let vector = &x.data[..k];
                // Explicit fold from +0.0 keeps the original accumulation order and start value.
                weights
                    .iter()
                    .zip(vector)
                    .fold(0.0_f32, |acc, (w, v)| acc + w * v)
            })
            .collect();
        Ok(DeviceBuffer::from_slice(&product, device_id))
    }

    /// Element-wise `max(v, 0.0)`.
    fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
        let clamped: Vec<f32> = x.data.iter().copied().map(|v| f32::max(v, 0.0)).collect();
        Ok(DeviceBuffer::from_slice(&clamped, device_id))
    }

    /// Softmax computed with the maximum subtracted before exponentiation.
    ///
    /// An empty input yields an empty buffer; if the exponentials sum to
    /// exactly zero a uniform distribution is returned instead.
    fn softmax(
        &self,
        x: &DeviceBuffer,
        size: usize,
        device_id: usize,
    ) -> Result<DeviceBuffer, GpuError> {
        let reported = x.size();
        if reported != size {
            return Err(GpuError::InvalidArgument(format!(
                "buffer size {} does not match size={}",
                reported, size
            )));
        }
        if size == 0 {
            return Ok(DeviceBuffer::new(0, device_id));
        }

        let peak = x
            .data
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, |best, v| best.max(v));
        let exps: Vec<f32> = x.data.iter().map(|&v| (v - peak).exp()).collect();
        let total: f32 = exps.iter().sum();

        let probabilities: Vec<f32> = if total == 0.0 {
            vec![1.0 / size as f32; size]
        } else {
            exps.iter().map(|&e| e / total).collect()
        };
        Ok(DeviceBuffer::from_slice(&probabilities, device_id))
    }

    /// Nothing to wait for on the CPU.
    fn synchronize(&self, _device_id: usize) -> Result<(), GpuError> {
        Ok(())
    }

    /// Reports `(free, total)` where `total` is the simulated size and `free` is half of it.
    fn memory_info(&self, _device_id: usize) -> Result<(usize, usize), GpuError> {
        let total = self.simulated_memory_bytes;
        Ok((total / 2, total))
    }
}
/// Placeholder CUDA backend, compiled only with the `cuda` feature.
///
/// It performs no GPU work: its trait implementation forwards every
/// operation to the embedded [`CpuBackend`].
#[cfg(feature = "cuda")]
pub struct CudaBackend {
/// Value returned by `device_count()`; `new` sets it to 1.
pub device_count: usize,
// CPU implementation that actually executes each operation.
cpu_fallback: CpuBackend,
}
#[cfg(feature = "cuda")]
impl CudaBackend {
/// Builds the stub backend and logs a warning that no real acceleration is active.
///
/// As written this constructor always returns `Ok`.
pub fn new() -> Result<Self, GpuError> {
warn!("CudaBackend: CUDA stub active — no real GPU acceleration");
Ok(Self {
device_count: 1,
cpu_fallback: CpuBackend::new(),
})
}
}
/// Stub implementation: each operation logs a warning and delegates to the CPU fallback.
///
/// Methods not listed here use the trait's default implementations.
#[cfg(feature = "cuda")]
impl GpuBackendTrait for CudaBackend {
fn name(&self) -> &'static str {
"cuda"
}
// Reports `false` because all work is executed by the CPU fallback.
fn is_accelerated(&self) -> bool {
false
}
fn device_count(&self) -> usize {
self.device_count
}
fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
warn!("CudaBackend::alloc delegating to CPU fallback");
self.cpu_fallback.alloc(size, device_id)
}
fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
warn!("CudaBackend::host_to_device delegating to CPU fallback");
self.cpu_fallback.host_to_device(src, device_id)
}
fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
warn!("CudaBackend::device_to_host delegating to CPU fallback");
self.cpu_fallback.device_to_host(buf)
}
fn matvec(
&self,
a: &DeviceBuffer,
x: &DeviceBuffer,
m: usize,
k: usize,
device_id: usize,
) -> Result<DeviceBuffer, GpuError> {
warn!("CudaBackend::matvec delegating to CPU fallback");
self.cpu_fallback.matvec(a, x, m, k, device_id)
}
fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
warn!("CudaBackend::relu delegating to CPU fallback");
self.cpu_fallback.relu(x, device_id)
}
fn softmax(
&self,
x: &DeviceBuffer,
size: usize,
device_id: usize,
) -> Result<DeviceBuffer, GpuError> {
warn!("CudaBackend::softmax delegating to CPU fallback");
self.cpu_fallback.softmax(x, size, device_id)
}
fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
warn!("CudaBackend::synchronize delegating to CPU fallback");
self.cpu_fallback.synchronize(device_id)
}
// Returns the CPU fallback's simulated figures, not real device memory.
fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
warn!("CudaBackend::memory_info delegating to CPU fallback");
self.cpu_fallback.memory_info(device_id)
}
}
/// Placeholder Metal backend, compiled only with the `metal` feature on macOS.
///
/// It performs no GPU work: its trait implementation forwards every
/// operation to the embedded [`CpuBackend`].
#[cfg(all(feature = "metal", target_os = "macos"))]
pub struct MetalBackend {
/// Value returned by `device_count()`; `new` sets it to 1.
pub device_count: usize,
// CPU implementation that actually executes each operation.
cpu_fallback: CpuBackend,
}
#[cfg(all(feature = "metal", target_os = "macos"))]
impl MetalBackend {
/// Builds the stub backend and logs a warning that no real acceleration is active.
///
/// As written this constructor always returns `Ok`.
pub fn new() -> Result<Self, GpuError> {
warn!("MetalBackend: Metal stub active — no real GPU acceleration");
Ok(Self {
device_count: 1,
cpu_fallback: CpuBackend::new(),
})
}
}
/// Stub implementation: each operation logs a warning and delegates to the CPU fallback.
///
/// Methods not listed here use the trait's default implementations.
#[cfg(all(feature = "metal", target_os = "macos"))]
impl GpuBackendTrait for MetalBackend {
fn name(&self) -> &'static str {
"metal"
}
// Reports `false` because all work is executed by the CPU fallback.
fn is_accelerated(&self) -> bool {
false
}
fn device_count(&self) -> usize {
self.device_count
}
fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
warn!("MetalBackend::alloc delegating to CPU fallback");
self.cpu_fallback.alloc(size, device_id)
}
fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
warn!("MetalBackend::host_to_device delegating to CPU fallback");
self.cpu_fallback.host_to_device(src, device_id)
}
fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
warn!("MetalBackend::device_to_host delegating to CPU fallback");
self.cpu_fallback.device_to_host(buf)
}
fn matvec(
&self,
a: &DeviceBuffer,
x: &DeviceBuffer,
m: usize,
k: usize,
device_id: usize,
) -> Result<DeviceBuffer, GpuError> {
warn!("MetalBackend::matvec delegating to CPU fallback");
self.cpu_fallback.matvec(a, x, m, k, device_id)
}
fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
warn!("MetalBackend::relu delegating to CPU fallback");
self.cpu_fallback.relu(x, device_id)
}
fn softmax(
&self,
x: &DeviceBuffer,
size: usize,
device_id: usize,
) -> Result<DeviceBuffer, GpuError> {
warn!("MetalBackend::softmax delegating to CPU fallback");
self.cpu_fallback.softmax(x, size, device_id)
}
fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
warn!("MetalBackend::synchronize delegating to CPU fallback");
self.cpu_fallback.synchronize(device_id)
}
// Returns the CPU fallback's simulated figures, not real device memory.
fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
warn!("MetalBackend::memory_info delegating to CPU fallback");
self.cpu_fallback.memory_info(device_id)
}
}
/// Newtype around a shared [`Scirs2Backend`] so it can be returned as a
/// boxed [`GpuBackendTrait`] object. Compiled only with the `gpu` feature.
#[cfg(feature = "gpu")]
pub(crate) struct Scirs2BackendHandle(pub(crate) std::sync::Arc<Scirs2Backend>);
/// Forwards every trait method to the wrapped `Scirs2Backend`.
///
/// The two batch-phase methods additionally translate a backend error into a
/// "not handled" result (`Ok(None)` / `Ok(false)`) after logging a warning.
#[cfg(feature = "gpu")]
impl GpuBackendTrait for Scirs2BackendHandle {
fn name(&self) -> &'static str {
self.0.name()
}
fn is_accelerated(&self) -> bool {
self.0.is_accelerated()
}
fn device_count(&self) -> usize {
self.0.device_count()
}
fn alloc(&self, size: usize, device_id: usize) -> Result<DeviceBuffer, GpuError> {
self.0.alloc(size, device_id)
}
fn host_to_device(&self, src: &[f32], device_id: usize) -> Result<DeviceBuffer, GpuError> {
self.0.host_to_device(src, device_id)
}
fn device_to_host(&self, buf: &DeviceBuffer) -> Result<Vec<f32>, GpuError> {
self.0.device_to_host(buf)
}
fn matvec(
&self,
a: &DeviceBuffer,
x: &DeviceBuffer,
m: usize,
k: usize,
device_id: usize,
) -> Result<DeviceBuffer, GpuError> {
self.0.matvec(a, x, m, k, device_id)
}
fn relu(&self, x: &DeviceBuffer, device_id: usize) -> Result<DeviceBuffer, GpuError> {
self.0.relu(x, device_id)
}
fn softmax(
&self,
x: &DeviceBuffer,
size: usize,
device_id: usize,
) -> Result<DeviceBuffer, GpuError> {
self.0.softmax(x, size, device_id)
}
fn synchronize(&self, device_id: usize) -> Result<(), GpuError> {
self.0.synchronize(device_id)
}
fn memory_info(&self, device_id: usize) -> Result<(usize, usize), GpuError> {
self.0.memory_info(device_id)
}
fn gemv_q1_g128(
&self,
block_bytes: &[u8],
input: &[f32],
n_rows: usize,
k: usize,
) -> Result<Vec<f32>, GpuError> {
self.0.gemv_q1_g128(block_bytes, input, n_rows, k)
}
fn gemm_q1_g128(
&self,
block_bytes: &[u8],
input: &[f32],
m: usize,
n_rows: usize,
k: usize,
) -> Result<Vec<f32>, GpuError> {
self.0.gemm_q1_g128(block_bytes, input, m, n_rows, k)
}
// Note the name difference: the trait's `upload_weights_raw` maps to the
// backend's `upload_weights`.
fn upload_weights_raw(
&self,
block_bytes: &[u8],
) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
self.0.upload_weights(block_bytes)
}
fn gemv_q1_g128_cached(
&self,
handle: crate::weight_cache::GpuWeightHandle,
input: &[f32],
n_rows: usize,
k: usize,
) -> Result<Vec<f32>, GpuError> {
self.0.gemv_q1_g128_cached(handle, input, n_rows, k)
}
fn upload_weights_ternary(
&self,
blocks: &[oxibonsai_core::BlockTQ2_0_g128],
) -> Result<crate::weight_cache::GpuWeightHandle, GpuError> {
self.0.upload_weights_ternary(blocks)
}
fn gemv_tq2_g128_cached(
&self,
handle: crate::weight_cache::GpuWeightHandle,
input: &[f32],
n_rows: usize,
k: usize,
) -> Result<Vec<f32>, GpuError> {
self.0.gemv_tq2_g128_cached(handle, input, n_rows, k)
}
// Never returns `Err`: a backend failure is logged and reported as
// `Ok(None)` so the caller can take another path.
fn batch_attn_phase(
&self,
hidden: &[f32],
norm_weight: &[f32],
norm_eps: f32,
qkv_handle: crate::weight_cache::GpuWeightHandle,
q_rows: usize,
k_rows: usize,
h: usize,
) -> Result<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>, GpuError> {
match self
.0
.batch_attn_phase(hidden, norm_weight, norm_eps, qkv_handle, q_rows, k_rows, h)
{
Ok(result) => Ok(Some(result)),
Err(e) => {
tracing::warn!(error = %e, "batch_attn_phase failed, falling back");
Ok(None)
}
}
}
// Never returns `Err`: a backend failure is logged and reported as
// `Ok(false)`. NOTE(review): whether `hidden` may have been partially
// modified on the failure path depends on the backend — confirm.
fn batch_ffn_phase(
&self,
hidden: &mut [f32],
attn_out: &[f32],
norm_weight: &[f32],
norm_eps: f32,
attn_proj_handle: crate::weight_cache::GpuWeightHandle,
gate_up_handle: crate::weight_cache::GpuWeightHandle,
down_handle: crate::weight_cache::GpuWeightHandle,
h: usize,
intermediate: usize,
attn_proj_k: usize,
) -> Result<bool, GpuError> {
match self.0.batch_ffn_phase(
hidden,
attn_out,
norm_weight,
norm_eps,
attn_proj_handle,
gate_up_handle,
down_handle,
h,
intermediate,
attn_proj_k,
) {
Ok(()) => Ok(true),
Err(e) => {
tracing::warn!(error = %e, "batch_ffn_phase failed, falling back");
Ok(false)
}
}
}
}
/// Returns the first backend that can be constructed, tried in this order
/// (each step exists only when its feature/target is enabled):
///
/// 1. `Scirs2Backend` (`gpu` feature), only if it reports being accelerated;
/// 2. `NativeCudaBackend` (`native-cuda` on Linux/Windows);
/// 3. `CudaBackend` stub (`cuda` feature);
/// 4. `MetalBackend` stub (`metal` feature on macOS);
/// 5. `CpuBackend`, which is always available.
///
/// Failures along the way are logged as warnings; the Scirs2 warnings are
/// emitted at most once per process.
pub fn select_backend() -> Box<dyn GpuBackendTrait> {
    #[cfg(feature = "gpu")]
    use std::sync::atomic::{AtomicBool, Ordering};

    // Invokes `msg` only for the first caller that finds `flag` unset.
    #[cfg(feature = "gpu")]
    fn warn_once(flag: &AtomicBool, msg: impl FnOnce()) {
        let already_warned = flag.swap(true, Ordering::Relaxed);
        if already_warned {
            return;
        }
        msg();
    }

    #[cfg(feature = "gpu")]
    {
        static SCIRS2_NOT_ACCEL: AtomicBool = AtomicBool::new(false);
        static SCIRS2_INIT_FAIL: AtomicBool = AtomicBool::new(false);

        match Scirs2Backend::global() {
            Ok(b) if b.is_accelerated() => return Box::new(Scirs2BackendHandle(b)),
            Ok(b) => warn_once(&SCIRS2_NOT_ACCEL, || {
                warn!(
                    "select_backend: Scirs2Backend is not accelerated (backend={}), trying next",
                    b.backend_name()
                );
            }),
            Err(e) => warn_once(&SCIRS2_INIT_FAIL, || {
                warn!("select_backend: Scirs2Backend init failed ({e}), trying next");
            }),
        }
    }

    #[cfg(all(
        feature = "native-cuda",
        any(target_os = "linux", target_os = "windows")
    ))]
    {
        match NativeCudaBackend::new() {
            Ok(b) => {
                tracing::info!("select_backend: NativeCudaBackend initialised");
                return Box::new(b);
            }
            Err(e) => warn!("select_backend: NativeCudaBackend init failed ({e}), trying next"),
        }
    }

    #[cfg(feature = "cuda")]
    {
        match CudaBackend::new() {
            Ok(b) => return Box::new(b),
            Err(e) => warn!("select_backend: CUDA init failed ({e}), trying next"),
        }
    }

    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        match MetalBackend::new() {
            Ok(b) => return Box::new(b),
            Err(e) => warn!("select_backend: Metal init failed ({e}), trying CPU"),
        }
    }

    Box::new(CpuBackend::new())
}
/// Multiplies the row-major `m x k` matrix `a` by the row-major `k x n`
/// matrix `b` using `backend`, returning the row-major `m x n` product.
///
/// The product is assembled one output column at a time: each column of `b`
/// is gathered, uploaded, multiplied via `backend.matvec`, and downloaded.
/// `backend.synchronize` is called once after all columns are done.
///
/// # Errors
///
/// Returns [`GpuError::InvalidArgument`] when `a.len() != m * k` or
/// `b.len() != k * n`, and propagates any error from the backend.
pub fn gpu_matmul(
    backend: &dyn GpuBackendTrait,
    a: &[f32],
    b: &[f32],
    m: usize,
    k: usize,
    n: usize,
    device_id: usize,
) -> Result<Vec<f32>, GpuError> {
    let lhs_len = a.len();
    if lhs_len != m * k {
        return Err(GpuError::InvalidArgument(format!(
            "a.len()={} does not match m={} k={}",
            lhs_len, m, k
        )));
    }
    let rhs_len = b.len();
    if rhs_len != k * n {
        return Err(GpuError::InvalidArgument(format!(
            "b.len()={} does not match k={} n={}",
            rhs_len, k, n
        )));
    }

    let lhs = backend.host_to_device(a, device_id)?;
    let mut product = vec![0.0_f32; m * n];

    for col in 0..n {
        // Column `col` of `b`: every n-th element starting at `col` (k values).
        let column: Vec<f32> = b.iter().skip(col).step_by(n).copied().collect();
        let rhs = backend.host_to_device(&column, device_id)?;
        let partial = backend.matvec(&lhs, &rhs, m, k, device_id)?;
        let values = backend.device_to_host(&partial)?;

        // Scatter into column `col` of the row-major result (m slots).
        for (row, slot) in product.iter_mut().skip(col).step_by(n).enumerate() {
            *slot = values[row];
        }
    }

    backend.synchronize(device_id)?;
    Ok(product)
}
/// 1-bit quantised matrix-vector product that uses the GPU when possible.
///
/// With the `gpu` feature enabled, the global `Scirs2Backend` is used if it
/// initialises and reports being accelerated. In every other case (feature
/// disabled, backend not accelerated, or initialisation failure, which is
/// logged) the CPU reference routine computes the result.
pub fn gpu_gemv_1bit(
    block_bytes: &[u8],
    input: &[f32],
    n_rows: usize,
    k: usize,
) -> Result<Vec<f32>, GpuError> {
    #[cfg(feature = "gpu")]
    {
        match Scirs2Backend::global() {
            Ok(backend) if backend.is_accelerated() => {
                return backend.gemv_q1_g128(block_bytes, input, n_rows, k);
            }
            // Backend exists but is not accelerated: silently use the CPU path.
            Ok(_) => {}
            Err(e) => {
                warn!("gpu_gemv_1bit: GPU init failed ({e}), using CPU fallback");
            }
        }
    }

    cpu_gemv_1bit_fallback(block_bytes, input, n_rows, k)
}
/// CPU reference implementation of the 1-bit, group-of-128 GEMV.
///
/// `block_bytes` is read as `n_rows * (k / 128)` consecutive 18-byte blocks,
/// row after row. Each block is a little-endian `f16` scale followed by four
/// little-endian `u32` words of sign bits; bit `i` of word `w` applies to
/// input element `w * 32 + i` of the block's group, where a set bit means
/// `+scale` and a clear bit means `-scale`.
///
/// # Errors
///
/// Returns [`GpuError::InvalidArgument`] when `k` is zero or not a multiple
/// of 128, when `input.len() != k`, or when `block_bytes` is shorter than
/// the number of bytes the shape requires.
fn cpu_gemv_1bit_fallback(
    block_bytes: &[u8],
    input: &[f32],
    n_rows: usize,
    k: usize,
) -> Result<Vec<f32>, GpuError> {
    const GROUP_LEN: usize = 128;
    const BLOCK_LEN: usize = 18;
    const WORD_LEN: usize = 4;
    const LANES_PER_WORD: usize = 32;

    if k == 0 || k % GROUP_LEN != 0 {
        return Err(GpuError::InvalidArgument(format!(
            "k={k} must be a positive multiple of 128"
        )));
    }
    let input_len = input.len();
    if input_len != k {
        return Err(GpuError::InvalidArgument(format!(
            "input.len()={} != k={}",
            input_len, k
        )));
    }

    let blocks_per_row = k / GROUP_LEN;
    let expected = n_rows * blocks_per_row * BLOCK_LEN;
    let available = block_bytes.len();
    if available < expected {
        return Err(GpuError::InvalidArgument(format!(
            "block_bytes too small: {} < {}",
            available, expected,
        )));
    }

    let row_stride = blocks_per_row * BLOCK_LEN;
    let output: Vec<f32> = (0..n_rows)
        .map(|row| {
            let row_bytes = &block_bytes[row * row_stride..(row + 1) * row_stride];
            let mut sum = 0.0_f32;
            // Pair each 18-byte block with its 128-element slice of the input.
            for (block, group) in row_bytes
                .chunks_exact(BLOCK_LEN)
                .zip(input.chunks_exact(GROUP_LEN))
            {
                let scale = half::f16::from_bits(u16::from_le_bytes([block[0], block[1]])).to_f32();
                for (word, lanes) in block[2..]
                    .chunks_exact(WORD_LEN)
                    .zip(group.chunks_exact(LANES_PER_WORD))
                {
                    let bits = u32::from_le_bytes([word[0], word[1], word[2], word[3]]);
                    for (i, &value) in lanes.iter().enumerate() {
                        let sign = if bits & (1_u32 << i) != 0 {
                            1.0_f32
                        } else {
                            -1.0_f32
                        };
                        // Same operand order as the element-wise reference accumulation.
                        sum += scale * sign * value;
                    }
                }
            }
            sum
        })
        .collect();

    Ok(output)
}
// Unit tests for the buffer type, launch configuration, CPU softmax and the
// 1-bit GEMV reference path.
#[cfg(test)]
mod tests {
use super::*;
// `DeviceBuffer::new` must record size/device and zero-fill the data.
#[test]
fn device_buffer_new_zeroed() {
let buf = DeviceBuffer::new(4, 0);
assert_eq!(buf.size(), 4);
assert_eq!(buf.device_id(), 0);
assert!(buf.data.iter().all(|&v| v == 0.0));
}
// Data passed to `from_slice` must come back unchanged from `to_vec`.
#[test]
fn device_buffer_from_slice_roundtrip() {
let src = [1.0_f32, 2.0, 3.0];
let buf = DeviceBuffer::from_slice(&src, 1);
assert_eq!(buf.to_vec(), src);
}
// Zero elements still yields a grid of at least one block.
#[test]
fn launch_config_for_zero_elements() {
let cfg = LaunchConfig::for_n_elements(0);
assert_eq!(cfg.grid_dim.0, 1);
}
// Softmax of an empty buffer returns an empty buffer rather than an error.
#[test]
fn cpu_softmax_empty() {
let backend = CpuBackend::new();
let buf = DeviceBuffer::new(0, 0);
let out = backend.softmax(&buf, 0, 0).expect("softmax empty");
assert_eq!(out.size(), 0);
}
// Scale 1.0 with every sign bit set: the result is the plain sum of the input.
#[test]
fn cpu_gemv_1bit_identity_scale() {
let scale = half::f16::from_f32(1.0);
let scale_bytes = scale.to_bits().to_le_bytes();
let mut block = vec![0u8; 18];
block[0] = scale_bytes[0];
block[1] = scale_bytes[1];
block[2..18].fill(0xFF);
let input: Vec<f32> = (0..128).map(|i| i as f32).collect();
let expected: f32 = input.iter().sum();
let result =
cpu_gemv_1bit_fallback(&block, &input, 1, 128).expect("cpu_gemv_1bit_fallback");
assert!(
(result[0] - expected).abs() < 1e-2,
"got {} expected {}",
result[0],
expected,
);
}
// NOTE(review): despite the name, the scale here is +1.0. The negative
// result comes from all sign bits being clear, so every weight is -1.
#[test]
fn cpu_gemv_1bit_negative_scale() {
let scale = half::f16::from_f32(1.0);
let scale_bytes = scale.to_bits().to_le_bytes();
let mut block = vec![0u8; 18];
block[0] = scale_bytes[0];
block[1] = scale_bytes[1];
let input = vec![1.0_f32; 128];
let result =
cpu_gemv_1bit_fallback(&block, &input, 1, 128).expect("cpu_gemv_1bit_fallback");
assert!(
(result[0] - (-128.0)).abs() < 1e-2,
"got {} expected -128",
result[0],
);
}
// `k` that is not a multiple of 128 must be rejected.
#[test]
fn cpu_gemv_1bit_bad_k() {
let result = cpu_gemv_1bit_fallback(&[], &[], 0, 64);
assert!(result.is_err());
}
// The public entry point must produce the reference result. Without the
// `gpu` feature it goes straight to the CPU fallback; with it, the result
// may come from the accelerated backend instead.
#[test]
fn gpu_gemv_1bit_without_gpu() {
let scale = half::f16::from_f32(1.0);
let scale_bytes = scale.to_bits().to_le_bytes();
let mut block = vec![0u8; 18];
block[0] = scale_bytes[0];
block[1] = scale_bytes[1];
block[2..18].fill(0xFF);
let input: Vec<f32> = vec![1.0_f32; 128];
let result = gpu_gemv_1bit(&block, &input, 1, 128).expect("gpu_gemv_1bit");
assert!((result[0] - 128.0).abs() < 1e-2, "got {}", result[0]);
}
}