Skip to main content

rnn/model_format/
model_format.rs

1use core::convert::TryInto;
2
3use crate::activations::ActivationKind;
4use crate::layers::{DenseLayerDesc, LayerSpec};
5
/// Signature stored in the first four bytes of every encoded model.
const MAGIC: &[u8; 4] = b"RMD1";
/// Format revision written right after the signature; the decoder rejects any other value.
const VERSION: u16 = 1;
/// Byte length of the fixed header: the signature, the `u16` version, a zeroed
/// `u16`, and three `u32` counts (layers, weights, biases).
const HEADER_SIZE: usize = MAGIC.len()
    + 2 * core::mem::size_of::<u16>()
    + 3 * core::mem::size_of::<u32>();
/// Byte length of one layer record: four `u32` fields (input size, output size,
/// weight offset, bias offset), one activation byte, and three padding bytes.
const LAYER_META_SIZE: usize = 4 * core::mem::size_of::<u32>() + 1 + 3;
10
/// Reasons encoding or decoding a v1 dense model can fail.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ModelFormatError {
    /// The input is shorter than the fixed header, or shorter than the size
    /// implied by the counts in the header.
    Truncated,
    /// The first four bytes are not the expected signature.
    BadMagic,
    /// The version field does not match the supported version.
    BadVersion,
    /// A size computation overflowed, or a count, size or offset does not fit
    /// in the `u32` used by the encoded form.
    BadHeader,
    /// A caller-provided output buffer is too small for the data.
    CapacityTooSmall,
    /// A layer record carries an activation byte that is not recognised.
    InvalidActivation,
    /// A layer has a zero dimension, or its weight/bias range overflows or
    /// falls outside the weight/bias arrays.
    InvalidLayer,
}
21
/// Element counts reported by a successful decode, as read from the header.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct DecodedCounts {
    /// Number of layer descriptors decoded.
    pub layers: usize,
    /// Number of `f32` weights decoded.
    pub weights: usize,
    /// Number of `f32` biases decoded.
    pub biases: usize,
}
28
29pub fn encoded_size_v1(layer_count: usize, weights_len: usize, biases_len: usize) -> Option<usize> {
30    let layers_bytes = layer_count.checked_mul(LAYER_META_SIZE)?;
31    let weights_bytes = weights_len.checked_mul(core::mem::size_of::<f32>())?;
32    let biases_bytes = biases_len.checked_mul(core::mem::size_of::<f32>())?;
33    HEADER_SIZE.checked_add(layers_bytes)?.checked_add(weights_bytes)?.checked_add(biases_bytes)
34}
35
36pub fn encode_dense_model_v1(layers: &[LayerSpec], weights: &[f32], biases: &[f32], out: &mut [u8]) -> Result<usize, ModelFormatError> {
37    let needed = encoded_size_v1(layers.len(), weights.len(), biases.len()).ok_or(ModelFormatError::BadHeader)?;
38    if out.len() < needed {
39        return Err(ModelFormatError::CapacityTooSmall);
40    }
41
42    out[0..4].copy_from_slice(MAGIC);
43    out[4..6].copy_from_slice(&VERSION.to_le_bytes());
44    out[6..8].copy_from_slice(&0u16.to_le_bytes());
45
46    let layer_count_u32 = u32::try_from(layers.len()).map_err(|_| ModelFormatError::BadHeader)?;
47    let weights_len_u32 = u32::try_from(weights.len()).map_err(|_| ModelFormatError::BadHeader)?;
48    let biases_len_u32 = u32::try_from(biases.len()).map_err(|_| ModelFormatError::BadHeader)?;
49
50    out[8..12].copy_from_slice(&layer_count_u32.to_le_bytes());
51    out[12..16].copy_from_slice(&weights_len_u32.to_le_bytes());
52    out[16..20].copy_from_slice(&biases_len_u32.to_le_bytes());
53
54    let mut cursor = HEADER_SIZE;
55    for layer in layers {
56        let dense = match layer {
57            LayerSpec::Dense(v) => *v,
58        };
59        validate_dense(&dense, weights.len(), biases.len())?;
60
61        out[cursor..cursor+4].copy_from_slice(&u32::try_from(dense.input_size).map_err(|_| ModelFormatError::BadHeader)?.to_le_bytes());
62        cursor += 4;
63        out[cursor..cursor+4].copy_from_slice(&u32::try_from(dense.output_size).map_err(|_| ModelFormatError::BadHeader)?.to_le_bytes());
64        cursor += 4;
65        out[cursor..cursor+4].copy_from_slice(&u32::try_from(dense.weight_offset).map_err(|_| ModelFormatError::BadHeader)?.to_le_bytes());
66        cursor += 4;
67        out[cursor..cursor+4].copy_from_slice(&u32::try_from(dense.bias_offset).map_err(|_| ModelFormatError::BadHeader)?.to_le_bytes());
68        cursor += 4;
69        out[cursor] = dense.activation.to_u8();
70        cursor += 1;
71        out[cursor..cursor+3].copy_from_slice(&[0u8; 3]);
72        cursor += 3;
73    }
74
75    for &w in weights {
76        out[cursor..cursor+4].copy_from_slice(&w.to_le_bytes());
77        cursor += 4;
78    }
79    for &b in biases {
80        out[cursor..cursor+4].copy_from_slice(&b.to_le_bytes());
81        cursor += 4;
82    }
83
84    Ok(cursor)
85}
86
87pub fn decode_dense_model_v1(
88    bytes: &[u8],
89    layers_out: &mut [LayerSpec],
90    weights_out: &mut [f32],
91    biases_out: &mut [f32],
92) -> Result<DecodedCounts, ModelFormatError> {
93    if bytes.len() < HEADER_SIZE {
94        return Err(ModelFormatError::Truncated);
95    }
96    if &bytes[0..4] != MAGIC {
97        return Err(ModelFormatError::BadMagic);
98    }
99    let version = u16::from_le_bytes(bytes[4..6].try_into().map_err(|_| ModelFormatError::BadHeader)?);
100    if version != VERSION {
101        return Err(ModelFormatError::BadVersion);
102    }
103
104    let layer_count = u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
105    let weights_len = u32::from_le_bytes(bytes[12..16].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
106    let biases_len = u32::from_le_bytes(bytes[16..20].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
107
108    if layers_out.len() < layer_count || weights_out.len() < weights_len || biases_out.len() < biases_len {
109        return Err(ModelFormatError::CapacityTooSmall);
110    }
111
112    let expected = encoded_size_v1(layer_count, weights_len, biases_len).ok_or(ModelFormatError::BadHeader)?;
113    if bytes.len() < expected {
114        return Err(ModelFormatError::Truncated);
115    }
116
117    let mut cursor = HEADER_SIZE;
118    for slot in layers_out.iter_mut().take(layer_count) {
119        let input_size = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
120        cursor += 4;
121        let output_size = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
122        cursor += 4;
123        let weight_offset = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
124        cursor += 4;
125        let bias_offset = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
126        cursor += 4;
127        let activation = ActivationKind::from_u8(bytes[cursor]).ok_or(ModelFormatError::InvalidActivation)?;
128        cursor += 1;
129        cursor += 3;
130
131        let dense = DenseLayerDesc {
132            input_size,
133            output_size,
134            weight_offset,
135            bias_offset,
136            activation,
137        };
138        validate_dense(&dense, weights_len, biases_len)?;
139        *slot = LayerSpec::Dense(dense);
140    }
141
142    for w in weights_out.iter_mut().take(weights_len) {
143        let v = f32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::Truncated)?);
144        *w = v;
145        cursor += 4;
146    }
147    for b in biases_out.iter_mut().take(biases_len) {
148        let v = f32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::Truncated)?);
149        *b = v;
150        cursor += 4;
151    }
152
153    Ok(DecodedCounts {
154        layers: layer_count,
155        weights: weights_len,
156        biases: biases_len,
157    })
158}
159
160fn validate_dense(layer: &DenseLayerDesc, weights_len: usize, biases_len: usize) -> Result<(), ModelFormatError> {
161    if layer.input_size == 0 || layer.output_size == 0 {
162        return Err(ModelFormatError::InvalidLayer);
163    }
164    let w_len = layer.weight_len().ok_or(ModelFormatError::InvalidLayer)?;
165    let w_end = layer.weight_offset.checked_add(w_len).ok_or(ModelFormatError::InvalidLayer)?;
166    let b_end = layer.bias_offset.checked_add(layer.output_size).ok_or(ModelFormatError::InvalidLayer)?;
167    if w_end > weights_len || b_end > biases_len {
168        return Err(ModelFormatError::InvalidLayer);
169    }
170    Ok(())
171}