use super::qtypes::QuantError;
/// Shape and scaling parameters consumed by [`matmul_i8_f32`].
///
/// The dimensions describe a product of an `m x k` matrix with a `k x n`
/// matrix, yielding an `m x n` result.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MatMulConfig {
    /// Number of rows of the left operand and of the output.
    pub m: usize,
    /// Number of columns of the right operand and of the output.
    pub n: usize,
    /// Shared inner dimension: columns of the left operand, rows of the right.
    pub k: usize,
    /// Scale factor for the left operand; [`matmul_i8_f32`] requires it to be
    /// finite and strictly positive.
    pub scale_a: f32,
    /// Scale factor for the right operand; [`matmul_i8_f32`] requires it to be
    /// finite and strictly positive.
    pub scale_b: f32,
}
/// Multiplies two row-major `i8` matrices and writes the scaled result as `f32`.
///
/// `a` is read as an `m x k` matrix, `b` as a `k x n` matrix, and `out` is
/// written as an `m x n` matrix, all in row-major order. Each output element
/// is the integer dot product of a row of `a` with a column of `b`,
/// converted to `f32` and multiplied by `scale_a * scale_b`.
///
/// Slices longer than the shape requires are accepted; the surplus elements
/// of `a` and `b` are ignored and the surplus elements of `out` are left
/// untouched.
///
/// # Errors
///
/// * [`QuantError::InvalidScale`] if either scale is non-finite, zero or
///   negative.
/// * [`QuantError::ShapeMismatch`] if `m * k`, `k * n` or `m * n` overflows
///   `usize`, or if any slice is shorter than its shape requires.
pub fn matmul_i8_f32(
    a: &[i8],
    b: &[i8],
    out: &mut [f32],
    cfg: MatMulConfig,
) -> Result<(), QuantError> {
    let MatMulConfig { m, n, k, scale_a, scale_b } = cfg;
    if !scale_a.is_finite() || !scale_b.is_finite() || scale_a <= 0.0 || scale_b <= 0.0 {
        return Err(QuantError::InvalidScale);
    }

    let len_a = m.checked_mul(k).ok_or(QuantError::ShapeMismatch)?;
    let len_b = k.checked_mul(n).ok_or(QuantError::ShapeMismatch)?;
    let len_o = m.checked_mul(n).ok_or(QuantError::ShapeMismatch)?;
    if a.len() < len_a || b.len() < len_b || out.len() < len_o {
        return Err(QuantError::ShapeMismatch);
    }

    // Trim every buffer to its logical extent once so the loops below only
    // ever see in-range data.
    let a = &a[..len_a];
    let b = &b[..len_b];
    let out = &mut out[..len_o];

    let scale = scale_a * scale_b;
    for row in 0..m {
        let a_row = &a[row * k..(row + 1) * k];
        let out_row = &mut out[row * n..(row + 1) * n];
        for (col, dst) in out_row.iter_mut().enumerate() {
            // Accumulate in i64: a single i8 * i8 product can reach 2^14, so an
            // i32 accumulator overflows once k exceeds roughly 131072.
            // `skip(col).step_by(n)` walks column `col` of the row-major `b`;
            // zipping with `a_row` caps the walk at `k` elements.
            let acc: i64 = a_row
                .iter()
                .zip(b.iter().skip(col).step_by(n))
                .map(|(&x, &y)| i64::from(x) * i64::from(y))
                .sum();
            // Rounding to f32 is intentional; it matches the output type.
            *dst = acc as f32 * scale;
        }
    }
    Ok(())
}