use crate::crypto::constant_time_eq;
use crate::visualization::errors::VisualizeError;
/// Metadata for one layer, decoded from a 20-byte layer record by `ModelView::layer_meta`.
///
/// Every field is copied verbatim from the record (little-endian `u32`s widened to `usize`,
/// plus one raw byte). None of them is checked against the sizes of the weight or bias
/// sections, so callers must validate offsets before using them for indexing.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct LayerMeta {
    /// Record bytes `0..4`; presumably the number of inputs to this layer.
    pub input_size: usize,
    /// Record bytes `4..8`; this is the value `ModelView::neuron_count` sums per layer.
    pub output_size: usize,
    /// Record bytes `8..12`.
    /// NOTE(review): unclear whether this is an element index into the weight array or a
    /// byte offset — confirm against the format writer.
    pub weight_offset: usize,
    /// Record bytes `12..16`. Same caveat as `weight_offset`, but for the bias array.
    pub bias_offset: usize,
    /// Record byte `16`, returned as-is; the activation encoding is not defined in this file.
    pub activation: u8,
}
/// Zero-copy, read-only view over a serialized `RMD1` model buffer.
///
/// The view only borrows the caller's bytes. Within this file it is built by
/// `ModelView::from_bytes`, which checks that every section described by the header lies
/// inside `bytes`. The per-element accessors bounds-check each read as well.
///
/// `Copy` is cheap here because the view is just a slice reference plus a few offsets.
/// `Debug` is deliberately not derived, since it would print the entire buffer.
#[derive(Clone, Copy)]
pub struct ModelView<'a> {
    /// The whole serialized buffer, borrowed from the caller.
    bytes: &'a [u8],
    /// Length of the fixed header, in bytes.
    header_size: usize,
    /// Number of layer metadata records.
    layer_count: usize,
    /// Length of one layer metadata record, in bytes.
    layer_meta_size: usize,
    /// Byte offset of the first layer metadata record.
    layer_meta_offset: usize,
    /// Byte offset of the first serialized weight.
    weights_offset: usize,
    /// Number of `f32` weights (an element count, not a byte count).
    weights_len: usize,
    /// Byte offset of the first serialized bias.
    biases_offset: usize,
    /// Number of `f32` biases (an element count, not a byte count).
    biases_len: usize,
}
impl<'a> ModelView<'a> {
pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, VisualizeError> {
if bytes.len() < 20 {
return Err(VisualizeError::InvalidFormat);
}
if !constant_time_eq(&bytes[0..4], b"RMD1") {
return Err(VisualizeError::InvalidFormat);
}
let layer_count = u32::from_le_bytes(
bytes[8..12]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
) as usize;
let weights_len = u32::from_le_bytes(
bytes[12..16]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
) as usize;
let biases_len = u32::from_le_bytes(
bytes[16..20]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
) as usize;
let header_size = 20usize;
let layer_meta_size = 20usize;
let needed_meta = layer_count
.checked_mul(layer_meta_size)
.ok_or(VisualizeError::InvalidFormat)?;
let layers_meta_end = header_size
.checked_add(needed_meta)
.ok_or(VisualizeError::InvalidFormat)?;
if layers_meta_end > bytes.len() {
return Err(VisualizeError::InvalidFormat);
}
let weights_bytes = weights_len
.checked_mul(4)
.ok_or(VisualizeError::InvalidFormat)?;
let weights_offset = layers_meta_end;
let weights_end = weights_offset
.checked_add(weights_bytes)
.ok_or(VisualizeError::InvalidFormat)?;
if weights_end > bytes.len() {
return Err(VisualizeError::InvalidFormat);
}
let biases_offset = weights_end;
let biases_bytes = biases_len
.checked_mul(4)
.ok_or(VisualizeError::InvalidFormat)?;
let biases_end = biases_offset
.checked_add(biases_bytes)
.ok_or(VisualizeError::InvalidFormat)?;
if biases_end > bytes.len() {
return Err(VisualizeError::InvalidFormat);
}
Ok(ModelView {
bytes,
header_size,
layer_count,
layer_meta_size,
layer_meta_offset: header_size,
weights_offset,
weights_len,
biases_offset,
biases_len,
})
}
pub fn layer_count(&self) -> usize {
self.layer_count
}
pub fn neuron_count(&self) -> usize {
let mut total = 0usize;
for i in 0..self.layer_count {
if let Ok(meta) = self.layer_meta(i) {
total = total.saturating_add(meta.output_size);
}
}
total
}
pub fn layer_meta(&self, idx: usize) -> Result<LayerMeta, VisualizeError> {
if idx >= self.layer_count {
return Err(VisualizeError::OutOfBounds);
}
let off = self.layer_meta_offset + idx * self.layer_meta_size;
let end = off + self.layer_meta_size;
if end > self.bytes.len() {
return Err(VisualizeError::OutOfBounds);
}
let input_size = u32::from_le_bytes(
self.bytes[off..off + 4]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
) as usize;
let output_size = u32::from_le_bytes(
self.bytes[off + 4..off + 8]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
) as usize;
let weight_offset = u32::from_le_bytes(
self.bytes[off + 8..off + 12]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
) as usize;
let bias_offset = u32::from_le_bytes(
self.bytes[off + 12..off + 16]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
) as usize;
let activation = self.bytes[off + 16];
Ok(LayerMeta {
input_size,
output_size,
weight_offset,
bias_offset,
activation,
})
}
pub fn weights_len(&self) -> usize {
self.weights_len
}
pub fn biases_len(&self) -> usize {
self.biases_len
}
pub fn weight_at(&self, idx: usize) -> Result<f32, VisualizeError> {
if idx >= self.weights_len {
return Err(VisualizeError::OutOfBounds);
}
let bstart = self.weights_offset + idx * 4;
let bend = bstart + 4;
if bend > self.bytes.len() {
return Err(VisualizeError::OutOfBounds);
}
Ok(f32::from_le_bytes(
self.bytes[bstart..bend]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
))
}
pub fn bias_at(&self, idx: usize) -> Result<f32, VisualizeError> {
if idx >= self.biases_len {
return Err(VisualizeError::OutOfBounds);
}
let bstart = self.biases_offset + idx * 4;
let bend = bstart + 4;
if bend > self.bytes.len() {
return Err(VisualizeError::OutOfBounds);
}
Ok(f32::from_le_bytes(
self.bytes[bstart..bend]
.try_into()
.map_err(|_| VisualizeError::InvalidFormat)?,
))
}
pub fn header_size(&self) -> usize {
self.header_size
}
}