// rnn/rnn_format/rnn_format.rs

use core::convert::{TryFrom, TryInto};
use core::mem::{align_of, size_of};

use crate::scratch::Scratch;

/// Metadata for one blob record found in an RNN container header.
///
/// `parse_rnn_from_bytes` writes one of these per record into the scratch
/// buffer. `name_offset` and `shape_offset` are byte offsets relative to the
/// start of the scratch buffer; `offset` and `length` are copied verbatim
/// from the record after the parser has checked that `offset + length` does
/// not exceed the input length.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlobMeta {
    /// Byte offset of the blob's name within the scratch buffer.
    pub name_offset: usize,
    /// Length of the name in bytes.
    pub name_len: usize,
    /// Raw type tag byte from the record; not interpreted by the parser.
    pub dtype: u8,
    /// Number of dimensions; the stored shape holds `ndim` 4-byte values.
    pub ndim: u8,
    /// Byte offset of the shape data within the scratch buffer.
    pub shape_offset: usize,
    /// Byte offset of the blob payload within the input buffer.
    pub offset: u64,
    /// Length of the blob payload in bytes.
    pub length: u64,
}

16pub struct RnnHandle<'bytes, 'scratch> {
17    pub bytes: &'bytes [u8],
18    pub blobs: &'scratch [BlobMeta],
19    pub scratch: &'scratch [u8],
20}
21
/// Reasons `parse_rnn_from_bytes` can reject its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The input is shorter than the 12-byte fixed preamble.
    Truncated,
    /// The first four bytes are not `b"RNN\0"`.
    BadMagic,
    /// The declared header size or the TLV / blob-record framing is inconsistent.
    BadHeader,
    /// A blob's `offset` / `length` does not fit in `usize` or runs past the input.
    BadBounds,
    /// The scratch buffer could not provide the storage the parser asked for.
    ScratchFull,
}

/// Largest header size, in bytes, that `parse_rnn_from_bytes` accepts.
const MAX_HEADER: usize = 65_536;

27pub fn parse_rnn_from_bytes<'bytes, 'scratch>(
28    bytes: &'bytes [u8],
29    scratch: &'scratch mut Scratch<'_>,
30) -> Result<RnnHandle<'bytes, 'scratch>, Error> {
31    if bytes.len() < 12 { return Err(Error::Truncated); }
32    if &bytes[0..4] != b"RNN\x00" { return Err(Error::BadMagic); }
33    let _version = u16::from_le_bytes(bytes[4..6].try_into().map_err(|_| Error::BadHeader)?);
34    let _flags = u16::from_le_bytes(bytes[6..8].try_into().map_err(|_| Error::BadHeader)?);
35    let header_size = u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| Error::BadHeader)?) as usize;
36    if header_size > bytes.len() || header_size > MAX_HEADER { return Err(Error::BadHeader); }
37    let mut cursor = 12usize;
38    let scratch_base = scratch.base_ptr() as usize;
39
40    let mut meta_count: usize = 0;
41    let mut meta_first_rel: Option<usize> = None;
42
43    while cursor + 5 <= header_size {
44        let t = bytes[cursor]; cursor += 1;
45        let l = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| Error::BadHeader)?) as usize; cursor += 4;
46        if cursor + l > header_size { return Err(Error::BadHeader); }
47        if t == 0x03 {
48            let mut p = cursor;
49            while p < cursor + l {
50                if p + 2 > cursor + l { return Err(Error::BadHeader); }
51                let name_len = u16::from_le_bytes(bytes[p..p+2].try_into().map_err(|_| Error::BadHeader)?) as usize; p += 2;
52                if p + name_len + 1 + 1 > cursor + l { return Err(Error::BadHeader); }
53                let name_bytes = &bytes[p..p+name_len]; p += name_len;
54                let dtype = bytes[p]; p += 1;
55                let ndim = bytes[p]; p += 1;
56                let shape_bytes = (ndim as usize).saturating_mul(4);
57                if p + shape_bytes + 8 + 8 + 32 > cursor + l { return Err(Error::BadHeader); }
58                let dims_src = &bytes[p..p+shape_bytes]; p += shape_bytes;
59                let offset = u64::from_le_bytes(bytes[p..p+8].try_into().map_err(|_| Error::BadHeader)?); p += 8;
60                let length = u64::from_le_bytes(bytes[p..p+8].try_into().map_err(|_| Error::BadHeader)?); p += 8;
61                let _sha = &bytes[p..p+32]; p += 32;
62                let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
63                let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
64                let end = offset_usize.checked_add(length_usize).ok_or(Error::BadBounds)?;
65                if end > bytes.len() { return Err(Error::BadBounds); }
66
67                let name_rel = {
68                    let name_store = scratch.alloc_align(name_len, align_of::<u8>()).ok_or(Error::ScratchFull)?;
69                    name_store.copy_from_slice(name_bytes);
70                    name_store.as_ptr() as usize - scratch_base
71                };
72
73                let dims_rel = {
74                    let dims_store = scratch.alloc_align(shape_bytes, align_of::<u32>()).ok_or(Error::ScratchFull)?;
75                    dims_store.copy_from_slice(dims_src);
76                    dims_store.as_ptr() as usize - scratch_base
77                };
78
79                {
80                    let meta_store = scratch.alloc_align(size_of::<BlobMeta>(), align_of::<BlobMeta>()).ok_or(Error::ScratchFull)?;
81                    if meta_first_rel.is_none() {
82                        meta_first_rel = Some(meta_store.as_ptr() as usize - scratch_base);
83                    }
84                    let meta_ptr = meta_store.as_mut_ptr() as *mut BlobMeta;
85                    unsafe {
86                        core::ptr::write(meta_ptr, BlobMeta{
87                            name_offset: name_rel,
88                            name_len,
89                            dtype,
90                            ndim,
91                            shape_offset: dims_rel,
92                            offset,
93                            length,
94                        });
95                    }
96                }
97                meta_count = meta_count.saturating_add(1);
98            }
99        }
100        cursor = cursor.checked_add(l).ok_or(Error::BadHeader)?;
101    }
102
103    let blobs_slice: &[BlobMeta] = if let Some(first_rel) = meta_first_rel {
104        let base = scratch.as_slice().as_ptr();
105        let ptr = unsafe { base.add(first_rel) } as *const BlobMeta;
106        unsafe { core::slice::from_raw_parts(ptr, meta_count) }
107    } else { &[] };
108
109    Ok(RnnHandle { bytes, blobs: blobs_slice, scratch: scratch.as_slice() })
110}
111
112impl<'bytes, 'scratch> RnnHandle<'bytes, 'scratch> {
113    pub fn blob_name(&self, i: usize) -> Option<&'scratch str> {
114        let m = self.blobs.get(i)?;
115        let end = m.name_offset.checked_add(m.name_len)?;
116        if end > self.scratch.len() { return None; }
117        let sl = &self.scratch[m.name_offset..end];
118        core::str::from_utf8(sl).ok()
119    }
120
121    pub fn blob_dims(&self, i: usize) -> Option<&'scratch [u32]> {
122        let m = self.blobs.get(i)?;
123        let len = m.ndim as usize;
124        let byte_len = len.checked_mul(size_of::<u32>())?;
125        let end = m.shape_offset.checked_add(byte_len)?;
126        if end > self.scratch.len() { return None; }
127        let bytes = &self.scratch[m.shape_offset..end];
128        if (bytes.as_ptr() as usize) % align_of::<u32>() != 0 { return None; }
129        let ptr = bytes.as_ptr() as *const u32;
130        Some(unsafe { core::slice::from_raw_parts(ptr, len) })
131    }
132}