use crate::visualization::errors::VisualizeError;
use crate::rnn_format::RnnHandle;
/// Decoded metadata record describing one dense layer.
///
/// Values are produced by `DenseModelView::layer_meta`, which reads four
/// little-endian `u32` fields followed by a single activation byte from a
/// 20-byte record and widens the integers to `usize`.
#[derive(Copy, Clone, Debug)]
pub struct DenseLayerMeta {
/// First `u32` of the record (record bytes 0..4); presumably the number of inputs to the layer.
pub input_size: usize,
/// Second `u32` of the record (record bytes 4..8); summed over all layers by
/// `DenseModelView::neuron_count`.
pub output_size: usize,
/// Third `u32` of the record (record bytes 8..12). Presumably where this layer's weights
/// begin within the weight region; units (elements vs. bytes) are not established
/// here — TODO confirm against the writer.
pub weight_offset: usize,
/// Fourth `u32` of the record (record bytes 12..16). Presumably where this layer's biases
/// begin within the bias region — TODO confirm units against the writer.
pub bias_offset: usize,
/// Raw activation code (record byte 16); the mapping from code to activation function
/// is not defined in this file.
pub activation: u8,
}
/// Borrowed, zero-copy view over a serialized dense-model blob.
///
/// As parsed by `DenseModelView::from_bytes`, the blob consists of a 20-byte
/// header starting with the magic `RMD1`, then `layer_count` layer records of
/// 20 bytes each, then `weights_len` 4-byte weights, then `biases_len` 4-byte
/// biases. All offsets stored here are byte offsets into `bytes`.
pub struct DenseModelView<'a> {
// The complete blob, header included; every offset below indexes into it.
bytes: &'a [u8],
// Header length in bytes (set to 20 by `from_bytes`).
header_size: usize,
// Number of layer records, read from header bytes 8..12.
layer_count: usize,
// Size of one layer record in bytes (set to 20 by `from_bytes`).
layer_meta_size: usize,
// Byte offset of the first layer record (equal to `header_size`).
layer_meta_offset: usize,
// Byte offset of the first weight, immediately after the layer records.
weights_offset: usize,
// Number of weights (4 bytes each), read from header bytes 12..16.
weights_len: usize,
// Byte offset of the first bias, immediately after the weights.
biases_offset: usize,
// Number of biases (4 bytes each), read from header bytes 16..20.
biases_len: usize,
}
impl<'a> DenseModelView<'a> {
pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, VisualizeError> {
if bytes.len() < 20 { return Err(VisualizeError::InvalidFormat); }
if &bytes[0..4] != b"RMD1" { return Err(VisualizeError::InvalidFormat); }
let layer_count = u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| VisualizeError::InvalidFormat)?) as usize;
let weights_len = u32::from_le_bytes(bytes[12..16].try_into().map_err(|_| VisualizeError::InvalidFormat)?) as usize;
let biases_len = u32::from_le_bytes(bytes[16..20].try_into().map_err(|_| VisualizeError::InvalidFormat)?) as usize;
let header_size = 20usize;
let layer_meta_size = 20usize;
let needed_meta = layer_count.checked_mul(layer_meta_size).ok_or(VisualizeError::InvalidFormat)?;
let layers_meta_end = header_size.checked_add(needed_meta).ok_or(VisualizeError::InvalidFormat)?;
if layers_meta_end > bytes.len() { return Err(VisualizeError::InvalidFormat); }
let weights_bytes = weights_len.checked_mul(4).ok_or(VisualizeError::InvalidFormat)?;
let weights_offset = layers_meta_end;
let weights_end = weights_offset.checked_add(weights_bytes).ok_or(VisualizeError::InvalidFormat)?;
if weights_end > bytes.len() { return Err(VisualizeError::InvalidFormat); }
let biases_offset = weights_end;
let biases_bytes = biases_len.checked_mul(4).ok_or(VisualizeError::InvalidFormat)?;
let biases_end = biases_offset.checked_add(biases_bytes).ok_or(VisualizeError::InvalidFormat)?;
if biases_end > bytes.len() { return Err(VisualizeError::InvalidFormat); }
Ok(DenseModelView {
bytes,
header_size,
layer_count,
layer_meta_size,
layer_meta_offset: header_size,
weights_offset,
weights_len,
biases_offset,
biases_len,
})
}
pub fn layer_count(&self) -> usize { self.layer_count }
pub fn neuron_count(&self) -> usize {
let mut total = 0usize;
for i in 0..self.layer_count {
if let Ok(meta) = self.layer_meta(i) {
total = total.saturating_add(meta.output_size);
}
}
total
}
pub fn layer_meta(&self, idx: usize) -> Result<DenseLayerMeta, VisualizeError> {
if idx >= self.layer_count { return Err(VisualizeError::OutOfBounds); }
let off = self.layer_meta_offset + idx * self.layer_meta_size;
let end = off + self.layer_meta_size;
if end > self.bytes.len() { return Err(VisualizeError::OutOfBounds); }
let input_size = u32::from_le_bytes(self.bytes[off..off+4].try_into().map_err(|_| VisualizeError::InvalidFormat)?) as usize;
let output_size = u32::from_le_bytes(self.bytes[off+4..off+8].try_into().map_err(|_| VisualizeError::InvalidFormat)?) as usize;
let weight_offset = u32::from_le_bytes(self.bytes[off+8..off+12].try_into().map_err(|_| VisualizeError::InvalidFormat)?) as usize;
let bias_offset = u32::from_le_bytes(self.bytes[off+12..off+16].try_into().map_err(|_| VisualizeError::InvalidFormat)?) as usize;
let activation = self.bytes[off+16];
Ok(DenseLayerMeta { input_size, output_size, weight_offset, bias_offset, activation })
}
pub fn weights_len(&self) -> usize { self.weights_len }
pub fn biases_len(&self) -> usize { self.biases_len }
pub fn weight_at(&self, idx: usize) -> Result<f32, VisualizeError> {
if idx >= self.weights_len { return Err(VisualizeError::OutOfBounds); }
let bstart = self.weights_offset + idx * 4;
let bend = bstart + 4;
if bend > self.bytes.len() { return Err(VisualizeError::OutOfBounds); }
Ok(f32::from_le_bytes(self.bytes[bstart..bend].try_into().map_err(|_| VisualizeError::InvalidFormat)?))
}
pub fn bias_at(&self, idx: usize) -> Result<f32, VisualizeError> {
if idx >= self.biases_len { return Err(VisualizeError::OutOfBounds); }
let bstart = self.biases_offset + idx * 4;
let bend = bstart + 4;
if bend > self.bytes.len() { return Err(VisualizeError::OutOfBounds); }
Ok(f32::from_le_bytes(self.bytes[bstart..bend].try_into().map_err(|_| VisualizeError::InvalidFormat)?))
}
pub fn header_size(&self) -> usize { self.header_size }
}
/// Locates the blob called `blob_name` inside `handle` and parses it as a
/// dense model.
///
/// The blob's descriptor supplies an offset and a length into `handle.bytes`;
/// that range is bounds-checked and then handed to
/// [`DenseModelView::from_bytes`]. The returned view borrows from the handle's
/// underlying bytes, not from the handle itself.
///
/// # Errors
///
/// Returns [`VisualizeError::InvalidFormat`] when no blob has that name, the
/// descriptor is missing, its offset or length does not fit in `usize`, the
/// range falls outside `handle.bytes`, or the blob fails dense-model
/// validation.
pub fn model_view_from_rnn<'bytes, 'scratch>(
    handle: &RnnHandle<'bytes, 'scratch>,
    blob_name: &str,
) -> Result<DenseModelView<'bytes>, VisualizeError> {
    use crate::rnn_format::find_blob_index;

    let blob = find_blob_index(handle, blob_name)
        .and_then(|i| handle.blobs.get(i))
        .ok_or(VisualizeError::InvalidFormat)?;

    let offset = usize::try_from(blob.offset).map_err(|_| VisualizeError::InvalidFormat)?;
    let len = usize::try_from(blob.length).map_err(|_| VisualizeError::InvalidFormat)?;

    // Reject ranges that overflow or extend past the container before slicing.
    let end = match offset.checked_add(len) {
        Some(end) if end <= handle.bytes.len() => end,
        _ => return Err(VisualizeError::InvalidFormat),
    };

    DenseModelView::from_bytes(&handle.bytes[offset..end])
}