// native_neural_network 0.3.1
//
// A `no_std` Rust library for native neural networks (`.rnn`).
// Documentation
use super::SamplingError;

/// Writes the temperature-scaled softmax of `logits` into the front of `out` (single precision).
///
/// For every logit this evaluates `exp((logit - max) / temperature)`, where `max` is the
/// largest logit, and then rescales the results so that the first `logits.len()` entries of
/// `out` sum to one. Entries of `out` past `logits.len()` are never touched.
///
/// The maximum is subtracted *before* dividing by `temperature`. Every exponent is therefore
/// `<= 0`, so a small temperature or large finite logits cannot overflow to `+inf` and turn
/// the result into `inf - inf = NaN`.
///
/// # Errors
///
/// * [`SamplingError::Empty`] if `logits` is empty.
/// * [`SamplingError::ShapeMismatch`] if `out` is shorter than `logits`.
/// * [`SamplingError::InvalidParameter`] if `temperature` is not finite or not strictly
///   positive, or if the sum of exponentials is not a finite, strictly positive number
///   (for instance when a NaN reaches the sum). In that second case the first
///   `logits.len()` entries of `out` have already been overwritten with unnormalized values.
pub fn softmax_temperature_f32(
    logits: &[f32],
    temperature: f32,
    out: &mut [f32],
) -> Result<(), SamplingError> {
    if logits.is_empty() {
        return Err(SamplingError::Empty);
    }
    if out.len() < logits.len() {
        return Err(SamplingError::ShapeMismatch);
    }
    if !temperature.is_finite() || temperature <= 0.0 {
        return Err(SamplingError::InvalidParameter);
    }

    // Only the prefix matching `logits` is written; the tail of `out` stays as the caller left it.
    let probs = &mut out[..logits.len()];

    // `temperature > 0`, so the largest raw logit is also the largest scaled logit.
    // Comparisons with NaN are false: a NaN in first position sticks, a later NaN is skipped
    // here and is fed to the exponential below instead.
    let max_logit = logits[1..]
        .iter()
        .fold(logits[0], |best, &v| if v > best { v } else { best });

    let mut sum = 0.0f32;
    for (p, &logit) in probs.iter_mut().zip(logits) {
        // Shift first, scale second: the difference is <= 0, so dividing by a tiny
        // temperature can only push it towards -inf, never towards +inf.
        let e = crate::math::expf((logit - max_logit) / temperature);
        *p = e;
        sum += e;
    }
    if !sum.is_finite() || sum <= 0.0 {
        return Err(SamplingError::InvalidParameter);
    }

    let inv = 1.0 / sum;
    for p in probs.iter_mut() {
        *p *= inv;
    }
    Ok(())
}

/// Default-precision entry point for the temperature-scaled softmax.
///
/// Forwards `logits`, `temperature` and `out` unchanged to [`softmax_temperature_f32`] and
/// returns that call's result as-is.
///
/// # Errors
///
/// Propagates whatever [`SamplingError`] the single-precision implementation reports.
#[inline]
pub fn softmax_temperature(
    logits: &[f32],
    temperature: f32,
    out: &mut [f32],
) -> Result<(), SamplingError> {
    // Pure delegation: no validation or post-processing happens at this level.
    softmax_temperature_f32(logits, temperature, out)
}

/// Writes the temperature-scaled softmax of `logits` into the front of `out` (double precision).
///
/// For every logit this evaluates `exp((logit - max) / temperature)`, where `max` is the
/// largest logit, and then rescales the results so that the first `logits.len()` entries of
/// `out` sum to one. Entries of `out` past `logits.len()` are never touched.
///
/// The maximum is subtracted *before* dividing by `temperature`. Every exponent is therefore
/// `<= 0`, so a small temperature or large finite logits cannot overflow to `+inf` and turn
/// the result into `inf - inf = NaN`.
///
/// # Errors
///
/// * [`SamplingError::Empty`] if `logits` is empty.
/// * [`SamplingError::ShapeMismatch`] if `out` is shorter than `logits`.
/// * [`SamplingError::InvalidParameter`] if `temperature` is not finite or not strictly
///   positive, or if the sum of exponentials is not a finite, strictly positive number
///   (for instance when a NaN reaches the sum). In that second case the first
///   `logits.len()` entries of `out` have already been overwritten with unnormalized values.
pub fn softmax_temperature_f64(
    logits: &[f64],
    temperature: f64,
    out: &mut [f64],
) -> Result<(), SamplingError> {
    if logits.is_empty() {
        return Err(SamplingError::Empty);
    }
    if out.len() < logits.len() {
        return Err(SamplingError::ShapeMismatch);
    }
    if !temperature.is_finite() || temperature <= 0.0 {
        return Err(SamplingError::InvalidParameter);
    }

    // Only the prefix matching `logits` is written; the tail of `out` stays as the caller left it.
    let probs = &mut out[..logits.len()];

    // `temperature > 0`, so the largest raw logit is also the largest scaled logit.
    // Comparisons with NaN are false: a NaN in first position sticks, a later NaN is skipped
    // here and is fed to the exponential below instead.
    let max_logit = logits[1..]
        .iter()
        .fold(logits[0], |best, &v| if v > best { v } else { best });

    let mut sum = 0.0f64;
    for (p, &logit) in probs.iter_mut().zip(logits) {
        // Shift first, scale second: the difference is <= 0, so dividing by a tiny
        // temperature can only push it towards -inf, never towards +inf.
        let e = crate::math::expd((logit - max_logit) / temperature);
        *p = e;
        sum += e;
    }
    if !sum.is_finite() || sum <= 0.0 {
        return Err(SamplingError::InvalidParameter);
    }

    let inv = 1.0f64 / sum;
    for p in probs.iter_mut() {
        *p *= inv;
    }
    Ok(())
}