//! native_neural_network_std 0.2.1
//!
//! Ergonomic `std` wrapper for the `no_std` `native_neural_network` crate: std-friendly
//! re-exports and utilities.
/// Error type returned by the quantization helpers in this crate.
///
/// Each variant mirrors the identically named variant of the core crate's
/// `native_neural_network::quantization::QuantError`. The local matmul kernel
/// (`matmul_i8_f32`) also produces `ShapeMismatch` and `InvalidScale` itself.
///
/// All variants are plain unit values, so the type is cheap to copy and can be compared
/// and hashed directly, e.g. with `assert_eq!` in tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum QuantizationStdError {
    /// Mirrors `QuantError::Empty` (presumably an empty input was supplied — confirm against
    /// the core crate).
    Empty,
    /// A buffer is shorter than its requested shape requires, or a shape product overflowed
    /// `usize`.
    ShapeMismatch,
    /// A scale factor was rejected (for the local matmul: NaN, infinite, or not strictly
    /// positive).
    InvalidScale,
}

impl From<native_neural_network::quantization::QuantError> for QuantizationStdError {
    /// Translates a core-crate quantization error into the std-facing variant of the same
    /// name; the mapping is one-to-one.
    fn from(err: native_neural_network::quantization::QuantError) -> Self {
        type Core = native_neural_network::quantization::QuantError;
        match err {
            Core::Empty => Self::Empty,
            Core::ShapeMismatch => Self::ShapeMismatch,
            Core::InvalidScale => Self::InvalidScale,
        }
    }
}

/// Quantizes `input` into `output` as symmetric signed 8-bit values.
///
/// Thin `std` wrapper around `native_neural_network::quantization::quantize_i8_symmetric`.
/// On success it returns the `f32` produced by the core function (presumably the scale to
/// pass back to [`dequantize_i8_symmetric`] — confirm against the core crate).
///
/// # Errors
///
/// Any core `QuantError` is converted into the matching [`QuantizationStdError`] variant.
pub fn quantize_i8_symmetric(
    input: &[f32],
    output: &mut [i8],
) -> Result<f32, QuantizationStdError> {
    let scale = native_neural_network::quantization::quantize_i8_symmetric(input, output)?;
    Ok(scale)
}

/// Dequantizes symmetric `i8` values from `input` into `output` using `scale`.
///
/// Thin `std` wrapper around `native_neural_network::quantization::dequantize_i8_symmetric`.
///
/// # Errors
///
/// Any core `QuantError` is converted into the matching [`QuantizationStdError`] variant.
pub fn dequantize_i8_symmetric(
    input: &[i8],
    output: &mut [f32],
    scale: f32,
) -> Result<(), QuantizationStdError> {
    native_neural_network::quantization::dequantize_i8_symmetric(input, output, scale)?;
    Ok(())
}

/// Per-operand dequantization scales for [`matmul_i8_f32`].
///
/// The kernel multiplies each integer dot product by `scale_a * scale_b`. Both scales must be
/// finite and strictly positive, otherwise the kernel returns
/// [`QuantizationStdError::InvalidScale`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MatMulScales {
    /// Scale applied to the left-hand operand `a`.
    pub scale_a: f32,
    /// Scale applied to the right-hand operand `b`.
    pub scale_b: f32,
}

/// Shape and scales for [`matmul_i8_f32`].
///
/// `a` is read as an `m × k` row-major matrix, `b` as a `k × n` row-major matrix, and the
/// result is written to `out` as an `m × n` row-major matrix.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MatMulParams {
    /// Number of rows of `a` and of the output.
    pub m: usize,
    /// Number of columns of `b` and of the output.
    pub n: usize,
    /// Shared inner dimension (columns of `a`, rows of `b`).
    pub k: usize,
    /// Dequantization scales for both operands.
    pub scales: MatMulScales,
}

/// Multiplies two row-major `i8` matrices and writes the dequantized `f32` product.
///
/// Computes `out[row * n + col] = scale_a * scale_b * Σ_p a[row * k + p] * b[p * n + col]`,
/// where `a` is `m × k` and `b` is `k × n`, both row-major. Only the first `m * k`, `k * n`
/// and `m * n` elements of `a`, `b` and `out` are used; trailing elements are neither read
/// nor written. With `k == 0` every output element in range is set to `0.0`.
///
/// Dot products are accumulated in `i64`, so the integer sum cannot overflow for any
/// realistic `k`.
///
/// # Errors
///
/// * [`QuantizationStdError::InvalidScale`] if either scale is NaN, infinite, or not strictly
///   positive, or if `scale_a * scale_b` overflows to infinity or underflows to zero.
/// * [`QuantizationStdError::ShapeMismatch`] if `m * k`, `k * n` or `m * n` overflows `usize`,
///   or if `a`, `b` or `out` is shorter than the shape requires.
pub fn matmul_i8_f32(
    a: &[i8],
    b: &[i8],
    out: &mut [f32],
    params: MatMulParams,
) -> Result<(), QuantizationStdError> {
    let scale_a = params.scales.scale_a;
    let scale_b = params.scales.scale_b;
    if !scale_a.is_finite() || !scale_b.is_finite() || scale_a <= 0.0 || scale_b <= 0.0 {
        return Err(QuantizationStdError::InvalidScale);
    }
    // Individually valid scales can still overflow to infinity or underflow to zero when
    // multiplied, which would turn every output into inf/NaN or 0; reject that up front.
    let scale = scale_a * scale_b;
    if !scale.is_finite() || scale <= 0.0 {
        return Err(QuantizationStdError::InvalidScale);
    }

    let (m, n, k) = (params.m, params.n, params.k);
    let len_a = m.checked_mul(k).ok_or(QuantizationStdError::ShapeMismatch)?;
    let len_b = k.checked_mul(n).ok_or(QuantizationStdError::ShapeMismatch)?;
    let len_o = m.checked_mul(n).ok_or(QuantizationStdError::ShapeMismatch)?;
    if a.len() < len_a || b.len() < len_b || out.len() < len_o {
        return Err(QuantizationStdError::ShapeMismatch);
    }
    // Trim every buffer to exactly the region the shape covers, so the row slicing below is
    // in bounds by construction and trailing elements are left untouched.
    let a = &a[..len_a];
    let b = &b[..len_b];
    let out = &mut out[..len_o];

    for row in 0..m {
        let a_row = &a[row * k..(row + 1) * k];
        let out_row = &mut out[row * n..(row + 1) * n];
        for (col, dst) in out_row.iter_mut().enumerate() {
            // Column `col` of row-major `b` is every `n`-th element starting at `col`.
            // `n > 0` is guaranteed here because `out_row` is non-empty.
            let b_col = b.iter().skip(col).step_by(n);
            // Each term is at most 128 * 128 = 2^14 in magnitude, so an i32 accumulator
            // could overflow for k >= 2^17 (panic in debug, silent wrap in release).
            let acc: i64 = a_row
                .iter()
                .zip(b_col)
                .map(|(&x, &y)| i64::from(x) * i64::from(y))
                .sum();
            *dst = acc as f32 * scale;
        }
    }
    Ok(())
}

/// Alias of [`matmul_i8_f32`]: identical arguments, validation and output.
///
/// The scales already travel inside `params`, so this entry point forwards everything
/// unchanged to [`matmul_i8_f32`].
///
/// # Errors
///
/// Exactly those returned by [`matmul_i8_f32`].
#[inline]
pub fn matmul_i8_f32_with_scales(
    a: &[i8],
    b: &[i8],
    out: &mut [f32],
    params: MatMulParams,
) -> Result<(), QuantizationStdError> {
    // Pure delegation; kept as a separate public name for API compatibility.
    matmul_i8_f32(a, b, out, params)
}
impl core::fmt::Display for QuantizationStdError {
    /// Renders the error as `QuantizationStdError::<Variant>`, e.g.
    /// `QuantizationStdError::ShapeMismatch`. Width/fill flags are not applied.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let variant = match self {
            Self::Empty => "Empty",
            Self::ShapeMismatch => "ShapeMismatch",
            Self::InvalidScale => "InvalidScale",
        };
        f.write_str("QuantizationStdError::")?;
        f.write_str(variant)
    }
}

impl std::error::Error for QuantizationStdError {}