//! native_neural_network 0.3.1
//!
//! `no_std` Rust library for the native neural network (`.rnn`) container
//! format: validation and header/summary extraction.
use super::bounds::{all_blob_bounds_valid, all_blob_digests_valid};
use super::lookup::find_blob_index;
use super::parser::{parse_rnn_from_bytes, RnnHandle};
use crate::crypto::{
    constant_time_eq, extract_rnn_crypto_timestamp, extract_rnn_issue_counter,
    extract_rnn_key_version, extract_rnn_public_has_benchmark, extract_rnn_public_has_model_name,
    extract_rnn_public_has_model_precision, extract_rnn_public_header_summary, is_encrypted_rnn,
    RNN_MIN_ACCEPTED_KEY_VERSION,
};
use crate::model_format::{
    header_tlv_payload, BLOB_BIASES, BLOB_LAYER_META, BLOB_WEIGHTS, LAYER_META_SIZE, RNN0_MAGIC,
    RNN0_VERSION, TLV_HEADER_BENCHMARK, TLV_HEADER_NETWORK_SUMMARY,
};
use crate::scratch::Scratch;

/// Errors produced while validating an `.rnn` container.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// Input is shorter than the fixed 12-byte prefix.
    Truncated,
    /// Leading 4 bytes do not match the expected magic (also returned when an
    /// expected-encrypted container fails the encrypted-magic probe).
    BadMagic,
    /// Version word in the prefix does not equal `RNN0_VERSION`.
    BadVersion,
    /// Header/summary TLV payload failed to decode, or a required public
    /// crypto field was absent from an encrypted container.
    BadHeader,
    /// A blob's declared bounds or digest failed verification.
    BadBounds,
    // NOTE(review): not produced anywhere in this module — presumably raised
    // by callers or other modules sharing this error type; confirm.
    CapacityTooSmall,
    /// A required blob or TLV is missing, empty, or structurally inconsistent.
    InvalidPayload,
    /// Container-level failure: unparseable plaintext container, empty key
    /// material, stale key version, or zero issue counter / timestamp.
    InvalidContainer,
}

/// Decoded network-summary header fields.
///
/// Produced either from the RLE-compressed `S5D0` summary, the compact
/// `VIZ\0` summary (which carries no flags; `flags` is then 0), or the
/// public header of an encrypted container.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct HeaderSummary {
    /// Element dtype tag; validation accepts 0 (4-byte) and 1 (8-byte) elements.
    pub dtype: u8,
    /// Number of layer-meta records.
    pub layer_count: u32,
    /// Total neuron count as declared by the header.
    pub total_neurons: u32,
    /// Number of weight elements (not bytes).
    pub weights_len: u32,
    /// Number of bias elements (not bytes).
    pub biases_len: u32,
    /// Number of blobs declared in the container.
    pub blob_count: u32,
    /// Header flag bits; 0 for the compact `VIZ\0` summary form.
    pub flags: u32,
}

/// Result of scanning an `.rnn` container.
///
/// Fields that cannot be determined for a given container kind are `None`:
/// plaintext scans leave the three crypto fields `None`, while encrypted
/// scans fill them from the public header and leave the `has_*` fields to
/// whatever the public header extractors report.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ScanReport {
    /// Whether the container was recognized as encrypted.
    pub encrypted: bool,
    /// Decoded header summary, when the summary TLV decoded successfully.
    pub header: Option<HeaderSummary>,
    /// Presence of the benchmark TLV (always `Some(true)` for a valid plaintext scan).
    pub has_benchmark: Option<bool>,
    /// Presence of the `model.name` blob.
    pub has_model_name: Option<bool>,
    /// Presence of the `model.precision` blob.
    pub has_model_precision: Option<bool>,
    /// Key version from the encrypted public header; `None` for plaintext.
    pub key_version: Option<u32>,
    /// Issue counter from the encrypted public header; `None` for plaintext.
    pub issue_counter: Option<u64>,
    /// Crypto timestamp from the encrypted public header; `None` for plaintext.
    pub crypto_timestamp: Option<u64>,
}

// Blob names probed to detect optional model metadata in plaintext containers.
const MODEL_NAME_BLOB_DATA: &str = "model.name";
const MODEL_PRECISION_BLOB_DATA: &str = "model.precision";

/// Computes the total encoded payload size in bytes for the given dtype and
/// element counts, or `None` when the dtype is unknown or the size would
/// overflow `usize`.
///
/// Layout accounted for: a fixed 20-byte preamble, `layer_count` layer-meta
/// records, then the weight and bias elements at the dtype's element width.
fn encoded_payload_size(
    dtype: u8,
    layer_count: usize,
    weights_len: usize,
    biases_len: usize,
) -> Option<usize> {
    // Element width in bytes: dtype 0 -> 4, dtype 1 -> 8; anything else is invalid.
    let width: usize = match dtype {
        0 => 4,
        1 => 8,
        _ => return None,
    };
    // Accumulate with checked arithmetic so overflow yields None.
    let mut total = 20usize; // fixed preamble
    total = total.checked_add(layer_count.checked_mul(LAYER_META_SIZE)?)?;
    total = total.checked_add(weights_len.checked_mul(width)?)?;
    total = total.checked_add(biases_len.checked_mul(width)?)?;
    Some(total)
}

pub(crate) fn extract_header_summary(bytes: &[u8]) -> Result<HeaderSummary, Error> {
    let payload =
        header_tlv_payload(bytes, TLV_HEADER_NETWORK_SUMMARY).map_err(|_| Error::InvalidPayload)?;
    if payload.len() >= 16 && constant_time_eq(&payload[0..4], b"VIZ\x00") {
        return decode_compact_viz_summary(payload);
    }
    decode_header_summary(payload)
}

/// Validates an `.rnn` container and reports what it contains.
///
/// `device_id` is only consulted for encrypted containers, where `Some` of an
/// empty slice is rejected; `None` skips that check.
pub fn validate(bytes: &[u8], device_id: Option<&[u8]>) -> Result<ScanReport, Error> {
    validate_scan_backend(bytes, device_id)
}

/// Crate-internal entry point for validation; delegates to `validate_inner`.
pub(crate) fn validate_scan_backend(
    bytes: &[u8],
    device_id: Option<&[u8]>,
) -> Result<ScanReport, Error> {
    validate_inner(bytes, device_id)
}

fn validate_inner(bytes: &[u8], device_id: Option<&[u8]>) -> Result<ScanReport, Error> {
    if bytes.len() < 12 {
        return Err(Error::Truncated);
    }
    if !constant_time_eq(&bytes[0..4], RNN0_MAGIC) {
        return Err(Error::BadMagic);
    }
    let version = u16::from_le_bytes([bytes[4], bytes[5]]);
    if version != RNN0_VERSION {
        return Err(Error::BadVersion);
    }

    let mut scratch_buf = [0u8; 16384];
    let mut scratch = Scratch::new(&mut scratch_buf);
    let handle = match parse_rnn_from_bytes(bytes, &mut scratch) {
        Ok(handle) => handle,
        Err(_) => {
            if is_encrypted_rnn(bytes) {
                return validate_encrypted_container(bytes, device_id);
            }
            return Err(Error::InvalidContainer);
        }
    };

    if !all_blob_bounds_valid(&handle) {
        return Err(Error::BadBounds);
    }
    if !all_blob_digests_valid(&handle) {
        return Err(Error::BadBounds);
    }

    let layer_meta = find_layer_meta_blob(&handle).ok_or(Error::InvalidPayload)?;
    let weights = find_weights_blob(&handle).ok_or(Error::InvalidPayload)?;
    let biases = find_biases_blob(&handle).ok_or(Error::InvalidPayload)?;

    let layer_meta_idx = find_blob_index(&handle, BLOB_LAYER_META).ok_or(Error::InvalidPayload)?;
    let weights_idx = find_blob_index(&handle, BLOB_WEIGHTS).ok_or(Error::InvalidPayload)?;
    let biases_idx = find_blob_index(&handle, BLOB_BIASES).ok_or(Error::InvalidPayload)?;

    let layer_meta_desc = handle
        .blobs
        .get(layer_meta_idx)
        .ok_or(Error::InvalidPayload)?;
    let weights_desc = handle.blobs.get(weights_idx).ok_or(Error::InvalidPayload)?;
    let biases_desc = handle.blobs.get(biases_idx).ok_or(Error::InvalidPayload)?;

    if layer_meta.is_empty() || weights.is_empty() || biases.is_empty() {
        return Err(Error::InvalidPayload);
    }
    if !layer_meta.len().is_multiple_of(LAYER_META_SIZE) {
        return Err(Error::InvalidPayload);
    }
    let benchmark = crate::model_format::header_tlv_payload(bytes, TLV_HEADER_BENCHMARK)
        .map_err(|_| Error::InvalidPayload)?;
    if benchmark.is_empty() {
        return Err(Error::InvalidPayload);
    }

    if layer_meta_desc.dtype != weights_desc.dtype || weights_desc.dtype != biases_desc.dtype {
        return Err(Error::InvalidPayload);
    }

    let layer_count = layer_meta.len() / LAYER_META_SIZE;
    match weights_desc.dtype {
        0 => {
            if !weights.len().is_multiple_of(4) || !biases.len().is_multiple_of(4) {
                return Err(Error::InvalidPayload);
            }
            let weights_len = weights.len() / 4;
            let biases_len = biases.len() / 4;
            encoded_payload_size(weights_desc.dtype, layer_count, weights_len, biases_len)
                .ok_or(Error::InvalidPayload)?;
        }
        1 => {
            if !weights.len().is_multiple_of(8) || !biases.len().is_multiple_of(8) {
                return Err(Error::InvalidPayload);
            }
            let weights_len = weights.len() / 8;
            let biases_len = biases.len() / 8;
            encoded_payload_size(weights_desc.dtype, layer_count, weights_len, biases_len)
                .ok_or(Error::InvalidPayload)?;
        }
        _ => return Err(Error::InvalidPayload),
    }

    let has_benchmark =
        crate::model_format::header_tlv_payload(bytes, TLV_HEADER_BENCHMARK).is_ok();
    let has_model_name = find_blob_index(&handle, MODEL_NAME_BLOB_DATA).is_some();
    let has_model_precision = find_blob_index(&handle, MODEL_PRECISION_BLOB_DATA).is_some();
    let header = extract_header_summary(bytes).ok();

    Ok(ScanReport {
        encrypted: false,
        header,
        has_benchmark: Some(has_benchmark),
        has_model_name: Some(has_model_name),
        has_model_precision: Some(has_model_precision),
        key_version: None,
        issue_counter: None,
        crypto_timestamp: None,
    })
}

fn validate_encrypted_container(
    bytes: &[u8],
    key_material: Option<&[u8]>,
) -> Result<ScanReport, Error> {
    if let Some(material) = key_material {
        if material.is_empty() {
            return Err(Error::InvalidContainer);
        }
    }
    if !is_encrypted_rnn(bytes) {
        return Err(Error::BadMagic);
    }
    let key_version = extract_rnn_key_version(bytes).ok_or(Error::BadHeader)?;
    let issue_counter = extract_rnn_issue_counter(bytes).ok_or(Error::BadHeader)?;
    let crypto_timestamp = extract_rnn_crypto_timestamp(bytes).ok_or(Error::BadHeader)?;
    if key_version < RNN_MIN_ACCEPTED_KEY_VERSION || issue_counter == 0 || crypto_timestamp == 0 {
        return Err(Error::InvalidContainer);
    }
    let header = extract_rnn_public_header_summary(bytes).map(
        |(dtype, layer_count, total_neurons, weights_len, biases_len, blob_count, flags)| {
            HeaderSummary {
                dtype,
                layer_count,
                total_neurons,
                weights_len,
                biases_len,
                blob_count,
                flags,
            }
        },
    );
    let has_benchmark = extract_rnn_public_has_benchmark(bytes);
    Ok(ScanReport {
        encrypted: true,
        header,
        has_benchmark,
        has_model_name: extract_rnn_public_has_model_name(bytes),
        has_model_precision: extract_rnn_public_has_model_precision(bytes),
        key_version: Some(key_version),
        issue_counter: Some(issue_counter),
        crypto_timestamp: Some(crypto_timestamp),
    })
}

/// Returns the byte slice backing blob `index`, or `None` when the index is
/// out of range, the offset/length don't fit in `usize`, the end offset
/// overflows, or the range falls outside the container bytes.
pub(crate) fn blob_bytes<'bytes, 'scratch>(
    handle: &RnnHandle<'bytes, 'scratch>,
    index: usize,
) -> Option<&'bytes [u8]> {
    handle.blobs.get(index).and_then(|meta| {
        // Checked conversions and addition: any failure means an invalid blob.
        let start = usize::try_from(meta.offset).ok()?;
        let end = start.checked_add(usize::try_from(meta.length).ok()?)?;
        handle.bytes.get(start..end)
    })
}

/// Looks up a blob by name and returns its backing bytes, or `None` when the
/// name is absent or the blob's range is invalid.
fn find_blob_bytes<'bytes, 'scratch>(
    handle: &RnnHandle<'bytes, 'scratch>,
    name: &str,
) -> Option<&'bytes [u8]> {
    find_blob_index(handle, name).and_then(|idx| blob_bytes(handle, idx))
}

/// Convenience lookup for the layer-meta blob's bytes.
fn find_layer_meta_blob<'bytes, 'scratch>(
    handle: &RnnHandle<'bytes, 'scratch>,
) -> Option<&'bytes [u8]> {
    find_blob_bytes(handle, BLOB_LAYER_META)
}

/// Convenience lookup for the weights blob's bytes.
fn find_weights_blob<'bytes, 'scratch>(
    handle: &RnnHandle<'bytes, 'scratch>,
) -> Option<&'bytes [u8]> {
    find_blob_bytes(handle, BLOB_WEIGHTS)
}

/// Convenience lookup for the biases blob's bytes.
fn find_biases_blob<'bytes, 'scratch>(
    handle: &RnnHandle<'bytes, 'scratch>,
) -> Option<&'bytes [u8]> {
    find_blob_bytes(handle, BLOB_BIASES)
}

/// Decodes the default (run-length-encoded) network-summary TLV payload.
///
/// Payload layout: `[summary_version: u8][codec: u8][RLE pairs...]`, where
/// each pair is `(run_length, byte_value)`. Only summary version 1 with
/// codec 1 is accepted. The decoded plaintext must be at least 30 bytes,
/// begin with the `S5D0` magic, and carry a format byte of 1; the remaining
/// 25 bytes are little-endian summary fields.
fn decode_header_summary(payload: &[u8]) -> Result<HeaderSummary, Error> {
    if payload.len() < 2 {
        return Err(Error::BadHeader);
    }
    let summary_version = payload[0];
    let codec = payload[1];
    if summary_version != 1 || codec != 1 {
        return Err(Error::BadHeader);
    }

    // RLE stream is a sequence of (run, value) byte pairs, so it must have
    // even length; this also guarantees `compressed[idx + 1]` is in bounds.
    let compressed = &payload[2..];
    if !compressed.len().is_multiple_of(2) {
        return Err(Error::BadHeader);
    }

    // Decode into a fixed 30-byte plaintext; `plain_len` tracks the total
    // logical output length even when it exceeds the buffer.
    let mut plain = [0u8; 30];
    let mut plain_len = 0usize;
    let mut idx = 0usize;
    while idx < compressed.len() {
        let run = compressed[idx] as usize;
        let value = compressed[idx + 1];
        // Zero-length runs are malformed.
        if run == 0 {
            return Err(Error::BadHeader);
        }
        plain_len = plain_len.checked_add(run).ok_or(Error::BadHeader)?;
        if plain_len <= 30 {
            // Run fits entirely inside the buffer.
            let start = plain_len - run;
            let end = plain_len;
            plain[start..end].fill(value);
        } else if plain_len - run < 30 {
            // Run straddles the 30-byte boundary: keep only the in-range part.
            let start = plain_len - run;
            let end = 30usize;
            plain[start..end].fill(value);
        }
        // NOTE(review): runs decoding entirely past byte 30 are silently
        // discarded (only `plain_len` grows), so over-long streams are
        // accepted while short ones are rejected below — presumably
        // forward-compat padding; confirm this asymmetry is intended.
        idx += 2;
    }

    if plain_len < 30 {
        return Err(Error::BadHeader);
    }
    if !constant_time_eq(&plain[0..4], b"S5D0") {
        return Err(Error::BadHeader);
    }
    // Plaintext format byte must be 1.
    if plain[4] != 1 {
        return Err(Error::BadHeader);
    }

    // Little-endian field extraction from the fixed 30-byte plaintext.
    let dtype = plain[5];
    let layer_count = u32::from_le_bytes([plain[6], plain[7], plain[8], plain[9]]);
    let total_neurons = u32::from_le_bytes([plain[10], plain[11], plain[12], plain[13]]);
    let weights_len = u32::from_le_bytes([plain[14], plain[15], plain[16], plain[17]]);
    let biases_len = u32::from_le_bytes([plain[18], plain[19], plain[20], plain[21]]);
    let blob_count = u32::from_le_bytes([plain[22], plain[23], plain[24], plain[25]]);
    let flags = u32::from_le_bytes([plain[26], plain[27], plain[28], plain[29]]);

    Ok(HeaderSummary {
        dtype,
        layer_count,
        total_neurons,
        weights_len,
        biases_len,
        blob_count,
        flags,
    })
}

/// Decodes the compact `VIZ\0` form of the network-summary TLV.
///
/// Layout (little-endian):
/// - bytes 0..4:   magic `VIZ\x00`
/// - byte  4:      dtype
/// - bytes 5..7:   layer_count (u16)
/// - bytes 7..11:  total_neurons (u32)
/// - bytes 11..13: weights_len (u16)
/// - bytes 13..15: biases_len (u16)
/// - byte  15:     blob_count
///
/// The compact form carries no flag bits, so `flags` is reported as 0.
///
/// Fix: the previous revision re-checked `payload[3] != 0` after the magic
/// comparison; that check was dead code because the 4-byte magic already
/// includes the trailing NUL.
fn decode_compact_viz_summary(payload: &[u8]) -> Result<HeaderSummary, Error> {
    if payload.len() < 16 {
        return Err(Error::BadHeader);
    }
    // The comparison covers all 4 magic bytes, including the trailing NUL.
    if !constant_time_eq(&payload[0..4], b"VIZ\x00") {
        return Err(Error::BadHeader);
    }

    Ok(HeaderSummary {
        dtype: payload[4],
        layer_count: u16::from_le_bytes([payload[5], payload[6]]) as u32,
        total_neurons: u32::from_le_bytes([payload[7], payload[8], payload[9], payload[10]]),
        weights_len: u16::from_le_bytes([payload[11], payload[12]]) as u32,
        biases_len: u16::from_le_bytes([payload[13], payload[14]]) as u32,
        blob_count: payload[15] as u32,
        flags: 0,
    })
}