fleximl_models/ml/
linear_model.rs1use crate::utils::tasks::Task;
2use ndarray::{Array1, Array2, Axis};
3use rand::rngs::StdRng;
4use rand::{Rng, SeedableRng};
5
6pub struct LinearModel {
7 pub weights: Array2<f64>,
8 pub bias: Array1<f64>,
9 pub learning_rate: f64,
10 pub task: Task,
11 pub rng: StdRng,
12}
13
14impl LinearModel {
15 pub fn new(
16 num_features: usize,
17 num_classes: usize,
18 learning_rate: f64,
19 task: Task,
20 seed: u64,
21 ) -> Self {
22 let mut rng = StdRng::seed_from_u64(seed);
23 let num_outputs = match task {
24 Task::BinaryClassification => 1,
25 Task::Regression => 1,
26 Task::MultiClassClassification => num_classes,
27 };
28
29 LinearModel {
30 weights: Array2::from_shape_fn((num_features, num_outputs), |_| {
31 rng.gen_range(-0.5..0.5)
32 }),
33 bias: Array1::zeros(num_outputs),
34 learning_rate,
35 task,
36 rng,
37 }
38 }
39
40 pub fn predict(&self, x: &Array1<f64>) -> Array1<f64> {
41 let linear_output = self.weights.t().dot(x) + &self.bias;
42 match self.task {
43 Task::BinaryClassification => Array1::from(vec![self.sigmoid(linear_output[0])]),
44 Task::Regression => linear_output,
45 Task::MultiClassClassification => {
46 let exp_output = linear_output.mapv(|x| x.exp());
47 let sum = exp_output.sum();
48 exp_output / sum
49 }
50 }
51 }
52
53 pub fn fit(&mut self, x: &Array2<f64>, y: &Array2<f64>, epochs: usize) {
54 for _ in 0..epochs {
55 let predictions = x.dot(&self.weights) + &self.bias;
56 let errors = match self.task {
57 Task::BinaryClassification => predictions.mapv(|p| self.sigmoid(p)) - y,
58 Task::Regression => predictions - y,
59 Task::MultiClassClassification => {
60 let softmax = predictions.mapv(|p| p.exp())
61 / predictions
62 .mapv(|p| p.exp())
63 .sum_axis(Axis(1))
64 .insert_axis(Axis(1));
65 softmax - y
66 }
67 };
68
69 let gradient = x.t().dot(&errors);
70 self.weights -= &(self.learning_rate * gradient / x.nrows() as f64);
71 self.bias -= &(self.learning_rate * errors.sum_axis(Axis(0)) / x.nrows() as f64);
72 }
73 }
74
75 fn sigmoid(&self, x: f64) -> f64 {
76 1.0 / (1.0 + (-x).exp())
77 }
78}