unda/core/serialize/
ser_layer.rs

1use crate::core::{layer::{layers::{Layer, LayerTypes}, dense::Dense}, data::input::Input};
2
/// Plain-text snapshot of a single layer.
///
/// It holds a one-character type tag, the weight-matrix dimensions, and the
/// weights and biases encoded as space-separated numbers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SerializedLayer {
    /// Layer type tag; `'D'` is the only tag `SerializedLayer::from` accepts.
    pub name: char,
    /// Number of rows in the weight matrix.
    pub rows: usize,
    /// Number of columns in the weight matrix (length of its first row).
    pub cols: usize,
    /// Flattened weight values joined by single spaces.
    pub weights: String,
    /// Flattened bias values joined by single spaces.
    pub bias: String,
}
10
11impl SerializedLayer {
12    pub fn new(layer: &Box<dyn Layer>, _layer_type: &LayerTypes) -> Self {
13        let rows = layer.get_weights().to_param_2d().len();
14        let cols = layer.get_weights().to_param_2d()[0].len();
15        let weights: String = SerializedLayer::flatten_string(&layer.get_weights().to_param_2d());
16        let bias = SerializedLayer::flatten_string(&layer.get_biases().to_param_2d());
17
18        Self { name: 'D', rows, cols, weights, bias }
19    }
20    pub fn from(&self) -> Box<dyn Layer> {
21        match self.name {
22            'D' => {
23                let weights_f32: Vec<f32> = self.weights.split(" ").into_iter().map(|val| val.parse().unwrap()).collect();
24                let bias_f32: Vec<f32> = self.bias.split(" ").into_iter().map(|val| val.parse().unwrap()).collect();
25                let dense_layer: Dense = Dense::new_ser(self.rows, self.cols, weights_f32, bias_f32);
26                return Box::new(dense_layer)
27            },
28            _ => panic!("Not a supported type"),
29        };
30    }
31    fn flatten_string(data: &Vec<Vec<f32>>) -> String {
32        data.to_param()
33            .into_iter()
34            .map(|d| d.to_string() + " ")
35            .collect::<String>().trim_end().to_string()
36    }
37    pub fn to_string(&self) -> String {
38        format!("{}|{}|{}|{}|{}", self.name, self.rows, self.cols, self.weights, self.bias)
39    }
40    pub fn from_string(data: String) -> Self {
41        let mut parse_res = data.split("|");
42        let name: char = parse_res.next().unwrap().chars().next().unwrap();
43        let rows: usize = str::parse(parse_res.next().unwrap()).unwrap();
44        let cols: usize = str::parse(parse_res.next().unwrap()).unwrap();
45        let weights = parse_res.next().unwrap().to_string();
46        let bias = parse_res.next().unwrap().to_string();
47
48        Self { name, rows, cols, weights, bias }
49    }
50}