1use crate::activation::*;
2use crate::checkpoint::Checkpoint;
3use crate::layers::*;
4use crate::loss::*;
5use crate::matrix::*;
6use crate::optimizer::*;
7use crate::save_load::save_model;
8use crate::utils::*;
9
/// A stack of layers evaluated in order, trained with `optimizer` and an L2
/// penalty weighted by `lambda`.
///
/// Every `Option` field is `None` after `Model::init`. They are filled on the
/// `debug` code paths of `evaluate`, `compute_loss`, `update_params` and
/// `train`, so that a training step can be inspected afterwards.
#[derive(Clone)]
pub struct Model {
    /// Layers applied in order by `evaluate`; the last layer's output goes through `softmax`.
    pub layers: Vec<Layer>,
    /// Coefficient passed to `l2_reg` and to every layer's `backprop`.
    pub lambda: f64,
    /// Optimizer handed (by reference) to every layer's `backprop`.
    pub optimizer: Optimizer,

    /// Debug: copy of `layers` taken right before a training step's parameter update.
    pub layers_debug: Option<Vec<Layer>>,
    /// Debug: the batch data fed to the model during the recorded training step.
    pub input: Option<Matrix>,
    /// Debug: the labels of the recorded training batch.
    pub input_label: Option<Matrix>,
    /// Debug: each layer's output, appended by `evaluate` when `debug` is true.
    /// NOTE: the name is misspelled ("itermediate") but the field is public, so it is kept.
    pub itermediate_evaluation_results: Option<Vec<Matrix>>,
    /// Debug: the softmax output returned by the last debug `evaluate` call.
    pub softmax_output: Option<Matrix>,
    /// Debug: the `cross_entropy` term computed by `compute_loss`.
    pub data_loss: Option<f64>,
    /// Debug: the `l2_reg` term computed by `compute_loss`.
    pub reg_loss: Option<f64>,
    /// Debug: `data_loss + reg_loss` of the recorded training batch.
    pub loss: Option<f64>,
    /// Debug: the output gradient produced by `compute_d_score` for the recorded batch.
    pub d_score: Option<Matrix>,
    /// Debug: buffer handed mutably to every `Layer::backprop` call
    /// (presumably filled with per-layer pre-activation gradients — confirm in `layers`).
    pub d_zs: Option<Vec<Matrix>>,
    /// Debug: buffer handed mutably to every `Layer::backprop` call
    /// (presumably filled with per-layer weight gradients — confirm in `layers`).
    pub d_ws: Option<Vec<Matrix>>,
    /// Debug: buffer handed mutably to every `Layer::backprop` call
    /// (presumably filled with per-layer bias gradients — confirm in `layers`).
    pub d_bs: Option<Vec<Matrix>>,
}
31
32impl Model {
36 pub fn init(layers: Vec<Layer>, optimizer: Optimizer, lambda: f64) -> Model {
37 let output = Model {
38 layers,
39 lambda,
40 optimizer,
41 layers_debug: None,
42 input: None,
43 input_label: None,
44 itermediate_evaluation_results: None,
45 softmax_output: None,
46 d_zs: None,
47 d_ws: None,
48 d_bs: None,
49 d_score: None,
50 loss: None,
51 reg_loss: None,
52 data_loss: None,
53 };
54
55 output
56 }
57
58 pub fn evaluate(&mut self, input: &Matrix, debug: bool) -> Matrix {
59 for index in 0..self.layers.len() {
60 if index == 0 {
61 self.layers[0].forward(input, false);
62 } else {
63 let (l, r) = self.layers.split_at_mut(index);
64 r[0].forward(&l[index - 1].output, false);
65 }
66
67 if debug {
68 self.itermediate_evaluation_results
69 .get_or_insert(Vec::new())
70 .push(self.layers[index].output.clone());
71 }
72 }
73
74 let output = softmax(&self.layers[self.layers.len() - 1].output);
75
76 if debug {
77 self.softmax_output = Some(output.clone());
78 }
79
80 output
81 }
82
83 pub fn compute_loss(&mut self, output: &Matrix, labels: &Matrix, debug: bool) -> (f64, f64) {
85 if debug {
86 self.data_loss = Some(cross_entropy(output, labels));
87 self.reg_loss = Some(l2_reg(&self.layers, self.lambda));
88 }
89
90 (
91 cross_entropy(output, labels),
92 l2_reg(&self.layers, self.lambda),
93 )
94 }
95
96 pub fn compute_d_score(score: &Matrix, labels: &Matrix) -> Matrix {
97 let mut output: Matrix = Matrix::init_zero(score.height, score.width);
98 for r in 0..score.height {
99 for c in 0..score.width {
100 if labels.get(0, r) == c as f64 {
102 let v: f64 = score.get(r, c) - 1.0;
104
105 output.set(v, r, c);
106 } else {
107 let v: f64 = score.get(r, c);
109 output.set(v, r, c);
110 }
111 }
112 }
113
114 output
115 }
116
117 pub fn update_params(&mut self, d_score: Matrix, input: Matrix, iteration: i32, debug: bool) {
118 let mut index: usize = self.layers.len() - 1;
119 let mut d_z: Matrix = d_score;
120
121 loop {
122 if index > 0 {
123 let (l, r) = self.layers.split_at_mut(index);
124 d_z = r[0].backprop(
125 &d_z,
126 &l[index - 1].output,
127 l[index - 1].relu,
128 self.lambda,
129 &self.optimizer,
130 iteration,
131 false,
132 debug,
133 &mut self.d_ws,
134 &mut self.d_bs,
135 &mut self.d_zs,
136 );
137 } else if index == 0 {
138 self.layers[index].backprop(
139 &d_z,
140 &input,
141 false,
142 self.lambda,
143 &self.optimizer,
144 iteration,
145 true,
146 debug,
147 &mut self.d_ws,
148 &mut self.d_bs,
149 &mut self.d_zs,
150 );
151 break;
152 }
153
154 index -= 1;
155 }
156 }
157
158 pub fn train(
163 &mut self,
164 data: &Matrix,
165 labels: &Matrix,
166 batch_size: u32,
167 epochs: u32,
168 validation_dataset_size: usize,
169 checkpoint: Option<Checkpoint>,
170 print_frequency: usize,
171 debug: bool,
172 silent_mode: bool, ) -> Option<Vec<Model>> {
174 let mut network_history: Option<Vec<Model>> = None;
175
176 let mut index_table: Vec<u32>;
177 let index_validation: Vec<u32>;
178 let mut validation_data: Matrix =
179 Matrix::init_zero(validation_dataset_size as usize, data.width);
180 let mut validation_label: Matrix = Matrix::init_zero(1, validation_dataset_size as usize);
181
182 if !debug {
186 index_table = generate_vec_rand_unique(data.height as u32);
187
188 index_validation = index_table[0..validation_dataset_size].to_vec();
189 index_table.drain(0..validation_dataset_size);
190
191 for i in 0..validation_dataset_size as usize {
192 let index: usize = index_validation[i] as usize;
193 validation_data.set_row(&data.get_row(index), i);
195 validation_label.set(labels.get(0, index), 0, i);
196 }
197 } else {
198 index_table = (0..data.height as u32).collect();
199 }
200
201 let mut iteration: i32 = 1;
202 let mut best_val_acc: Option<f64> = None;
203 let mut best_val_loss: Option<f64> = None;
204 for epoch in 0..epochs {
205 let index_matrix: Vec<Vec<f64>> = generate_batch_index(&index_table, batch_size);
206
207 for batch_row in 0..index_matrix.len() {
208 let batch_indexes: Vec<f64> = index_matrix[batch_row].clone();
209 let mut batch_data: Matrix = Matrix::init_zero(batch_indexes.len(), data.width);
210 let mut batch_label: Matrix = Matrix::init_zero(1, batch_indexes.len());
211
212 for i in 0..batch_indexes.len() as usize {
213 let index: usize = batch_indexes[i] as usize;
214 batch_data.set_row(&data.get_row(index), i);
215 batch_label.set(labels.get(0, index), 0, i);
216 }
217
218 let score: Matrix = self.evaluate(&batch_data, debug);
219 let d_score: Matrix = Model::compute_d_score(&score, &batch_label);
220
221 if debug {
222 let (loss, l2_reg_penalty): (f64, f64) =
223 self.compute_loss(&score, &batch_label, debug);
224 self.d_score = Some(d_score.clone());
225 self.input = Some(batch_data.clone());
226 self.input_label = Some(batch_label.clone());
227 self.loss = Some(loss + l2_reg_penalty);
228 self.layers_debug = Some(self.layers.clone());
229 }
230
231 self.update_params(d_score, batch_data, iteration, debug);
232
233 if debug {
234 network_history.get_or_insert(Vec::new()).push(self.clone());
235 self.itermediate_evaluation_results = None;
236 self.d_zs = None;
237 self.d_bs = None;
238 self.d_ws = None;
239 }
240
241 match &checkpoint {
242 Some(checkpoint) => match checkpoint {
243 Checkpoint::ValAcc { save_path } => {
244 let score_validation: Matrix = self.evaluate(&validation_data, false);
245 let acc_validation: f64 =
246 self.accuracy(&score_validation, &validation_label);
247 match best_val_acc {
248 Some(prev) => {
249 if acc_validation > prev {
250 save_model(self, save_path.to_string()).unwrap();
251 best_val_acc = Some(acc_validation);
252 }
253 }
254 None => {
255 best_val_acc = Some(acc_validation);
256 }
257 }
258 }
259 Checkpoint::ValLoss { save_path } => {
260 let score_validation: Matrix = self.evaluate(&validation_data, false);
261 let (loss_validation, _): (f64, f64) =
262 self.compute_loss(&score_validation, &validation_label, debug);
263 match best_val_loss {
264 Some(prev) => {
265 if loss_validation < prev {
266 save_model(self, save_path.to_string()).unwrap();
267 best_val_loss = Some(loss_validation);
268 }
269 }
270 None => {
271 best_val_loss = Some(loss_validation);
272 }
273 }
274 }
275 },
276 None => (),
277 }
278
279 if ((batch_row + 1) % print_frequency == 0 || batch_row + 1 == index_matrix.len())
280 && !debug
281 && !silent_mode
282 {
283 let score_validation: Matrix = self.evaluate(&validation_data, false);
284 let (loss_validation, _): (f64, f64) =
285 self.compute_loss(&score_validation, &validation_label, debug);
286 let (loss_training, l2_reg_penalty_training): (f64, f64) =
287 self.compute_loss(&score, &batch_label, debug);
288 let acc_training: f64 = self.accuracy(&score, &batch_label);
289 let acc_validation: f64 = self.accuracy(&score_validation, &validation_label);
290
291 println!(
292 "Epoch : {}, Batch : {}, Loss : {}, L2 reg penalty {} , Acc {}, Val_loss : {}, Val_acc : {}",
293 epoch + 1,
294 batch_row + 1,
295 loss_training,
296 l2_reg_penalty_training,
297 acc_training,
298 loss_validation,
299 acc_validation
300 );
301 }
302
303 iteration += 1;
304 }
305 }
306
307 if !silent_mode {
308 match &checkpoint {
309 Some(checkpoint) => {
310 match checkpoint {
311 Checkpoint::ValAcc { save_path } => println!("The best model has been saved at the path : {} it's validation accuracy is : {}", save_path, best_val_acc.unwrap_or(0.0)),
312 Checkpoint::ValLoss { save_path } => println!("The best model has been saved at the path : {} it's validation loss is : {}", save_path, best_val_loss.unwrap_or(0.0))
313 }
314 },
315 None => ()
316 }
317 }
318
319 network_history
320 }
321
322 pub fn accuracy(&mut self, score: &Matrix, labels: &Matrix) -> f64 {
323 let answer = Self::evaluation_output(&score);
324
325 let mut sum = 0;
326 for index in 0..answer.width {
327 if answer.get(0, index) == labels.get(0, index) {
328 sum += 1;
329 }
330 }
331
332 sum as f64 / answer.width as f64
333 }
334
335 pub fn evaluation_output(score: &Matrix) -> Matrix {
336 let mut output: Matrix = Matrix::init_zero(1, score.height);
337 for r in 0..score.height {
338 let one_input: Vec<f64> = score.get_row(r);
339 let index_max: usize = one_input
340 .iter()
341 .enumerate()
342 .max_by(|(_, a), (_, b)| a.total_cmp(b))
343 .map(|(index, _)| index)
344 .unwrap();
345
346 output.set(index_max as f64, 0, r);
347 }
348
349 output
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::{Matrix, Model};

    /// A 2 x 3 score matrix whose expected per-row arg-max is `[1, 0]`.
    fn sample_scores() -> Matrix {
        Matrix::init(2, 3, vec![0.1, 1.3, 0.5, 12.0, 1.01, -1000.0])
    }

    #[test]
    fn evaluation_output_test() {
        let predicted = Model::evaluation_output(&sample_scores());

        assert!(Matrix::init(1, 2, vec![1.0, 0.0]).is_equal(&predicted, 10));
    }
}