//! native_neural_network 0.3.1
//!
//! A `no_std` Rust library for native neural networks (`.rnn`).
//!
//! Documentation
use super::NormError;

/// RMS-normalises `x` in place and scales each element by the matching `gamma` entry.
///
/// On the CPU path this computes
/// `x[i] = x[i] * (1 / sqrt(mean(x²) + eps)) * gamma[i]`.
///
/// `crate::engine::try_invoke_gpu_rms_norm_f32` is tried first. If it returns `true`,
/// this function returns `Ok(())` immediately without touching `x` itself.
/// NOTE(review): this presumably means the GPU routine has already written the
/// normalised values into `x`; confirm against `crate::engine`.
///
/// # Errors
///
/// All checks run before anything (GPU or CPU) touches `x`, so `x` is unchanged on error:
/// - [`NormError::Empty`] if `x` is empty.
/// - [`NormError::ShapeMismatch`] if `gamma.len() != x.len()`.
/// - [`NormError::InvalidEps`] if `eps` is not finite or is `<= 0`.
pub fn rms_norm_in_place_f32(x: &mut [f32], gamma: &[f32], eps: f32) -> Result<(), NormError> {
    if x.is_empty() {
        return Err(NormError::Empty);
    }
    if gamma.len() != x.len() {
        return Err(NormError::ShapeMismatch);
    }
    if !eps.is_finite() || eps <= 0.0 {
        return Err(NormError::InvalidEps);
    }

    if crate::engine::try_invoke_gpu_rms_norm_f32(x, gamma, eps) {
        return Ok(());
    }

    // Accumulate the sum of squares in f64. Squaring an f32 above ~1.8e19 overflows
    // f32 to +inf, which would zero every output. A long f32 running sum also loses
    // precision. Every f32 square (<= ~1.2e77) fits comfortably in f64.
    let sum_sq = x.iter().fold(0.0f64, |acc, &v| {
        let v = f64::from(v);
        acc + v * v
    });
    let mean_sq = sum_sq / x.len() as f64;
    // Since eps > 0, 1/sqrt(mean_sq + eps) is at most 1/sqrt(eps) <= ~2.7e22 for any
    // positive f32 eps, so narrowing back to f32 cannot overflow.
    let inv_rms = (1.0 / crate::math::sqrtd(mean_sq + f64::from(eps))) as f32;

    // Lengths were checked above, so zip visits every element of `x` exactly once.
    for (xi, &g) in x.iter_mut().zip(gamma) {
        *xi = *xi * inv_rms * g;
    }

    Ok(())
}

/// Unsuffixed alias for [`rms_norm_in_place_f32`]; operates on `f32` data.
///
/// # Errors
///
/// Propagates every error from [`rms_norm_in_place_f32`] unchanged.
pub fn rms_norm_in_place(x: &mut [f32], gamma: &[f32], eps: f32) -> Result<(), NormError> {
    rms_norm_in_place_f32(x, gamma, eps)?;
    Ok(())
}

/// Double-precision variant: RMS-normalises `x` in place and scales each element by `gamma`.
///
/// On the CPU path this computes
/// `x[i] = x[i] * (1 / sqrt(mean(x²) + eps)) * gamma[i]`.
///
/// `crate::engine::try_invoke_gpu_rms_norm_f64` is tried first. If it returns `true`,
/// this function returns `Ok(())` immediately without touching `x` itself.
/// NOTE(review): this presumably means the GPU routine has already written the
/// normalised values into `x`; confirm against `crate::engine`.
///
/// The sum of squares is accumulated in `f64`. Inputs with magnitudes above roughly
/// `1e154` overflow it to `+inf`, and the CPU result is then not meaningful.
///
/// # Errors
///
/// All checks run before anything (GPU or CPU) touches `x`, so `x` is unchanged on error:
/// - [`NormError::Empty`] if `x` is empty.
/// - [`NormError::ShapeMismatch`] if `gamma.len() != x.len()`.
/// - [`NormError::InvalidEps`] if `eps` is not finite or is `<= 0`.
pub fn rms_norm_in_place_f64(x: &mut [f64], gamma: &[f64], eps: f64) -> Result<(), NormError> {
    if x.is_empty() {
        return Err(NormError::Empty);
    }
    if gamma.len() != x.len() {
        return Err(NormError::ShapeMismatch);
    }
    if !eps.is_finite() || eps <= 0.0 {
        return Err(NormError::InvalidEps);
    }

    if crate::engine::try_invoke_gpu_rms_norm_f64(x, gamma, eps) {
        return Ok(());
    }

    // Same left-to-right accumulation order as a plain loop starting at 0.0.
    let sum_sq = x.iter().fold(0.0f64, |acc, &v| acc + v * v);
    let mean_sq = sum_sq / x.len() as f64;
    let inv_rms = 1.0 / crate::math::sqrtd(mean_sq + eps);

    // Lengths were checked above, so zip visits every element of `x` exactly once
    // without per-index bounds checks.
    for (xi, &g) in x.iter_mut().zip(gamma) {
        *xi = *xi * inv_rms * g;
    }

    Ok(())
}

/// Out-of-place `f32` RMS normalisation: writes the normalised, `gamma`-scaled
/// copy of `input` into `out`, leaving `input` untouched.
///
/// # Errors
///
/// - [`NormError::ShapeMismatch`] if `out.len() != input.len()`; `out` is not modified.
/// - Otherwise `input` is first copied into `out`, and then any error from
///   [`rms_norm_in_place_f32`] (empty input, `gamma` length mismatch, invalid `eps`)
///   is returned. In that case `out` holds an unnormalised copy of `input`.
pub fn rms_norm_f32(
    input: &[f32],
    gamma: &[f32],
    eps: f32,
    out: &mut [f32],
) -> Result<(), NormError> {
    if input.len() == out.len() {
        out.copy_from_slice(input);
        rms_norm_in_place_f32(out, gamma, eps)
    } else {
        Err(NormError::ShapeMismatch)
    }
}

/// Out-of-place `f64` RMS normalisation: writes the normalised, `gamma`-scaled
/// copy of `input` into `out`, leaving `input` untouched.
///
/// # Errors
///
/// - [`NormError::ShapeMismatch`] if `out.len() != input.len()`; `out` is not modified.
/// - Otherwise `input` is first copied into `out`, and then any error from
///   [`rms_norm_in_place_f64`] (empty input, `gamma` length mismatch, invalid `eps`)
///   is returned. In that case `out` holds an unnormalised copy of `input`.
pub fn rms_norm_f64(
    input: &[f64],
    gamma: &[f64],
    eps: f64,
    out: &mut [f64],
) -> Result<(), NormError> {
    if input.len() == out.len() {
        out.copy_from_slice(input);
        rms_norm_in_place_f64(out, gamma, eps)
    } else {
        Err(NormError::ShapeMismatch)
    }
}