brique/
save_load.rs

1use crate::{layers::Layer, matrix::Matrix, model::Model, optimizer::Optimizer};
2use core::panic;
3use std::{collections::HashMap, fmt, fs};
4
/// Extension appended to every path given to `save_model` and `load_model`.
const FILE_EXTENSION: &str = ".brq";
/// Format version stored in the header and required again when reading a file.
const VERSION: u8 = 2;
/// Header length in bytes: 6 (file magic number) + 1 (version) + 8 (total length).
const HEADER_SIZE: u64 = 15;
/// Marker written in front of every serialized object: ASCII "CAT".
const START_OF_OBJECT_MAGIC_NUMBER: [u8; 3] = *b"CAT";
/// Marker written at the very start of a file: ASCII "COOKIE".
const START_OF_FILE_MAGIC_NUMBER: [u8; 6] = *b"COOKIE";
12
/// Maps the name of a serializable struct to the one-byte type id that is
/// written right after the start-of-object magic number.
struct LookupStructBinaryId {
    lookup_table: HashMap<String, u8>,
}

impl LookupStructBinaryId {
    /// Builds the table with the three known ids: Matrix = 0, Layer = 1, Model = 2.
    pub fn init() -> LookupStructBinaryId {
        let lookup_table = [("Matrix", 0u8), ("Layer", 1), ("Model", 2)]
            .iter()
            .map(|&(name, id)| (name.to_string(), id))
            .collect();

        LookupStructBinaryId { lookup_table }
    }

    /// Returns the id registered for `struct_name`.
    ///
    /// The table is only borrowed, so a single instance can answer any
    /// number of lookups.
    ///
    /// # Panics
    /// Panics when `struct_name` is not one of the registered names; this is
    /// a programming error, not a data error.
    pub fn lookup(&self, struct_name: &str) -> u8 {
        match self.lookup_table.get(struct_name) {
            Some(id) => *id,
            None => panic!("Key in struct id lookup table not found"),
        }
    }
}
40
/// Errors reported while saving, reading or decoding a model.
///
/// Every variant carries a human-readable detail message.
#[derive(Debug)]
pub enum ModelManagementError {
    /// Writing the serialized model to disk failed.
    CouldNotSaveModel(String),
    /// The model file could not be read from disk.
    CouldNotReadFile(String),
    /// The bytes do not form a valid model (bad marker, version, length, ...).
    CouldNotDecodeBinary(String),
}

impl fmt::Display for ModelManagementError {
    /// Writes "<summary>, details : <message>" for the variant at hand.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (summary, details) = match self {
            ModelManagementError::CouldNotSaveModel(msg) => ("Could not save the model", msg),
            ModelManagementError::CouldNotReadFile(msg) => ("Could not read file", msg),
            ModelManagementError::CouldNotDecodeBinary(msg) => ("Could not decode binary", msg),
        };

        write!(f, "{}, details : {}", summary, details)
    }
}

// Lets callers box the error as `dyn Error` and propagate it with `?`
// alongside other error types.
impl std::error::Error for ModelManagementError {}
63
64pub fn save_model(model: &Model, file_path: String) -> Result<(), ModelManagementError> {
65    let mut byte_stream: Vec<u8> = vec![];
66    byte_stream.append(&mut model_to_binary(model));
67    byte_stream.splice(0..0, add_header(byte_stream.len() as u64));
68
69    let res_write = fs::write(file_path + FILE_EXTENSION, byte_stream);
70
71    if res_write.is_ok() {
72        Ok(())
73    } else {
74        Err(ModelManagementError::CouldNotSaveModel(
75            res_write.unwrap_err().to_string(),
76        ))
77    }
78}
79
80pub fn load_model(file_path: String) -> Result<Model, ModelManagementError> {
81    let byte_stream: Vec<u8> = match fs::read(file_path + FILE_EXTENSION) {
82        Ok(output) => output,
83        Err(e) => return Err(ModelManagementError::CouldNotReadFile(e.to_string())),
84    };
85
86    match check_header(&byte_stream) {
87        Ok(()) => (),
88        Err(e) => return Err(e),
89    };
90
91    binary_to_model(&byte_stream, HEADER_SIZE as usize)
92}
93
94pub fn load_model_from_byte_stream(byte_stream: &Vec<u8>) -> Result<Model, ModelManagementError> {
95    match check_header(&byte_stream) {
96        Ok(()) => (),
97        Err(e) => return Err(e),
98    };
99
100    binary_to_model(&byte_stream, HEADER_SIZE as usize)
101}
102
103// header (size 15 bytes)
104// magic number : 6 bytes
105// version, would match the version of the release of the lib, i.e, 0.2 => 2, 0.3 => 3 .... 1.0 => 10, 1.1 => 11 : 1 byte
106// length of the binary (data and header combined) in bytes : 8 bytes
107pub fn add_header(data_size: u64) -> Vec<u8> {
108    let mut header: Vec<u8> = vec![];
109    header.append(&mut START_OF_FILE_MAGIC_NUMBER.to_vec());
110    header.push(VERSION);
111    header.append(&mut (data_size + HEADER_SIZE).to_be_bytes().to_vec());
112
113    header
114}
115
116pub fn check_header(byte_stream: &Vec<u8>) -> Result<(), ModelManagementError> {
117    let mut offset: usize = 0;
118    if offset + 6 > byte_stream.len() {
119        return Err(ModelManagementError::CouldNotDecodeBinary(
120            "while attempting to decode the header : Unexpected EOF".to_string(),
121        ));
122    }
123    if byte_stream[offset..offset + 6] != START_OF_FILE_MAGIC_NUMBER {
124        return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode the header : Binary start of the file code not found, file may be corrupted".to_string()));
125    }
126
127    offset += 6;
128
129    if offset > byte_stream.len() {
130        return Err(ModelManagementError::CouldNotDecodeBinary(
131            "while attempting to decode the header : Unexpected EOF".to_string(),
132        ));
133    }
134
135    if byte_stream[offset] != VERSION {
136        return Err(ModelManagementError::CouldNotDecodeBinary(
137            "while attempting to decode the header : wrong file version".to_string(),
138        ));
139    }
140    offset += 1;
141
142    if offset + 8 > byte_stream.len() {
143        return Err(ModelManagementError::CouldNotDecodeBinary(
144            "while attempting to decode the header : Unexpected EOF".to_string(),
145        ));
146    }
147    let length: u64 = u64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap());
148
149    if length != byte_stream.len() as u64 {
150        return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode the header : file length different than expected, file may be corrupted".to_string()));
151    }
152
153    Ok(())
154}
155
/// Encodes every value of `input` as 8 big-endian bytes, in order.
///
/// The result is `input.len() * 8` bytes long.
pub fn f64_array_to_binary(input: &Vec<f64>) -> Vec<u8> {
    let mut bytes: Vec<u8> = Vec::with_capacity(input.len() * 8);

    for value in input {
        bytes.extend_from_slice(&value.to_be_bytes());
    }

    bytes
}
165
/// Decodes `input` as a sequence of big-endian f64 values, 8 bytes each.
///
/// # Panics
/// Panics when the length of `input` is not a multiple of 8, because the
/// last chunk cannot be turned into an 8-byte array.
pub fn binary_to_f64_array(input: Vec<u8>) -> Vec<f64> {
    input
        .chunks(8)
        .map(|chunk| {
            let raw: [u8; 8] = chunk
                .try_into()
                .expect("Error while parsing binary chunks");
            f64::from_be_bytes(raw)
        })
        .collect()
}
177
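// Matrix binary layout, in the order it is written:
// start-of-object magic number (3 bytes)
// id (1 byte)
// transposed bool (1 byte)
// height usize (written as u64, 8 bytes)
// width usize (written as u64, 8 bytes)
// data Vec<f64> (8 bytes per value)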
183pub fn matrix_to_binary(input: &Matrix) -> Vec<u8> {
184    let mut output: Vec<u8> = vec![];
185    // forcing usize to 64bit, just in case we are running on a 32bit system
186    // not sure its the best way to deal with this
187    let mut height_binary = (input.height as u64).to_be_bytes().to_vec();
188    let mut width_binary = (input.width as u64).to_be_bytes().to_vec();
189
190    let id_lookup_table = LookupStructBinaryId::init();
191
192    output.append(&mut START_OF_OBJECT_MAGIC_NUMBER.to_vec());
193    output.push(id_lookup_table.lookup("Matrix"));
194    output.push(input.transposed as u8);
195    output.append(&mut height_binary);
196    output.append(&mut width_binary);
197    output.append(&mut f64_array_to_binary(&input.data));
198
199    output
200}
201
202pub fn binary_to_matrix(
203    byte_stream: &Vec<u8>,
204    input_offset: usize,
205) -> Result<(Matrix, usize), ModelManagementError> {
206    let mut offset = input_offset;
207
208    if byte_stream[offset..offset + 3] != START_OF_OBJECT_MAGIC_NUMBER {
209        return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a matrix : Binary start of object code not found, file may be corrupted".to_string()));
210    }
211
212    offset += 3;
213    let id_lookup_table = LookupStructBinaryId::init();
214
215    if byte_stream[offset] != id_lookup_table.lookup("Matrix") {
216        return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a matrix : Binary id code does not match the lookup table for the Matrix entry, file may be corrupted".to_string()));
217    }
218    offset += 1;
219
220    let transposed: bool = byte_stream[offset] != 0;
221    offset += 1;
222
223    let height: usize =
224        u64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap()) as usize;
225    offset += 8;
226
227    let width: usize =
228        u64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap()) as usize;
229    offset += 8;
230
231    let data_size: usize = height * width * 8;
232    if offset + data_size > byte_stream.len() {
233        return Err(ModelManagementError::CouldNotDecodeBinary(
234            "Save binary reading - while attempting to decode a matrix : Unexpected EOF"
235                .to_string(),
236        ));
237    }
238
239    if offset + data_size + 3 <= byte_stream.len() {
240        if byte_stream[offset + data_size..offset + data_size + 3] != START_OF_OBJECT_MAGIC_NUMBER {
241            return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a matrix : Binary start of object code not found, file may be corrupted".to_string()));
242        }
243    } else {
244        if offset + data_size != byte_stream.len() {
245            return Err(ModelManagementError::CouldNotDecodeBinary(
246                "while attempting to decode a matrix : Unexpected EOF".to_string(),
247            ));
248        }
249    }
250
251    let data: Vec<f64> = binary_to_f64_array(byte_stream[offset..offset + data_size].to_vec());
252    offset += data_size;
253
254    let output_matrix = Matrix {
255        height,
256        width,
257        transposed,
258        data,
259    };
260
261    Ok((output_matrix, offset))
262}
263
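// Layer binary layout, in the order it is written:
// start-of-object magic number (3 bytes)
// id (1 byte)
// activation (relu) : bool (1 byte)
// weights : matrix
// biases : matrix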
267pub fn layer_to_binary(input_layer: &Layer) -> Vec<u8> {
268    let mut output: Vec<u8> = vec![];
269
270    let id_lookup_table = LookupStructBinaryId::init();
271
272    output.append(&mut START_OF_OBJECT_MAGIC_NUMBER.to_vec());
273    output.push(id_lookup_table.lookup("Layer"));
274    output.push(input_layer.relu as u8);
275    output.append(&mut matrix_to_binary(&input_layer.weights_t));
276    output.append(&mut matrix_to_binary(&input_layer.biases));
277
278    output
279}
280
281pub fn binary_to_layer(
282    byte_stream: &Vec<u8>,
283    input_offset: usize,
284) -> Result<(Layer, usize), ModelManagementError> {
285    let mut offset = input_offset;
286
287    if byte_stream[offset..offset + 3] != START_OF_OBJECT_MAGIC_NUMBER {
288        return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a layer : Binary start of object code not found, file may be corrupted".to_string()));
289    }
290    offset += 3;
291    let id_lookup_table = LookupStructBinaryId::init();
292
293    if byte_stream[offset] != id_lookup_table.lookup("Layer") {
294        return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode a layer : Binary id code does not match the lookup table for the Matrix entry, file may be corrupted".to_string()));
295    }
296    offset += 1;
297
298    let activation: bool = byte_stream[offset] != 0;
299    offset += 1;
300
301    let (weights_t, offset) = match binary_to_matrix(byte_stream, offset) {
302        Ok((matrix, offset)) => (matrix, offset),
303        Err(e) => return Err(e),
304    };
305
306    let (biases, offset) = match binary_to_matrix(byte_stream, offset) {
307        Ok((matrix, offset)) => (matrix, offset),
308        Err(e) => return Err(e),
309    };
310
311    let output_layer = Layer::init_with_data(weights_t, biases, activation);
312
313    Ok((output_layer, offset))
314}
315
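// Model binary layout, in the order it is written:
// start-of-object magic number (3 bytes)
// id (1 byte)
// lambda f64
// number of layers (u64)
// layers Vec<Layer>
// Note: the learning step is NOT stored; binary_to_model rebuilds the model
// with Optimizer::SGD { learning_step: 0.01 }.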
320pub fn model_to_binary(input_model: &Model) -> Vec<u8> {
321    let mut output: Vec<u8> = vec![];
322
323    let id_lookup_table = LookupStructBinaryId::init();
324
325    output.append(&mut START_OF_OBJECT_MAGIC_NUMBER.to_vec());
326    output.push(id_lookup_table.lookup("Model"));
327    output.append(&mut input_model.lambda.to_be_bytes().to_vec());
328    output.append(&mut (input_model.layers.len() as u64).to_be_bytes().to_vec());
329
330    input_model
331        .layers
332        .iter()
333        .for_each(|layer| output.append(&mut layer_to_binary(&layer)));
334
335    output
336}
337
338pub fn binary_to_model(
339    byte_stream: &Vec<u8>,
340    input_offset: usize,
341) -> Result<Model, ModelManagementError> {
342    let mut offset: usize = input_offset;
343
344    if byte_stream[offset..offset + 3] != START_OF_OBJECT_MAGIC_NUMBER {
345        return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode the model : Binary start of object code not found, file may be corrupted".to_string()));
346    }
347
348    offset += 3;
349    let id_lookup_table = LookupStructBinaryId::init();
350
351    if byte_stream[offset] != id_lookup_table.lookup("Model") {
352        return Err(ModelManagementError::CouldNotDecodeBinary("while attempting to decode the model : Binary id code does not match the lookup table for the Matrix entry, file may be corrupted".to_string()));
353    }
354    offset += 1;
355
356    let lambda: f64 = f64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap());
357    offset += 8;
358
359    let number_of_layers: usize =
360        u64::from_be_bytes(byte_stream[offset..offset + 8].try_into().unwrap()) as usize;
361    offset += 8;
362
363    let mut layers: Vec<Layer> = vec![];
364    for _ in 0..number_of_layers {
365        let (layer, new_offset) = match binary_to_layer(byte_stream, offset) {
366            Ok((layer, offset)) => (layer, offset),
367            Err(e) => return Err(e),
368        };
369
370        offset = new_offset;
371        layers.push(layer);
372    }
373
374    Ok(Model::init(
375        layers,
376        Optimizer::SGD {
377            learning_step: 0.01,
378        },
379        lambda,
380    ))
381}
382
383//unit test
#[cfg(test)]
mod tests {
    use core::panic;
    use std::fs;

    use crate::{layers::Layer, model::Model, optimizer::Optimizer, save_load::FILE_EXTENSION};

    use super::{load_model, save_model};

    /// Saves a four-layer model, loads it back and checks that lambda, the
    /// number of layers and every layer's weights and biases match.
    #[test]
    fn succesful_model_save_and_load() {
        let layer1 = Layer::init(10, 100, true);
        let layer2 = Layer::init(100, 200, true);
        let layer3 = Layer::init(200, 200, true);
        let layer4 = Layer::init(200, 3, false);

        let lambda: f64 = 0.012;

        let file_path: String = "test_model_save".to_string();
        let model = Model::init(
            vec![layer1, layer2, layer3, layer4],
            Optimizer::SGD {
                learning_step: 0.01,
            },
            lambda,
        );
        save_model(&model, file_path.clone()).unwrap();

        // Hold on to the result instead of panicking right away, so the file
        // written above is removed even when loading fails.
        let load_result = load_model(file_path.clone());

        match fs::remove_file(file_path + FILE_EXTENSION) {
            Ok(()) => (),
            Err(e) => panic!("{}", e),
        };

        let loaded_model = match load_result {
            Ok(model) => model,
            Err(e) => panic!("{}", e),
        };

        assert_eq!(
            model.lambda, loaded_model.lambda,
            "Models lambdas are not the same"
        );
        assert_eq!(
            model.layers.len(),
            loaded_model.layers.len(),
            "Models do not have the same number of layers"
        );
        for i in 0..model.layers.len() {
            assert!(
                model.layers[i]
                    .weights_t
                    .is_equal(&loaded_model.layers[i].weights_t, 10),
                "Layer {} weights are different in the two models",
                i
            );
            assert!(
                model.layers[i]
                    .biases
                    .is_equal(&loaded_model.layers[i].biases, 10),
                "Layer {} biases are different in the two models",
                i
            );
        }
    }
}
448}