// ferrite_rs/multivariate_regression/gradient/mod.rs

1use ndarray::Array2;
2use crate::matrix_operations::mat_mul::matrix_mul;
3use crate::multivariate_regression::gradient::gradient_type::GradientType;
4use crate::multivariate_regression::regularization::regularization::{Regularization, RegularizationType};
5use ndarray::indices;
6
7pub mod gradient_type;
8
9
10pub struct Gradient {
11    pub(crate) gradient : GradientType,
12    pub(crate) regularization: Regularization
13}
14
15impl Gradient {
16     pub fn mean_absolute_error(regularization: Regularization) -> Self{
17            Self{
18                gradient : GradientType::MeanAbsoluteError,
19                regularization
20            }
21        }
22
23    pub fn mean_squared_error(regularization: Regularization) -> Self{
24        Self{
25            gradient : GradientType::MeanAbsoluteError,
26            regularization,
27        }
28    }
29
30    pub fn huber_loss(regularization: Regularization) -> Self{
31        Self{
32            gradient : GradientType::HuberError,
33            regularization
34        }
35    }
36
37    pub fn calculate_gradient(&self,delta : f64,input : &Array2<f64>,y_pred : &Array2<f64>,y_true : &Array2<f64>,weight : &Array2<f64>) -> Array2<f64> {
38
39        let total_elements : f64 = y_true.len() as f64;
40        let error_matrix = y_true - y_pred;
41        let mut raw_gradient_matrix = Array2::<f64>::zeros((weight.nrows(), weight.ncols()));
42
43        match &self.gradient {
44            GradientType::MeanAbsoluteError => {
45                let sign_error_matrix = error_matrix.mapv(|x| x.signum());
46                raw_gradient_matrix = (-1./total_elements) * matrix_mul(&input.t().to_owned(), &sign_error_matrix);
47            },
48            GradientType::MeanSquaredError => {
49                raw_gradient_matrix = (-1./total_elements) * matrix_mul(&input.t().to_owned(), &error_matrix);
50
51            },
52            GradientType::HuberError => {
53                raw_gradient_matrix = (1./total_elements) * error_matrix.mapv(|x| {
54                    if x.abs()<=delta { x }
55                    else {delta*x.signum() }
56                });
57                println!("{:?}",raw_gradient_matrix);
58
59            }
60        }
61        match &self.regularization.regularization_type {
62             RegularizationType::LassoL1 => {
63                let sign_weight_matrix = weight.mapv(|x| x.signum());
64                (self.regularization.lambda1 * sign_weight_matrix) + raw_gradient_matrix
65            },
66            RegularizationType::RidgeL2 => {
67                (self.regularization.lambda2 * 2. * weight) + raw_gradient_matrix
68            },
69            RegularizationType::ElasticNet => {
70                let sign_weight_matrix = weight.mapv(|x| x.signum());
71                (self.regularization.lambda1 * sign_weight_matrix) + (self.regularization.lambda2 * 2. * weight) + raw_gradient_matrix
72            }
73        }
74    }
75
76}