use crate::error::{QuantError, QuantResult};
use crate::types::QuantTensor;
/// Number of quantized values packed into one Q8_0 block.
const Q8_0_BLOCK_SIZE: usize = 32;
/// Encoded size of one Q8_0 block in bytes: a 2-byte little-endian f16 scale
/// followed by one `i8` quant per value.
const Q8_0_BLOCK_BYTES: usize = 2 + Q8_0_BLOCK_SIZE;
/// A kernel implementing dequantization and matrix products for one quantized
/// weight format.
///
/// Implementors must be `Send + Sync` so a kernel can be shared across threads.
pub trait QuantKernel: Send + Sync {
    /// Dequantizes one packed block of this format into `output`.
    ///
    /// The default [`matvec_q8_fused`](Self::matvec_q8_fused) always passes exactly
    /// [`block_bytes`](Self::block_bytes) bytes in `block` and an `output` slice of
    /// [`block_size`](Self::block_size) elements.
    fn dequant_block(&self, block: &[u8], output: &mut [f32]) -> QuantResult<()>;

    /// Matrix-vector product of `quant_matrix` and `input` into `output`.
    ///
    /// NOTE(review): whether `output` is overwritten or accumulated into is defined
    /// by each implementation and is not visible here. Confirm before relying on either.
    fn gemv(
        &self,
        quant_matrix: &QuantTensor,
        input: &[f32],
        output: &mut [f32],
    ) -> QuantResult<()>;

    /// Matrix-matrix product of `quant_matrix` and `input` into `output`, with
    /// problem dimensions `m`, `n`, `k`.
    ///
    /// NOTE(review): the memory layout of `input`/`output` and which operand each
    /// dimension refers to are not visible here. Confirm against the implementations.
    fn gemm(
        &self,
        quant_matrix: &QuantTensor,
        input: &[f32],
        output: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> QuantResult<()>;

    /// Adds `dot(weight row r, activations)` to `out[r]` for each of the first
    /// `n_rows` rows.
    ///
    /// This default implementation dequantizes each weight block to `f32` via
    /// [`dequant_block`](Self::dequant_block) and multiplies it against the
    /// dequantized activations. Kernels with a native fused path may override it.
    ///
    /// # Layout
    /// - `weights`: `n_rows` contiguous rows. Each row is
    ///   `n_cols.div_ceil(block_size())` blocks of `block_bytes()` bytes.
    /// - `acts_q8`: `n_cols` activations stored as consecutive Q8_0 blocks. Each
    ///   block is a 2-byte little-endian f16 scale followed by 32 `i8` quants, so
    ///   at least `n_cols.div_ceil(32) * 34` bytes are required. This blocking is
    ///   independent of the weight format's block size.
    /// - `out`: results are *added* to the existing contents, not written over them.
    ///   Zero it first if a plain product is wanted. Elements past `n_rows` are untouched.
    ///
    /// # Errors
    /// - `DimensionMismatch` if `out.len() < n_rows`.
    /// - `KernelError` if `block_size()` returns 0.
    /// - `BufferTooSmall` if `weights` or `acts_q8` is shorter than the layout
    ///   above requires.
    /// - Any error from `dequant_block` is propagated. Rows processed before
    ///   the failure keep their updated `out` values.
    fn matvec_q8_fused(
        &self,
        weights: &[u8],
        acts_q8: &[u8],
        out: &mut [f32],
        n_rows: usize,
        n_cols: usize,
    ) -> QuantResult<()> {
        if out.len() < n_rows {
            return Err(QuantError::DimensionMismatch {
                expected: n_rows,
                got: out.len(),
            });
        }
        let bs = self.block_size();
        let bb = self.block_bytes();
        if bs == 0 {
            return Err(QuantError::KernelError {
                message: "block_size() returned 0 — cannot fuse GEMV".to_string(),
            });
        }

        // Saturating arithmetic makes an absurd shape surface as BufferTooSmall.
        // Otherwise it would wrap (release) or panic (debug) and slip past the
        // length checks.
        let blocks_per_row = n_cols.div_ceil(bs);
        let row_bytes = blocks_per_row.saturating_mul(bb);
        let weights_needed = n_rows.saturating_mul(row_bytes);
        if weights.len() < weights_needed {
            return Err(QuantError::BufferTooSmall {
                needed: weights_needed,
                available: weights.len(),
            });
        }
        // Activations are always Q8_0, so their block count depends only on
        // `n_cols`, not on this kernel's weight block size.
        let acts_needed = n_cols.div_ceil(Q8_0_BLOCK_SIZE).saturating_mul(Q8_0_BLOCK_BYTES);
        if acts_q8.len() < acts_needed {
            return Err(QuantError::BufferTooSmall {
                needed: acts_needed,
                available: acts_q8.len(),
            });
        }

        // The activation vector is shared by every row, so dequantize it once.
        // Decoding by element position (instead of by weight-block index) keeps
        // weights and activations aligned for any weight block size.
        let mut acts = vec![0.0f32; n_cols];
        for (dst, a_block) in acts
            .chunks_mut(Q8_0_BLOCK_SIZE)
            .zip(acts_q8.chunks_exact(Q8_0_BLOCK_BYTES))
        {
            let d_a =
                half::f16::from_bits(u16::from_le_bytes([a_block[0], a_block[1]])).to_f32();
            // The final chunk of `acts` may be shorter than 32; zip stops there.
            for (a, &q) in dst.iter_mut().zip(&a_block[2..]) {
                *a = f32::from(q as i8) * d_a;
            }
        }

        let mut w_scratch = vec![0.0f32; bs];
        for (row, out_val) in out[..n_rows].iter_mut().enumerate() {
            let row_start = row * row_bytes;
            let mut sum = 0.0f32;
            // `chunks(bs)` yields exactly `blocks_per_row` chunks. The last one holds
            // only in-range columns, so padding in the final weight block is ignored.
            for (blk, a_chunk) in acts.chunks(bs).enumerate() {
                let w_block_start = row_start + blk * bb;
                let w_block = &weights[w_block_start..w_block_start + bb];
                self.dequant_block(w_block, &mut w_scratch)?;
                for (w, a) in w_scratch.iter().zip(a_chunk) {
                    sum += w * a;
                }
            }
            *out_val += sum;
        }
        Ok(())
    }

    /// Number of values encoded by one block of this format.
    ///
    /// Must be non-zero for the default `matvec_q8_fused` path.
    fn block_size(&self) -> usize;

    /// Size in bytes of one encoded block of this format.
    fn block_bytes(&self) -> usize;

    /// Static name identifying this kernel.
    fn name(&self) -> &'static str;
}