//! fleximl_models/ml/linear_model.rs

use crate::utils::tasks::Task;
use ndarray::{Array1, Array2, Axis};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

/// A linear model trained by full-batch gradient descent, supporting
/// regression, binary (logistic) classification, and multi-class
/// (softmax) classification depending on `task`.
pub struct LinearModel {
    // Weight matrix, shape (num_features, num_outputs); see `new` for
    // how num_outputs is derived from the task.
    pub weights: Array2<f64>,
    // Per-output bias terms, length num_outputs (initialized to zero).
    pub bias: Array1<f64>,
    // Step size used by the gradient-descent updates in `fit`.
    pub learning_rate: f64,
    // Which objective this model optimizes (see crate::utils::tasks::Task).
    pub task: Task,
    // Seeded RNG retained after reproducible weight initialization.
    pub rng: StdRng,
}
14impl LinearModel {
15    pub fn new(
16        num_features: usize,
17        num_classes: usize,
18        learning_rate: f64,
19        task: Task,
20        seed: u64,
21    ) -> Self {
22        let mut rng = StdRng::seed_from_u64(seed);
23        let num_outputs = match task {
24            Task::BinaryClassification => 1,
25            Task::Regression => 1,
26            Task::MultiClassClassification => num_classes,
27        };
28
29        LinearModel {
30            weights: Array2::from_shape_fn((num_features, num_outputs), |_| {
31                rng.gen_range(-0.5..0.5)
32            }),
33            bias: Array1::zeros(num_outputs),
34            learning_rate,
35            task,
36            rng,
37        }
38    }
39
40    pub fn predict(&self, x: &Array1<f64>) -> Array1<f64> {
41        let linear_output = self.weights.t().dot(x) + &self.bias;
42        match self.task {
43            Task::BinaryClassification => Array1::from(vec![self.sigmoid(linear_output[0])]),
44            Task::Regression => linear_output,
45            Task::MultiClassClassification => {
46                let exp_output = linear_output.mapv(|x| x.exp());
47                let sum = exp_output.sum();
48                exp_output / sum
49            }
50        }
51    }
52
53    pub fn fit(&mut self, x: &Array2<f64>, y: &Array2<f64>, epochs: usize) {
54        for _ in 0..epochs {
55            let predictions = x.dot(&self.weights) + &self.bias;
56            let errors = match self.task {
57                Task::BinaryClassification => predictions.mapv(|p| self.sigmoid(p)) - y,
58                Task::Regression => predictions - y,
59                Task::MultiClassClassification => {
60                    let softmax = predictions.mapv(|p| p.exp())
61                        / predictions
62                            .mapv(|p| p.exp())
63                            .sum_axis(Axis(1))
64                            .insert_axis(Axis(1));
65                    softmax - y
66                }
67            };
68
69            let gradient = x.t().dot(&errors);
70            self.weights -= &(self.learning_rate * gradient / x.nrows() as f64);
71            self.bias -= &(self.learning_rate * errors.sum_axis(Axis(0)) / x.nrows() as f64);
72        }
73    }
74
75    fn sigmoid(&self, x: f64) -> f64 {
76        1.0 / (1.0 + (-x).exp())
77    }
78}