use ndarray::{Array1, Array2, Axis};
use rand::seq::SliceRandom;
use rand::thread_rng;
use rand_distr::{Distribution, Normal};

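/// A network layer: forward pass, backward pass, and in-place parameter update.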
pub trait Layer {
    fn forward(&mut self, input: &Array2<f64>) -> Array2<f64>;
    fn backward(&mut self, grad_output: &Array2<f64>) -> Array2<f64>;
    fn update_params(&mut self, learning_rate: f64);
}

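/// Fully connected (affine) layer computing `output = input.dot(weights) + biases`.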
pub struct DenseLayer {
    pub weights: Array2<f64>,
    pub biases: Array1<f64>,

    input_cache: Option<Array2<f64>>,
    weight_grads: Option<Array2<f64>>,
    bias_grads: Option<Array1<f64>>,
}

impl DenseLayer {
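    /// Builds a layer with weights drawn from a normal distribution
    /// (mean 0, standard deviation `init_std`) and biases initialised to zero.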
    pub fn new(in_features: usize, out_features: usize, init_std: f64) -> Self {
        let mut rng = thread_rng();
        let dist = Normal::new(0.0, init_std).unwrap();

        let weights =
            Array2::from_shape_fn((in_features, out_features), |_| dist.sample(&mut rng));
        let biases = Array1::zeros(out_features);

        Self {
            weights,
            biases,
            input_cache: None,
            weight_grads: None,
            bias_grads: None,
        }
    }
}

#[allow(non_snake_case)]
impl Layer for DenseLayer {
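    // Forward pass: y = x * W + b. The input is cached for the backward pass.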
    fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
        self.input_cache = Some(input.clone());
        let mut output = input.dot(&self.weights);
        output += &self.biases;
        output
    }

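    // Backward pass: given dL/dy, store dL/dW = x^T * dy and dL/db = sum of dy over the batch,
    // then return dL/dx = dy * W^T for the previous layer.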
    fn backward(&mut self, grad_output: &Array2<f64>) -> Array2<f64> {
        let input = self
            .input_cache
            .as_ref()
            .expect("Must call forward before backward.");

        let dW = input.t().dot(grad_output);
        let dB = grad_output.sum_axis(Axis(0));
        let dX = grad_output.dot(&self.weights.t());

        self.weight_grads = Some(dW);
        self.bias_grads = Some(dB);

        dX
    }

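    // Plain gradient-descent step; cached activations and gradients are cleared afterwards.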
    fn update_params(&mut self, lr: f64) {
        if let Some(dw) = &self.weight_grads {
            self.weights = &self.weights - &(dw * lr);
        }
        if let Some(db) = &self.bias_grads {
            self.biases = &self.biases - &(db * lr);
        }
        self.input_cache = None;
        self.weight_grads = None;
        self.bias_grads = None;
    }
}

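/// Element-wise logistic activation: sigmoid(x) = 1 / (1 + exp(-x)).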
pub struct Sigmoid {
    output_cache: Option<Array2<f64>>,
}

impl Default for Sigmoid {
    fn default() -> Self {
        Self::new()
    }
}

impl Sigmoid {
    pub fn new() -> Self {
        Self { output_cache: None }
    }
}

impl Layer for Sigmoid {
    fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
        let output = input.mapv(|x| 1.0 / (1.0 + (-x).exp()));
        self.output_cache = Some(output.clone());
        output
    }

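    // d(sigmoid)/dx = sigmoid(x) * (1 - sigmoid(x)), applied element-wise to the incoming gradient.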
    fn backward(&mut self, grad_output: &Array2<f64>) -> Array2<f64> {
        let out = self
            .output_cache
            .as_ref()
            .expect("Must call forward before backward.");
        out * (1.0 - out) * grad_output
    }

    fn update_params(&mut self, _lr: f64) {
        self.output_cache = None;
    }
}

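/// A feed-forward network: an ordered stack of layers trained with one shared learning rate.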
pub struct SequentialNN {
    pub layers: Vec<Box<dyn Layer>>,
    pub learning_rate: f64,
}

impl SequentialNN {
    pub fn new(layers: Vec<Box<dyn Layer>>, learning_rate: f64) -> Self {
        Self {
            layers,
            learning_rate,
        }
    }

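    /// Runs `input` through every layer in order and returns the final activations.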
    pub fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
        let mut x = input.clone();
        for layer in self.layers.iter_mut() {
            x = layer.forward(&x);
        }
        x
    }

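    /// Propagates the loss gradient backwards through the layers, from last to first.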
    pub fn backward(&mut self, grad_output: &Array2<f64>) {
        let mut grad = grad_output.clone();
        for layer in self.layers.iter_mut().rev() {
            grad = layer.backward(&grad);
        }
    }

    pub fn update_params(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.update_params(self.learning_rate);
        }
    }

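    /// Runs a forward pass and returns the mean squared error of the predictions against `targets`.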
    pub fn mse_loss(&mut self, inputs: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let preds = self.forward(inputs);
        let diff = &preds - targets;
        diff.mapv(|x| x.powi(2)).mean().unwrap_or(0.0)
    }
}

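/// Runs one epoch of mini-batch SGD: shuffles the sample indices, then for each batch
/// performs a forward pass, backpropagation, and a parameter update. The loss gradient
/// is taken as (pred - target); the constant 2/N factor of the MSE derivative is omitted
/// and effectively absorbed into the learning rate.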
pub fn train_sgd(
    net: &mut SequentialNN,
    inputs: &Array2<f64>,
    targets: &Array2<f64>,
    batch_size: usize,
) {
    let n_samples = inputs.len_of(Axis(0));
    let mut indices: Vec<usize> = (0..n_samples).collect();
    indices.shuffle(&mut thread_rng());

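    // Walk the shuffled indices in chunks of `batch_size` (the last chunk may be smaller)
    // and copy the selected rows into dense batch arrays.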
    for chunk in indices.chunks(batch_size) {
        let batch_input =
            Array2::from_shape_fn((chunk.len(), inputs.len_of(Axis(1))), |(i, j)| {
                inputs[[chunk[i], j]]
            });
        let batch_target =
            Array2::from_shape_fn((chunk.len(), targets.len_of(Axis(1))), |(i, j)| {
                targets[[chunk[i], j]]
            });

        let preds = net.forward(&batch_input);
        let grad_loss = &preds - &batch_target;

        net.backward(&grad_loss);

        net.update_params();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_dense_layer_forward() {
        let mut layer = DenseLayer::new(2, 3, 0.1);
        layer.weights =
            Array2::from_shape_vec((2, 3), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
        layer.biases = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let output = layer.forward(&input);

        assert_relative_eq!(output[[0, 0]], 1.0 * 0.1 + 2.0 * 0.4 + 0.1, epsilon = 1e-10);
        assert_relative_eq!(output[[0, 1]], 1.0 * 0.2 + 2.0 * 0.5 + 0.2, epsilon = 1e-10);
        assert_relative_eq!(output[[0, 2]], 1.0 * 0.3 + 2.0 * 0.6 + 0.3, epsilon = 1e-10);
    }

    #[test]
    fn test_sigmoid_activation() {
        let mut sigmoid = Sigmoid::new();
        let input = Array2::from_shape_vec((1, 3), vec![0.0, 1.0, -1.0]).unwrap();
        let output = sigmoid.forward(&input);

        assert_relative_eq!(output[[0, 0]], 0.5, epsilon = 1e-10);
        assert_relative_eq!(
            output[[0, 1]],
            1.0 / (1.0 + (-1.0f64).exp()),
            epsilon = 1e-10
        );
        assert_relative_eq!(output[[0, 2]], 1.0 / (1.0 + 1.0f64.exp()), epsilon = 1e-10);
    }

    #[test]
    fn test_dense_layer_backward() {
        let mut layer = DenseLayer::new(2, 2, 0.1);
        layer.weights = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        layer.biases = Array1::from_vec(vec![0.1, 0.2]);

        let input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        layer.forward(&input);

        let grad_output = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
        let grad_input = layer.backward(&grad_output);

        assert_eq!(grad_input.shape(), &[1, 2]);
        assert!(layer.weight_grads.is_some());
        assert!(layer.bias_grads.is_some());
    }

    #[test]
    fn test_sequential_network() {
        let mut net = SequentialNN::new(
            vec![
                Box::new(DenseLayer::new(2, 3, 0.1)),
                Box::new(Sigmoid::new()),
                Box::new(DenseLayer::new(3, 1, 0.1)),
            ],
            0.1,
        );

        let input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let output = net.forward(&input);
        assert_eq!(output.shape(), &[1, 1]);

        let target = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let loss = net.mse_loss(&input, &target);
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_sgd_training() {
        let mut net = SequentialNN::new(
            vec![
                Box::new(DenseLayer::new(2, 3, 0.1)),
                Box::new(Sigmoid::new()),
                Box::new(DenseLayer::new(3, 1, 0.1)),
            ],
            0.1,
        );

        let inputs =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]).unwrap();

        let targets = Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 1.0, 0.0]).unwrap();

        let initial_loss = net.mse_loss(&inputs, &targets);

        for _ in 0..100 {
            train_sgd(&mut net, &inputs, &targets, 2);
        }

        let final_loss = net.mse_loss(&inputs, &targets);
        assert!(final_loss < initial_loss, "Training should reduce loss");
    }
}