use core::convert::TryInto;
use core::mem::{align_of, size_of};
use crate::scratch::Scratch;
/// Metadata for one blob record, as produced by [`parse_rnn_from_bytes`].
///
/// `name_offset` and `shape_offset` are byte offsets into the handle's scratch slice
/// (measured from the scratch base pointer at parse time) and are resolved by
/// [`RnnHandle::blob_name`] and [`RnnHandle::blob_dims`]. `offset` and `length` locate the
/// blob's data inside the parsed input bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlobMeta {
    /// Byte offset of the blob's name within the scratch slice.
    pub name_offset: usize,
    /// Length of the blob's name in bytes.
    pub name_len: usize,
    /// Raw dtype byte copied from the header; the parser does not interpret it.
    pub dtype: u8,
    /// Number of dimensions; the shape consists of this many `u32` values.
    pub ndim: u8,
    /// Byte offset of the shape values within the scratch slice.
    pub shape_offset: usize,
    /// Byte offset of the blob's data in the input; the parser rejects records whose
    /// `offset + length` exceeds the input length.
    pub offset: u64,
    /// Length of the blob's data in bytes.
    pub length: u64,
}
pub struct RnnHandle<'bytes, 'scratch> {
pub bytes: &'bytes [u8],
pub blobs: &'scratch [BlobMeta],
pub scratch: &'scratch [u8],
pub version: u16,
pub flags: u16,
}
/// Reasons [`parse_rnn_from_bytes`] rejects its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The input is shorter than the 12-byte fixed header.
    Truncated,
    /// The input does not begin with the magic bytes `b"RNN\x00"`.
    BadMagic,
    /// The header size field, or a section or blob record inside the header, does not fit
    /// in the bytes available to it.
    BadHeader,
    /// A blob's `offset`/`length` does not fit in `usize` or describes a range past the end
    /// of the input.
    BadBounds,
    /// The scratch buffer could not provide the memory needed for the blob table.
    ScratchFull,
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            Error::Truncated => "input is shorter than the 12-byte RNN header",
            Error::BadMagic => "input does not start with the RNN magic bytes",
            Error::BadHeader => "malformed RNN header",
            Error::BadBounds => "blob data lies outside the input",
            Error::ScratchFull => "scratch buffer exhausted",
        };
        f.write_str(msg)
    }
}

impl core::error::Error for Error {}
/// Largest `header_size` value (64 KiB) that the parser accepts.
const MAX_HEADER: usize = 64 * 1024;
pub fn parse_rnn_from_bytes<'bytes, 'scratch>(
bytes: &'bytes [u8],
scratch: &'scratch mut Scratch<'_>,
) -> Result<RnnHandle<'bytes, 'scratch>, Error> {
if bytes.len() < 12 { return Err(Error::Truncated); }
if &bytes[0..4] != b"RNN\x00" { return Err(Error::BadMagic); }
let version = u16::from_le_bytes(bytes[4..6].try_into().map_err(|_| Error::BadHeader)?);
let flags = u16::from_le_bytes(bytes[6..8].try_into().map_err(|_| Error::BadHeader)?);
let header_size = u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| Error::BadHeader)?) as usize;
if header_size > bytes.len() || header_size > MAX_HEADER { return Err(Error::BadHeader); }
let mut cursor = 12usize;
let scratch_base = scratch.base_ptr() as usize;
let mut meta_count: usize = 0;
let mut meta_first_rel: Option<usize> = None;
while cursor + 5 <= header_size {
let t = bytes[cursor]; cursor += 1;
let l = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| Error::BadHeader)?) as usize; cursor += 4;
if cursor + l > header_size { return Err(Error::BadHeader); }
if t == 0x03 {
let mut p = cursor;
while p < cursor + l {
if p + 2 > cursor + l { return Err(Error::BadHeader); }
let name_len = u16::from_le_bytes(bytes[p..p+2].try_into().map_err(|_| Error::BadHeader)?) as usize; p += 2;
if p + name_len + 1 + 1 > cursor + l { return Err(Error::BadHeader); }
let name_bytes = &bytes[p..p+name_len]; p += name_len;
let dtype = bytes[p]; p += 1;
let ndim = bytes[p]; p += 1;
let shape_bytes = (ndim as usize).saturating_mul(4);
if p + shape_bytes + 8 + 8 + 32 > cursor + l { return Err(Error::BadHeader); }
let dims_src = &bytes[p..p+shape_bytes]; p += shape_bytes;
let offset = u64::from_le_bytes(bytes[p..p+8].try_into().map_err(|_| Error::BadHeader)?); p += 8;
let length = u64::from_le_bytes(bytes[p..p+8].try_into().map_err(|_| Error::BadHeader)?); p += 8;
p += 32;
let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
let end = offset_usize.checked_add(length_usize).ok_or(Error::BadBounds)?;
if end > bytes.len() { return Err(Error::BadBounds); }
let name_rel = {
let name_store = scratch.alloc_align(name_len, align_of::<u8>()).ok_or(Error::ScratchFull)?;
name_store.copy_from_slice(name_bytes);
name_store.as_ptr() as usize - scratch_base
};
let dims_rel = {
let dims_store = scratch.alloc_align(shape_bytes, align_of::<u32>()).ok_or(Error::ScratchFull)?;
dims_store.copy_from_slice(dims_src);
dims_store.as_ptr() as usize - scratch_base
};
{
let meta_store = scratch.alloc_align(size_of::<BlobMeta>(), align_of::<BlobMeta>()).ok_or(Error::ScratchFull)?;
if meta_first_rel.is_none() {
meta_first_rel = Some(meta_store.as_ptr() as usize - scratch_base);
}
let meta_ptr = meta_store.as_mut_ptr() as *mut BlobMeta;
unsafe {
core::ptr::write(meta_ptr, BlobMeta{
name_offset: name_rel,
name_len,
dtype,
ndim,
shape_offset: dims_rel,
offset,
length,
});
}
}
meta_count = meta_count.saturating_add(1);
}
}
cursor = cursor.checked_add(l).ok_or(Error::BadHeader)?;
}
let blobs_slice: &[BlobMeta] = if let Some(first_rel) = meta_first_rel {
let base = scratch.as_slice().as_ptr();
let ptr = unsafe { base.add(first_rel) } as *const BlobMeta;
unsafe { core::slice::from_raw_parts(ptr, meta_count) }
} else { &[] };
Ok(RnnHandle { bytes, blobs: blobs_slice, scratch: scratch.as_slice(), version, flags })
}
impl<'bytes, 'scratch> RnnHandle<'bytes, 'scratch> {
    /// Looks up the name of blob `i` in the scratch copy.
    ///
    /// Yields `None` when `i` is not a valid blob index, when
    /// `name_offset..name_offset + name_len` is not inside `self.scratch`, or when those
    /// bytes are not valid UTF-8.
    pub fn blob_name(&self, i: usize) -> Option<&'scratch str> {
        let meta = self.blobs.get(i)?;
        let start = meta.name_offset;
        let stop = start.checked_add(meta.name_len)?;
        self.scratch
            .get(start..stop)
            .and_then(|raw| core::str::from_utf8(raw).ok())
    }

    /// Looks up the shape of blob `i` as a `u32` slice borrowed from the scratch copy.
    ///
    /// The stored bytes are reinterpreted in native byte order. Yields `None` when `i` is
    /// not a valid blob index, when the `4 * ndim` bytes at `shape_offset` are not inside
    /// `self.scratch`, or when their start address is not aligned for `u32`.
    pub fn blob_dims(&self, i: usize) -> Option<&'scratch [u32]> {
        let meta = self.blobs.get(i)?;
        let dim_count = usize::from(meta.ndim);
        let span = size_of::<u32>().checked_mul(dim_count)?;
        let stop = meta.shape_offset.checked_add(span)?;
        let raw = self.scratch.get(meta.shape_offset..stop)?;
        let first = raw.as_ptr().cast::<u32>();
        if !first.is_aligned() {
            return None;
        }
        // SAFETY: `raw` is `dim_count * 4` initialized bytes borrowed from `self.scratch` for
        // `'scratch`, its start was just checked to be `u32`-aligned, and every bit pattern
        // is a valid `u32`.
        let dims = unsafe { core::slice::from_raw_parts(first, dim_count) };
        Some(dims)
    }
}