// native_neural_network 0.3.1
//
// A `no_std` Rust library for native neural networks (`.rnn`).
use super::qtypes::QuantError;

/// Shape and quantization parameters for [`matmul_i8_f32`].
///
/// Describes the product of an `m × k` row-major `i8` matrix with a `k × n`
/// row-major `i8` matrix, producing an `m × n` row-major `f32` result.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MatMulConfig {
    /// Number of rows in the left operand and in the output.
    pub m: usize,
    /// Number of columns in the right operand and in the output.
    pub n: usize,
    /// Inner dimension: columns of the left operand, rows of the right operand.
    pub k: usize,
    /// Dequantization scale of the left operand; must be finite and strictly positive.
    pub scale_a: f32,
    /// Dequantization scale of the right operand; must be finite and strictly positive.
    pub scale_b: f32,
}

/// Multiplies two row-major `i8` matrices and writes the dequantized `f32` product.
///
/// With `A` the `m × k` matrix stored in `a` and `B` the `k × n` matrix stored in `b`,
/// this sets
///
/// `out[row * n + col] = (Σ_p a[row * k + p] * b[p * n + col]) as f32 * (scale_a * scale_b)`
///
/// for every `row < m` and `col < n`. Only the first `m * k`, `k * n` and `m * n`
/// elements of `a`, `b` and `out` are used; trailing elements of `out` are left
/// untouched. When `k == 0` every written output element is `0.0`.
///
/// # Errors
///
/// * [`QuantError::InvalidScale`] if `scale_a` or `scale_b` is not finite or not
///   strictly positive, or if their product overflows to infinity or underflows to zero.
/// * [`QuantError::ShapeMismatch`] if `m * k`, `k * n` or `m * n` overflows `usize`, or
///   if `a`, `b` or `out` is shorter than the corresponding product.
pub fn matmul_i8_f32(
    a: &[i8],
    b: &[i8],
    out: &mut [f32],
    cfg: MatMulConfig,
) -> Result<(), QuantError> {
    let MatMulConfig {
        m,
        n,
        k,
        scale_a,
        scale_b,
    } = cfg;
    if !scale_a.is_finite() || !scale_b.is_finite() || scale_a <= 0.0 || scale_b <= 0.0 {
        return Err(QuantError::InvalidScale);
    }
    // Two individually valid scales can still multiply to `inf` (e.g. 1e30 * 1e30) or
    // flush to `0.0` (e.g. 1e-30 * 1e-30), which would silently turn every output into
    // NaN/inf or zero.
    let scale = scale_a * scale_b;
    if !scale.is_finite() || scale <= 0.0 {
        return Err(QuantError::InvalidScale);
    }

    let len_a = m.checked_mul(k).ok_or(QuantError::ShapeMismatch)?;
    let len_b = k.checked_mul(n).ok_or(QuantError::ShapeMismatch)?;
    let len_o = m.checked_mul(n).ok_or(QuantError::ShapeMismatch)?;
    if a.len() < len_a || b.len() < len_b || out.len() < len_o {
        return Err(QuantError::ShapeMismatch);
    }
    // Only these exact-size prefixes take part in the product.
    let (a, b, out) = (&a[..len_a], &b[..len_b], &mut out[..len_o]);

    for row in 0..m {
        // `row < m`, so `(row + 1) * k <= len_a` and `(row + 1) * n <= len_o`: no overflow.
        let a_row = &a[row * k..(row + 1) * k];
        let out_row = &mut out[row * n..(row + 1) * n];
        for (col, dst) in out_row.iter_mut().enumerate() {
            // Column `col` of `B` is every `n`-th element starting at `col` (`n >= 1` here,
            // since `col < n`). `skip` is used instead of `b[col..]` so an empty `b`
            // (when `k == 0`) does not panic.
            let b_col = b.iter().skip(col).step_by(n);
            // Accumulate in i64: one i8×i8 product can reach 2^14, so an i32 sum can
            // overflow for `k >= 2^17` (panic in debug, wrong result in release).
            let acc: i64 = a_row
                .iter()
                .zip(b_col)
                .map(|(&x, &y)| i64::from(x) * i64::from(y))
                .sum();
            // Rounds to the nearest representable f32, same as the former i32 cast.
            *dst = acc as f32 * scale;
        }
    }

    Ok(())
}