use oxibonsai_core::{BlockQ2K, BlockQ3K, BlockQ4K, BlockQ8K};
use oxibonsai_kernels::{gemv_q2k, gemv_q3k, gemv_q4k, gemv_q8k};
use crate::error::{ModelError, ModelResult};
// Compile-time layout checks for the CUDA path. `forward` on each Linear*K type
// reinterprets its `&[BlockQ*K]` slice as raw bytes before handing it to the
// device kernel. These checks guarantee each block type's in-memory size equals
// the corresponding `BLOCK_*_BYTES` constant, so that byte view has the
// intended length.
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
const _: () = {
    use std::mem::size_of;
    assert!(size_of::<oxibonsai_core::BlockQ2K>() == oxibonsai_core::BLOCK_Q2_K_BYTES);
    assert!(size_of::<oxibonsai_core::BlockQ3K>() == oxibonsai_core::BLOCK_Q3K_BYTES);
    assert!(size_of::<oxibonsai_core::BlockQ4K>() == oxibonsai_core::BLOCK_Q4_K_BYTES);
    assert!(size_of::<oxibonsai_core::BlockQ8K>() == oxibonsai_core::BLOCK_Q8K_BYTES);
};
/// Linear layer whose weight matrix is stored as borrowed Q2_K blocks.
///
/// `new` requires exactly `out_features * (in_features / 256)` blocks
/// (presumably row-major, one row of blocks per output feature — TODO confirm
/// against the `gemv_q2k` kernel).
#[derive(Debug)]
pub struct LinearQ2K<'a> {
    /// Quantized weight blocks; each block covers 256 input features.
    blocks: &'a [BlockQ2K],
    /// Number of values produced per input row.
    out_features: usize,
    /// Number of values consumed per input row; a non-zero multiple of 256.
    in_features: usize,
}
impl<'a> LinearQ2K<'a> {
    /// Wraps `blocks` as an `out_features x in_features` Q2_K weight matrix.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] when `in_features` is zero or not a
    /// multiple of 256, when `out_features * (in_features / 256)` overflows
    /// `usize`, or when `blocks.len()` differs from that product.
    pub fn new(
        blocks: &'a [BlockQ2K],
        out_features: usize,
        in_features: usize,
    ) -> ModelResult<Self> {
        // Number of input features covered by one block.
        const QK_K: usize = 256;
        if in_features == 0 || in_features % QK_K != 0 {
            // Report the next valid width instead of echoing the bad one back,
            // so the error tells the caller what would have been accepted.
            let nearest_valid = (in_features / QK_K).saturating_add(1).saturating_mul(QK_K);
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ2K".into(),
                expected: vec![out_features, nearest_valid],
                actual: vec![out_features, in_features],
            });
        }
        let blocks_per_row = in_features / QK_K;
        // Checked: with untrusted shape metadata a wrapped product could
        // spuriously equal `blocks.len()`, and the shape handed to the kernels
        // would then no longer describe the buffer.
        let expected_blocks =
            out_features
                .checked_mul(blocks_per_row)
                .ok_or_else(|| ModelError::ShapeMismatch {
                    name: "LinearQ2K".into(),
                    expected: vec![out_features, blocks_per_row],
                    actual: vec![blocks.len()],
                })?;
        if blocks.len() != expected_blocks {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ2K".into(),
                expected: vec![expected_blocks],
                actual: vec![blocks.len()],
            });
        }
        Ok(Self {
            blocks,
            out_features,
            in_features,
        })
    }
    /// Number of values produced per input row.
    pub fn out_features(&self) -> usize {
        self.out_features
    }
    /// Number of values consumed per input row.
    pub fn in_features(&self) -> usize {
        self.in_features
    }
    /// The borrowed quantized weight blocks.
    pub fn blocks(&self) -> &[BlockQ2K] {
        self.blocks
    }
    /// Size of the quantized weights in bytes (`blocks.len() * BLOCK_Q2_K_BYTES`).
    pub fn memory_bytes(&self) -> usize {
        self.blocks.len() * oxibonsai_core::BLOCK_Q2_K_BYTES
    }
    /// Runs the quantized matrix-vector product (GEMV) for one input vector.
    ///
    /// With the `native-cuda` feature on Linux/Windows, a CUDA GEMV is tried
    /// first when `CudaGraph::global()` succeeds. On any CUDA error the call
    /// falls back to the CPU kernel; the error is logged via `tracing::warn!`
    /// unless its message contains "no CUDA device".
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::Kernel`] if the CPU kernel `gemv_q2k` fails.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        if oxibonsai_kernels::CudaGraph::global().is_ok() {
            // SAFETY: `self.blocks` is a live shared borrow for this whole call,
            // `u8` has alignment 1, and `size_of_val` is exactly the slice's byte
            // length, so the view is in bounds. Reading every byte as `u8`
            // assumes `BlockQ2K` has no padding (TODO confirm it is a plain-data
            // `repr(C)` type); the module-level assertion only checks that its
            // size equals `BLOCK_Q2_K_BYTES`.
            let raw = unsafe {
                std::slice::from_raw_parts(
                    self.blocks.as_ptr().cast::<u8>(),
                    std::mem::size_of_val(self.blocks),
                )
            };
            match oxibonsai_kernels::cuda_gemv_q2k(
                raw,
                input,
                output,
                self.out_features,
                self.in_features,
            ) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // A missing device is an expected configuration, so it is not logged.
                    let msg = format!("{e}");
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA Q2K GEMV failed, falling back to CPU scalar"
                        );
                    }
                }
            }
        }
        gemv_q2k(
            self.blocks,
            input,
            output,
            self.out_features,
            self.in_features,
        )
        .map_err(ModelError::Kernel)
    }
    /// Applies [`forward`](Self::forward) to `m` consecutive rows.
    ///
    /// `input` is read as `m` rows of `in_features` values and `output` is
    /// written as `m` rows of `out_features` values; elements past those rows
    /// are ignored (input) or left untouched (output).
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] if either slice is too short to
    /// hold `m` rows. Propagates the first error from [`forward`](Self::forward);
    /// rows before the failing one have already been written.
    pub fn forward_batch(&self, input: &[f32], output: &mut [f32], m: usize) -> ModelResult<()> {
        // True when `m` rows of `width` values fit in `len` without overflow.
        let holds_rows =
            |len: usize, width: usize| matches!(m.checked_mul(width), Some(need) if need <= len);
        if !holds_rows(input.len(), self.in_features) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ2K".into(),
                expected: vec![m, self.in_features],
                actual: vec![input.len()],
            });
        }
        if !holds_rows(output.len(), self.out_features) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ2K".into(),
                expected: vec![m, self.out_features],
                actual: vec![output.len()],
            });
        }
        // Both slices were validated above, so these ranges neither overflow
        // nor go out of bounds. An index loop is used instead of `chunks_exact`
        // because `out_features` may legitimately be 0.
        for batch in 0..m {
            let input_row = &input[batch * self.in_features..(batch + 1) * self.in_features];
            let output_row =
                &mut output[batch * self.out_features..(batch + 1) * self.out_features];
            self.forward(input_row, output_row)?;
        }
        Ok(())
    }
}
/// Linear layer whose weight matrix is stored as borrowed Q3_K blocks.
///
/// `new` requires exactly `out_features * (in_features / 256)` blocks
/// (presumably row-major, one row of blocks per output feature — TODO confirm
/// against the `gemv_q3k` kernel).
#[derive(Debug)]
pub struct LinearQ3K<'a> {
    /// Quantized weight blocks; each block covers 256 input features.
    blocks: &'a [BlockQ3K],
    /// Number of values produced per input row.
    out_features: usize,
    /// Number of values consumed per input row; a non-zero multiple of 256.
    in_features: usize,
}
impl<'a> LinearQ3K<'a> {
    /// Wraps `blocks` as an `out_features x in_features` Q3_K weight matrix.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] when `in_features` is zero or not a
    /// multiple of 256, when `out_features * (in_features / 256)` overflows
    /// `usize`, or when `blocks.len()` differs from that product.
    pub fn new(
        blocks: &'a [BlockQ3K],
        out_features: usize,
        in_features: usize,
    ) -> ModelResult<Self> {
        // Number of input features covered by one block.
        const QK_K: usize = 256;
        if in_features == 0 || in_features % QK_K != 0 {
            // Report the next valid width instead of echoing the bad one back,
            // so the error tells the caller what would have been accepted.
            let nearest_valid = (in_features / QK_K).saturating_add(1).saturating_mul(QK_K);
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ3K".into(),
                expected: vec![out_features, nearest_valid],
                actual: vec![out_features, in_features],
            });
        }
        let blocks_per_row = in_features / QK_K;
        // Checked: with untrusted shape metadata a wrapped product could
        // spuriously equal `blocks.len()`, and the shape handed to the kernels
        // would then no longer describe the buffer.
        let expected_blocks =
            out_features
                .checked_mul(blocks_per_row)
                .ok_or_else(|| ModelError::ShapeMismatch {
                    name: "LinearQ3K".into(),
                    expected: vec![out_features, blocks_per_row],
                    actual: vec![blocks.len()],
                })?;
        if blocks.len() != expected_blocks {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ3K".into(),
                expected: vec![expected_blocks],
                actual: vec![blocks.len()],
            });
        }
        Ok(Self {
            blocks,
            out_features,
            in_features,
        })
    }
    /// Number of values produced per input row.
    pub fn out_features(&self) -> usize {
        self.out_features
    }
    /// Number of values consumed per input row.
    pub fn in_features(&self) -> usize {
        self.in_features
    }
    /// The borrowed quantized weight blocks.
    pub fn blocks(&self) -> &[BlockQ3K] {
        self.blocks
    }
    /// Size of the quantized weights in bytes (`blocks.len() * BLOCK_Q3K_BYTES`).
    pub fn memory_bytes(&self) -> usize {
        self.blocks.len() * oxibonsai_core::BLOCK_Q3K_BYTES
    }
    /// Runs the quantized matrix-vector product (GEMV) for one input vector.
    ///
    /// With the `native-cuda` feature on Linux/Windows, a CUDA GEMV is tried
    /// first when `CudaGraph::global()` succeeds. On any CUDA error the call
    /// falls back to the CPU kernel; the error is logged via `tracing::warn!`
    /// unless its message contains "no CUDA device".
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::Kernel`] if the CPU kernel `gemv_q3k` fails.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        if oxibonsai_kernels::CudaGraph::global().is_ok() {
            // SAFETY: `self.blocks` is a live shared borrow for this whole call,
            // `u8` has alignment 1, and `size_of_val` is exactly the slice's byte
            // length, so the view is in bounds. Reading every byte as `u8`
            // assumes `BlockQ3K` has no padding (TODO confirm it is a plain-data
            // `repr(C)` type); the module-level assertion only checks that its
            // size equals `BLOCK_Q3K_BYTES`.
            let raw = unsafe {
                std::slice::from_raw_parts(
                    self.blocks.as_ptr().cast::<u8>(),
                    std::mem::size_of_val(self.blocks),
                )
            };
            match oxibonsai_kernels::cuda_gemv_q3k(
                raw,
                input,
                output,
                self.out_features,
                self.in_features,
            ) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // A missing device is an expected configuration, so it is not logged.
                    let msg = format!("{e}");
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA Q3K GEMV failed, falling back to CPU scalar"
                        );
                    }
                }
            }
        }
        gemv_q3k(
            self.blocks,
            input,
            output,
            self.out_features,
            self.in_features,
        )
        .map_err(ModelError::Kernel)
    }
    /// Applies [`forward`](Self::forward) to `m` consecutive rows.
    ///
    /// `input` is read as `m` rows of `in_features` values and `output` is
    /// written as `m` rows of `out_features` values; elements past those rows
    /// are ignored (input) or left untouched (output).
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] if either slice is too short to
    /// hold `m` rows. Propagates the first error from [`forward`](Self::forward);
    /// rows before the failing one have already been written.
    pub fn forward_batch(&self, input: &[f32], output: &mut [f32], m: usize) -> ModelResult<()> {
        // True when `m` rows of `width` values fit in `len` without overflow.
        let holds_rows =
            |len: usize, width: usize| matches!(m.checked_mul(width), Some(need) if need <= len);
        if !holds_rows(input.len(), self.in_features) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ3K".into(),
                expected: vec![m, self.in_features],
                actual: vec![input.len()],
            });
        }
        if !holds_rows(output.len(), self.out_features) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ3K".into(),
                expected: vec![m, self.out_features],
                actual: vec![output.len()],
            });
        }
        // Both slices were validated above, so these ranges neither overflow
        // nor go out of bounds. An index loop is used instead of `chunks_exact`
        // because `out_features` may legitimately be 0.
        for batch in 0..m {
            let input_row = &input[batch * self.in_features..(batch + 1) * self.in_features];
            let output_row =
                &mut output[batch * self.out_features..(batch + 1) * self.out_features];
            self.forward(input_row, output_row)?;
        }
        Ok(())
    }
}
/// Linear layer whose weight matrix is stored as borrowed Q4_K blocks.
///
/// `new` requires exactly `out_features * (in_features / 256)` blocks
/// (presumably row-major, one row of blocks per output feature — TODO confirm
/// against the `gemv_q4k` kernel).
#[derive(Debug)]
pub struct LinearQ4K<'a> {
    /// Quantized weight blocks; each block covers 256 input features.
    blocks: &'a [BlockQ4K],
    /// Number of values produced per input row.
    out_features: usize,
    /// Number of values consumed per input row; a non-zero multiple of 256.
    in_features: usize,
}
impl<'a> LinearQ4K<'a> {
    /// Wraps `blocks` as an `out_features x in_features` Q4_K weight matrix.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] when `in_features` is zero or not a
    /// multiple of 256, when `out_features * (in_features / 256)` overflows
    /// `usize`, or when `blocks.len()` differs from that product.
    pub fn new(
        blocks: &'a [BlockQ4K],
        out_features: usize,
        in_features: usize,
    ) -> ModelResult<Self> {
        // Number of input features covered by one block.
        const QK_K: usize = 256;
        if in_features == 0 || in_features % QK_K != 0 {
            // Report the next valid width instead of echoing the bad one back,
            // so the error tells the caller what would have been accepted.
            let nearest_valid = (in_features / QK_K).saturating_add(1).saturating_mul(QK_K);
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ4K".into(),
                expected: vec![out_features, nearest_valid],
                actual: vec![out_features, in_features],
            });
        }
        let blocks_per_row = in_features / QK_K;
        // Checked: with untrusted shape metadata a wrapped product could
        // spuriously equal `blocks.len()`, and the shape handed to the kernels
        // would then no longer describe the buffer.
        let expected_blocks =
            out_features
                .checked_mul(blocks_per_row)
                .ok_or_else(|| ModelError::ShapeMismatch {
                    name: "LinearQ4K".into(),
                    expected: vec![out_features, blocks_per_row],
                    actual: vec![blocks.len()],
                })?;
        if blocks.len() != expected_blocks {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ4K".into(),
                expected: vec![expected_blocks],
                actual: vec![blocks.len()],
            });
        }
        Ok(Self {
            blocks,
            out_features,
            in_features,
        })
    }
    /// Number of values produced per input row.
    pub fn out_features(&self) -> usize {
        self.out_features
    }
    /// Number of values consumed per input row.
    pub fn in_features(&self) -> usize {
        self.in_features
    }
    /// The borrowed quantized weight blocks.
    pub fn blocks(&self) -> &[BlockQ4K] {
        self.blocks
    }
    /// Size of the quantized weights in bytes (`blocks.len() * BLOCK_Q4_K_BYTES`).
    pub fn memory_bytes(&self) -> usize {
        self.blocks.len() * oxibonsai_core::BLOCK_Q4_K_BYTES
    }
    /// Runs the quantized matrix-vector product (GEMV) for one input vector.
    ///
    /// With the `native-cuda` feature on Linux/Windows, a CUDA GEMV is tried
    /// first when `CudaGraph::global()` succeeds. On any CUDA error the call
    /// falls back to the CPU kernel; the error is logged via `tracing::warn!`
    /// unless its message contains "no CUDA device".
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::Kernel`] if the CPU kernel `gemv_q4k` fails.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        if oxibonsai_kernels::CudaGraph::global().is_ok() {
            // SAFETY: `self.blocks` is a live shared borrow for this whole call,
            // `u8` has alignment 1, and `size_of_val` is exactly the slice's byte
            // length, so the view is in bounds. Reading every byte as `u8`
            // assumes `BlockQ4K` has no padding (TODO confirm it is a plain-data
            // `repr(C)` type); the module-level assertion only checks that its
            // size equals `BLOCK_Q4_K_BYTES`.
            let raw = unsafe {
                std::slice::from_raw_parts(
                    self.blocks.as_ptr().cast::<u8>(),
                    std::mem::size_of_val(self.blocks),
                )
            };
            match oxibonsai_kernels::cuda_gemv_q4k(
                raw,
                input,
                output,
                self.out_features,
                self.in_features,
            ) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // A missing device is an expected configuration, so it is not logged.
                    let msg = format!("{e}");
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA Q4K GEMV failed, falling back to CPU scalar"
                        );
                    }
                }
            }
        }
        gemv_q4k(
            self.blocks,
            input,
            output,
            self.out_features,
            self.in_features,
        )
        .map_err(ModelError::Kernel)
    }
    /// Applies [`forward`](Self::forward) to `m` consecutive rows.
    ///
    /// `input` is read as `m` rows of `in_features` values and `output` is
    /// written as `m` rows of `out_features` values; elements past those rows
    /// are ignored (input) or left untouched (output).
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] if either slice is too short to
    /// hold `m` rows. Propagates the first error from [`forward`](Self::forward);
    /// rows before the failing one have already been written.
    pub fn forward_batch(&self, input: &[f32], output: &mut [f32], m: usize) -> ModelResult<()> {
        // True when `m` rows of `width` values fit in `len` without overflow.
        let holds_rows =
            |len: usize, width: usize| matches!(m.checked_mul(width), Some(need) if need <= len);
        if !holds_rows(input.len(), self.in_features) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ4K".into(),
                expected: vec![m, self.in_features],
                actual: vec![input.len()],
            });
        }
        if !holds_rows(output.len(), self.out_features) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ4K".into(),
                expected: vec![m, self.out_features],
                actual: vec![output.len()],
            });
        }
        // Both slices were validated above, so these ranges neither overflow
        // nor go out of bounds. An index loop is used instead of `chunks_exact`
        // because `out_features` may legitimately be 0.
        for batch in 0..m {
            let input_row = &input[batch * self.in_features..(batch + 1) * self.in_features];
            let output_row =
                &mut output[batch * self.out_features..(batch + 1) * self.out_features];
            self.forward(input_row, output_row)?;
        }
        Ok(())
    }
}
/// Linear layer whose weight matrix is stored as borrowed Q8_K blocks.
///
/// `new` requires exactly `out_features * (in_features / 256)` blocks
/// (presumably row-major, one row of blocks per output feature — TODO confirm
/// against the `gemv_q8k` kernel).
#[derive(Debug)]
pub struct LinearQ8K<'a> {
    /// Quantized weight blocks; each block covers 256 input features.
    blocks: &'a [BlockQ8K],
    /// Number of values produced per input row.
    out_features: usize,
    /// Number of values consumed per input row; a non-zero multiple of 256.
    in_features: usize,
}
impl<'a> LinearQ8K<'a> {
    /// Wraps `blocks` as an `out_features x in_features` Q8_K weight matrix.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] when `in_features` is zero or not a
    /// multiple of 256, when `out_features * (in_features / 256)` overflows
    /// `usize`, or when `blocks.len()` differs from that product.
    pub fn new(
        blocks: &'a [BlockQ8K],
        out_features: usize,
        in_features: usize,
    ) -> ModelResult<Self> {
        // Number of input features covered by one block.
        const QK_K: usize = 256;
        if in_features == 0 || in_features % QK_K != 0 {
            // Report the next valid width instead of echoing the bad one back,
            // so the error tells the caller what would have been accepted.
            let nearest_valid = (in_features / QK_K).saturating_add(1).saturating_mul(QK_K);
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ8K".into(),
                expected: vec![out_features, nearest_valid],
                actual: vec![out_features, in_features],
            });
        }
        let blocks_per_row = in_features / QK_K;
        // Checked: with untrusted shape metadata a wrapped product could
        // spuriously equal `blocks.len()`, and the shape handed to the kernels
        // would then no longer describe the buffer.
        let expected_blocks =
            out_features
                .checked_mul(blocks_per_row)
                .ok_or_else(|| ModelError::ShapeMismatch {
                    name: "LinearQ8K".into(),
                    expected: vec![out_features, blocks_per_row],
                    actual: vec![blocks.len()],
                })?;
        if blocks.len() != expected_blocks {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ8K".into(),
                expected: vec![expected_blocks],
                actual: vec![blocks.len()],
            });
        }
        Ok(Self {
            blocks,
            out_features,
            in_features,
        })
    }
    /// Number of values produced per input row.
    pub fn out_features(&self) -> usize {
        self.out_features
    }
    /// Number of values consumed per input row.
    pub fn in_features(&self) -> usize {
        self.in_features
    }
    /// The borrowed quantized weight blocks.
    pub fn blocks(&self) -> &[BlockQ8K] {
        self.blocks
    }
    /// Size of the quantized weights in bytes (`blocks.len() * BLOCK_Q8K_BYTES`).
    pub fn memory_bytes(&self) -> usize {
        self.blocks.len() * oxibonsai_core::BLOCK_Q8K_BYTES
    }
    /// Runs the quantized matrix-vector product (GEMV) for one input vector.
    ///
    /// With the `native-cuda` feature on Linux/Windows, a CUDA GEMV is tried
    /// first when `CudaGraph::global()` succeeds. On any CUDA error the call
    /// falls back to the CPU kernel; the error is logged via `tracing::warn!`
    /// unless its message contains "no CUDA device".
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::Kernel`] if the CPU kernel `gemv_q8k` fails.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) -> ModelResult<()> {
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        if oxibonsai_kernels::CudaGraph::global().is_ok() {
            // SAFETY: `self.blocks` is a live shared borrow for this whole call,
            // `u8` has alignment 1, and `size_of_val` is exactly the slice's byte
            // length, so the view is in bounds. Reading every byte as `u8`
            // assumes `BlockQ8K` has no padding (TODO confirm it is a plain-data
            // `repr(C)` type); the module-level assertion only checks that its
            // size equals `BLOCK_Q8K_BYTES`.
            let raw = unsafe {
                std::slice::from_raw_parts(
                    self.blocks.as_ptr().cast::<u8>(),
                    std::mem::size_of_val(self.blocks),
                )
            };
            match oxibonsai_kernels::cuda_gemv_q8k(
                raw,
                input,
                output,
                self.out_features,
                self.in_features,
            ) {
                Ok(()) => return Ok(()),
                Err(e) => {
                    // A missing device is an expected configuration, so it is not logged.
                    let msg = format!("{e}");
                    if !msg.contains("no CUDA device") {
                        tracing::warn!(
                            error = %e,
                            "CUDA Q8K GEMV failed, falling back to CPU scalar"
                        );
                    }
                }
            }
        }
        gemv_q8k(
            self.blocks,
            input,
            output,
            self.out_features,
            self.in_features,
        )
        .map_err(ModelError::Kernel)
    }
    /// Applies [`forward`](Self::forward) to `m` consecutive rows.
    ///
    /// `input` is read as `m` rows of `in_features` values and `output` is
    /// written as `m` rows of `out_features` values; elements past those rows
    /// are ignored (input) or left untouched (output).
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::ShapeMismatch`] if either slice is too short to
    /// hold `m` rows. Propagates the first error from [`forward`](Self::forward);
    /// rows before the failing one have already been written.
    pub fn forward_batch(&self, input: &[f32], output: &mut [f32], m: usize) -> ModelResult<()> {
        // True when `m` rows of `width` values fit in `len` without overflow.
        let holds_rows =
            |len: usize, width: usize| matches!(m.checked_mul(width), Some(need) if need <= len);
        if !holds_rows(input.len(), self.in_features) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ8K".into(),
                expected: vec![m, self.in_features],
                actual: vec![input.len()],
            });
        }
        if !holds_rows(output.len(), self.out_features) {
            return Err(ModelError::ShapeMismatch {
                name: "LinearQ8K".into(),
                expected: vec![m, self.out_features],
                actual: vec![output.len()],
            });
        }
        // Both slices were validated above, so these ranges neither overflow
        // nor go out of bounds. An index loop is used instead of `chunks_exact`
        // because `out_features` may legitimately be 0.
        for batch in 0..m {
            let input_row = &input[batch * self.in_features..(batch + 1) * self.in_features];
            let output_row =
                &mut output[batch * self.out_features..(batch + 1) * self.out_features];
            self.forward(input_row, output_row)?;
        }
        Ok(())
    }
}