use oxibonsai_core::{BlockQ4_0, BlockQ8_0, QK_Q4_0, QK_Q8_0};
use oxibonsai_kernels::{gemv_q4_0, gemv_q8_0};
use crate::error::{ModelError, ModelResult};
// Compile-time layout checks for the CUDA path. `forward` reinterprets the block
// slices as raw bytes and sizes that view with `BLOCK_Q4_0_BYTES` / `BLOCK_Q8_0_BYTES`,
// so the in-memory size of each block type must equal those constants.
// NOTE(review): presumably the constants are the packed field sizes, which makes this
// also a "no padding" check — confirm in oxibonsai_core.
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
const _: () = {
    assert!(
        std::mem::size_of::<oxibonsai_core::BlockQ4_0>() == oxibonsai_core::BLOCK_Q4_0_BYTES
    );
    assert!(
        std::mem::size_of::<oxibonsai_core::BlockQ8_0>() == oxibonsai_core::BLOCK_Q8_0_BYTES
    );
};
/// Linear (fully connected) layer whose weight matrix is stored as borrowed
/// Q4_0-quantized blocks.
///
/// The layer does not own its weights: `blocks` is borrowed for `'a`. There is no bias
/// term. Construction via [`LinearQ4_0::new`] guarantees that `in_features` is a
/// non-zero multiple of [`QK_Q4_0`] and that exactly
/// `out_features * (in_features / QK_Q4_0)` blocks are present.
#[derive(Debug)]
pub struct LinearQ4_0<'a> {
    /// Quantized weights, in whatever order `gemv_q4_0` expects.
    /// NOTE(review): presumably one row of `in_features / QK_Q4_0` blocks per output
    /// feature — confirm against the kernel.
    blocks: &'a [BlockQ4_0],
    /// Length of each output vector.
    out_features: usize,
    /// Length of each input vector; always a non-zero multiple of `QK_Q4_0`.
    in_features: usize,
}

impl<'a> LinearQ4_0<'a> {
    /// Wraps `blocks` as an `out_features x in_features` Q4_0 weight matrix.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] in two cases:
    /// - `in_features` is zero or not a multiple of [`QK_Q4_0`]. The error's `expected`
    ///   field carries the smallest valid `in_features` at or above the requested value.
    /// - `blocks.len()` differs from `out_features * (in_features / QK_Q4_0)`. If that
    ///   product overflows `usize`, `expected` is reported as `usize::MAX`.
    pub fn new(
        blocks: &'a [BlockQ4_0],
        out_features: usize,
        in_features: usize,
    ) -> ModelResult<Self> {
        if in_features == 0 || in_features % QK_Q4_0 != 0 {
            // Round up to a whole number of blocks (at least one) so the error shows
            // an acceptable shape instead of repeating the rejected one.
            let valid_in = in_features
                .div_ceil(QK_Q4_0)
                .max(1)
                .saturating_mul(QK_Q4_0);
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ4_0".into(),
                expected: vec![out_features, valid_in],
                actual: vec![out_features, in_features],
            });
        }
        // `checked_mul`: in release builds a wrapped product could spuriously match a
        // small `blocks.len()` and accept a bogus shape.
        let expected_blocks = out_features.checked_mul(in_features / QK_Q4_0);
        if expected_blocks != Some(blocks.len()) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ4_0".into(),
                expected: vec![expected_blocks.unwrap_or(usize::MAX)],
                actual: vec![blocks.len()],
            });
        }
        Ok(Self {
            blocks,
            out_features,
            in_features,
        })
    }

    /// Number of output features (length of each output vector).
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Number of input features (length of each input vector).
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// The borrowed quantized weight blocks.
    ///
    /// The slice lives for `'a`, so it may outlive this wrapper.
    pub fn blocks(&self) -> &'a [BlockQ4_0] {
        self.blocks
    }

    /// Multiplies the quantized weight matrix by `input` (GEMV) and writes the result
    /// into `output`.
    ///
    /// When built with `native-cuda` on Linux/Windows, the CUDA kernel is tried first
    /// if a global CUDA graph is available. On CUDA failure the CPU `gemv_q4_0` kernel
    /// is used instead. A warning is logged unless the error message mentions
    /// "no CUDA device".
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::Kernel`] if the CPU kernel fails. Presumably this includes
    /// `input`/`output` length mismatches, since those checks live in the kernel.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        if oxibonsai_kernels::CudaGraph::global().is_ok() {
            // SAFETY: `self.blocks` is a live, initialised slice borrowed for `'a`.
            // Its storage is therefore readable for the duration of this call, and
            // `u8` has alignment 1. The module-level const assertion guarantees
            // `size_of::<BlockQ4_0>() == BLOCK_Q4_0_BYTES`, so the byte length below
            // exactly covers the slice. NOTE(review): this relies on BlockQ4_0 having
            // no padding bytes, which presumably is what that assertion pins down.
            let raw = unsafe {
                std::slice::from_raw_parts(
                    self.blocks.as_ptr().cast::<u8>(),
                    self.blocks.len() * oxibonsai_core::BLOCK_Q4_0_BYTES,
                )
            };
            match oxibonsai_kernels::cuda_gemv_q4_0(
                raw,
                input,
                output,
                self.out_features,
                self.in_features,
            ) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // Errors mentioning a missing device are not logged (presumably
                    // an expected CPU-only setup). Every failure falls through to the
                    // CPU kernel below.
                    let msg = e.to_string();
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA Q4_0 GEMV failed, falling back to CPU scalar"
                        );
                    }
                }
            }
        }
        gemv_q4_0(
            self.blocks,
            input,
            output,
            self.out_features,
            self.in_features,
        )
        .map_err(ModelError::Kernel)
    }

    /// Applies [`Self::forward`] to `m` consecutive rows.
    ///
    /// `input` holds `m` rows of `in_features` values back to back. `output` receives
    /// `m` rows of `out_features` values. Elements beyond those rows are ignored and
    /// left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] before doing any work if `input` or
    /// `output` is too short for `m` rows, instead of panicking on an out-of-range
    /// slice. Otherwise propagates the first error from [`Self::forward`]; rows
    /// processed before that error have already been written.
    pub fn forward_batch(&self, input: &[f32], output: &mut [f32], m: usize) -> ModelResult<()> {
        let (k, n) = (self.in_features, self.out_features);
        let need_in = m.checked_mul(k);
        let need_out = m.checked_mul(n);
        let fits = |need: Option<usize>, have: usize| matches!(need, Some(req) if req <= have);
        if !(fits(need_in, input.len()) && fits(need_out, output.len())) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ4_0::forward_batch".into(),
                expected: vec![
                    need_in.unwrap_or(usize::MAX),
                    need_out.unwrap_or(usize::MAX),
                ],
                actual: vec![input.len(), output.len()],
            });
        }
        // Both `m * k` and `m * n` fit in `usize` and in the buffers, so none of the
        // per-row bounds below can overflow or go out of range.
        for t in 0..m {
            let inp = &input[t * k..(t + 1) * k];
            let out = &mut output[t * n..(t + 1) * n];
            self.forward(inp, out)?;
        }
        Ok(())
    }
}
/// Linear (fully connected) layer whose weight matrix is stored as borrowed
/// Q8_0-quantized blocks.
///
/// The layer does not own its weights: `blocks` is borrowed for `'a`. There is no bias
/// term. Construction via [`LinearQ8_0::new`] guarantees that `in_features` is a
/// non-zero multiple of [`QK_Q8_0`] and that exactly
/// `out_features * (in_features / QK_Q8_0)` blocks are present.
#[derive(Debug)]
pub struct LinearQ8_0<'a> {
    /// Quantized weights, in whatever order `gemv_q8_0` expects.
    /// NOTE(review): presumably one row of `in_features / QK_Q8_0` blocks per output
    /// feature — confirm against the kernel.
    blocks: &'a [BlockQ8_0],
    /// Length of each output vector.
    out_features: usize,
    /// Length of each input vector; always a non-zero multiple of `QK_Q8_0`.
    in_features: usize,
}

impl<'a> LinearQ8_0<'a> {
    /// Wraps `blocks` as an `out_features x in_features` Q8_0 weight matrix.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] in two cases:
    /// - `in_features` is zero or not a multiple of [`QK_Q8_0`]. The error's `expected`
    ///   field carries the smallest valid `in_features` at or above the requested value.
    /// - `blocks.len()` differs from `out_features * (in_features / QK_Q8_0)`. If that
    ///   product overflows `usize`, `expected` is reported as `usize::MAX`.
    pub fn new(
        blocks: &'a [BlockQ8_0],
        out_features: usize,
        in_features: usize,
    ) -> ModelResult<Self> {
        if in_features == 0 || in_features % QK_Q8_0 != 0 {
            // Round up to a whole number of blocks (at least one) so the error shows
            // an acceptable shape instead of repeating the rejected one.
            let valid_in = in_features
                .div_ceil(QK_Q8_0)
                .max(1)
                .saturating_mul(QK_Q8_0);
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ8_0".into(),
                expected: vec![out_features, valid_in],
                actual: vec![out_features, in_features],
            });
        }
        // `checked_mul`: in release builds a wrapped product could spuriously match a
        // small `blocks.len()` and accept a bogus shape.
        let expected_blocks = out_features.checked_mul(in_features / QK_Q8_0);
        if expected_blocks != Some(blocks.len()) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ8_0".into(),
                expected: vec![expected_blocks.unwrap_or(usize::MAX)],
                actual: vec![blocks.len()],
            });
        }
        Ok(Self {
            blocks,
            out_features,
            in_features,
        })
    }

    /// Number of output features (length of each output vector).
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Number of input features (length of each input vector).
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// The borrowed quantized weight blocks.
    ///
    /// The slice lives for `'a`, so it may outlive this wrapper.
    pub fn blocks(&self) -> &'a [BlockQ8_0] {
        self.blocks
    }

    /// Multiplies the quantized weight matrix by `input` (GEMV) and writes the result
    /// into `output`.
    ///
    /// When built with `native-cuda` on Linux/Windows, the CUDA kernel is tried first
    /// if a global CUDA graph is available. On CUDA failure the CPU `gemv_q8_0` kernel
    /// is used instead. A warning is logged unless the error message mentions
    /// "no CUDA device".
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::Kernel`] if the CPU kernel fails. Presumably this includes
    /// `input`/`output` length mismatches, since those checks live in the kernel.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        if oxibonsai_kernels::CudaGraph::global().is_ok() {
            // SAFETY: `self.blocks` is a live, initialised slice borrowed for `'a`.
            // Its storage is therefore readable for the duration of this call, and
            // `u8` has alignment 1. The module-level const assertion guarantees
            // `size_of::<BlockQ8_0>() == BLOCK_Q8_0_BYTES`, so the byte length below
            // exactly covers the slice. NOTE(review): this relies on BlockQ8_0 having
            // no padding bytes, which presumably is what that assertion pins down.
            let raw = unsafe {
                std::slice::from_raw_parts(
                    self.blocks.as_ptr().cast::<u8>(),
                    self.blocks.len() * oxibonsai_core::BLOCK_Q8_0_BYTES,
                )
            };
            match oxibonsai_kernels::cuda_gemv_q8_0(
                raw,
                input,
                output,
                self.out_features,
                self.in_features,
            ) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // Errors mentioning a missing device are not logged (presumably
                    // an expected CPU-only setup). Every failure falls through to the
                    // CPU kernel below.
                    let msg = e.to_string();
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA Q8_0 GEMV failed, falling back to CPU scalar"
                        );
                    }
                }
            }
        }
        gemv_q8_0(
            self.blocks,
            input,
            output,
            self.out_features,
            self.in_features,
        )
        .map_err(ModelError::Kernel)
    }

    /// Applies [`Self::forward`] to `m` consecutive rows.
    ///
    /// `input` holds `m` rows of `in_features` values back to back. `output` receives
    /// `m` rows of `out_features` values. Elements beyond those rows are ignored and
    /// left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] before doing any work if `input` or
    /// `output` is too short for `m` rows, instead of panicking on an out-of-range
    /// slice. Otherwise propagates the first error from [`Self::forward`]; rows
    /// processed before that error have already been written.
    pub fn forward_batch(&self, input: &[f32], output: &mut [f32], m: usize) -> ModelResult<()> {
        let (k, n) = (self.in_features, self.out_features);
        let need_in = m.checked_mul(k);
        let need_out = m.checked_mul(n);
        let fits = |need: Option<usize>, have: usize| matches!(need, Some(req) if req <= have);
        if !(fits(need_in, input.len()) && fits(need_out, output.len())) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ8_0::forward_batch".into(),
                expected: vec![
                    need_in.unwrap_or(usize::MAX),
                    need_out.unwrap_or(usize::MAX),
                ],
                actual: vec![input.len(), output.len()],
            });
        }
        // Both `m * k` and `m * n` fit in `usize` and in the buffers, so none of the
        // per-row bounds below can overflow or go out of range.
        for t in 0..m {
            let inp = &input[t * k..(t + 1) * k];
            let out = &mut output[t * n..(t + 1) * n];
            self.forward(inp, out)?;
        }
        Ok(())
    }
}