brique/
model.rs

1use crate::activation::*;
2use crate::checkpoint::Checkpoint;
3use crate::layers::*;
4use crate::loss::*;
5use crate::matrix::*;
6use crate::optimizer::*;
7use crate::save_load::save_model;
8use crate::utils::*;
9
10#[derive(Clone)]
11pub struct Model {
12    pub layers: Vec<Layer>,
13    pub lambda: f64,
14    pub optimizer: Optimizer,
15
16    // these elements are stored in the struct for debugging purposes
17    // only if debug arg is true
18    pub layers_debug: Option<Vec<Layer>>,
19    pub input: Option<Matrix>,
20    pub input_label: Option<Matrix>,
21    pub itermediate_evaluation_results: Option<Vec<Matrix>>,
22    pub softmax_output: Option<Matrix>,
23    pub data_loss: Option<f64>,
24    pub reg_loss: Option<f64>,
25    pub loss: Option<f64>,
26    pub d_score: Option<Matrix>,
27    pub d_zs: Option<Vec<Matrix>>,
28    pub d_ws: Option<Vec<Matrix>>,
29    pub d_bs: Option<Vec<Matrix>>,
30}
31
32// all the variables begining with d (like d_score) are the derivative
33// of the loss function compared to said variable, so d_score is d Loss/ d Score
34// doing so for ease of read
35impl Model {
36    pub fn init(layers: Vec<Layer>, optimizer: Optimizer, lambda: f64) -> Model {
37        let output = Model {
38            layers,
39            lambda,
40            optimizer,
41            layers_debug: None,
42            input: None,
43            input_label: None,
44            itermediate_evaluation_results: None,
45            softmax_output: None,
46            d_zs: None,
47            d_ws: None,
48            d_bs: None,
49            d_score: None,
50            loss: None,
51            reg_loss: None,
52            data_loss: None,
53        };
54
55        output
56    }
57
58    pub fn evaluate(&mut self, input: &Matrix, debug: bool) -> Matrix {
59        for index in 0..self.layers.len() {
60            if index == 0 {
61                self.layers[0].forward(input, false);
62            } else {
63                let (l, r) = self.layers.split_at_mut(index);
64                r[0].forward(&l[index - 1].output, false);
65            }
66
67            if debug {
68                self.itermediate_evaluation_results
69                    .get_or_insert(Vec::new())
70                    .push(self.layers[index].output.clone());
71            }
72        }
73
74        let output = softmax(&self.layers[self.layers.len() - 1].output);
75
76        if debug {
77            self.softmax_output = Some(output.clone());
78        }
79
80        output
81    }
82
83    // implementing cross-entropy and L2 regulariztion
84    pub fn compute_loss(&mut self, output: &Matrix, labels: &Matrix, debug: bool) -> (f64, f64) {
85        if debug {
86            self.data_loss = Some(cross_entropy(output, labels));
87            self.reg_loss = Some(l2_reg(&self.layers, self.lambda));
88        }
89
90        (
91            cross_entropy(output, labels),
92            l2_reg(&self.layers, self.lambda),
93        )
94    }
95
96    pub fn compute_d_score(score: &Matrix, labels: &Matrix) -> Matrix {
97        let mut output: Matrix = Matrix::init_zero(score.height, score.width);
98        for r in 0..score.height {
99            for c in 0..score.width {
100                //TODO make a choice, to divide or not to divide
101                if labels.get(0, r) == c as f64 {
102                    //output.data[r][c] = (score.data[r][c] - 1.0) / score.height as f64;
103                    let v: f64 = score.get(r, c) - 1.0;
104
105                    output.set(v, r, c);
106                } else {
107                    //output.data[r][c] = score.data[r][c] / score.height as f64;
108                    let v: f64 = score.get(r, c);
109                    output.set(v, r, c);
110                }
111            }
112        }
113
114        output
115    }
116
117    pub fn update_params(&mut self, d_score: Matrix, input: Matrix, iteration: i32, debug: bool) {
118        let mut index: usize = self.layers.len() - 1;
119        let mut d_z: Matrix = d_score;
120
121        loop {
122            if index > 0 {
123                let (l, r) = self.layers.split_at_mut(index);
124                d_z = r[0].backprop(
125                    &d_z,
126                    &l[index - 1].output,
127                    l[index - 1].relu,
128                    self.lambda,
129                    &self.optimizer,
130                    iteration,
131                    false,
132                    debug,
133                    &mut self.d_ws,
134                    &mut self.d_bs,
135                    &mut self.d_zs,
136                );
137            } else if index == 0 {
138                self.layers[index].backprop(
139                    &d_z,
140                    &input,
141                    false,
142                    self.lambda,
143                    &self.optimizer,
144                    iteration,
145                    true,
146                    debug,
147                    &mut self.d_ws,
148                    &mut self.d_bs,
149                    &mut self.d_zs,
150                );
151                break;
152            }
153
154            index -= 1;
155        }
156    }
157
158    // the steps :
159    // before every epoch :
160    //  - shuffle dataset (use the algo of rand crate)
161    //  - generate batch from shuffled dataset
162    pub fn train(
163        &mut self,
164        data: &Matrix,
165        labels: &Matrix,
166        batch_size: u32,
167        epochs: u32,
168        validation_dataset_size: usize,
169        checkpoint: Option<Checkpoint>,
170        print_frequency: usize,
171        debug: bool,
172        silent_mode: bool, // if true will not print anything
173    ) -> Option<Vec<Model>> {
174        let mut network_history: Option<Vec<Model>> = None;
175
176        let mut index_table: Vec<u32>;
177        let index_validation: Vec<u32>;
178        let mut validation_data: Matrix =
179            Matrix::init_zero(validation_dataset_size as usize, data.width);
180        let mut validation_label: Matrix = Matrix::init_zero(1, validation_dataset_size as usize);
181
182        // first step is to randomize the input data
183        // and to create the validation dataset
184        // if debugging mode is on, no validation and no randomization
185        if !debug {
186            index_table = generate_vec_rand_unique(data.height as u32);
187
188            index_validation = index_table[0..validation_dataset_size].to_vec();
189            index_table.drain(0..validation_dataset_size);
190
191            for i in 0..validation_dataset_size as usize {
192                let index: usize = index_validation[i] as usize;
193                // TODO write test for validation dataset creation
194                validation_data.set_row(&data.get_row(index), i);
195                validation_label.set(labels.get(0, index), 0, i);
196            }
197        } else {
198            index_table = (0..data.height as u32).collect();
199        }
200
201        let mut iteration: i32 = 1;
202        let mut best_val_acc: Option<f64> = None;
203        let mut best_val_loss: Option<f64> = None;
204        for epoch in 0..epochs {
205            let index_matrix: Vec<Vec<f64>> = generate_batch_index(&index_table, batch_size);
206
207            for batch_row in 0..index_matrix.len() {
208                let batch_indexes: Vec<f64> = index_matrix[batch_row].clone();
209                let mut batch_data: Matrix = Matrix::init_zero(batch_indexes.len(), data.width);
210                let mut batch_label: Matrix = Matrix::init_zero(1, batch_indexes.len());
211
212                for i in 0..batch_indexes.len() as usize {
213                    let index: usize = batch_indexes[i] as usize;
214                    batch_data.set_row(&data.get_row(index), i);
215                    batch_label.set(labels.get(0, index), 0, i);
216                }
217
218                let score: Matrix = self.evaluate(&batch_data, debug);
219                let d_score: Matrix = Model::compute_d_score(&score, &batch_label);
220
221                if debug {
222                    let (loss, l2_reg_penalty): (f64, f64) =
223                        self.compute_loss(&score, &batch_label, debug);
224                    self.d_score = Some(d_score.clone());
225                    self.input = Some(batch_data.clone());
226                    self.input_label = Some(batch_label.clone());
227                    self.loss = Some(loss + l2_reg_penalty);
228                    self.layers_debug = Some(self.layers.clone());
229                }
230
231                self.update_params(d_score, batch_data, iteration, debug);
232
233                if debug {
234                    network_history.get_or_insert(Vec::new()).push(self.clone());
235                    self.itermediate_evaluation_results = None;
236                    self.d_zs = None;
237                    self.d_bs = None;
238                    self.d_ws = None;
239                }
240
241                match &checkpoint {
242                    Some(checkpoint) => match checkpoint {
243                        Checkpoint::ValAcc { save_path } => {
244                            let score_validation: Matrix = self.evaluate(&validation_data, false);
245                            let acc_validation: f64 =
246                                self.accuracy(&score_validation, &validation_label);
247                            match best_val_acc {
248                                Some(prev) => {
249                                    if acc_validation > prev {
250                                        save_model(self, save_path.to_string()).unwrap();
251                                        best_val_acc = Some(acc_validation);
252                                    }
253                                }
254                                None => {
255                                    best_val_acc = Some(acc_validation);
256                                }
257                            }
258                        }
259                        Checkpoint::ValLoss { save_path } => {
260                            let score_validation: Matrix = self.evaluate(&validation_data, false);
261                            let (loss_validation, _): (f64, f64) =
262                                self.compute_loss(&score_validation, &validation_label, debug);
263                            match best_val_loss {
264                                Some(prev) => {
265                                    if loss_validation < prev {
266                                        save_model(self, save_path.to_string()).unwrap();
267                                        best_val_loss = Some(loss_validation);
268                                    }
269                                }
270                                None => {
271                                    best_val_loss = Some(loss_validation);
272                                }
273                            }
274                        }
275                    },
276                    None => (),
277                }
278
279                if ((batch_row + 1) % print_frequency == 0 || batch_row + 1 == index_matrix.len())
280                    && !debug
281                    && !silent_mode
282                {
283                    let score_validation: Matrix = self.evaluate(&validation_data, false);
284                    let (loss_validation, _): (f64, f64) =
285                        self.compute_loss(&score_validation, &validation_label, debug);
286                    let (loss_training, l2_reg_penalty_training): (f64, f64) =
287                        self.compute_loss(&score, &batch_label, debug);
288                    let acc_training: f64 = self.accuracy(&score, &batch_label);
289                    let acc_validation: f64 = self.accuracy(&score_validation, &validation_label);
290
291                    println!(
292                        "Epoch : {}, Batch : {}, Loss : {}, L2 reg penalty {} , Acc {}, Val_loss : {}, Val_acc : {}",
293                        epoch + 1,
294                        batch_row + 1,
295                        loss_training,
296                        l2_reg_penalty_training,
297                        acc_training,
298                        loss_validation,
299                        acc_validation
300                    );
301                }
302
303                iteration += 1;
304            }
305        }
306
307        if !silent_mode {
308            match &checkpoint {
309                Some(checkpoint) => {
310                    match checkpoint {
311                        Checkpoint::ValAcc { save_path } => println!("The best model has been saved at the path : {} it's validation accuracy is : {}", save_path, best_val_acc.unwrap_or(0.0)),
312                        Checkpoint::ValLoss { save_path } => println!("The best model has been saved at the path : {} it's validation loss is : {}", save_path, best_val_loss.unwrap_or(0.0))
313                    }
314                },
315                None => ()
316            }
317        }
318
319        network_history
320    }
321
322    pub fn accuracy(&mut self, score: &Matrix, labels: &Matrix) -> f64 {
323        let answer = Self::evaluation_output(&score);
324
325        let mut sum = 0;
326        for index in 0..answer.width {
327            if answer.get(0, index) == labels.get(0, index) {
328                sum += 1;
329            }
330        }
331
332        sum as f64 / answer.width as f64
333    }
334
335    pub fn evaluation_output(score: &Matrix) -> Matrix {
336        let mut output: Matrix = Matrix::init_zero(1, score.height);
337        for r in 0..score.height {
338            let one_input: Vec<f64> = score.get_row(r);
339            let index_max: usize = one_input
340                .iter()
341                .enumerate()
342                .max_by(|(_, a), (_, b)| a.total_cmp(b))
343                .map(|(index, _)| index)
344                .unwrap();
345
346            output.set(index_max as f64, 0, r);
347        }
348
349        output
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::{Matrix, Model};

    /// 2 x 3 score matrix (row-major): row 0 peaks at column 1, row 1 peaks at column 0.
    fn get_test_matrix() -> Matrix {
        Matrix::init(2, 3, vec![0.1, 1.3, 0.5, 12.0, 1.01, -1000.0])
    }

    #[test]
    fn evaluation_output_test() {
        let expected_output = Matrix::init(1, 2, vec![1.0, 0.0]);
        let output = Model::evaluation_output(&get_test_matrix());

        assert!(expected_output.is_equal(&output, 10));
    }

    #[test]
    fn compute_d_score_test() {
        // Row 0 is labelled class 1 and row 1 class 0: only those two entries lose 1.0.
        // The expected values use the same arithmetic as the function, so they match exactly.
        let labels = Matrix::init(1, 2, vec![1.0, 0.0]);
        let expected_output = Matrix::init(
            2,
            3,
            vec![0.1, 1.3 - 1.0, 0.5, 12.0 - 1.0, 1.01, -1000.0],
        );
        let output = Model::compute_d_score(&get_test_matrix(), &labels);

        assert!(expected_output.is_equal(&output, 10));
    }
}
370}