1use crate::{layers::Layer, matrix::Matrix, model::Model, optimizer::Optimizer};
2use core::panic;
3use std::{collections::HashMap, fmt, fs};
4
/// Extension appended to every path handed to `save_model` / `load_model`.
const FILE_EXTENSION: &str = ".brq";
/// Format version byte stored in the file header.
const VERSION: u8 = 2;
/// Size of the file header in bytes: magic number (6) + version (1) + total length (8).
const HEADER_SIZE: u64 = 6 + 1 + 8;
/// Marker written in front of every serialized object (ASCII `CAT`).
const START_OF_OBJECT_MAGIC_NUMBER: [u8; 3] = *b"CAT";
/// Marker written at the very beginning of a file (ASCII `COOKIE`).
const START_OF_FILE_MAGIC_NUMBER: [u8; 6] = *b"COOKIE";
12
/// Maps the name of each serializable struct to the one-byte id that is
/// written right after the start-of-object magic number.
struct LookupStructBinaryId {
    lookup_table: HashMap<String, u8>,
}

impl LookupStructBinaryId {
    /// Builds the table with the ids of the known structs:
    /// `"Matrix"` = 0, `"Layer"` = 1 and `"Model"` = 2.
    pub fn init() -> LookupStructBinaryId {
        let lookup_table = [("Matrix", 0u8), ("Layer", 1), ("Model", 2)]
            .iter()
            .map(|&(name, id)| (name.to_string(), id))
            .collect();

        LookupStructBinaryId { lookup_table }
    }

    /// Returns the binary id registered for `struct_name`.
    ///
    /// Takes `&self` so that one table can serve several lookups instead of
    /// being consumed by the first one.
    ///
    /// # Panics
    ///
    /// Panics if `struct_name` is not present in the table; the names are
    /// hard-coded by the callers, so a miss is a programming error.
    pub fn lookup(&self, struct_name: &str) -> u8 {
        match self.lookup_table.get(struct_name) {
            Some(id) => *id,
            None => panic!("Key in struct id lookup table not found"),
        }
    }
}
40
/// Errors reported while saving a model to disk or loading one back.
///
/// Each variant carries a human-readable detail message.
#[derive(Debug)]
pub enum ModelManagementError {
    /// Writing the serialized model to disk failed.
    CouldNotSaveModel(String),
    /// Reading the model file from disk failed.
    CouldNotReadFile(String),
    /// The byte stream does not follow the expected binary layout.
    CouldNotDecodeBinary(String),
}

impl fmt::Display for ModelManagementError {
    /// Formats the error as `<summary>, details : <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (summary, details) = match self {
            ModelManagementError::CouldNotSaveModel(msg) => ("Could not save the model", msg),
            ModelManagementError::CouldNotReadFile(msg) => ("Could not read file", msg),
            ModelManagementError::CouldNotDecodeBinary(msg) => ("Could not decode binary", msg),
        };

        write!(f, "{}, details : {}", summary, details)
    }
}

// Lets callers treat this type like any other standard error (for example
// converting it with `?` into a boxed error).
impl std::error::Error for ModelManagementError {}
63
64pub fn save_model(model: &Model, file_path: String) -> Result<(), ModelManagementError> {
65 let mut byte_stream: Vec<u8> = vec![];
66 byte_stream.append(&mut model_to_binary(model));
67 byte_stream.splice(0..0, add_header(byte_stream.len() as u64));
68
69 let res_write = fs::write(file_path + FILE_EXTENSION, byte_stream);
70
71 if res_write.is_ok() {
72 Ok(())
73 } else {
74 Err(ModelManagementError::CouldNotSaveModel(
75 res_write.unwrap_err().to_string(),
76 ))
77 }
78}
79
80pub fn load_model(file_path: String) -> Result<Model, ModelManagementError> {
81 let byte_stream: Vec<u8> = match fs::read(file_path + FILE_EXTENSION) {
82 Ok(output) => output,
83 Err(e) => return Err(ModelManagementError::CouldNotReadFile(e.to_string())),
84 };
85
86 match check_header(&byte_stream) {
87 Ok(()) => (),
88 Err(e) => return Err(e),
89 };
90
91 binary_to_model(&byte_stream, HEADER_SIZE as usize)
92}
93
94pub fn load_model_from_byte_stream(byte_stream: &Vec<u8>) -> Result<Model, ModelManagementError> {
95 match check_header(&byte_stream) {
96 Ok(()) => (),
97 Err(e) => return Err(e),
98 };
99
100 binary_to_model(&byte_stream, HEADER_SIZE as usize)
101}
102
103pub fn add_header(data_size: u64) -> Vec<u8> {
108 let mut header: Vec<u8> = vec![];
109 header.append(&mut START_OF_FILE_MAGIC_NUMBER.to_vec());
110 header.push(VERSION);
111 header.append(&mut (data_size + HEADER_SIZE).to_be_bytes().to_vec());
112
113 header
114}
115
116pub fn check_header(byte_stream: &Vec<u8>) -> Result<(), ModelManagementError> {
117 let mut offset: usize = 0;
118 if offset + 6 > byte_stream.len() {
119 return Err(ModelManagementError::CouldNotDecodeBinary(
120 "while attempting to decode the header : Unexpected EOF".to_string(),
121 ));
122 }
123 if byte_stream[offset..offset + 6] != START_OF_FILE_MAGIC_NUMBER {
124 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode the header : Binary start of the file code not found, file may be corrupted".to_string()));
125 }
126
127 offset += 6;
128
129 if offset > byte_stream.len() {
130 return Err(ModelManagementError::CouldNotDecodeBinary(
131 "while attempting to decode the header : Unexpected EOF".to_string(),
132 ));
133 }
134
135 if byte_stream[offset] != VERSION {
136 return Err(ModelManagementError::CouldNotDecodeBinary(
137 "while attempting to decode the header : wrong file version".to_string(),
138 ));
139 }
140 offset += 1;
141
142 if offset + 8 > byte_stream.len() {
143 return Err(ModelManagementError::CouldNotDecodeBinary(
144 "while attempting to decode the header : Unexpected EOF".to_string(),
145 ));
146 }
147 let length: u64 = u64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap());
148
149 if length != byte_stream.len() as u64 {
150 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode the header : file length different than expected, file may be corrupted".to_string()));
151 }
152
153 Ok(())
154}
155
/// Encodes every value of `input` as 8 big-endian bytes, in order.
pub fn f64_array_to_binary(input: &Vec<f64>) -> Vec<u8> {
    let mut binary: Vec<u8> = Vec::with_capacity(input.len() * 8);

    for value in input {
        binary.extend_from_slice(&value.to_be_bytes());
    }

    binary
}
165
/// Decodes `input` as a sequence of big-endian `f64` values, 8 bytes each.
///
/// # Panics
///
/// Panics if the length of `input` is not a multiple of 8.
pub fn binary_to_f64_array(input: Vec<u8>) -> Vec<f64> {
    input
        .chunks(8)
        .map(|chunk| {
            let bytes: [u8; 8] = chunk.try_into().expect("Error while parsing binary chunks");
            f64::from_be_bytes(bytes)
        })
        .collect()
}
177
178pub fn matrix_to_binary(input: &Matrix) -> Vec<u8> {
184 let mut output: Vec<u8> = vec![];
185 let mut height_binary = (input.height as u64).to_be_bytes().to_vec();
188 let mut width_binary = (input.width as u64).to_be_bytes().to_vec();
189
190 let id_lookup_table = LookupStructBinaryId::init();
191
192 output.append(&mut START_OF_OBJECT_MAGIC_NUMBER.to_vec());
193 output.push(id_lookup_table.lookup("Matrix"));
194 output.push(input.transposed as u8);
195 output.append(&mut height_binary);
196 output.append(&mut width_binary);
197 output.append(&mut f64_array_to_binary(&input.data));
198
199 output
200}
201
202pub fn binary_to_matrix(
203 byte_stream: &Vec<u8>,
204 input_offset: usize,
205) -> Result<(Matrix, usize), ModelManagementError> {
206 let mut offset = input_offset;
207
208 if byte_stream[offset..offset + 3] != START_OF_OBJECT_MAGIC_NUMBER {
209 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a matrix : Binary start of object code not found, file may be corrupted".to_string()));
210 }
211
212 offset += 3;
213 let id_lookup_table = LookupStructBinaryId::init();
214
215 if byte_stream[offset] != id_lookup_table.lookup("Matrix") {
216 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a matrix : Binary id code does not match the lookup table for the Matrix entry, file may be corrupted".to_string()));
217 }
218 offset += 1;
219
220 let transposed: bool = byte_stream[offset] != 0;
221 offset += 1;
222
223 let height: usize =
224 u64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap()) as usize;
225 offset += 8;
226
227 let width: usize =
228 u64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap()) as usize;
229 offset += 8;
230
231 let data_size: usize = height * width * 8;
232 if offset + data_size > byte_stream.len() {
233 return Err(ModelManagementError::CouldNotDecodeBinary(
234 "Save binary reading - while attempting to decode a matrix : Unexpected EOF"
235 .to_string(),
236 ));
237 }
238
239 if offset + data_size + 3 <= byte_stream.len() {
240 if byte_stream[offset + data_size..offset + data_size + 3] != START_OF_OBJECT_MAGIC_NUMBER {
241 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a matrix : Binary start of object code not found, file may be corrupted".to_string()));
242 }
243 } else {
244 if offset + data_size != byte_stream.len() {
245 return Err(ModelManagementError::CouldNotDecodeBinary(
246 "while attempting to decode a matrix : Unexpected EOF".to_string(),
247 ));
248 }
249 }
250
251 let data: Vec<f64> = binary_to_f64_array(byte_stream[offset..offset + data_size].to_vec());
252 offset += data_size;
253
254 let output_matrix = Matrix {
255 height,
256 width,
257 transposed,
258 data,
259 };
260
261 Ok((output_matrix, offset))
262}
263
264pub fn layer_to_binary(input_layer: &Layer) -> Vec<u8> {
268 let mut output: Vec<u8> = vec![];
269
270 let id_lookup_table = LookupStructBinaryId::init();
271
272 output.append(&mut START_OF_OBJECT_MAGIC_NUMBER.to_vec());
273 output.push(id_lookup_table.lookup("Layer"));
274 output.push(input_layer.relu as u8);
275 output.append(&mut matrix_to_binary(&input_layer.weights_t));
276 output.append(&mut matrix_to_binary(&input_layer.biases));
277
278 output
279}
280
281pub fn binary_to_layer(
282 byte_stream: &Vec<u8>,
283 input_offset: usize,
284) -> Result<(Layer, usize), ModelManagementError> {
285 let mut offset = input_offset;
286
287 if byte_stream[offset..offset + 3] != START_OF_OBJECT_MAGIC_NUMBER {
288 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a layer : Binary start of object code not found, file may be corrupted".to_string()));
289 }
290 offset += 3;
291 let id_lookup_table = LookupStructBinaryId::init();
292
293 if byte_stream[offset] != id_lookup_table.lookup("Layer") {
294 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a layer : Binary id code does not match the lookup table for the Matrix entry, file may be corrupted".to_string()));
295 }
296 offset += 1;
297
298 let activation: bool = byte_stream[offset] != 0;
299 offset += 1;
300
301 let (weights_t, offset) = match binary_to_matrix(byte_stream, offset) {
302 Ok((matrix, offset)) => (matrix, offset),
303 Err(e) => return Err(e),
304 };
305
306 let (biases, offset) = match binary_to_matrix(byte_stream, offset) {
307 Ok((matrix, offset)) => (matrix, offset),
308 Err(e) => return Err(e),
309 };
310
311 let output_layer = Layer::init_with_data(weights_t, biases, activation);
312
313 Ok((output_layer, offset))
314}
315
316pub fn model_to_binary(input_model: &Model) -> Vec<u8> {
321 let mut output: Vec<u8> = vec![];
322
323 let id_lookup_table = LookupStructBinaryId::init();
324
325 output.append(&mut START_OF_OBJECT_MAGIC_NUMBER.to_vec());
326 output.push(id_lookup_table.lookup("Model"));
327 output.append(&mut input_model.lambda.to_be_bytes().to_vec());
328 output.append(&mut (input_model.layers.len() as u64).to_be_bytes().to_vec());
329
330 input_model
331 .layers
332 .iter()
333 .for_each(|layer| output.append(&mut layer_to_binary(&layer)));
334
335 output
336}
337
338pub fn binary_to_model(
339 byte_stream: &Vec<u8>,
340 input_offset: usize,
341) -> Result<Model, ModelManagementError> {
342 let mut offset: usize = input_offset;
343
344 if byte_stream[offset..offset + 3] != START_OF_OBJECT_MAGIC_NUMBER {
345 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode the model : Binary start of object code not found, file may be corrupted".to_string()));
346 }
347
348 offset += 3;
349 let id_lookup_table = LookupStructBinaryId::init();
350
351 if byte_stream[offset] != id_lookup_table.lookup("Model") {
352 return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode the model : Binary id code does not match the lookup table for the Matrix entry, file may be corrupted".to_string()));
353 }
354 offset += 1;
355
356 let lambda: f64 = f64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap());
357 offset += 8;
358
359 let number_of_layers: usize =
360 u64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap()) as usize;
361 offset += 8;
362
363 let mut layers: Vec<Layer> = vec![];
364 for _ in 0..number_of_layers {
365 let (layer, new_offset) = match binary_to_layer(byte_stream, offset) {
366 Ok((layer, offset)) => (layer, offset),
367 Err(e) => return Err(e),
368 };
369
370 offset = new_offset;
371 layers.push(layer);
372 }
373
374 Ok(Model::init(
375 layers,
376 Optimizer::SGD {
377 learning_step: 0.01,
378 },
379 lambda,
380 ))
381}
382
#[cfg(test)]
mod tests {
    use core::panic;
    use std::fs;

    use crate::{layers::Layer, model::Model, optimizer::Optimizer, save_load::FILE_EXTENSION};

    use super::{load_model, save_model};

    /// Saves a four-layer model, loads it back and checks that lambda, the
    /// number of layers and every weights/biases matrix survived the round trip.
    #[test]
    fn succesful_model_save_and_load() {
        let lambda: f64 = 0.012;
        let file_path: String = "test_model_save".to_string();

        let model = Model::init(
            vec![
                Layer::init(10, 100, true),
                Layer::init(100, 200, true),
                Layer::init(200, 200, true),
                Layer::init(200, 3, false),
            ],
            Optimizer::SGD {
                learning_step: 0.01,
            },
            lambda,
        );
        save_model(&model, file_path.clone()).unwrap();

        // Remove the fixture before looking at the load result, so a failed
        // load does not leave the file behind.
        let load_result = load_model(file_path.clone());
        if let Err(e) = fs::remove_file(file_path + FILE_EXTENSION) {
            panic!("{}", e);
        }
        let loaded_model = match load_result {
            Ok(model) => model,
            Err(e) => panic!("{}", e),
        };

        assert_eq!(
            model.lambda, loaded_model.lambda,
            "Models lambdas are not the same"
        );
        assert_eq!(
            model.layers.len(),
            loaded_model.layers.len(),
            "Models do not have the same number of layers"
        );

        let layer_pairs = model.layers.iter().zip(loaded_model.layers.iter());
        for (i, (expected, actual)) in layer_pairs.enumerate() {
            assert!(
                expected.weights_t.is_equal(&actual.weights_t, 10),
                "Layer {} weights are different in the two models",
                i
            );
            assert!(
                expected.biases.is_equal(&actual.biases, 10),
                "Layer {} biases are different in the two models",
                i
            );
        }
    }
}
448}