algos/ml/deep/backpropagation.rs

use ndarray::{Array1, Array2, Axis};
use rand::seq::SliceRandom;
use rand::thread_rng;
use rand_distr::{Distribution, Normal};

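/// A single network layer: it can run a forward pass, backpropagate a gradient,
/// and apply a gradient-descent step to its own parameters.
///
/// Minimal usage sketch of the forward/backward/update cycle (marked `ignore`
/// because the `use` path depends on how this module is mounted in the crate):
///
/// ```ignore
/// use ndarray::Array2;
///
/// let mut layer = DenseLayer::new(2, 3, 0.1);
/// let x = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
/// let y = layer.forward(&x);                        // shape (1, 3)
/// let grad_in = layer.backward(&Array2::ones((1, 3)));
/// layer.update_params(0.01);                        // SGD step; caches are cleared
/// assert_eq!(grad_in.shape(), &[1, 2]);
/// ```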
pub trait Layer {
    /// Forward pass: takes input (batch_size, in_features).
    fn forward(&mut self, input: &Array2<f64>) -> Array2<f64>;
    /// Backward pass: given grad w.r.t. layer output, returns grad w.r.t. layer input.
    fn backward(&mut self, grad_output: &Array2<f64>) -> Array2<f64>;
    /// Update internal parameters using stored gradients.
    fn update_params(&mut self, learning_rate: f64);
}

/// A dense (fully connected) layer with weights + biases.
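/// The forward pass computes `output = input · weights + biases`, with `weights`
/// of shape `(in_features, out_features)` and `input` of shape
/// `(batch_size, in_features)`.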
pub struct DenseLayer {
    pub weights: Array2<f64>,
    pub biases: Array1<f64>,

    input_cache: Option<Array2<f64>>,
    weight_grads: Option<Array2<f64>>,
    bias_grads: Option<Array1<f64>>,
}

impl DenseLayer {
    pub fn new(in_features: usize, out_features: usize, init_std: f64) -> Self {
        let mut rng = thread_rng();
        let dist = Normal::new(0.0, init_std).unwrap();

        let weights = Array2::from_shape_fn((in_features, out_features), |_| dist.sample(&mut rng));
        let biases = Array1::zeros(out_features);

        Self {
            weights,
            biases,
            input_cache: None,
            weight_grads: None,
            bias_grads: None,
        }
    }
}

#[allow(non_snake_case)]
impl Layer for DenseLayer {
    fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
        self.input_cache = Some(input.clone());
        let mut output = input.dot(&self.weights);
        output += &self.biases;
        output
    }

    fn backward(&mut self, grad_output: &Array2<f64>) -> Array2<f64> {
        let input = self
            .input_cache
            .as_ref()
            .expect("Must call forward before backward.");

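        // With output = input · W + b and grad_output = dL/d(output),
        // the chain rule gives the three gradients below.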
        // dW = input^T * grad_output
        let dW = input.t().dot(grad_output);
        // dB = sum of grad_output over the batch
        let dB = grad_output.sum_axis(Axis(0));
        // dX = grad_output * W^T
        let dX = grad_output.dot(&self.weights.t());

        self.weight_grads = Some(dW);
        self.bias_grads = Some(dB);

        dX
    }

    fn update_params(&mut self, lr: f64) {
        if let Some(dw) = &self.weight_grads {
            self.weights = &self.weights - &(dw * lr);
        }
        if let Some(db) = &self.bias_grads {
            self.biases = &self.biases - &(db * lr);
        }
        self.input_cache = None;
        self.weight_grads = None;
        self.bias_grads = None;
    }
}

/// Simple sigmoid activation layer
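///
/// Applies `sigmoid(x) = 1 / (1 + exp(-x))` element-wise. Only the output is
/// cached, since the derivative can be written as `sigmoid(x) * (1 - sigmoid(x))`.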
pub struct Sigmoid {
    output_cache: Option<Array2<f64>>,
}

impl Default for Sigmoid {
    fn default() -> Self {
        Self::new()
    }
}

impl Sigmoid {
    pub fn new() -> Self {
        Self { output_cache: None }
    }
}

impl Layer for Sigmoid {
    fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
        let output = input.mapv(|x| 1.0 / (1.0 + (-x).exp()));
        self.output_cache = Some(output.clone());
        output
    }

    fn backward(&mut self, grad_output: &Array2<f64>) -> Array2<f64> {
        let out = self
            .output_cache
            .as_ref()
            .expect("Must call forward before backward.");
        // dL/dx = sigmoid(x) * (1 - sigmoid(x)) * dL/dy, using the cached output
        out * (1.0 - out) * grad_output
    }

    fn update_params(&mut self, _lr: f64) {
        // Sigmoid has no parameters; just drop the cached output.
        self.output_cache = None;
    }
}

/// A small sequential network with multiple layers.
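///
/// Sketch of building a tiny 2-3-1 network (mirrors the unit tests below;
/// `ignore`d because the `use` path depends on the crate layout):
///
/// ```ignore
/// let mut net = SequentialNN::new(
///     vec![
///         Box::new(DenseLayer::new(2, 3, 0.1)),
///         Box::new(Sigmoid::new()),
///         Box::new(DenseLayer::new(3, 1, 0.1)),
///     ],
///     0.1,
/// );
/// ```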
pub struct SequentialNN {
    pub layers: Vec<Box<dyn Layer>>,
    pub learning_rate: f64,
}

impl SequentialNN {
    pub fn new(layers: Vec<Box<dyn Layer>>, learning_rate: f64) -> Self {
        Self {
            layers,
            learning_rate,
        }
    }

    /// Forward pass through the entire network
    pub fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
        let mut x = input.clone();
        for layer in self.layers.iter_mut() {
            x = layer.forward(&x);
        }
        x
    }

    /// Backward pass
    pub fn backward(&mut self, grad_output: &Array2<f64>) {
        let mut grad = grad_output.clone();
        for layer in self.layers.iter_mut().rev() {
            grad = layer.backward(&grad);
        }
    }

    /// Update all params
    pub fn update_params(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.update_params(self.learning_rate);
        }
    }

    /// Mean-squared-error loss of the network's predictions against `targets`
    /// (runs a forward pass on `inputs`).
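    /// Computed as the mean of `(pred - target)^2` over every element of the
    /// output matrix.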
    pub fn mse_loss(&mut self, inputs: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let preds = self.forward(inputs);
        let diff = &preds - targets;
        diff.mapv(|x| x.powi(2)).mean().unwrap_or(0.0)
    }
}

/// Train for one epoch using **stochastic gradient descent** (mini-batch style).
/// - `batch_size = 1` is "pure" SGD.
/// - Larger `batch_size` is "mini-batch" SGD.
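///
/// Sketch of a full training loop (mirrors the unit tests; `ignore`d because
/// the `use` path depends on the crate layout):
///
/// ```ignore
/// use ndarray::Array2;
///
/// // XOR inputs/targets.
/// let inputs =
///     Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]).unwrap();
/// let targets = Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
///
/// for _ in 0..100 {
///     train_sgd(&mut net, &inputs, &targets, 2); // `net` built as in the SequentialNN docs
/// }
/// let loss = net.mse_loss(&inputs, &targets);
/// ```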
pub fn train_sgd(
    net: &mut SequentialNN,
    inputs: &Array2<f64>,
    targets: &Array2<f64>,
    batch_size: usize,
) {
    let n_samples = inputs.len_of(Axis(0));
    let mut indices: Vec<usize> = (0..n_samples).collect();
    indices.shuffle(&mut thread_rng());

    // Loop over mini-batches in random order
    for chunk in indices.chunks(batch_size) {
        // Gather the current mini-batch
        let batch_input = Array2::from_shape_fn((chunk.len(), inputs.len_of(Axis(1))), |(i, j)| {
            inputs[[chunk[i], j]]
        });
        let batch_target =
            Array2::from_shape_fn((chunk.len(), targets.len_of(Axis(1))), |(i, j)| {
                targets[[chunk[i], j]]
            });

        // Forward
        let preds = net.forward(&batch_input);
        // Gradient of 0.5 * Σ(preds - targets)^2 w.r.t. preds; the omitted
        // 1/batch_size factor from the true MSE only rescales the step size.
        let grad_loss = &preds - &batch_target;

        // Backward
        net.backward(&grad_loss);

        // Update
        net.update_params();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_dense_layer_forward() {
        let mut layer = DenseLayer::new(2, 3, 0.1);
        // Set deterministic weights and biases for testing
        layer.weights = Array2::from_shape_vec((2, 3), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
        layer.biases = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let output = layer.forward(&input);

        // Manual calculation: input · weights + biases
        assert_relative_eq!(output[[0, 0]], 1.0 * 0.1 + 2.0 * 0.4 + 0.1, epsilon = 1e-10);
        assert_relative_eq!(output[[0, 1]], 1.0 * 0.2 + 2.0 * 0.5 + 0.2, epsilon = 1e-10);
        assert_relative_eq!(output[[0, 2]], 1.0 * 0.3 + 2.0 * 0.6 + 0.3, epsilon = 1e-10);
    }

    #[test]
    fn test_sigmoid_activation() {
        let mut sigmoid = Sigmoid::new();
        let input = Array2::from_shape_vec((1, 3), vec![0.0, 1.0, -1.0]).unwrap();
        let output = sigmoid.forward(&input);

        // Test sigmoid(0) = 0.5
        assert_relative_eq!(output[[0, 0]], 0.5, epsilon = 1e-10);
        // Test sigmoid(1) ≈ 0.731...
        assert_relative_eq!(
            output[[0, 1]],
            1.0 / (1.0 + (-1.0f64).exp()),
            epsilon = 1e-10
        );
        // Test sigmoid(-1) ≈ 0.269...
        assert_relative_eq!(output[[0, 2]], 1.0 / (1.0 + 1.0f64.exp()), epsilon = 1e-10);
    }

    #[test]
    fn test_dense_layer_backward() {
        let mut layer = DenseLayer::new(2, 2, 0.1);
        layer.weights = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        layer.biases = Array1::from_vec(vec![0.1, 0.2]);

        let input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        layer.forward(&input);

        let grad_output = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
        let grad_input = layer.backward(&grad_output);

        // Check gradient shapes
        assert_eq!(grad_input.shape(), &[1, 2]);
        assert!(layer.weight_grads.is_some());
        assert!(layer.bias_grads.is_some());
    }

    #[test]
    fn test_sequential_network() {
        let mut net = SequentialNN::new(
            vec![
                Box::new(DenseLayer::new(2, 3, 0.1)),
                Box::new(Sigmoid::new()),
                Box::new(DenseLayer::new(3, 1, 0.1)),
            ],
            0.1,
        );

        // Test forward pass
        let input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let output = net.forward(&input);
        assert_eq!(output.shape(), &[1, 1]);

        // Test loss calculation
        let target = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let loss = net.mse_loss(&input, &target);
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_sgd_training() {
        let mut net = SequentialNN::new(
            vec![
                Box::new(DenseLayer::new(2, 3, 0.1)),
                Box::new(Sigmoid::new()),
                Box::new(DenseLayer::new(3, 1, 0.1)),
            ],
            0.1,
        );

        // XOR problem inputs and outputs
        let inputs =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]).unwrap();

        let targets = Array2::from_shape_vec((4, 1), vec![0.0, 1.0, 1.0, 0.0]).unwrap();

        // Initial loss
        let initial_loss = net.mse_loss(&inputs, &targets);

        // Train for 100 epochs
        for _ in 0..100 {
            train_sgd(&mut net, &inputs, &targets, 2);
        }

        // Final loss should be lower than initial loss
        let final_loss = net.mse_loss(&inputs, &targets);
        assert!(final_loss < initial_loss, "Training should reduce loss");
    }
}