native_neural_network 0.1.6

A `no_std` Rust library for native neural networks (`.rnn`).
Documentation
use crate::model_config::{ConfigError, TransformerConfig};
use super::profile::RuntimeProfile;

/// Byte-level breakdown of the memory needed to run a model.
///
/// All fields are byte counts. When produced by [`estimate_runtime_memory`],
/// `total_bytes` is the sum of the other three fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RuntimeEstimate {
    /// Bytes attributed to model parameters.
    pub parameter_bytes: usize,
    /// Bytes attributed to activations.
    pub activation_bytes: usize,
    /// Bytes attributed to the key/value cache.
    pub kv_cache_bytes: usize,
    /// Combined byte count of the three components above.
    pub total_bytes: usize,
}

/// Reasons a runtime memory estimate could not be produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeError {
    /// The supplied runtime profile failed its own validation check.
    InvalidProfile,
    /// A size computation exceeded the range of `usize`.
    Overflow,
    /// The model configuration was rejected as invalid.
    InvalidConfig,
}

/// Estimates how many bytes are needed to run `config` under `profile`.
///
/// Let `elems = batch_size * sequence_len * hidden_size * num_layers`. The
/// components are:
///
/// - `parameter_bytes = approximate_parameter_count() * bytes_per_param`
/// - `activation_bytes = elems * bytes_per_activation`
/// - `kv_cache_bytes = elems * 2 * bytes_per_kv`
/// - `total_bytes` = the sum of the three above
///
/// # Errors
///
/// - [`RuntimeError::InvalidConfig`] if `config.validate()` fails, or if the
///   parameter count reports `ConfigError::Invalid`.
/// - [`RuntimeError::InvalidProfile`] if `profile.validate()` returns `false`.
/// - [`RuntimeError::Overflow`] if the parameter count reports
///   `ConfigError::Overflow`, or if any product or sum above overflows `usize`.
pub fn estimate_runtime_memory(
    config: &TransformerConfig,
    profile: RuntimeProfile,
) -> Result<RuntimeEstimate, RuntimeError> {
    // The concrete validation error is intentionally not surfaced to callers.
    if config.validate().is_err() {
        return Err(RuntimeError::InvalidConfig);
    }
    if !profile.validate() {
        return Err(RuntimeError::InvalidProfile);
    }

    let param_count = match config.approximate_parameter_count() {
        Ok(count) => count,
        Err(ConfigError::Invalid) => return Err(RuntimeError::InvalidConfig),
        Err(ConfigError::Overflow) => return Err(RuntimeError::Overflow),
    };

    // Element count shared by the activation and KV-cache terms; computed once.
    let state_elems = profile
        .batch_size
        .checked_mul(profile.sequence_len)
        .and_then(|n| n.checked_mul(config.hidden_size))
        .and_then(|n| n.checked_mul(config.num_layers));

    // Any `None` along the way means some intermediate value overflowed.
    let estimate = state_elems.and_then(|elems| {
        let parameter_bytes = param_count.checked_mul(profile.bytes_per_param)?;
        let activation_bytes = elems.checked_mul(profile.bytes_per_activation)?;
        // The factor of 2 presumably covers separate key and value tensors.
        let kv_cache_bytes = elems.checked_mul(2)?.checked_mul(profile.bytes_per_kv)?;
        let total_bytes = parameter_bytes
            .checked_add(activation_bytes)?
            .checked_add(kv_cache_bytes)?;
        Some(RuntimeEstimate {
            parameter_bytes,
            activation_bytes,
            kv_cache_bytes,
            total_bytes,
        })
    });

    estimate.ok_or(RuntimeError::Overflow)
}