/// Error type returned by the quantization helpers in this module.
///
/// The variants correspond one-to-one to those of
/// `native_neural_network::quantization::QuantError`, from which this type
/// can be converted. The type is a plain field-less enum, so it is cheap to
/// copy and can be compared directly (useful for `assert_eq!` in tests and
/// for callers that branch on a specific failure).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationStdError {
    /// Mapped from the upstream `QuantError::Empty`
    /// (presumably an empty input buffer; confirm against the upstream docs).
    Empty,
    /// A buffer is too short for the requested shape, a dimension product
    /// overflowed `usize`, or the upstream reported `QuantError::ShapeMismatch`.
    ShapeMismatch,
    /// A scale was non-finite or not strictly positive, or the upstream
    /// reported `QuantError::InvalidScale`.
    InvalidScale,
}
/// Converts the upstream quantization error into this crate's error type.
///
/// Each upstream variant maps to the variant of the same name, so `?` and
/// `.into()` can be used on results coming from
/// `native_neural_network::quantization`.
impl From<native_neural_network::quantization::QuantError> for QuantizationStdError {
    fn from(err: native_neural_network::quantization::QuantError) -> Self {
        // Exhaustive on purpose: a new upstream variant must be handled explicitly.
        match err {
            native_neural_network::quantization::QuantError::Empty => Self::Empty,
            native_neural_network::quantization::QuantError::ShapeMismatch => Self::ShapeMismatch,
            native_neural_network::quantization::QuantError::InvalidScale => Self::InvalidScale,
        }
    }
}
/// Thin wrapper around `native_neural_network::quantization::quantize_i8_symmetric`.
///
/// Forwards `input` and `output` unchanged and returns the `f32` produced by
/// the upstream call (presumably the quantization scale; confirm against
/// the upstream documentation).
///
/// # Errors
///
/// Any upstream `QuantError` is converted into the matching
/// [`QuantizationStdError`] variant.
pub fn quantize_i8_symmetric(
    input: &[f32],
    output: &mut [i8],
) -> Result<f32, QuantizationStdError> {
    // `?` performs the `QuantError -> QuantizationStdError` conversion via `From`.
    let value = native_neural_network::quantization::quantize_i8_symmetric(input, output)?;
    Ok(value)
}
pub fn dequantize_i8_symmetric(
input: &[i8],
output: &mut [f32],
scale: f32,
) -> Result<(), QuantizationStdError> {
native_neural_network::quantization::dequantize_i8_symmetric(input, output, scale)
.map_err(Into::into)
}
/// Per-operand scale factors for the int8 matrix multiplication.
///
/// `matmul_i8_f32` multiplies the two values together and applies the
/// product to every accumulated output element.
#[derive(Debug, Clone, Copy)]
pub struct MatMulScales {
    /// Scale associated with the left-hand operand `a`.
    pub scale_a: f32,
    /// Scale associated with the right-hand operand `b`.
    pub scale_b: f32,
}
/// Shape and scale description for an int8 matrix multiplication.
///
/// `matmul_i8_f32` reads the left operand as `m x k`, the right operand as
/// `k x n`, and writes an `m x n` result.
#[derive(Debug, Clone, Copy)]
pub struct MatMulParams {
    /// Number of rows of the left operand and of the output.
    pub m: usize,
    /// Number of columns of the right operand and of the output.
    pub n: usize,
    /// Shared inner dimension (columns of the left operand, rows of the right).
    pub k: usize,
    /// Scale factors applied to the accumulated products.
    pub scales: MatMulScales,
}
/// Multiplies two row-major `i8` matrices and writes the rescaled `f32` product.
///
/// `a` is read as an `m x k` matrix and `b` as a `k x n` matrix (both
/// row-major). For every output position the integer dot product of the
/// corresponding row of `a` and column of `b` is computed, converted to
/// `f32`, and multiplied by `scale_a * scale_b`. The result is written to the
/// first `m * n` elements of `out` in row-major order.
///
/// Buffers longer than required are accepted; surplus elements of `a` and
/// `b` are ignored and surplus elements of `out` are left untouched.
///
/// # Errors
///
/// * [`QuantizationStdError::InvalidScale`] if either scale is non-finite or
///   not strictly positive. This is checked before the shapes.
/// * [`QuantizationStdError::ShapeMismatch`] if `m * k`, `k * n` or `m * n`
///   overflows `usize`, or if any buffer is shorter than its required length.
pub fn matmul_i8_f32(
    a: &[i8],
    b: &[i8],
    out: &mut [f32],
    params: MatMulParams,
) -> Result<(), QuantizationStdError> {
    let scale_a = params.scales.scale_a;
    let scale_b = params.scales.scale_b;
    let scale_is_valid = |s: f32| s.is_finite() && s > 0.0;
    if !(scale_is_valid(scale_a) && scale_is_valid(scale_b)) {
        return Err(QuantizationStdError::InvalidScale);
    }

    let (m, n, k) = (params.m, params.n, params.k);
    let len_a = m
        .checked_mul(k)
        .ok_or(QuantizationStdError::ShapeMismatch)?;
    let len_b = k
        .checked_mul(n)
        .ok_or(QuantizationStdError::ShapeMismatch)?;
    let len_o = m
        .checked_mul(n)
        .ok_or(QuantizationStdError::ShapeMismatch)?;
    if a.len() < len_a || b.len() < len_b || out.len() < len_o {
        return Err(QuantizationStdError::ShapeMismatch);
    }

    let scale = scale_a * scale_b;
    for row in 0..m {
        // `(row + 1) * k <= m * k`, which was verified above not to overflow.
        let a_row = &a[row * k..(row + 1) * k];
        for col in 0..n {
            // Accumulate in i64: a single product can reach 2^14 in magnitude,
            // so an i32 sum overflows once `k` exceeds roughly 131k, which
            // panics in debug builds and wraps silently in release builds.
            let mut acc = 0i64;
            for (p, &lhs) in a_row.iter().enumerate() {
                acc += i64::from(lhs) * i64::from(b[p * n + col]);
            }
            out[row * n + col] = acc as f32 * scale;
        }
    }
    Ok(())
}
/// Alias for [`matmul_i8_f32`].
///
/// Forwards `a`, `b`, `out` and `params` unchanged and returns the result of
/// that call, so the scales used are the ones stored in `params.scales` and
/// the error conditions are exactly those of [`matmul_i8_f32`].
pub fn matmul_i8_f32_with_scales(
a: &[i8],
b: &[i8],
out: &mut [f32],
params: MatMulParams,
) -> Result<(), QuantizationStdError> {
// Pure delegation: no additional validation or scaling happens here.
matmul_i8_f32(a, b, out, params)
}
/// Renders the error as `QuantizationStdError::<Variant>`.
///
/// Width, fill and alignment flags on the formatter are not applied; the
/// text is written verbatim.
impl core::fmt::Display for QuantizationStdError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive on purpose: adding a variant must force an update here.
        let variant = match self {
            Self::Empty => "Empty",
            Self::ShapeMismatch => "ShapeMismatch",
            Self::InvalidScale => "InvalidScale",
        };
        f.write_str("QuantizationStdError::")?;
        f.write_str(variant)
    }
}
// Registers `QuantizationStdError` as a standard error type. The impl body is
// empty, so all trait methods keep their defaults (in particular, `source()`
// reports no underlying cause).
impl std::error::Error for QuantizationStdError {}