native_neural_network 0.1.6

A `no_std` Rust library for native neural networks (.rnn).
Documentation
use crate::layers::LayerSpec;

use super::RnnApiError;

/// Element counts reported by `rnn_dense_required_buffers` for a dense network.
///
/// Each field is a plain `usize` count computed from the layer widths of the
/// topology passed to that function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RnnDenseCapacity {
    /// Number of layer specs: one per pair of adjacent topology widths.
    pub layer_specs: usize,
    /// Total weight count: the sum of `fan_in * fan_out` over adjacent width pairs.
    pub weights: usize,
    /// Total bias count: the sum of `fan_out` over adjacent width pairs.
    pub biases: usize,
    /// Inference scratch length: twice the widest layer in the topology.
    pub infer_scratch: usize,
}

/// Computes the buffer sizes needed for a dense network described by `topology`.
///
/// `topology[i]` is the width of layer `i`. Each adjacent pair of widths
/// `(fan_in, fan_out)` accounts for one layer spec, `fan_in * fan_out` weights and
/// `fan_out` biases. The inference scratch size is twice the widest layer
/// (presumably two activation buffers used alternately — confirm against the
/// inference code).
///
/// # Errors
///
/// - [`RnnApiError::InvalidTopology`] if `topology` has fewer than two entries or
///   contains a zero width. Validation happens before any size arithmetic, so an
///   invalid topology is reported as such even when its sizes would also overflow.
/// - [`RnnApiError::CapacityTooSmall`] if any computed count overflows `usize`.
pub fn rnn_dense_required_buffers(topology: &[usize]) -> Result<RnnDenseCapacity, RnnApiError> {
    // Validate the whole shape before any arithmetic so a zero width is never
    // masked by an overflow error coming from an earlier pair of layers.
    if topology.len() < 2 || topology.contains(&0) {
        return Err(RnnApiError::InvalidTopology);
    }

    let mut weights = 0usize;
    let mut biases = 0usize;
    for pair in topology.windows(2) {
        let (fan_in, fan_out) = (pair[0], pair[1]);
        let layer_weights = fan_in
            .checked_mul(fan_out)
            .ok_or(RnnApiError::CapacityTooSmall)?;
        weights = weights
            .checked_add(layer_weights)
            .ok_or(RnnApiError::CapacityTooSmall)?;
        biases = biases
            .checked_add(fan_out)
            .ok_or(RnnApiError::CapacityTooSmall)?;
    }

    // `topology` has at least two entries here, so `max` always returns `Some`.
    let max_width = topology.iter().copied().max().unwrap_or(0);
    let infer_scratch = max_width
        .checked_mul(2)
        .ok_or(RnnApiError::CapacityTooSmall)?;

    Ok(RnnDenseCapacity {
        // One spec per adjacent pair; cannot underflow after the length check.
        layer_specs: topology.len() - 1,
        weights,
        biases,
        infer_scratch,
    })
}

/// Returns the inference scratch length required by a sequence of layer specs.
///
/// The result is twice the largest input or output width found across `layers`
/// (presumably two activation buffers used alternately — confirm against the
/// inference code).
///
/// # Errors
///
/// - [`RnnApiError::InvalidTopology`] if `layers` is empty, or as soon as a layer
///   reports a zero input or output size.
/// - [`RnnApiError::CapacityTooSmall`] if doubling the widest width overflows `usize`.
pub fn rnn_dense_required_infer_scratch_from_specs(layers: &[LayerSpec]) -> Result<usize, RnnApiError> {
    if layers.is_empty() {
        return Err(RnnApiError::InvalidTopology);
    }

    // Track the widest side seen so far, bailing out on the first degenerate layer.
    let widest = layers.iter().try_fold(0usize, |widest, spec| {
        let (inputs, outputs) = (spec.input_size(), spec.output_size());
        if inputs == 0 || outputs == 0 {
            Err(RnnApiError::InvalidTopology)
        } else {
            Ok(widest.max(inputs).max(outputs))
        }
    })?;

    widest.checked_mul(2).ok_or(RnnApiError::CapacityTooSmall)
}