Skip to main content

rnn/model_format/
header.rs

1use core::convert::TryInto;
2
/// Decoded fields of a version-1 header whose magic bytes are `b"RMD1"`.
///
/// Values are produced by [`parse_header_v1`], which reads every field as a
/// little-endian integer and widens the length/count fields to `usize`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct HeaderV1 {
    /// Version number, stored on disk as a `u16` at byte offset 4.
    pub version: u16,
    /// Layer count, stored on disk as a `u32` at byte offset 8.
    pub layer_count: usize,
    /// Weights length field, stored on disk as a `u32` at byte offset 12.
    /// NOTE(review): unit (bytes vs. element count) is not established here — confirm.
    pub weights_len: usize,
    /// Biases length field, stored on disk as a `u32` at byte offset 16.
    /// NOTE(review): unit (bytes vs. element count) is not established here — confirm.
    pub biases_len: usize,
}
10
11pub fn parse_header_v1(bytes: &[u8]) -> Option<HeaderV1> {
12    if bytes.len() < 20 {
13        return None;
14    }
15    if &bytes[0..4] != b"RMD1" {
16        return None;
17    }
18    let version = u16::from_le_bytes(bytes[4..6].try_into().ok()?);
19    let layer_count = u32::from_le_bytes(bytes[8..12].try_into().ok()?) as usize;
20    let weights_len = u32::from_le_bytes(bytes[12..16].try_into().ok()?) as usize;
21    let biases_len = u32::from_le_bytes(bytes[16..20].try_into().ok()?) as usize;
22    Some(HeaderV1 {
23        version,
24        layer_count,
25        weights_len,
26        biases_len,
27    })
28}
29
30pub fn is_v1_header_consistent(h: &HeaderV1) -> bool {
31    if h.version != 1 {
32        return false;
33    }
34    h.layer_count > 0
35}