//! native_neural_network 0.1.6
//!
//! A `no_std` Rust library for the native neural network (`.rnn`) file format.
use core::convert::TryInto;
use core::mem::{align_of, size_of};
use crate::scratch::Scratch;

/// Metadata for one blob (tensor) record parsed out of an `.rnn` header.
///
/// `name_offset` / `shape_offset` are byte offsets **into the scratch
/// buffer** (not into the file), while `offset` / `length` locate the blob's
/// payload inside the original file image.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlobMeta {
    /// Offset of the name bytes within the scratch buffer.
    pub name_offset: usize,
    /// Length of the name in bytes.
    pub name_len: usize,
    /// Element-type tag (raw wire value; semantics defined by the format).
    pub dtype: u8,
    /// Number of dimensions; the shape is `ndim` little-endian u32 values.
    pub ndim: u8,
    /// Offset of the packed u32 shape data within the scratch buffer.
    pub shape_offset: usize,
    /// Byte offset of the blob payload within the file image.
    pub offset: u64,
    /// Byte length of the blob payload.
    pub length: u64,
}

pub struct RnnHandle<'bytes, 'scratch> {
    pub bytes: &'bytes [u8],
    pub blobs: &'scratch [BlobMeta],
    pub scratch: &'scratch [u8],
    pub version: u16,
    pub flags: u16,
}

/// Failure modes of `.rnn` parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Input shorter than the fixed 12-byte header.
    Truncated,
    /// Leading magic is not `RNN\0`.
    BadMagic,
    /// Malformed or out-of-range header structure.
    BadHeader,
    /// A blob's offset/length span escapes the file image.
    BadBounds,
    /// The scratch arena ran out of space.
    ScratchFull,
}

/// Hard cap on the declared header size (64 KiB); headers claiming more are rejected.
const MAX_HEADER: usize = 65536;

pub fn parse_rnn_from_bytes<'bytes, 'scratch>(
    bytes: &'bytes [u8],
    scratch: &'scratch mut Scratch<'_>,
) -> Result<RnnHandle<'bytes, 'scratch>, Error> {
    if bytes.len() < 12 { return Err(Error::Truncated); }
    if &bytes[0..4] != b"RNN\x00" { return Err(Error::BadMagic); }
    let version = u16::from_le_bytes(bytes[4..6].try_into().map_err(|_| Error::BadHeader)?);
    let flags = u16::from_le_bytes(bytes[6..8].try_into().map_err(|_| Error::BadHeader)?);
    let header_size = u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| Error::BadHeader)?) as usize;
    if header_size > bytes.len() || header_size > MAX_HEADER { return Err(Error::BadHeader); }
    let mut cursor = 12usize;
    let scratch_base = scratch.base_ptr() as usize;

    let mut meta_count: usize = 0;
    let mut meta_first_rel: Option<usize> = None;

    while cursor + 5 <= header_size {
        let t = bytes[cursor]; cursor += 1;
        let l = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| Error::BadHeader)?) as usize; cursor += 4;
        if cursor + l > header_size { return Err(Error::BadHeader); }
        if t == 0x03 {
            let mut p = cursor;
            while p < cursor + l {
                if p + 2 > cursor + l { return Err(Error::BadHeader); }
                let name_len = u16::from_le_bytes(bytes[p..p+2].try_into().map_err(|_| Error::BadHeader)?) as usize; p += 2;
                if p + name_len + 1 + 1 > cursor + l { return Err(Error::BadHeader); }
                let name_bytes = &bytes[p..p+name_len]; p += name_len;
                let dtype = bytes[p]; p += 1;
                let ndim = bytes[p]; p += 1;
                let shape_bytes = (ndim as usize).saturating_mul(4);
                if p + shape_bytes + 8 + 8 + 32 > cursor + l { return Err(Error::BadHeader); }
                let dims_src = &bytes[p..p+shape_bytes]; p += shape_bytes;
                let offset = u64::from_le_bytes(bytes[p..p+8].try_into().map_err(|_| Error::BadHeader)?); p += 8;
                let length = u64::from_le_bytes(bytes[p..p+8].try_into().map_err(|_| Error::BadHeader)?); p += 8;
                p += 32;
                let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
                let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
                let end = offset_usize.checked_add(length_usize).ok_or(Error::BadBounds)?;
                if end > bytes.len() { return Err(Error::BadBounds); }

                let name_rel = {
                    let name_store = scratch.alloc_align(name_len, align_of::<u8>()).ok_or(Error::ScratchFull)?;
                    name_store.copy_from_slice(name_bytes);
                    name_store.as_ptr() as usize - scratch_base
                };

                let dims_rel = {
                    let dims_store = scratch.alloc_align(shape_bytes, align_of::<u32>()).ok_or(Error::ScratchFull)?;
                    dims_store.copy_from_slice(dims_src);
                    dims_store.as_ptr() as usize - scratch_base
                };

                {
                    let meta_store = scratch.alloc_align(size_of::<BlobMeta>(), align_of::<BlobMeta>()).ok_or(Error::ScratchFull)?;
                    if meta_first_rel.is_none() {
                        meta_first_rel = Some(meta_store.as_ptr() as usize - scratch_base);
                    }
                    let meta_ptr = meta_store.as_mut_ptr() as *mut BlobMeta;
                    unsafe {
                        core::ptr::write(meta_ptr, BlobMeta{
                            name_offset: name_rel,
                            name_len,
                            dtype,
                            ndim,
                            shape_offset: dims_rel,
                            offset,
                            length,
                        });
                    }
                }
                meta_count = meta_count.saturating_add(1);
            }
        }
        cursor = cursor.checked_add(l).ok_or(Error::BadHeader)?;
    }

    let blobs_slice: &[BlobMeta] = if let Some(first_rel) = meta_first_rel {
        let base = scratch.as_slice().as_ptr();
        let ptr = unsafe { base.add(first_rel) } as *const BlobMeta;
        unsafe { core::slice::from_raw_parts(ptr, meta_count) }
    } else { &[] };

    
    Ok(RnnHandle { bytes, blobs: blobs_slice, scratch: scratch.as_slice(), version, flags })
}

impl<'bytes, 'scratch> RnnHandle<'bytes, 'scratch> {
    /// Looks up blob `i`'s name as UTF-8.
    ///
    /// Returns `None` when the index is out of range, the recorded name span
    /// does not fit in the scratch buffer, or the bytes are not valid UTF-8.
    pub fn blob_name(&self, i: usize) -> Option<&'scratch str> {
        let meta = self.blobs.get(i)?;
        let start = meta.name_offset;
        let stop = start.checked_add(meta.name_len)?;
        if stop > self.scratch.len() {
            return None;
        }
        core::str::from_utf8(&self.scratch[start..stop]).ok()
    }

    /// Looks up blob `i`'s shape as a slice of `u32` dimension sizes.
    ///
    /// Returns `None` when the index is out of range, the shape span does not
    /// fit in the scratch buffer, or its storage is not aligned for `u32`.
    pub fn blob_dims(&self, i: usize) -> Option<&'scratch [u32]> {
        let meta = self.blobs.get(i)?;
        let count = meta.ndim as usize;
        let nbytes = count.checked_mul(size_of::<u32>())?;
        let stop = meta.shape_offset.checked_add(nbytes)?;
        if stop > self.scratch.len() {
            return None;
        }
        let raw = &self.scratch[meta.shape_offset..stop];
        if (raw.as_ptr() as usize) % align_of::<u32>() != 0 {
            return None;
        }
        // SAFETY: `raw` covers exactly `count * 4` in-bounds bytes and its
        // start is aligned for u32 (checked just above).
        Some(unsafe { core::slice::from_raw_parts(raw.as_ptr() as *const u32, count) })
    }
}