use crate::utils::tasks::Task;
use ndarray::{Array1, Array2};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

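/// A feed-forward multilayer perceptron trained with per-sample
/// (stochastic) gradient descent. For regression tasks the output layer
/// is forced to be linear; see `activate`.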
pub struct MLP {
    pub layers: Vec<Layer>,
    pub learning_rate: f64,
    pub task: Task,
    pub rng: StdRng,
}

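/// A single fully connected layer storing its weights, biases, and the
/// activation applied to its output.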
pub struct Layer {
    pub weights: Array2<f64>,
    pub biases: Array1<f64>,
    pub activation: Activation,
    // Momentum buffers, allocated by `MLP::new`; the current `fit`
    // update rule does not read them yet.
    pub weight_momentum: Array2<f64>,
    pub bias_momentum: Array1<f64>,
}

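/// Supported activation functions. `Linear` is the identity and is also
/// applied implicitly at the output layer for regression tasks.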
#[derive(Clone, Copy)]
pub enum Activation {
    ReLU,
    Sigmoid,
    TanH,
    Linear,
}

impl MLP {
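    /// Builds a network from `layer_sizes` (e.g. `[2, 4, 1]` gives one
    /// hidden layer) and one activation per weight layer. Weights are
    /// drawn uniformly from (-1, 1) and scaled by sqrt(2 / n_in)
    /// (He-style); biases start at zero. The seed makes initialization
    /// reproducible.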
    pub fn new(
        layer_sizes: &[usize],
        activations: &[Activation],
        learning_rate: f64,
        task: Task,
        seed: u64,
    ) -> Self {
        assert!(
            layer_sizes.len() >= 2,
            "At least input and output layers are required"
        );
        assert_eq!(
            layer_sizes.len() - 1,
            activations.len(),
            "Number of activations must match number of layers - 1"
        );

        let mut rng = StdRng::seed_from_u64(seed);

        let layers: Vec<Layer> = layer_sizes
            .windows(2)
            .zip(activations.iter())
            .map(|(sizes, &activation)| {
                let (n_in, n_out) = (sizes[0], sizes[1]);
                // He-style initialization: uniform in (-1, 1), scaled by
                // sqrt(2 / n_in) to keep activation variance stable.
                let weights = Array2::from_shape_fn((n_out, n_in), |_| {
                    rng.gen_range(-1.0..1.0) * (2.0 / n_in as f64).sqrt()
                });
                let biases = Array1::zeros(n_out);
                Layer {
                    weights,
                    biases,
                    activation,
                    weight_momentum: Array2::zeros((n_out, n_in)),
                    bias_momentum: Array1::zeros(n_out),
                }
            })
            .collect();

        MLP {
            layers,
            learning_rate,
            task,
            rng,
        }
    }

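    /// Runs a forward pass on a single sample and returns the network
    /// output.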
    pub fn predict(&self, x: &Array1<f64>) -> Array1<f64> {
        let num_layers = self.layers.len();
        // Fold the input through each layer in turn.
        self.layers
            .iter()
            .enumerate()
            .fold(x.clone(), |input, (i, layer)| {
                let output = layer.weights.dot(&input) + &layer.biases;
                Self::activate(&output, layer.activation, i == num_layers - 1, self.task)
            })
    }

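    /// Trains the network with per-sample gradient descent. Each row of
    /// `x` is one sample and the matching row of `y` its target. The
    /// output-layer delta `prediction - target` corresponds to squared
    /// error with a linear output (regression) or cross-entropy with a
    /// sigmoid output (binary classification).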
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array2<f64>, epochs: usize) {
        let num_layers = self.layers.len();
        for _ in 0..epochs {
            for (input, target) in x.outer_iter().zip(y.outer_iter()) {
                // Forward pass, caching every layer's activated output;
                // activations[i] is the input to layer i.
                let mut activations = vec![input.to_owned()];
                for (i, layer) in self.layers.iter().enumerate() {
                    let output =
                        layer.weights.dot(activations.last().unwrap()) + &layer.biases;
                    activations.push(Self::activate(
                        &output,
                        layer.activation,
                        i == num_layers - 1,
                        self.task,
                    ));
                }

                // Output-layer error; see the doc comment above for the
                // losses under which `prediction - target` is the delta.
                let mut delta = activations.last().unwrap() - &target;

                // Backward pass, from the output layer down to the first.
                for i in (0..num_layers).rev() {
                    let gradient = delta.clone();

                    // Propagate the error to the previous layer while this
                    // layer's weights are still unmodified. The derivative is
                    // the *previous* layer's activation function, evaluated
                    // on its cached output (see `activate_derivative`).
                    if i > 0 {
                        let prev_activation = self.layers[i - 1].activation;
                        delta = self.layers[i].weights.t().dot(&delta)
                            * Self::activate_derivative(
                                &activations[i],
                                prev_activation,
                                false,
                                self.task,
                            );
                    }

                    // The weight gradient is the outer product of this
                    // layer's error with its input.
                    let layer_input = &activations[i];
                    let weight_update = gradient
                        .clone()
                        .into_shape((gradient.len(), 1))
                        .unwrap()
                        .dot(&layer_input.view().into_shape((1, layer_input.len())).unwrap());

                    let layer = &mut self.layers[i];
                    layer.weights -= &(self.learning_rate * &weight_update);
                    layer.biases -= &(self.learning_rate * &gradient);
                }
            }
        }
    }

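    /// Applies `activation` elementwise. For regression tasks the output
    /// layer is kept linear regardless of the configured activation.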
    fn activate(
        x: &Array1<f64>,
        activation: Activation,
        is_output_layer: bool,
        task: Task,
    ) -> Array1<f64> {
        if is_output_layer && task == Task::Regression {
            return x.to_owned();
        }
        match activation {
            Activation::ReLU => x.mapv(|v| v.max(0.0)),
            Activation::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
            Activation::TanH => x.mapv(|v| v.tanh()),
            Activation::Linear => x.to_owned(),
        }
    }

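    /// Derivative of `activation`, expressed in terms of the *activated*
    /// output `a` cached during the forward pass (not the pre-activation),
    /// which is what `fit` passes in.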
    fn activate_derivative(
        a: &Array1<f64>,
        activation: Activation,
        is_output_layer: bool,
        task: Task,
    ) -> Array1<f64> {
        if is_output_layer && task == Task::Regression {
            return Array1::ones(a.len());
        }
        match activation {
            // a > 0 exactly when the pre-activation was > 0.
            Activation::ReLU => a.mapv(|v| if v > 0.0 { 1.0 } else { 0.0 }),
            // sigma'(z) = sigma(z)(1 - sigma(z)) = a(1 - a)
            Activation::Sigmoid => a.mapv(|v| v * (1.0 - v)),
            // tanh'(z) = 1 - tanh(z)^2 = 1 - a^2
            Activation::TanH => a.mapv(|v| 1.0 - v * v),
            Activation::Linear => Array1::ones(a.len()),
        }
    }
}
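
// A minimal smoke test sketching how the API fits together: build a tiny
// network, train it on y = 2x, and check that the mean squared error
// drops. The architecture, learning rate, and epoch count below are
// illustrative choices, not tuned values from this crate;
// `Task::Regression` is the variant this module already matches on.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Mean squared error over a batch; a small helper for the test below.
    fn mse(mlp: &MLP, x: &Array2<f64>, y: &Array2<f64>) -> f64 {
        x.outer_iter()
            .zip(y.outer_iter())
            .map(|(xi, yi)| {
                let pred = mlp.predict(&xi.to_owned());
                (&pred - &yi).mapv(|d| d * d).sum()
            })
            .sum::<f64>()
            / x.nrows() as f64
    }

    #[test]
    fn training_reduces_regression_error() {
        let mut mlp = MLP::new(
            &[1, 8, 1],
            &[Activation::TanH, Activation::Linear],
            0.05,
            Task::Regression,
            42,
        );

        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let y = x.mapv(|v| 2.0 * v);

        let before = mse(&mlp, &x, &y);
        mlp.fit(&x, &y, 200);
        let after = mse(&mlp, &x, &y);

        assert!(after < before, "MSE should drop: {before} -> {after}");
    }
}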