use oxibonsai_core::{BlockQ5K, BlockQ6K};
use oxibonsai_kernels::{gemv_q5k, gemv_q6k};
use crate::error::{ModelError, ModelResult};
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
const _: () = {
    // The CUDA GEMV path in `LinearQ5K::forward` hands `&[BlockQ5K]` to the kernel as
    // raw bytes, so the Rust struct size has to agree with `BLOCK_Q5K_BYTES`.
    // Evaluated at compile time: a mismatch fails the build instead of corrupting data.
    let rust_block_size = std::mem::size_of::<oxibonsai_core::BlockQ5K>();
    assert!(rust_block_size == oxibonsai_core::BLOCK_Q5K_BYTES);
};
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
const _: () = {
    // The CUDA GEMV path in `LinearQ6K::forward` hands `&[BlockQ6K]` to the kernel as
    // raw bytes, so the Rust struct size has to agree with `BLOCK_Q6K_BYTES`.
    // Evaluated at compile time: a mismatch fails the build instead of corrupting data.
    let rust_block_size = std::mem::size_of::<oxibonsai_core::BlockQ6K>();
    assert!(rust_block_size == oxibonsai_core::BLOCK_Q6K_BYTES);
};
/// Borrowed view of a linear layer whose weights are stored as [`BlockQ5K`] blocks.
///
/// The struct does not own the weights; it only pairs a block slice with the layer
/// dimensions. `LinearQ5K::new` validates that the slice length is consistent with
/// those dimensions before a value can be constructed.
#[derive(Debug)]
pub struct LinearQ5K<'a> {
/// Quantized weight blocks; `new` requires exactly
/// `out_features * (in_features / 256)` of them.
blocks: &'a [BlockQ5K],
/// Output dimension passed to the GEMV kernels.
out_features: usize,
/// Input dimension passed to the GEMV kernels; `new` requires a non-zero
/// multiple of 256.
in_features: usize,
}
impl<'a> LinearQ5K<'a> {
    /// Wraps `blocks` as an `out_features x in_features` Q5K linear layer.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] when `in_features` is zero or not a
    /// multiple of 256, or when `blocks.len()` differs from
    /// `out_features * (in_features / 256)`.
    pub fn new(
        blocks: &'a [BlockQ5K],
        out_features: usize,
        in_features: usize,
    ) -> ModelResult<Self> {
        // Number of weights covered by one block, as assumed by the checks below.
        const QK_K: usize = 256;

        if in_features == 0 || in_features % QK_K != 0 {
            // Report the next valid width as "expected" so the error does not show
            // two identical shapes.
            let nearest_valid = (in_features / QK_K + 1).saturating_mul(QK_K);
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ5K".into(),
                expected: vec![out_features, nearest_valid],
                actual: vec![out_features, in_features],
            });
        }

        let blocks_per_row = in_features / QK_K;
        // `checked_mul` keeps absurd dimensions from panicking (debug) or wrapping
        // (release); an overflowing product can never match a real slice length.
        if out_features.checked_mul(blocks_per_row) != Some(blocks.len()) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ5K".into(),
                expected: vec![out_features.saturating_mul(blocks_per_row)],
                actual: vec![blocks.len()],
            });
        }

        Ok(Self {
            blocks,
            out_features,
            in_features,
        })
    }

    /// Output dimension of the layer.
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Input dimension of the layer.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// The borrowed quantized weight blocks.
    pub fn blocks(&self) -> &[BlockQ5K] {
        self.blocks
    }

    /// Runs the Q5K GEMV for a single input vector, writing into `output`.
    ///
    /// On builds with the `native-cuda` feature (Linux/Windows) the CUDA kernel is
    /// tried first when `CudaGraph::global()` succeeds. If it fails, a warning is
    /// logged (unless the error text contains "no CUDA device") and the CPU kernel
    /// `gemv_q5k` is used instead.
    ///
    /// Length validation of `input` / `output` is left to the kernels; presumably
    /// they expect `in_features` and `out_features` elements respectively.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::Kernel`] when the CPU kernel reports an error.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        if oxibonsai_kernels::CudaGraph::global().is_ok() {
            // SAFETY: the pointer and byte length both come from the live slice
            // `self.blocks`, so the byte view covers exactly that allocation and
            // cannot outlive it; `u8` has alignment 1. The file-level const
            // assertion guarantees this length equals
            // `blocks.len() * BLOCK_Q5K_BYTES`. This assumes `BlockQ5K` has no
            // padding bytes (plain packed data) — confirm in `oxibonsai_core`.
            let raw = unsafe {
                std::slice::from_raw_parts(
                    self.blocks.as_ptr().cast::<u8>(),
                    std::mem::size_of_val(self.blocks),
                )
            };
            match oxibonsai_kernels::cuda_gemv_q5k(
                raw,
                input,
                output,
                self.out_features,
                self.in_features,
            ) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // A missing device is an expected configuration, not worth a
                    // warning on every call; anything else is logged before the
                    // CPU fallback below.
                    if !e.to_string().contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA Q5K GEMV failed, falling back to CPU scalar"
                        );
                    }
                }
            }
        }

        gemv_q5k(
            self.blocks,
            input,
            output,
            self.out_features,
            self.in_features,
        )
        .map_err(ModelError::Kernel)
    }

    /// Applies [`forward`](Self::forward) to `m` consecutive rows.
    ///
    /// Row `i` reads `input[i * in_features..(i + 1) * in_features]` and writes
    /// `output[i * out_features..(i + 1) * out_features]`. Buffers longer than
    /// required are accepted; the excess is left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] when `input` holds fewer than
    /// `m * in_features` elements or `output` fewer than `m * out_features`
    /// (including when those products overflow `usize`), and propagates any error
    /// from `forward`.
    pub fn forward_batch(&self, input: &[f32], output: &mut [f32], m: usize) -> ModelResult<()> {
        let in_width = self.in_features;
        let out_width = self.out_features;

        // Validate up front so short buffers surface as an error instead of a
        // slice-index panic halfway through the batch.
        match m.checked_mul(in_width) {
            Some(needed) if input.len() >= needed => {}
            _ => {
                return Err(ModelError::ShapeMismatch {
                    name: "LinearQ5K::forward_batch input".into(),
                    expected: vec![m, in_width],
                    actual: vec![input.len()],
                });
            }
        }
        match m.checked_mul(out_width) {
            Some(needed) if output.len() >= needed => {}
            _ => {
                return Err(ModelError::ShapeMismatch {
                    name: "LinearQ5K::forward_batch output".into(),
                    expected: vec![m, out_width],
                    actual: vec![output.len()],
                });
            }
        }

        // The index arithmetic below cannot overflow or go out of bounds: every
        // upper bound is at most `m * width`, which was checked above.
        for row in 0..m {
            let input_row = &input[row * in_width..(row + 1) * in_width];
            let output_row = &mut output[row * out_width..(row + 1) * out_width];
            self.forward(input_row, output_row)?;
        }
        Ok(())
    }
}
/// Borrowed view of a linear layer whose weights are stored as [`BlockQ6K`] blocks.
///
/// The struct does not own the weights; it only pairs a block slice with the layer
/// dimensions. `LinearQ6K::new` validates that the slice length is consistent with
/// those dimensions before a value can be constructed.
#[derive(Debug)]
pub struct LinearQ6K<'a> {
/// Quantized weight blocks; `new` requires exactly
/// `out_features * (in_features / 256)` of them.
blocks: &'a [BlockQ6K],
/// Output dimension passed to the GEMV kernels.
out_features: usize,
/// Input dimension passed to the GEMV kernels; `new` requires a non-zero
/// multiple of 256.
in_features: usize,
}
impl<'a> LinearQ6K<'a> {
    /// Wraps `blocks` as an `out_features x in_features` Q6K linear layer.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] when `in_features` is zero or not a
    /// multiple of 256, or when `blocks.len()` differs from
    /// `out_features * (in_features / 256)`.
    pub fn new(
        blocks: &'a [BlockQ6K],
        out_features: usize,
        in_features: usize,
    ) -> ModelResult<Self> {
        // Number of weights covered by one block, as assumed by the checks below.
        const QK_K: usize = 256;

        if in_features == 0 || in_features % QK_K != 0 {
            // Report the next valid width as "expected" so the error does not show
            // two identical shapes.
            let nearest_valid = (in_features / QK_K + 1).saturating_mul(QK_K);
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ6K".into(),
                expected: vec![out_features, nearest_valid],
                actual: vec![out_features, in_features],
            });
        }

        let blocks_per_row = in_features / QK_K;
        // `checked_mul` keeps absurd dimensions from panicking (debug) or wrapping
        // (release); an overflowing product can never match a real slice length.
        if out_features.checked_mul(blocks_per_row) != Some(blocks.len()) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ6K".into(),
                expected: vec![out_features.saturating_mul(blocks_per_row)],
                actual: vec![blocks.len()],
            });
        }

        Ok(Self {
            blocks,
            out_features,
            in_features,
        })
    }

    /// Output dimension of the layer.
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Input dimension of the layer.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// The borrowed quantized weight blocks.
    pub fn blocks(&self) -> &[BlockQ6K] {
        self.blocks
    }

    /// Runs the Q6K GEMV for a single input vector, writing into `output`.
    ///
    /// On builds with the `native-cuda` feature (Linux/Windows) the CUDA kernel is
    /// tried first when `CudaGraph::global()` succeeds. If it fails, a warning is
    /// logged (unless the error text contains "no CUDA device") and the CPU kernel
    /// `gemv_q6k` is used instead.
    ///
    /// Length validation of `input` / `output` is left to the kernels; presumably
    /// they expect `in_features` and `out_features` elements respectively.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::Kernel`] when the CPU kernel reports an error.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        if oxibonsai_kernels::CudaGraph::global().is_ok() {
            // SAFETY: the pointer and byte length both come from the live slice
            // `self.blocks`, so the byte view covers exactly that allocation and
            // cannot outlive it; `u8` has alignment 1. The file-level const
            // assertion guarantees this length equals
            // `blocks.len() * BLOCK_Q6K_BYTES`. This assumes `BlockQ6K` has no
            // padding bytes (plain packed data) — confirm in `oxibonsai_core`.
            let raw = unsafe {
                std::slice::from_raw_parts(
                    self.blocks.as_ptr().cast::<u8>(),
                    std::mem::size_of_val(self.blocks),
                )
            };
            match oxibonsai_kernels::cuda_gemv_q6k(
                raw,
                input,
                output,
                self.out_features,
                self.in_features,
            ) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // A missing device is an expected configuration, not worth a
                    // warning on every call; anything else is logged before the
                    // CPU fallback below.
                    if !e.to_string().contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA Q6K GEMV failed, falling back to CPU scalar"
                        );
                    }
                }
            }
        }

        gemv_q6k(
            self.blocks,
            input,
            output,
            self.out_features,
            self.in_features,
        )
        .map_err(ModelError::Kernel)
    }

    /// Applies [`forward`](Self::forward) to `m` consecutive rows.
    ///
    /// Row `i` reads `input[i * in_features..(i + 1) * in_features]` and writes
    /// `output[i * out_features..(i + 1) * out_features]`. Buffers longer than
    /// required are accepted; the excess is left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] when `input` holds fewer than
    /// `m * in_features` elements or `output` fewer than `m * out_features`
    /// (including when those products overflow `usize`), and propagates any error
    /// from `forward`.
    pub fn forward_batch(&self, input: &[f32], output: &mut [f32], m: usize) -> ModelResult<()> {
        let in_width = self.in_features;
        let out_width = self.out_features;

        // Validate up front so short buffers surface as an error instead of a
        // slice-index panic halfway through the batch.
        match m.checked_mul(in_width) {
            Some(needed) if input.len() >= needed => {}
            _ => {
                return Err(ModelError::ShapeMismatch {
                    name: "LinearQ6K::forward_batch input".into(),
                    expected: vec![m, in_width],
                    actual: vec![input.len()],
                });
            }
        }
        match m.checked_mul(out_width) {
            Some(needed) if output.len() >= needed => {}
            _ => {
                return Err(ModelError::ShapeMismatch {
                    name: "LinearQ6K::forward_batch output".into(),
                    expected: vec![m, out_width],
                    actual: vec![output.len()],
                });
            }
        }

        // The index arithmetic below cannot overflow or go out of bounds: every
        // upper bound is at most `m * width`, which was checked above.
        for row in 0..m {
            let input_row = &input[row * in_width..(row + 1) * in_width];
            let output_row = &mut output[row * out_width..(row + 1) * out_width];
            self.forward(input_row, output_row)?;
        }
        Ok(())
    }
}