rnn/model_format/
model_format.rs1use core::convert::TryInto;
2
3use crate::activations::ActivationKind;
4use crate::layers::{DenseLayerDesc, LayerSpec};
5
/// Four-byte tag that every encoded model starts with.
const MAGIC: &[u8; 4] = b"RMD1";
/// The only format version this module writes and accepts.
const VERSION: u16 = 1;
/// Header bytes: magic (4) + `u16` version (2) + `u16` reserved (2) + three `u32` counts
/// (layers, weights, biases).
const HEADER_SIZE: usize = 4 + 2 + 2 + 3 * 4;
/// Bytes per layer record: four `u32` fields (input size, output size, weight offset,
/// bias offset), one activation byte, and three bytes of zero padding.
const LAYER_META_SIZE: usize = 4 * 4 + 1 + 3;
10
/// Errors reported by the v1 dense-model encoder and decoder in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormatError {
    /// The input ends before the header, or before the full size the header declares.
    Truncated,
    /// The first four bytes are not the expected magic tag.
    BadMagic,
    /// The header carries a format version other than the supported one.
    BadVersion,
    /// A size is not representable: the total encoded size overflows `usize`, or a value
    /// does not fit the integer type it has to be converted to.
    BadHeader,
    /// A caller-supplied output buffer is too short for the data.
    CapacityTooSmall,
    /// A layer record's activation byte is not a known activation kind.
    InvalidActivation,
    /// A layer has a zero dimension, or its weight/bias range falls outside the stored data.
    InvalidLayer,
}
21
/// Number of entries that `decode_dense_model_v1` wrote into each caller-provided buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecodedCounts {
    /// Layers written to the front of the layer output slice.
    pub layers: usize,
    /// Weights written to the front of the weight output slice.
    pub weights: usize,
    /// Biases written to the front of the bias output slice.
    pub biases: usize,
}
28
29pub fn encoded_size_v1(layer_count: usize, weights_len: usize, biases_len: usize) -> Option<usize> {
30 let layers_bytes = layer_count.checked_mul(LAYER_META_SIZE)?;
31 let weights_bytes = weights_len.checked_mul(core::mem::size_of::<f32>())?;
32 let biases_bytes = biases_len.checked_mul(core::mem::size_of::<f32>())?;
33 HEADER_SIZE.checked_add(layers_bytes)?.checked_add(weights_bytes)?.checked_add(biases_bytes)
34}
35
36pub fn encode_dense_model_v1(layers: &[LayerSpec], weights: &[f32], biases: &[f32], out: &mut [u8]) -> Result<usize, ModelFormatError> {
37 let needed = encoded_size_v1(layers.len(), weights.len(), biases.len()).ok_or(ModelFormatError::BadHeader)?;
38 if out.len() < needed {
39 return Err(ModelFormatError::CapacityTooSmall);
40 }
41
42 out[0..4].copy_from_slice(MAGIC);
43 out[4..6].copy_from_slice(&VERSION.to_le_bytes());
44 out[6..8].copy_from_slice(&0u16.to_le_bytes());
45
46 let layer_count_u32 = u32::try_from(layers.len()).map_err(|_| ModelFormatError::BadHeader)?;
47 let weights_len_u32 = u32::try_from(weights.len()).map_err(|_| ModelFormatError::BadHeader)?;
48 let biases_len_u32 = u32::try_from(biases.len()).map_err(|_| ModelFormatError::BadHeader)?;
49
50 out[8..12].copy_from_slice(&layer_count_u32.to_le_bytes());
51 out[12..16].copy_from_slice(&weights_len_u32.to_le_bytes());
52 out[16..20].copy_from_slice(&biases_len_u32.to_le_bytes());
53
54 let mut cursor = HEADER_SIZE;
55 for layer in layers {
56 let dense = match layer {
57 LayerSpec::Dense(v) => *v,
58 };
59 validate_dense(&dense, weights.len(), biases.len())?;
60
61 out[cursor..cursor+4].copy_from_slice(&u32::try_from(dense.input_size).map_err(|_| ModelFormatError::BadHeader)?.to_le_bytes());
62 cursor += 4;
63 out[cursor..cursor+4].copy_from_slice(&u32::try_from(dense.output_size).map_err(|_| ModelFormatError::BadHeader)?.to_le_bytes());
64 cursor += 4;
65 out[cursor..cursor+4].copy_from_slice(&u32::try_from(dense.weight_offset).map_err(|_| ModelFormatError::BadHeader)?.to_le_bytes());
66 cursor += 4;
67 out[cursor..cursor+4].copy_from_slice(&u32::try_from(dense.bias_offset).map_err(|_| ModelFormatError::BadHeader)?.to_le_bytes());
68 cursor += 4;
69 out[cursor] = dense.activation.to_u8();
70 cursor += 1;
71 out[cursor..cursor+3].copy_from_slice(&[0u8; 3]);
72 cursor += 3;
73 }
74
75 for &w in weights {
76 out[cursor..cursor+4].copy_from_slice(&w.to_le_bytes());
77 cursor += 4;
78 }
79 for &b in biases {
80 out[cursor..cursor+4].copy_from_slice(&b.to_le_bytes());
81 cursor += 4;
82 }
83
84 Ok(cursor)
85}
86
87pub fn decode_dense_model_v1(
88 bytes: &[u8],
89 layers_out: &mut [LayerSpec],
90 weights_out: &mut [f32],
91 biases_out: &mut [f32],
92) -> Result<DecodedCounts, ModelFormatError> {
93 if bytes.len() < HEADER_SIZE {
94 return Err(ModelFormatError::Truncated);
95 }
96 if &bytes[0..4] != MAGIC {
97 return Err(ModelFormatError::BadMagic);
98 }
99 let version = u16::from_le_bytes(bytes[4..6].try_into().map_err(|_| ModelFormatError::BadHeader)?);
100 if version != VERSION {
101 return Err(ModelFormatError::BadVersion);
102 }
103
104 let layer_count = u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
105 let weights_len = u32::from_le_bytes(bytes[12..16].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
106 let biases_len = u32::from_le_bytes(bytes[16..20].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
107
108 if layers_out.len() < layer_count || weights_out.len() < weights_len || biases_out.len() < biases_len {
109 return Err(ModelFormatError::CapacityTooSmall);
110 }
111
112 let expected = encoded_size_v1(layer_count, weights_len, biases_len).ok_or(ModelFormatError::BadHeader)?;
113 if bytes.len() < expected {
114 return Err(ModelFormatError::Truncated);
115 }
116
117 let mut cursor = HEADER_SIZE;
118 for slot in layers_out.iter_mut().take(layer_count) {
119 let input_size = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
120 cursor += 4;
121 let output_size = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
122 cursor += 4;
123 let weight_offset = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
124 cursor += 4;
125 let bias_offset = u32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::BadHeader)?) as usize;
126 cursor += 4;
127 let activation = ActivationKind::from_u8(bytes[cursor]).ok_or(ModelFormatError::InvalidActivation)?;
128 cursor += 1;
129 cursor += 3;
130
131 let dense = DenseLayerDesc {
132 input_size,
133 output_size,
134 weight_offset,
135 bias_offset,
136 activation,
137 };
138 validate_dense(&dense, weights_len, biases_len)?;
139 *slot = LayerSpec::Dense(dense);
140 }
141
142 for w in weights_out.iter_mut().take(weights_len) {
143 let v = f32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::Truncated)?);
144 *w = v;
145 cursor += 4;
146 }
147 for b in biases_out.iter_mut().take(biases_len) {
148 let v = f32::from_le_bytes(bytes[cursor..cursor+4].try_into().map_err(|_| ModelFormatError::Truncated)?);
149 *b = v;
150 cursor += 4;
151 }
152
153 Ok(DecodedCounts {
154 layers: layer_count,
155 weights: weights_len,
156 biases: biases_len,
157 })
158}
159
160fn validate_dense(layer: &DenseLayerDesc, weights_len: usize, biases_len: usize) -> Result<(), ModelFormatError> {
161 if layer.input_size == 0 || layer.output_size == 0 {
162 return Err(ModelFormatError::InvalidLayer);
163 }
164 let w_len = layer.weight_len().ok_or(ModelFormatError::InvalidLayer)?;
165 let w_end = layer.weight_offset.checked_add(w_len).ok_or(ModelFormatError::InvalidLayer)?;
166 let b_end = layer.bias_offset.checked_add(layer.output_size).ok_or(ModelFormatError::InvalidLayer)?;
167 if w_end > weights_len || b_end > biases_len {
168 return Err(ModelFormatError::InvalidLayer);
169 }
170 Ok(())
171}