ferrite_rs/multivariate_regression/update_weight/mod.rs

1use ndarray::{Array2, Axis, s};
2use crate::matrix_operations::mat_mul::matrix_mul;
3use crate::multivariate_regression::cost_fn::cost_fn::CostFn;
4use crate::multivariate_regression::gradient::Gradient;
5use crate::multivariate_regression::regularization::regularization::Regularization;
6
/// Strategy for splitting the training rows into batches in `update_weight`.
///
/// - `SGD`: stochastic gradient descent, one weight update per row.
/// - `BGD`: batch gradient descent, a single update using every row.
/// - `MiniBatchGD`: one update per consecutive chunk of rows, sized by `MiniBatchSize`.
// The variant names are public API; renaming them to `Sgd`/`Bgd` would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdatationMethod {
    SGD,
    BGD,
    MiniBatchGD,
}
12
/// Number of rows per batch used by `UpdatationMethod::MiniBatchGD`.
///
/// Each variant's discriminant is its batch size, so `size as usize` yields the
/// row count directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MiniBatchSize {
    Small = 4,
    Medium = 16,
    Large = 64,
    ExtraLarge = 256,
}
20
21pub fn update_weight(
22    input : &Array2<f64>,
23    output : &Array2<f64>,
24    weight: &mut Array2<f64>,
25    updatation_method: &UpdatationMethod,
26    mini_batch_size: &Option<MiniBatchSize>,
27    regularization: Regularization,
28    grad : &Gradient,
29    cost_fn : &CostFn,
30    delta : f64,
31    lr : f64,
32    log : bool,
33    
34) {
35    match updatation_method {
36        UpdatationMethod::SGD => {
37            for i in 0..input.nrows() {
38                let input_row = input.slice(s![i..i+1, ..]).to_owned();  // 2D slice of one row
39                let output_row = output.slice(s![i..i+1, ..]).to_owned();
40                let pred = matrix_mul(&input_row, weight); 
41                if log {
42                    let cost = cost_fn.calculate_cost(&output_row, &pred, &regularization, weight);
43                    println!(" {}", cost);
44                }
45                let gradient = grad.calculate_gradient(delta, &input_row, &pred, &output_row, weight);
46                *weight -= &(lr * gradient);
47            }
48        },
49        UpdatationMethod::BGD => {
50            let pred = matrix_mul(input,&weight);
51            if log {
52                let cost = cost_fn.calculate_cost(output,&pred,&regularization,&weight);
53                println!(" {}", cost);
54            }
55            let gradient = grad.calculate_gradient(delta,input,&pred,output,&weight);
56            *weight -= &(lr * gradient);
57        },
58        UpdatationMethod::MiniBatchGD => {
59            let batch_size = mini_batch_size.clone().unwrap_or(MiniBatchSize::Medium) as usize;
60            let mut batches = input.nrows()/batch_size as usize + 1;
61            let mut tmp = 0;
62            while batches > 0 {
63                let batch_ip = input.slice(s![tmp..tmp+batch_size, ..]).to_owned();
64                let batch_op = output.slice(s![tmp..tmp+batch_size, ..]).to_owned();
65                let pred = matrix_mul(&batch_ip,&weight);
66                if log {
67                    let cost = cost_fn.calculate_cost(output,&pred,&regularization,&weight);
68                    print!(" {}", cost);
69                } 
70                let gradient = grad.calculate_gradient(delta,&batch_ip,&pred,&batch_op,&weight);
71                *weight -= &(lr * gradient);
72                batches -= 1;
73                tmp += batch_ip.len();
74            }
75            
76        }
77    }
78}