fleximl_models/ml/mlp.rs

use crate::utils::tasks::Task;
use ndarray::{Array1, Array2, Axis};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

/// A multi-layer perceptron trained with per-sample (stochastic) gradient
/// descent; see `fit`.
pub struct MLP {
    pub layers: Vec<Layer>,
    pub learning_rate: f64,
    pub task: Task,
    pub rng: StdRng,
}

pub struct Layer {
    pub weights: Array2<f64>,
    pub biases: Array1<f64>,
    pub activation: Activation,
    // Momentum buffers, allocated in `new` but not yet touched by the plain
    // SGD update in `fit`.
    pub weight_momentum: Array2<f64>,
    pub bias_momentum: Array1<f64>,
}

#[derive(Clone, Copy)]
pub enum Activation {
    ReLU,
    Sigmoid,
    TanH,
    Linear,
}

impl MLP {
    pub fn new(
        layer_sizes: &[usize],
        activations: &[Activation],
        learning_rate: f64,
        task: Task,
        seed: u64,
    ) -> Self {
        assert!(
            layer_sizes.len() >= 2,
            "At least input and output layers are required"
        );
        assert_eq!(
            layer_sizes.len() - 1,
            activations.len(),
            "Number of activations must match number of layers - 1"
        );

        let mut rng = StdRng::seed_from_u64(seed);

        let layers: Vec<Layer> = layer_sizes
            .windows(2)
            .zip(activations.iter())
            .map(|(sizes, &activation)| {
                let (n_in, n_out) = (sizes[0], sizes[1]);
                // He-style sqrt(2 / n_in) scaling of a uniform draw in (-1, 1).
                let weights = Array2::from_shape_fn((n_out, n_in), |_| {
                    rng.gen_range(-1.0..1.0) * (2.0 / n_in as f64).sqrt()
                });
                let biases = Array1::zeros(n_out);
                Layer {
                    weights,
                    biases,
                    activation,
                    weight_momentum: Array2::zeros((n_out, n_in)),
                    bias_momentum: Array1::zeros(n_out),
                }
            })
            .collect();

        MLP {
            layers,
            learning_rate,
            task,
            rng,
        }
    }

    /// Runs a forward pass on a single sample and returns the network output.
    pub fn predict(&self, x: &Array1<f64>) -> Array1<f64> {
        let num_layers = self.layers.len();
        self.layers
            .iter()
            .enumerate()
            .fold(x.clone(), |input, (i, layer)| {
                let output = layer.weights.dot(&input) + &layer.biases;
                Self::activate(&output, layer.activation, i == num_layers - 1, self.task)
            })
    }

    pub fn fit(&mut self, x: &Array2<f64>, y: &Array2<f64>, epochs: usize) {
        let num_layers = self.layers.len();
        // Snapshot each layer's activation kind so the backward pass (which
        // borrows `self.layers` mutably) can still look up the activation of
        // the layer that produced a stored output.
        let kinds: Vec<Activation> = self.layers.iter().map(|l| l.activation).collect();
        for _ in 0..epochs {
            for (input, target) in x.outer_iter().zip(y.outer_iter()) {
                let mut activations = vec![input.to_owned()];

                // Forward pass, keeping every layer's output for backprop.
                for (i, layer) in self.layers.iter().enumerate() {
                    let output = layer.weights.dot(activations.last().unwrap()) + &layer.biases;
                    activations.push(Self::activate(
                        &output,
                        layer.activation,
                        i == num_layers - 1,
                        self.task,
                    ));
                }

                // Backward pass. `delta = a - y` is the output-layer error for
                // MSE with a linear output, or for a sigmoid output paired
                // with binary cross-entropy.
                let mut delta = activations.last().unwrap() - &target;

                // Walk the layers in reverse; `input_act` is the input that
                // was fed into `layer`, i.e. the previous layer's stored output.
                for (i, (layer, input_act)) in self
                    .layers
                    .iter_mut()
                    .rev()
                    .zip(activations.iter().rev().skip(1))
                    .enumerate()
                {
                    let gradient = delta.clone();
                    if i < num_layers - 1 {
                        // Propagate the error to the previous layer (always a
                        // hidden layer), using the derivative of *that* layer's
                        // activation evaluated at its stored output.
                        let prev_kind = kinds[num_layers - 2 - i];
                        delta = layer.weights.t().dot(&delta)
                            * Self::activate_derivative(input_act, prev_kind);
                    }

                    // Outer product: gradient (n_out × 1) · input_act (1 × n_in).
                    let weight_update = gradient
                        .view()
                        .insert_axis(Axis(1))
                        .dot(&input_act.view().insert_axis(Axis(0)));

                    layer.weights -= &(self.learning_rate * &weight_update);
                    layer.biases -= &(self.learning_rate * &gradient);
                }
            }
        }
    }

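    // A minimal sketch (not called anywhere) of the SGD-with-momentum step
    // that the currently unused `weight_momentum` / `bias_momentum` buffers
    // anticipate; `beta` is a hypothetical decay coefficient that the struct
    // does not define yet.
    #[allow(dead_code)]
    fn momentum_step(
        layer: &mut Layer,
        weight_grad: &Array2<f64>,
        bias_grad: &Array1<f64>,
        learning_rate: f64,
        beta: f64,
    ) {
        // v <- beta * v + grad; w <- w - lr * v
        layer.weight_momentum = beta * &layer.weight_momentum + weight_grad;
        layer.bias_momentum = beta * &layer.bias_momentum + bias_grad;
        layer.weights -= &(learning_rate * &layer.weight_momentum);
        layer.biases -= &(learning_rate * &layer.bias_momentum);
    }
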
    /// Applies `activation` elementwise. For regression tasks the output
    /// layer is forced to identity, whatever activation was configured.
    fn activate(
        x: &Array1<f64>,
        activation: Activation,
        is_output_layer: bool,
        task: Task,
    ) -> Array1<f64> {
        if is_output_layer && task == Task::Regression {
            return x.to_owned();
        }
        match activation {
            Activation::ReLU => x.mapv(|v| v.max(0.0)),
            Activation::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
            Activation::TanH => x.mapv(|v| v.tanh()),
            Activation::Linear => x.to_owned(),
        }
    }

    /// Derivative of `activation`, expressed in terms of the layer's stored
    /// *output* `a` (post-activation), which is what the backward pass keeps:
    /// relu'(z) = [a > 0], sigmoid'(z) = a(1 - a), tanh'(z) = 1 - a^2.
    fn activate_derivative(a: &Array1<f64>, activation: Activation) -> Array1<f64> {
        match activation {
            Activation::ReLU => a.mapv(|v| if v > 0.0 { 1.0 } else { 0.0 }),
            Activation::Sigmoid => a.mapv(|v| v * (1.0 - v)),
            Activation::TanH => a.mapv(|v| 1.0 - v * v),
            Activation::Linear => Array1::ones(a.len()),
        }
    }
}
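
#[cfg(test)]
mod tests {
    // A hedged usage sketch rather than part of the original module: it
    // assumes `Task` is `Copy + PartialEq` with a `Regression` variant (as
    // `activate` already requires) and that a small net can fit a linear
    // target within a few thousand SGD steps.
    use super::*;
    use crate::utils::tasks::Task;
    use ndarray::array;

    #[test]
    fn fits_a_linear_regression_target() {
        let mut mlp = MLP::new(
            &[1, 4, 1],
            &[Activation::TanH, Activation::Linear],
            0.05,
            Task::Regression,
            42,
        );
        // Learn y = x on five points in [0, 1].
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let y = x.clone();
        mlp.fit(&x, &y, 2000);
        let pred = mlp.predict(&array![0.5]);
        assert!((pred[0] - 0.5).abs() < 0.1, "prediction was {}", pred[0]);
    }
}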