use crate::crypto::{constant_time_eq, sha256_bytes};
use crate::model_format::{
BLOB_BIASES, BLOB_LAYER_META, BLOB_NEURON_POSITIONS, BLOB_RUNTIME_INPUT, BLOB_WEIGHTS,
};
use crate::scratch::Scratch;
use core::convert::TryInto;
use core::mem::{align_of, size_of};
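/// Metadata for a single blob, materialized into the caller's scratch arena.
///
/// `name_offset` and `shape_offset` are byte offsets relative to the scratch
/// base pointer; `offset` and `length` locate the blob payload inside the
/// original container bytes. For compact table entries the SHA-256 digest is
/// computed over the payload during parsing; for verbose entries it is copied
/// from the header.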
#[repr(C)]
pub(crate) struct BlobMeta {
pub(crate) name_offset: usize,
pub(crate) name_len: usize,
pub(crate) dtype: u8,
pub(crate) ndim: u8,
pub(crate) shape_offset: usize,
pub(crate) offset: u64,
pub(crate) length: u64,
pub(crate) digest_sha256: [u8; 32],
}
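/// Borrowed view over a parsed container: `bytes` is the caller's original
/// buffer, while `blobs` and `scratch` both point into the scratch arena that
/// was passed to the parser.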
pub(crate) struct RnnHandle<'bytes, 'scratch> {
pub(crate) bytes: &'bytes [u8],
pub(crate) blobs: &'scratch [BlobMeta],
pub(crate) scratch: &'scratch [u8],
}
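/// Parse failures. `WrongFormatRmd1` signals that the bytes carry the sibling
/// `RMD1` magic rather than an RNN container; `BadBounds` means a blob's
/// offset/length range falls outside the provided bytes; `ScratchFull` means
/// the scratch arena could not hold the derived metadata.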
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Error {
Truncated,
BadMagic,
WrongFormatRmd1,
BadHeader,
BadBounds,
ScratchFull,
}
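/// Upper bound on the declared header size; keeps header walking bounded even
/// for hostile inputs.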
const MAX_HEADER: usize = 65536;
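/// Maps a compact-table blob id to its canonical blob name constant.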
fn blob_name_from_compact_id(id: u8) -> Option<&'static str> {
match id {
1 => Some(BLOB_NEURON_POSITIONS),
2 => Some(BLOB_LAYER_META),
3 => Some(BLOB_WEIGHTS),
4 => Some(BLOB_BIASES),
5 => Some(BLOB_RUNTIME_INPUT),
_ => None,
}
}
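/// Container flavours this parser accepts; currently only the `RNN\0` layout.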
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum RnnContainerFormat {
Rnn0,
}
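/// Returns the 4-byte magic expected at the start of the container.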
fn magic_for_rnn_format(fmt: RnnContainerFormat) -> [u8; 4] {
match fmt {
RnnContainerFormat::Rnn0 => *b"RNN\x00",
}
}
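/// Detects the sibling `RMD1` container so callers get a more specific error
/// than `BadMagic`.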
fn is_rmd1_magic(bytes: &[u8]) -> bool {
bytes.len() >= 4 && constant_time_eq(&bytes[0..4], b"RMD1")
}
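/// Parses an RNN container using the default [`RnnContainerFormat::Rnn0`] magic.
///
/// A minimal usage sketch (not compiled as a doctest); `Scratch::with_buffer`
/// is a hypothetical constructor used for illustration only, the real
/// `Scratch` API may differ:
///
/// ```ignore
/// let mut backing = [0u8; 4096];
/// let mut scratch = Scratch::with_buffer(&mut backing); // hypothetical
/// let handle = parse_rnn_from_bytes(model_bytes, &mut scratch)?;
/// for i in 0..handle.blobs.len() {
///     if let Some(name) = handle.blob_name(i) {
///         // `name` points into the scratch arena
///     }
/// }
/// ```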
pub(crate) fn parse_rnn_from_bytes<'bytes, 'scratch>(
bytes: &'bytes [u8],
scratch: &'scratch mut Scratch<'_>,
) -> Result<RnnHandle<'bytes, 'scratch>, Error> {
parse_rnn_from_bytes_with_format(bytes, scratch, RnnContainerFormat::Rnn0)
}
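/// Parses an RNN container with an explicit [`RnnContainerFormat`].
///
/// Prologue layout (little-endian): bytes 0..4 magic, 4..6 version,
/// 6..8 flags, 8..12 declared header size. Blob metadata is taken either
/// from a compact table at fixed offset 0x50 (marker 0xC1) or from TLV
/// records of type 0x03 that follow the prologue. All derived metadata
/// (names, shapes, the `BlobMeta` array) is allocated in `scratch`.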
pub(crate) fn parse_rnn_from_bytes_with_format<'bytes, 'scratch>(
bytes: &'bytes [u8],
scratch: &'scratch mut Scratch<'_>,
format: RnnContainerFormat,
) -> Result<RnnHandle<'bytes, 'scratch>, Error> {
if bytes.len() < 12 {
return Err(Error::Truncated);
}
let expected_magic = magic_for_rnn_format(format);
if !constant_time_eq(&bytes[0..4], &expected_magic) {
if is_rmd1_magic(bytes) {
return Err(Error::WrongFormatRmd1);
}
return Err(Error::BadMagic);
}
let version = u16::from_le_bytes(bytes[4..6].try_into().map_err(|_| Error::BadHeader)?);
let flags = u16::from_le_bytes(bytes[6..8].try_into().map_err(|_| Error::BadHeader)?);
core::hint::black_box(version);
core::hint::black_box(flags);
let header_size =
u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| Error::BadHeader)?) as usize;
if header_size < 12 || header_size > bytes.len() || header_size > MAX_HEADER {
return Err(Error::BadHeader);
}
let scratch_base = scratch.base_ptr() as usize;
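// Fast path: a compact blob table at fixed offset 0x50 (marker byte 0xC1).
// When present, parse it directly and skip the TLV walk below.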
if header_size >= 0xA0 && bytes.len() >= 0xA0 && bytes[0x50] == 0xC1 {
let table = &bytes[0x50..0xA0];
if table[1] != 1 {
return Err(Error::BadHeader);
}
let count = table[2] as usize;
// Each compact record is 15 bytes: id, dtype, ndim, d0, d1, offset, length.
let need = 3usize
.checked_add(count.checked_mul(15).ok_or(Error::BadHeader)?)
.ok_or(Error::BadHeader)?;
if need > table.len() {
return Err(Error::BadHeader);
}
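// Materialize the BlobMeta array for the fixed-offset table into scratch.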
let blobs_slice: &[BlobMeta] = if count == 0 {
&[]
} else {
let meta_bytes = count
.checked_mul(size_of::<BlobMeta>())
.ok_or(Error::ScratchFull)?;
let meta_store = scratch
.alloc_align(meta_bytes, align_of::<BlobMeta>())
.ok_or(Error::ScratchFull)?;
let meta_ptr = meta_store.as_mut_ptr() as *mut BlobMeta;
for i in 0..count {
let p = 3 + i * 15;
let blob_id = table[p];
let dtype = table[p + 1];
let ndim = table[p + 2];
if !matches!(dtype, 0..=2) || ndim == 0 || ndim > 2 {
return Err(Error::BadHeader);
}
let d0 = u16::from_le_bytes([table[p + 3], table[p + 4]]) as u32;
let d1 = u16::from_le_bytes([table[p + 5], table[p + 6]]) as u32;
let offset =
u32::from_le_bytes([table[p + 7], table[p + 8], table[p + 9], table[p + 10]])
as u64;
let length = u32::from_le_bytes([
table[p + 11],
table[p + 12],
table[p + 13],
table[p + 14],
]) as u64;
let name = blob_name_from_compact_id(blob_id).ok_or(Error::BadHeader)?;
let name_bytes = name.as_bytes();
let name_len = name_bytes.len();
let name_rel = {
let name_store = scratch
.alloc_align(name_len, align_of::<u8>())
.ok_or(Error::ScratchFull)?;
name_store.copy_from_slice(name_bytes);
name_store.as_ptr() as usize - scratch_base
};
let shape_bytes = (ndim as usize).saturating_mul(4);
let dims_rel = {
let dims_store = scratch
.alloc_align(shape_bytes, align_of::<u32>())
.ok_or(Error::ScratchFull)?;
let mut w = 0usize;
if ndim >= 1 {
dims_store[w..w + 4].copy_from_slice(&d0.to_le_bytes());
w += 4;
}
if ndim >= 2 {
dims_store[w..w + 4].copy_from_slice(&d1.to_le_bytes());
}
dims_store.as_ptr() as usize - scratch_base
};
let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
let end = offset_usize
.checked_add(length_usize)
.ok_or(Error::BadBounds)?;
if end > bytes.len() {
return Err(Error::BadBounds);
}
let payload = &bytes[offset_usize..end];
let mut digest_sha256 = [0u8; 32];
sha256_bytes(payload, &mut digest_sha256);
unsafe {
core::ptr::write(
meta_ptr.add(i),
BlobMeta {
name_offset: name_rel,
name_len,
dtype,
ndim,
shape_offset: dims_rel,
offset,
length,
digest_sha256,
},
);
}
}
unsafe { core::slice::from_raw_parts(meta_ptr as *const BlobMeta, count) }
};
return Ok(RnnHandle {
bytes,
blobs: blobs_slice,
scratch: scratch.as_slice(),
});
}
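// General path: walk the TLV records that follow the 12-byte prologue.
// Pass 1 validates the records and counts blob entries so the BlobMeta
// allocation can be sized before anything is written to scratch.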
let mut cursor = 12usize;
let mut metas_count = 0usize;
while cursor < header_size {
if bytes[cursor] == 0 {
cursor += 1;
continue;
}
if cursor + 5 > header_size {
return Err(Error::BadHeader);
}
let t = bytes[cursor];
cursor += 1;
let l = u32::from_le_bytes(
bytes[cursor..cursor + 4]
.try_into()
.map_err(|_| Error::BadHeader)?,
) as usize;
cursor += 4;
if cursor.checked_add(l).ok_or(Error::BadHeader)? > header_size {
return Err(Error::BadHeader);
}
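// Type 0x03 records carry blob metadata, either as a compact table
// (0xC1 marker, version, count, 15-byte entries) or as verbose entries.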
if t == 0x03 {
if l >= 3 && bytes[cursor] == 0xC1 {
let version_compact = bytes[cursor + 1];
if version_compact != 1 {
return Err(Error::BadHeader);
}
let count = bytes[cursor + 2] as usize;
let need = 3usize
.checked_add(count.checked_mul(15).ok_or(Error::BadHeader)?)
.ok_or(Error::BadHeader)?;
if need > l {
return Err(Error::BadHeader);
}
let mut p = cursor + 3;
for _ in 0..count {
let blob_id = bytes[p];
let dtype = bytes[p + 1];
let ndim = bytes[p + 2];
let d0 = u16::from_le_bytes([bytes[p + 3], bytes[p + 4]]) as u32;
let d1 = u16::from_le_bytes([bytes[p + 5], bytes[p + 6]]) as u32;
let offset = u32::from_le_bytes([
bytes[p + 7],
bytes[p + 8],
bytes[p + 9],
bytes[p + 10],
]) as u64;
let length = u32::from_le_bytes([
bytes[p + 11],
bytes[p + 12],
bytes[p + 13],
bytes[p + 14],
]) as u64;
p += 15;
if blob_name_from_compact_id(blob_id).is_none() {
return Err(Error::BadHeader);
}
if !matches!(dtype, 0..=2) || ndim == 0 || ndim > 2 {
return Err(Error::BadHeader);
}
let _ = (d0, d1);
let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
let end = offset_usize
.checked_add(length_usize)
.ok_or(Error::BadBounds)?;
if end > bytes.len() {
return Err(Error::BadBounds);
}
metas_count = metas_count.checked_add(1).ok_or(Error::ScratchFull)?;
}
cursor = cursor.checked_add(l).ok_or(Error::BadHeader)?;
continue;
}
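// Verbose entry layout: name_len (u16 LE), name bytes, dtype (u8),
// ndim (u8), `ndim` u32 dims, offset (u64 LE), length (u64 LE),
// 32-byte SHA-256 digest.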
let mut p = cursor;
while p < cursor + l {
if p + 2 > cursor + l {
return Err(Error::BadHeader);
}
let name_len =
u16::from_le_bytes(bytes[p..p + 2].try_into().map_err(|_| Error::BadHeader)?)
as usize;
p += 2;
if p + name_len + 1 + 1 > cursor + l {
return Err(Error::BadHeader);
}
p += name_len;
if !matches!(bytes[p], 0..=2) {
return Err(Error::BadHeader);
}
p += 1;
let ndim = bytes[p];
p += 1;
if ndim == 0 {
return Err(Error::BadHeader);
}
let shape_bytes = (ndim as usize).saturating_mul(4);
if p + shape_bytes + 8 + 8 + 32 > cursor + l {
return Err(Error::BadHeader);
}
p += shape_bytes;
let offset =
u64::from_le_bytes(bytes[p..p + 8].try_into().map_err(|_| Error::BadHeader)?);
p += 8;
let length =
u64::from_le_bytes(bytes[p..p + 8].try_into().map_err(|_| Error::BadHeader)?);
p += 8;
p += 32;
let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
let end = offset_usize
.checked_add(length_usize)
.ok_or(Error::BadBounds)?;
if end > bytes.len() {
return Err(Error::BadBounds);
}
metas_count = metas_count.checked_add(1).ok_or(Error::ScratchFull)?;
}
}
cursor = cursor.checked_add(l).ok_or(Error::BadHeader)?;
}
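// Pass 2: allocate the BlobMeta array in scratch, then re-walk the header
// and materialize every entry that pass 1 counted.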
let blobs_slice: &[BlobMeta] = if metas_count == 0 {
&[]
} else {
let meta_bytes = metas_count
.checked_mul(size_of::<BlobMeta>())
.ok_or(Error::ScratchFull)?;
let meta_store = scratch
.alloc_align(meta_bytes, align_of::<BlobMeta>())
.ok_or(Error::ScratchFull)?;
let meta_ptr = meta_store.as_mut_ptr() as *mut BlobMeta;
cursor = 12usize;
let mut meta_index = 0usize;
while cursor < header_size {
if bytes[cursor] == 0 {
cursor += 1;
continue;
}
if cursor + 5 > header_size {
return Err(Error::BadHeader);
}
let t = bytes[cursor];
cursor += 1;
let l = u32::from_le_bytes(
bytes[cursor..cursor + 4]
.try_into()
.map_err(|_| Error::BadHeader)?,
) as usize;
cursor += 4;
if cursor.checked_add(l).ok_or(Error::BadHeader)? > header_size {
return Err(Error::BadHeader);
}
if t == 0x03 {
if l >= 3 && bytes[cursor] == 0xC1 {
let version_compact = bytes[cursor + 1];
if version_compact != 1 {
return Err(Error::BadHeader);
}
let count = bytes[cursor + 2] as usize;
let need = 3usize
.checked_add(count.checked_mul(15).ok_or(Error::BadHeader)?)
.ok_or(Error::BadHeader)?;
if need > l {
return Err(Error::BadHeader);
}
let mut p = cursor + 3;
for _ in 0..count {
let blob_id = bytes[p];
let dtype = bytes[p + 1];
let ndim = bytes[p + 2];
if blob_name_from_compact_id(blob_id).is_none() {
return Err(Error::BadHeader);
}
if !matches!(dtype, 0..=2) || ndim == 0 || ndim > 2 {
return Err(Error::BadHeader);
}
let d0 = u16::from_le_bytes([bytes[p + 3], bytes[p + 4]]) as u32;
let d1 = u16::from_le_bytes([bytes[p + 5], bytes[p + 6]]) as u32;
let offset = u32::from_le_bytes([
bytes[p + 7],
bytes[p + 8],
bytes[p + 9],
bytes[p + 10],
]) as u64;
let length = u32::from_le_bytes([
bytes[p + 11],
bytes[p + 12],
bytes[p + 13],
bytes[p + 14],
]) as u64;
p += 15;
let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
let end = offset_usize
.checked_add(length_usize)
.ok_or(Error::BadBounds)?;
if end > bytes.len() {
return Err(Error::BadBounds);
}
let payload = &bytes[offset_usize..end];
let mut digest_sha256 = [0u8; 32];
sha256_bytes(payload, &mut digest_sha256);
let name = blob_name_from_compact_id(blob_id).ok_or(Error::BadHeader)?;
let name_bytes = name.as_bytes();
let name_len = name_bytes.len();
let name_rel = {
let name_store = scratch
.alloc_align(name_len, align_of::<u8>())
.ok_or(Error::ScratchFull)?;
name_store.copy_from_slice(name_bytes);
name_store.as_ptr() as usize - scratch_base
};
let shape_bytes = (ndim as usize).saturating_mul(4);
let dims_rel = {
let dims_store = scratch
.alloc_align(shape_bytes, align_of::<u32>())
.ok_or(Error::ScratchFull)?;
let mut w = 0usize;
if ndim >= 1 {
dims_store[w..w + 4].copy_from_slice(&d0.to_le_bytes());
w += 4;
}
if ndim >= 2 {
dims_store[w..w + 4].copy_from_slice(&d1.to_le_bytes());
}
dims_store.as_ptr() as usize - scratch_base
};
unsafe {
core::ptr::write(
meta_ptr.add(meta_index),
BlobMeta {
name_offset: name_rel,
name_len,
dtype,
ndim,
shape_offset: dims_rel,
offset,
length,
digest_sha256,
},
);
}
meta_index = meta_index.checked_add(1).ok_or(Error::ScratchFull)?;
}
cursor = cursor.checked_add(l).ok_or(Error::BadHeader)?;
continue;
}
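// Verbose entries: same layout as in pass 1, now copying the name, dims,
// and digest into scratch.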
let mut p = cursor;
while p < cursor + l {
if p + 2 > cursor + l {
return Err(Error::BadHeader);
}
let name_len = u16::from_le_bytes(
bytes[p..p + 2].try_into().map_err(|_| Error::BadHeader)?,
) as usize;
p += 2;
if p + name_len + 1 + 1 > cursor + l {
return Err(Error::BadHeader);
}
let name_bytes = &bytes[p..p + name_len];
p += name_len;
let dtype = bytes[p];
p += 1;
let ndim = bytes[p];
p += 1;
if ndim == 0 {
return Err(Error::BadHeader);
}
let shape_bytes = (ndim as usize).saturating_mul(4);
if p + shape_bytes + 8 + 8 + 32 > cursor + l {
return Err(Error::BadHeader);
}
let dims_src = &bytes[p..p + shape_bytes];
p += shape_bytes;
let offset = u64::from_le_bytes(
bytes[p..p + 8].try_into().map_err(|_| Error::BadHeader)?,
);
p += 8;
let length = u64::from_le_bytes(
bytes[p..p + 8].try_into().map_err(|_| Error::BadHeader)?,
);
p += 8;
let mut digest_sha256 = [0u8; 32];
digest_sha256.copy_from_slice(&bytes[p..p + 32]);
p += 32;
let offset_usize = usize::try_from(offset).map_err(|_| Error::BadBounds)?;
let length_usize = usize::try_from(length).map_err(|_| Error::BadBounds)?;
let end = offset_usize
.checked_add(length_usize)
.ok_or(Error::BadBounds)?;
if end > bytes.len() {
return Err(Error::BadBounds);
}
let name_rel = {
let name_store = scratch
.alloc_align(name_len, align_of::<u8>())
.ok_or(Error::ScratchFull)?;
name_store.copy_from_slice(name_bytes);
name_store.as_ptr() as usize - scratch_base
};
let dims_rel = {
let dims_store = scratch
.alloc_align(shape_bytes, align_of::<u32>())
.ok_or(Error::ScratchFull)?;
dims_store.copy_from_slice(dims_src);
dims_store.as_ptr() as usize - scratch_base
};
unsafe {
core::ptr::write(
meta_ptr.add(meta_index),
BlobMeta {
name_offset: name_rel,
name_len,
dtype,
ndim,
shape_offset: dims_rel,
offset,
length,
digest_sha256,
},
);
}
meta_index = meta_index.checked_add(1).ok_or(Error::ScratchFull)?;
}
}
cursor = cursor.checked_add(l).ok_or(Error::BadHeader)?;
}
if meta_index != metas_count {
return Err(Error::BadHeader);
}
unsafe { core::slice::from_raw_parts(meta_ptr as *const BlobMeta, metas_count) }
};
Ok(RnnHandle {
bytes,
blobs: blobs_slice,
scratch: scratch.as_slice(),
})
}
impl<'bytes, 'scratch> RnnHandle<'bytes, 'scratch> {
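/// Returns the UTF-8 name of blob `i`, resolved against the scratch arena.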
pub(crate) fn blob_name(&self, i: usize) -> Option<&'scratch str> {
let m = self.blobs.get(i)?;
let end = m.name_offset.checked_add(m.name_len)?;
if end > self.scratch.len() {
return None;
}
let sl = &self.scratch[m.name_offset..end];
core::str::from_utf8(sl).ok()
}
}