native_neural_network 0.1.6

A no_std Rust library for native neural networks (.rnn)
Documentation
use crate::activations::ActivationKind;
use crate::engine::{forward_dense_plan, ForwardError};
use crate::initializers::expected_parameter_counts;
use crate::layers::{build_dense_specs_from_layers, LayerError, LayerPlan, LayerSpec};
use crate::model_format::{
    decode_dense_model_v1,
    encode_dense_model_v1,
    encoded_size_v1,
    DecodedCounts,
    ModelFormatError,
};

/// Four-byte tag expected at the very start of a v1 model blob.
const MAGIC: &[u8; 4] = b"RMD1";

/// Format version accepted by this module (stored little-endian at bytes 4..6).
const VERSION: u16 = 1;

/// Size in bytes of the fixed v1 header, spelled out field by field.
const HEADER_SIZE: usize = MAGIC.len() // bytes 0..4: magic tag
    + 2 // bytes 4..6: format version (u16, little-endian)
    + 2 // bytes 6..8: not read anywhere in this file
    + 4 // bytes 8..12: layer count (u32, little-endian)
    + 4 // bytes 12..16: weight count (u32, little-endian)
    + 4; // bytes 16..20: bias count (u32, little-endian)

/// Flat error type returned by the `rnn_*` convenience functions in this file.
///
/// Errors from the layer, model-format and forward-pass modules are folded
/// into these variants by the `map_*_error` helpers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RnnApiError {
    /// The topology has fewer than two entries, its parameter counts could not
    /// be computed, or the layer builder reported an empty plan, an invalid
    /// shape or an incompatible chain.
    InvalidTopology,
    /// A caller-provided buffer (scratch or output) is too small.
    CapacityTooSmall,
    /// The byte blob is too short, has the wrong magic or version, or failed
    /// model decoding/encoding for a reason other than capacity.
    BadBytes,
    /// Any layer-building error not covered by the variants above.
    Layer,
    /// NOTE(review): not produced by any code visible in this file; presumably
    /// reserved for model-level failures — confirm against the rest of the crate.
    Model,
    /// Any forward-pass error other than an undersized scratch buffer.
    Forward,
}

/// Computes how many layers, weights and biases a dense network with the
/// given `topology` needs.
///
/// The layer count is `topology.len() - 1`; weight and bias totals come from
/// [`expected_parameter_counts`].
///
/// # Errors
///
/// Returns [`RnnApiError::InvalidTopology`] when `topology` has fewer than two
/// entries or when `expected_parameter_counts` yields `None`.
pub fn rnn_required_dense_from_topology(topology: &[usize]) -> Result<DecodedCounts, RnnApiError> {
    // A network needs at least an input and an output size.
    let layers = match topology.len() {
        0 | 1 => return Err(RnnApiError::InvalidTopology),
        entries => entries - 1,
    };

    match expected_parameter_counts(topology) {
        Some((weights, biases)) => Ok(DecodedCounts {
            layers,
            weights,
            biases,
        }),
        None => Err(RnnApiError::InvalidTopology),
    }
}

/// Reads the fixed header of a v1 model blob and reports how many layers,
/// weights and biases it declares.
///
/// The header is validated in this order: minimum length (`HEADER_SIZE`),
/// magic tag (`MAGIC` at bytes 0..4), version (`VERSION`, little-endian `u16`
/// at bytes 4..6). Bytes 6..8 are not inspected here. The three counts are
/// little-endian `u32` values at bytes 8..12, 12..16 and 16..20.
///
/// The blob must be at least as long as `encoded_size_v1` reports for the
/// declared counts; extra trailing bytes are not rejected by this function.
///
/// # Errors
///
/// Returns [`RnnApiError::BadBytes`] when the blob is too short, the magic or
/// version does not match, a declared count does not fit in `usize`, or
/// `encoded_size_v1` cannot compute a size for the declared counts.
pub fn rnn_required_dense_from_bytes_v1(bytes: &[u8]) -> Result<DecodedCounts, RnnApiError> {
    if bytes.len() < HEADER_SIZE {
        return Err(RnnApiError::BadBytes);
    }
    if &bytes[0..4] != MAGIC {
        return Err(RnnApiError::BadBytes);
    }

    let version = u16::from_le_bytes([bytes[4], bytes[5]]);
    if version != VERSION {
        return Err(RnnApiError::BadBytes);
    }

    // Indexing below is in bounds: the length check above guarantees at least
    // HEADER_SIZE bytes. The conversion is checked rather than an `as` cast so
    // that targets whose `usize` is narrower than 32 bits reject oversized
    // counts instead of silently truncating them.
    let read_count = |offset: usize| -> Result<usize, RnnApiError> {
        let raw = u32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ]);
        <usize as core::convert::TryFrom<u32>>::try_from(raw).map_err(|_| RnnApiError::BadBytes)
    };

    let layer_count = read_count(8)?;
    let weights_len = read_count(12)?;
    let biases_len = read_count(16)?;

    let expected_size = encoded_size_v1(layer_count, weights_len, biases_len).ok_or(RnnApiError::BadBytes)?;
    if bytes.len() < expected_size {
        return Err(RnnApiError::BadBytes);
    }

    Ok(DecodedCounts {
        layers: layer_count,
        weights: weights_len,
        biases: biases_len,
    })
}

pub fn rnn_pack_dense_v1(
    topology: &[usize],
    hidden_activation: ActivationKind,
    output_activation: ActivationKind,
    weights: &[f32],
    biases: &[f32],
    layer_specs_scratch: &mut [LayerSpec],
    out_bytes: &mut [u8],
) -> Result<usize, RnnApiError> {
    let layer_count = build_dense_specs_from_layers(
        topology,
        hidden_activation,
        output_activation,
        weights.len(),
        biases.len(),
        layer_specs_scratch,
    )
    .map_err(map_layer_error)?;

    encode_dense_model_v1(&layer_specs_scratch[..layer_count], weights, biases, out_bytes).map_err(map_model_error)
}

pub fn rnn_run_dense_v1(
    bytes: &[u8],
    input: &[f32],
    output: &mut [f32],
    layer_specs_scratch: &mut [LayerSpec],
    weights_scratch: &mut [f32],
    biases_scratch: &mut [f32],
    infer_scratch: &mut [f32],
) -> Result<(), RnnApiError> {
    let counts = rnn_required_dense_from_bytes_v1(bytes)?;
    if layer_specs_scratch.len() < counts.layers || weights_scratch.len() < counts.weights || biases_scratch.len() < counts.biases {
        return Err(RnnApiError::CapacityTooSmall);
    }

    let decoded = decode_dense_model_v1(
        bytes,
        &mut layer_specs_scratch[..counts.layers],
        &mut weights_scratch[..counts.weights],
        &mut biases_scratch[..counts.biases],
    )
    .map_err(map_model_error)?;

    let plan = LayerPlan {
        layers: &layer_specs_scratch[..decoded.layers],
        weights: &weights_scratch[..decoded.weights],
        biases: &biases_scratch[..decoded.biases],
    };

    forward_dense_plan(&plan, input, output, infer_scratch).map_err(map_forward_error)
}

/// Translates a [`LayerError`] into the public [`RnnApiError`].
///
/// Structural problems (empty plan, invalid shape, incompatible chain) become
/// `InvalidTopology`, an undersized buffer becomes `CapacityTooSmall`, and
/// every other variant becomes `Layer`.
fn map_layer_error(err: LayerError) -> RnnApiError {
    match err {
        LayerError::BufferTooSmall => RnnApiError::CapacityTooSmall,
        LayerError::EmptyPlan => RnnApiError::InvalidTopology,
        LayerError::InvalidShape => RnnApiError::InvalidTopology,
        LayerError::IncompatibleChain => RnnApiError::InvalidTopology,
        _ => RnnApiError::Layer,
    }
}

/// Translates a [`ModelFormatError`] into the public [`RnnApiError`].
///
/// Only `CapacityTooSmall` keeps its meaning; every other variant is reported
/// as `BadBytes`.
pub(crate) fn map_model_error(err: ModelFormatError) -> RnnApiError {
    if let ModelFormatError::CapacityTooSmall = err {
        RnnApiError::CapacityTooSmall
    } else {
        RnnApiError::BadBytes
    }
}

/// Translates a [`ForwardError`] into the public [`RnnApiError`].
///
/// An undersized scratch buffer is reported as `CapacityTooSmall`; every other
/// variant is reported as `Forward`.
fn map_forward_error(err: ForwardError) -> RnnApiError {
    if matches!(err, ForwardError::ScratchTooSmall) {
        return RnnApiError::CapacityTooSmall;
    }
    RnnApiError::Forward
}