// rnn/rnn_format/rnn_format.rs
use core::convert::{TryFrom, TryInto};
use core::mem::{align_of, size_of};

use crate::scratch::Scratch;

/// Index record for one blob, written into the caller's scratch buffer by
/// `parse_rnn_from_bytes` and exposed through `RnnHandle::blobs`.
///
/// `#[repr(C)]` pins the field layout: the parser writes these records into raw
/// scratch bytes and reads them back as a `&[BlobMeta]`.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlobMeta {
    /// Byte offset of the blob's name within the handle's `scratch` slice.
    pub name_offset: usize,
    /// Length of the name in bytes (the bytes are not guaranteed to be UTF-8).
    pub name_len: usize,
    /// Element type code, copied verbatim from the image (not interpreted here).
    pub dtype: u8,
    /// Number of dimensions; the shape consists of `ndim` `u32` values.
    pub ndim: u8,
    /// Byte offset of the shape (`ndim` x `u32`) within the handle's `scratch` slice.
    pub shape_offset: usize,
    /// Start of the blob's data, in bytes from the start of the input image.
    pub offset: u64,
    /// Length of the blob's data in bytes.
    pub length: u64,
}

16pub struct RnnHandle<'bytes, 'scratch> {
17 pub bytes: &'bytes [u8],
18 pub blobs: &'scratch [BlobMeta],
19 pub scratch: &'scratch [u8],
20}
21
/// Reasons `parse_rnn_from_bytes` can reject an image.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The input is shorter than the fixed 12-byte preamble.
    Truncated,
    /// The first four bytes are not `b"RNN\0"`.
    BadMagic,
    /// The declared header size, or a header record/entry, is inconsistent.
    BadHeader,
    /// A blob's data range does not fit inside the input (or inside `usize`).
    BadBounds,
    /// The scratch buffer could not hold the blob index.
    ScratchFull,
}

/// Largest `header_size` (64 KiB) that `parse_rnn_from_bytes` will accept.
const MAX_HEADER: usize = 64 * 1024;

27pub fn parse_rnn_from_bytes<'bytes, 'scratch>(
28 bytes: &'bytes [u8],
29 scratch: &'scratch mut Scratch<'_>,
30) -> Result<RnnHandle<'bytes, 'scratch>, Error> {
31 if bytes.len() < 12 { return Err(Error::Truncated); }
32 if &bytes[0..4] != b"RNN\x00" { return Err(Error::BadMagic); }
33 let _version = u16::from_le_bytes(bytes[4..6].try_into().map_err(|_| Error::BadHeader)?);
34 let _flags = u16::from_le_bytes(bytes[6..8].try_into().map_err(|_| Error::BadHeader)?);
35 let header_size = u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| Error::BadHeader)?) as usize;
36 if header_size > bytes.len() || header_size > MAX_HEADER { return Err(Error::BadHeader); }
37 let mut cursor = 12usize;
38 let scratch_base = scratch.base_ptr() as usize;
39
40 let mut meta_count: usize = 0;
41 let mut meta_first_rel: Option<usize> = None;
42
43 while cursor + 5 <= header_size {
44 let t = bytes[cursor]; cursor += 1;
45 let l = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| Error::BadHeader)?) as usize; cursor += 4;
46 if cursor + l > header_size { return Err(Error::BadHeader); }
47 if t == 0x03 {
48 let mut p = cursor;
49 while p < cursor + l {
50 if p + 2 > cursor + l { return Err(Error::BadHeader); }
51 let name_len = u16::from_le_bytes(bytes[p..p+2].try_into().map_err(|_| Error::BadHeader)?) as usize; p += 2;
52 if p + name_len + 1 + 1 > cursor + l { return Err(Error::BadHeader); }
53 let name_bytes = &bytes[p..p+name_len]; p += name_len;
54 let dtype = bytes[p]; p += 1;
55 let ndim = bytes[p]; p += 1;
56 let shape_bytes = (ndim as usize).saturating_mul(4);
57 if p + shape_bytes + 8 + 8 + 32 > cursor + l { return Err(Error::BadHeader); }
58 let dims_src = &bytes[p..p+shape_bytes]; p += shape_bytes;
59 let offset = u64::from_le_bytes(bytes[p..p+8].try_into().map_err(|_| Error::BadHeader)?); p += 8;
60 let length = u64::from_le_bytes(bytes[p..p+8].try_into().map_err(|_| Error::BadHeader)?); p += 8;
61 let _sha = &bytes[p..p+32]; p += 32;
62 let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
63 let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
64 let end = offset_usize.checked_add(length_usize).ok_or(Error::BadBounds)?;
65 if end > bytes.len() { return Err(Error::BadBounds); }
66
67 let name_rel = {
68 let name_store = scratch.alloc_align(name_len, align_of::<u8>()).ok_or(Error::ScratchFull)?;
69 name_store.copy_from_slice(name_bytes);
70 name_store.as_ptr() as usize - scratch_base
71 };
72
73 let dims_rel = {
74 let dims_store = scratch.alloc_align(shape_bytes, align_of::<u32>()).ok_or(Error::ScratchFull)?;
75 dims_store.copy_from_slice(dims_src);
76 dims_store.as_ptr() as usize - scratch_base
77 };
78
79 {
80 let meta_store = scratch.alloc_align(size_of::<BlobMeta>(), align_of::<BlobMeta>()).ok_or(Error::ScratchFull)?;
81 if meta_first_rel.is_none() {
82 meta_first_rel = Some(meta_store.as_ptr() as usize - scratch_base);
83 }
84 let meta_ptr = meta_store.as_mut_ptr() as *mut BlobMeta;
85 unsafe {
86 core::ptr::write(meta_ptr, BlobMeta{
87 name_offset: name_rel,
88 name_len,
89 dtype,
90 ndim,
91 shape_offset: dims_rel,
92 offset,
93 length,
94 });
95 }
96 }
97 meta_count = meta_count.saturating_add(1);
98 }
99 }
100 cursor = cursor.checked_add(l).ok_or(Error::BadHeader)?;
101 }
102
103 let blobs_slice: &[BlobMeta] = if let Some(first_rel) = meta_first_rel {
104 let base = scratch.as_slice().as_ptr();
105 let ptr = unsafe { base.add(first_rel) } as *const BlobMeta;
106 unsafe { core::slice::from_raw_parts(ptr, meta_count) }
107 } else { &[] };
108
109 Ok(RnnHandle { bytes, blobs: blobs_slice, scratch: scratch.as_slice() })
110}
111
112impl<'bytes, 'scratch> RnnHandle<'bytes, 'scratch> {
113 pub fn blob_name(&self, i: usize) -> Option<&'scratch str> {
114 let m = self.blobs.get(i)?;
115 let end = m.name_offset.checked_add(m.name_len)?;
116 if end > self.scratch.len() { return None; }
117 let sl = &self.scratch[m.name_offset..end];
118 core::str::from_utf8(sl).ok()
119 }
120
121 pub fn blob_dims(&self, i: usize) -> Option<&'scratch [u32]> {
122 let m = self.blobs.get(i)?;
123 let len = m.ndim as usize;
124 let byte_len = len.checked_mul(size_of::<u32>())?;
125 let end = m.shape_offset.checked_add(byte_len)?;
126 if end > self.scratch.len() { return None; }
127 let bytes = &self.scratch[m.shape_offset..end];
128 if (bytes.as_ptr() as usize) % align_of::<u32>() != 0 { return None; }
129 let ptr = bytes.as_ptr() as *const u32;
130 Some(unsafe { core::slice::from_raw_parts(ptr, len) })
131 }
132}