ferrite_rs/multivariate_regression/gradient/
mod.rs1use ndarray::Array2;
2use crate::matrix_operations::mat_mul::matrix_mul;
3use crate::multivariate_regression::gradient::gradient_type::GradientType;
4use crate::multivariate_regression::regularization::regularization::{Regularization, RegularizationType};
5use ndarray::indices;
6
7pub mod gradient_type;
8
9
10pub struct Gradient {
11 pub(crate) gradient : GradientType,
12 pub(crate) regularization: Regularization
13}
14
15impl Gradient {
16 pub fn mean_absolute_error(regularization: Regularization) -> Self{
17 Self{
18 gradient : GradientType::MeanAbsoluteError,
19 regularization
20 }
21 }
22
23 pub fn mean_squared_error(regularization: Regularization) -> Self{
24 Self{
25 gradient : GradientType::MeanAbsoluteError,
26 regularization,
27 }
28 }
29
30 pub fn huber_loss(regularization: Regularization) -> Self{
31 Self{
32 gradient : GradientType::HuberError,
33 regularization
34 }
35 }
36
37 pub fn calculate_gradient(&self,delta : f64,input : &Array2<f64>,y_pred : &Array2<f64>,y_true : &Array2<f64>,weight : &Array2<f64>) -> Array2<f64> {
38
39 let total_elements : f64 = y_true.len() as f64;
40 let error_matrix = y_true - y_pred;
41 let mut raw_gradient_matrix = Array2::<f64>::zeros((weight.nrows(), weight.ncols()));
42
43 match &self.gradient {
44 GradientType::MeanAbsoluteError => {
45 let sign_error_matrix = error_matrix.mapv(|x| x.signum());
46 raw_gradient_matrix = (-1./total_elements) * matrix_mul(&input.t().to_owned(), &sign_error_matrix);
47 },
48 GradientType::MeanSquaredError => {
49 raw_gradient_matrix = (-1./total_elements) * matrix_mul(&input.t().to_owned(), &error_matrix);
50
51 },
52 GradientType::HuberError => {
53 raw_gradient_matrix = (1./total_elements) * error_matrix.mapv(|x| {
54 if x.abs()<=delta { x }
55 else {delta*x.signum() }
56 });
57 println!("{:?}",raw_gradient_matrix);
58
59 }
60 }
61 match &self.regularization.regularization_type {
62 RegularizationType::LassoL1 => {
63 let sign_weight_matrix = weight.mapv(|x| x.signum());
64 (self.regularization.lambda1 * sign_weight_matrix) + raw_gradient_matrix
65 },
66 RegularizationType::RidgeL2 => {
67 (self.regularization.lambda2 * 2. * weight) + raw_gradient_matrix
68 },
69 RegularizationType::ElasticNet => {
70 let sign_weight_matrix = weight.mapv(|x| x.signum());
71 (self.regularization.lambda1 * sign_weight_matrix) + (self.regularization.lambda2 * 2. * weight) + raw_gradient_matrix
72 }
73 }
74 }
75
76}