ferrite_rs/multivariate_regression/update_weight/
mod.rs1use ndarray::{Array2, Axis, s};
2use crate::matrix_operations::mat_mul::matrix_mul;
3use crate::multivariate_regression::cost_fn::cost_fn::CostFn;
4use crate::multivariate_regression::gradient::Gradient;
5use crate::multivariate_regression::regularization::regularization::Regularization;
6
/// Strategy that decides how many training rows feed each weight update.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdatationMethod {
    /// Stochastic gradient descent: one weight update per training row.
    SGD,
    /// Batch gradient descent: a single weight update computed from all rows.
    BGD,
    /// Mini-batch gradient descent: one weight update per chunk of rows.
    MiniBatchGD,
}
12
/// Preset mini-batch sizes.
///
/// Each variant's discriminant is the number of rows in one batch, so the
/// value can be obtained with an `as usize` cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MiniBatchSize {
    /// 4 rows per batch.
    Small = 4,
    /// 16 rows per batch.
    Medium = 16,
    /// 64 rows per batch.
    Large = 64,
    /// 256 rows per batch.
    ExtraLarge = 256,
}
20
21pub fn update_weight(
22 input : &Array2<f64>,
23 output : &Array2<f64>,
24 weight: &mut Array2<f64>,
25 updatation_method: &UpdatationMethod,
26 mini_batch_size: &Option<MiniBatchSize>,
27 regularization: Regularization,
28 grad : &Gradient,
29 cost_fn : &CostFn,
30 delta : f64,
31 lr : f64,
32 log : bool,
33
34) {
35 match updatation_method {
36 UpdatationMethod::SGD => {
37 for i in 0..input.nrows() {
38 let input_row = input.slice(s![i..i+1, ..]).to_owned(); let output_row = output.slice(s![i..i+1, ..]).to_owned();
40 let pred = matrix_mul(&input_row, weight);
41 if log {
42 let cost = cost_fn.calculate_cost(&output_row, &pred, ®ularization, weight);
43 println!(" {}", cost);
44 }
45 let gradient = grad.calculate_gradient(delta, &input_row, &pred, &output_row, weight);
46 *weight -= &(lr * gradient);
47 }
48 },
49 UpdatationMethod::BGD => {
50 let pred = matrix_mul(input,&weight);
51 if log {
52 let cost = cost_fn.calculate_cost(output,&pred,®ularization,&weight);
53 println!(" {}", cost);
54 }
55 let gradient = grad.calculate_gradient(delta,input,&pred,output,&weight);
56 *weight -= &(lr * gradient);
57 },
58 UpdatationMethod::MiniBatchGD => {
59 let batch_size = mini_batch_size.clone().unwrap_or(MiniBatchSize::Medium) as usize;
60 let mut batches = input.nrows()/batch_size as usize + 1;
61 let mut tmp = 0;
62 while batches > 0 {
63 let batch_ip = input.slice(s![tmp..tmp+batch_size, ..]).to_owned();
64 let batch_op = output.slice(s![tmp..tmp+batch_size, ..]).to_owned();
65 let pred = matrix_mul(&batch_ip,&weight);
66 if log {
67 let cost = cost_fn.calculate_cost(output,&pred,®ularization,&weight);
68 print!(" {}", cost);
69 }
70 let gradient = grad.calculate_gradient(delta,&batch_ip,&pred,&batch_op,&weight);
71 *weight -= &(lr * gradient);
72 batches -= 1;
73 tmp += batch_ip.len();
74 }
75
76 }
77 }
78}