// rust_mlp/layer.rs
1//! Dense layer implementation.
2//!
3//! A `Layer` is a dense affine transform followed by an element-wise activation:
4//!
5//! - `z = W x + b`
6//! - `y = activation(z)`
7//!
8//! The activation is stored in the layer so an `Mlp` can mix activation functions
9//! across layers.
10//!
11//! Shape mismatches are treated as programmer error and will panic via `assert!`.
12
13use rand::Rng;
14use rand::distributions::{Distribution, Uniform};
15
16use crate::Activation;
17use crate::{Error, Result};
18
/// Initialization scheme for layer weights.
///
/// This selects how `Layer::new_with_rng` fills the weight matrix; biases
/// are always initialized to zero regardless of the scheme.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Init {
    /// All weights start at zero.
    ///
    /// With zero weights every unit in the layer computes the same function,
    /// so this is mainly useful for tests or deliberate baselines.
    Zeros,
    /// Xavier/Glorot uniform init: `U(-l, l)` with `l = sqrt(6 / (fan_in + fan_out))`.
    ///
    /// This is a good default for `tanh`, `sigmoid`, and `identity` activations.
    Xavier,
    /// He/Kaiming uniform init: `U(-l, l)` with `l = sqrt(6 / fan_in)`.
    ///
    /// This is a good default for `ReLU`-family activations.
    He,
}
32
/// A dense layer: `y = activation(Wx + b)`.
///
/// Weights use row-major layout with shape `(out_dim, in_dim)`.
#[derive(Debug, Clone)]
pub struct Layer {
    /// Number of input features accepted by `forward`.
    in_dim: usize,
    /// Number of output features produced by `forward`.
    out_dim: usize,
    /// Element-wise activation applied after the affine transform.
    activation: Activation,
    /// Row-major matrix with shape (out_dim, in_dim); row `o` holds the
    /// weights feeding output unit `o`.
    weights: Vec<f32>,
    /// Bias vector of length `out_dim`.
    biases: Vec<f32>,
}
45
impl Layer {
    /// Returns this layer's activation.
    ///
    /// `Activation` is `Copy` here (returned by value), so callers get an
    /// independent copy rather than a borrow of the layer.
    #[inline]
    pub fn activation(&self) -> Activation {
        self.activation
    }
}
53
54impl Layer {
55    /// Construct a layer from explicit parameter buffers.
56    ///
57    /// This is primarily intended for model loading/serialization.
58    #[cfg(feature = "serde")]
59    pub(crate) fn from_parts(
60        in_dim: usize,
61        out_dim: usize,
62        activation: Activation,
63        weights: Vec<f32>,
64        biases: Vec<f32>,
65    ) -> Result<Self> {
66        if in_dim == 0 || out_dim == 0 {
67            return Err(Error::InvalidData(format!(
68                "layer dims must be > 0, got in_dim={in_dim} out_dim={out_dim}"
69            )));
70        }
71
72        activation
73            .validate()
74            .map_err(|e| Error::InvalidData(format!("invalid activation: {e}")))?;
75
76        let expected_w = in_dim
77            .checked_mul(out_dim)
78            .ok_or_else(|| Error::InvalidData("layer weight shape overflow".to_owned()))?;
79        if weights.len() != expected_w {
80            return Err(Error::InvalidData(format!(
81                "weights length {} does not match out_dim * in_dim ({} * {})",
82                weights.len(),
83                out_dim,
84                in_dim
85            )));
86        }
87        if biases.len() != out_dim {
88            return Err(Error::InvalidData(format!(
89                "biases length {} does not match out_dim {}",
90                biases.len(),
91                out_dim
92            )));
93        }
94
95        if weights.iter().any(|v| !v.is_finite()) {
96            return Err(Error::InvalidData(
97                "weights must contain only finite values".to_owned(),
98            ));
99        }
100        if biases.iter().any(|v| !v.is_finite()) {
101            return Err(Error::InvalidData(
102                "biases must contain only finite values".to_owned(),
103            ));
104        }
105
106        Ok(Self {
107            in_dim,
108            out_dim,
109            activation,
110            weights,
111            biases,
112        })
113    }
114
115    pub fn new_with_rng<R: Rng + ?Sized>(
116        in_dim: usize,
117        out_dim: usize,
118        init: Init,
119        activation: Activation,
120        rng: &mut R,
121    ) -> Result<Self> {
122        if in_dim == 0 || out_dim == 0 {
123            return Err(Error::InvalidConfig("layer dims must be > 0".to_owned()));
124        }
125
126        activation.validate()?;
127
128        let mut weights = vec![0.0; in_dim * out_dim];
129        match init {
130            Init::Zeros => {}
131            Init::Xavier => {
132                let fan_in = in_dim as f32;
133                let fan_out = out_dim as f32;
134                let limit = (6.0 / (fan_in + fan_out)).sqrt();
135                let dist = Uniform::new(-limit, limit);
136                for w in &mut weights {
137                    *w = dist.sample(rng);
138                }
139            }
140            Init::He => {
141                let fan_in = in_dim as f32;
142                let limit = (6.0 / fan_in).sqrt();
143                let dist = Uniform::new(-limit, limit);
144                for w in &mut weights {
145                    *w = dist.sample(rng);
146                }
147            }
148        }
149
150        let biases = vec![0.0; out_dim];
151
152        Ok(Self {
153            in_dim,
154            out_dim,
155            activation,
156            weights,
157            biases,
158        })
159    }
160
161    #[inline]
162    /// Returns the input dimension.
163    pub fn in_dim(&self) -> usize {
164        self.in_dim
165    }
166
167    #[inline]
168    /// Returns the output dimension.
169    pub fn out_dim(&self) -> usize {
170        self.out_dim
171    }
172
173    #[inline]
174    pub(crate) fn weights(&self) -> &[f32] {
175        &self.weights
176    }
177
178    #[inline]
179    pub(crate) fn biases(&self) -> &[f32] {
180        &self.biases
181    }
182
183    #[inline]
184    #[cfg(test)]
185    pub(crate) fn weights_mut(&mut self) -> &mut [f32] {
186        &mut self.weights
187    }
188
189    #[inline]
190    #[cfg(test)]
191    pub(crate) fn biases_mut(&mut self) -> &mut [f32] {
192        &mut self.biases
193    }
194
195    /// Forward pass for a single sample.
196    ///
197    /// Computes:
198    /// - `z = W * inputs + b`
199    /// - `outputs = activation(z)`
200    ///
201    /// Shape contract:
202    /// - `inputs.len() == self.in_dim`
203    /// - `outputs.len() == self.out_dim`
204    #[inline]
205    pub fn forward(&self, inputs: &[f32], outputs: &mut [f32]) {
206        assert_eq!(
207            inputs.len(),
208            self.in_dim,
209            "inputs len {} does not match layer in_dim {}",
210            inputs.len(),
211            self.in_dim
212        );
213        assert_eq!(
214            outputs.len(),
215            self.out_dim,
216            "outputs len {} does not match layer out_dim {}",
217            outputs.len(),
218            self.out_dim
219        );
220
221        let activation = self.activation;
222
223        for (o, out) in outputs.iter_mut().enumerate() {
224            let mut sum = self.biases[o];
225            let row = o * self.in_dim;
226            for (i, &x) in inputs.iter().enumerate() {
227                sum = self.weights[row + i].mul_add(x, sum);
228            }
229            *out = activation.forward(sum);
230        }
231    }
232
233    /// Backward pass for a single sample.
234    ///
235    /// This uses overwrite semantics:
236    /// - `d_inputs` is overwritten (and internally zeroed before accumulation)
237    /// - `d_weights` is overwritten
238    /// - `d_biases` is overwritten
239    ///
240    /// Inputs:
241    /// - `inputs`: the same inputs passed to `forward`
242    /// - `outputs`: the outputs previously produced by `forward` (post-activation)
243    /// - `d_outputs`: upstream gradient dL/d(outputs)
244    ///
245    /// Shape contract:
246    /// - `inputs.len() == self.in_dim`
247    /// - `outputs.len() == self.out_dim`
248    /// - `d_outputs.len() == self.out_dim`
249    /// - `d_inputs.len() == self.in_dim`
250    /// - `d_weights.len() == self.weights.len()`
251    /// - `d_biases.len() == self.out_dim`
252    #[inline]
253    pub fn backward(
254        &self,
255        inputs: &[f32],
256        outputs: &[f32],
257        d_outputs: &[f32],
258        d_inputs: &mut [f32],
259        d_weights: &mut [f32],
260        d_biases: &mut [f32],
261    ) {
262        assert_eq!(
263            inputs.len(),
264            self.in_dim,
265            "inputs len {} does not match layer in_dim {}",
266            inputs.len(),
267            self.in_dim
268        );
269        assert_eq!(
270            outputs.len(),
271            self.out_dim,
272            "outputs len {} does not match layer out_dim {}",
273            outputs.len(),
274            self.out_dim
275        );
276        assert_eq!(
277            d_outputs.len(),
278            self.out_dim,
279            "d_outputs len {} does not match layer out_dim {}",
280            d_outputs.len(),
281            self.out_dim
282        );
283        assert_eq!(
284            d_inputs.len(),
285            self.in_dim,
286            "d_inputs len {} does not match layer in_dim {}",
287            d_inputs.len(),
288            self.in_dim
289        );
290        assert_eq!(
291            d_weights.len(),
292            self.weights.len(),
293            "d_weights len {} does not match weights len {}",
294            d_weights.len(),
295            self.weights.len()
296        );
297        assert_eq!(
298            d_biases.len(),
299            self.out_dim,
300            "d_biases len {} does not match layer out_dim {}",
301            d_biases.len(),
302            self.out_dim
303        );
304
305        // d_inputs accumulates contributions from all outputs.
306        d_inputs.fill(0.0);
307
308        let activation = self.activation;
309
310        for o in 0..self.out_dim {
311            let d_z = d_outputs[o] * activation.grad_from_output(outputs[o]);
312            d_biases[o] = d_z;
313
314            let row = o * self.in_dim;
315            for i in 0..self.in_dim {
316                let w = self.weights[row + i];
317                d_weights[row + i] = d_z * inputs[i];
318                d_inputs[i] = w.mul_add(d_z, d_inputs[i]);
319            }
320        }
321    }
322
323    /// Backward pass for a single sample (parameter accumulation semantics).
324    ///
325    /// This is identical to `backward` except that parameter gradients are *accumulated*:
326    /// - `d_inputs` is overwritten (and internally zeroed before accumulation)
327    /// - `d_weights` is accumulated into (`+=`)
328    /// - `d_biases` is accumulated into (`+=`)
329    ///
330    /// This is useful for mini-batch training where you sum gradients over multiple samples.
331    ///
332    /// Shape contract:
333    /// - `inputs.len() == self.in_dim`
334    /// - `outputs.len() == self.out_dim`
335    /// - `d_outputs.len() == self.out_dim`
336    /// - `d_inputs.len() == self.in_dim`
337    /// - `d_weights.len() == self.weights.len()`
338    /// - `d_biases.len() == self.out_dim`
339    #[inline]
340    pub fn backward_accumulate(
341        &self,
342        inputs: &[f32],
343        outputs: &[f32],
344        d_outputs: &[f32],
345        d_inputs: &mut [f32],
346        d_weights: &mut [f32],
347        d_biases: &mut [f32],
348    ) {
349        assert_eq!(
350            inputs.len(),
351            self.in_dim,
352            "inputs len {} does not match layer in_dim {}",
353            inputs.len(),
354            self.in_dim
355        );
356        assert_eq!(
357            outputs.len(),
358            self.out_dim,
359            "outputs len {} does not match layer out_dim {}",
360            outputs.len(),
361            self.out_dim
362        );
363        assert_eq!(
364            d_outputs.len(),
365            self.out_dim,
366            "d_outputs len {} does not match layer out_dim {}",
367            d_outputs.len(),
368            self.out_dim
369        );
370        assert_eq!(
371            d_inputs.len(),
372            self.in_dim,
373            "d_inputs len {} does not match layer in_dim {}",
374            d_inputs.len(),
375            self.in_dim
376        );
377        assert_eq!(
378            d_weights.len(),
379            self.weights.len(),
380            "d_weights len {} does not match weights len {}",
381            d_weights.len(),
382            self.weights.len()
383        );
384        assert_eq!(
385            d_biases.len(),
386            self.out_dim,
387            "d_biases len {} does not match layer out_dim {}",
388            d_biases.len(),
389            self.out_dim
390        );
391
392        // d_inputs accumulates contributions from all outputs.
393        d_inputs.fill(0.0);
394
395        let activation = self.activation;
396
397        for o in 0..self.out_dim {
398            let d_z = d_outputs[o] * activation.grad_from_output(outputs[o]);
399            d_biases[o] += d_z;
400
401            let row = o * self.in_dim;
402            for i in 0..self.in_dim {
403                let w = self.weights[row + i];
404                d_weights[row + i] += d_z * inputs[i];
405                d_inputs[i] = w.mul_add(d_z, d_inputs[i]);
406            }
407        }
408    }
409
410    /// Applies an SGD update: `param -= lr * d_param`.
411    ///
412    /// Shape contract:
413    /// - `d_weights.len() == self.weights.len()`
414    /// - `d_biases.len() == self.biases.len()`
415    #[inline]
416    pub fn sgd_step(&mut self, d_weights: &[f32], d_biases: &[f32], lr: f32) {
417        assert_eq!(
418            d_weights.len(),
419            self.weights.len(),
420            "d_weights len {} does not match weights len {}",
421            d_weights.len(),
422            self.weights.len()
423        );
424        assert_eq!(
425            d_biases.len(),
426            self.biases.len(),
427            "d_biases len {} does not match biases len {}",
428            d_biases.len(),
429            self.biases.len()
430        );
431
432        for (w, &dw) in self.weights.iter_mut().zip(d_weights) {
433            *w -= lr * dw;
434        }
435        for (b, &db) in self.biases.iter_mut().zip(d_biases) {
436            *b -= lr * db;
437        }
438    }
439
440    /// Apply decoupled weight decay to weights only.
441    ///
442    /// `w -= lr * weight_decay * w`.
443    pub(crate) fn apply_weight_decay(&mut self, lr: f32, weight_decay: f32) {
444        assert!(
445            lr.is_finite() && lr > 0.0,
446            "learning rate must be finite and > 0"
447        );
448        assert!(
449            weight_decay.is_finite() && weight_decay >= 0.0,
450            "weight_decay must be finite and >= 0"
451        );
452
453        if weight_decay == 0.0 {
454            return;
455        }
456
457        let scale = lr * weight_decay;
458        for w in &mut self.weights {
459            *w -= scale * *w;
460        }
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    /// Runs `forward` into `out` and returns the MSE loss against `target`.
    ///
    /// Shared helper for the central-difference gradient checks below.
    fn loss_for_layer(layer: &Layer, input: &[f32], target: &[f32], out: &mut [f32]) -> f32 {
        layer.forward(input, out);
        crate::loss::mse(out, target)
    }

    /// Asserts `analytic` and `numeric` agree within an absolute OR relative
    /// tolerance; `scale` is floored at 1.0 so near-zero gradients fall back
    /// to the absolute test.
    fn assert_close(analytic: f32, numeric: f32, abs_tol: f32, rel_tol: f32) {
        let diff = (analytic - numeric).abs();
        let scale = analytic.abs().max(numeric.abs()).max(1.0);
        assert!(
            diff <= abs_tol || diff / scale <= rel_tol,
            "analytic={analytic} numeric={numeric} diff={diff}"
        );
    }

    #[test]
    fn seeded_init_is_deterministic() {
        // Two RNGs with the same seed must yield byte-identical parameters.
        let mut rng_a = StdRng::seed_from_u64(123);
        let mut rng_b = StdRng::seed_from_u64(123);
        let a = Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng_a).unwrap();
        let b = Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng_b).unwrap();
        assert_eq!(a.weights, b.weights);
        assert_eq!(a.biases, b.biases);
    }

    /// Checks every analytic gradient from `backward` (weights, biases, and
    /// inputs) against a central difference of the MSE loss:
    /// `dL/dp ≈ (L(p + eps) - L(p - eps)) / (2 * eps)`.
    #[test]
    fn backward_matches_numeric_gradients() {
        let in_dim = 3;
        let out_dim = 2;
        let mut rng = StdRng::seed_from_u64(0);
        let mut layer =
            Layer::new_with_rng(in_dim, out_dim, Init::Xavier, Activation::Tanh, &mut rng).unwrap();

        let mut input = vec![0.3_f32, -0.7_f32, 0.1_f32];
        let target = vec![0.2_f32, -0.1_f32];

        let mut outputs = vec![0.0_f32; out_dim];
        layer.forward(&input, &mut outputs);

        // Upstream gradient dL/d(outputs) from the MSE loss.
        let mut d_outputs = vec![0.0_f32; out_dim];
        let _loss = crate::loss::mse_backward(&outputs, &target, &mut d_outputs);

        let mut d_inputs = vec![0.0_f32; in_dim];
        let mut d_weights = vec![0.0_f32; in_dim * out_dim];
        let mut d_biases = vec![0.0_f32; out_dim];

        layer.backward(
            &input,
            &outputs,
            &d_outputs,
            &mut d_inputs,
            &mut d_weights,
            &mut d_biases,
        );

        // Loose tolerances: f32 central differences are noisy at eps=1e-3.
        let eps = 1e-3_f32;
        let abs_tol = 1e-3_f32;
        let rel_tol = 1e-2_f32;

        // Weights: perturb each entry in turn, then restore it exactly.
        let mut out_tmp = vec![0.0_f32; out_dim];
        for (p, &analytic) in d_weights.iter().enumerate() {
            let orig = layer.weights[p];

            layer.weights[p] = orig + eps;
            let loss_plus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            layer.weights[p] = orig - eps;
            let loss_minus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            // Restore before the next parameter so perturbations never compound.
            layer.weights[p] = orig;

            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            assert_close(analytic, numeric, abs_tol, rel_tol);
        }

        // Biases: same perturb/restore scheme.
        for (p, &analytic) in d_biases.iter().enumerate() {
            let orig = layer.biases[p];

            layer.biases[p] = orig + eps;
            let loss_plus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            layer.biases[p] = orig - eps;
            let loss_minus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            layer.biases[p] = orig;

            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            assert_close(analytic, numeric, abs_tol, rel_tol);
        }

        // Inputs: validates d_inputs, the gradient propagated to the previous layer.
        for i in 0..input.len() {
            let orig = input[i];

            input[i] = orig + eps;
            let loss_plus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            input[i] = orig - eps;
            let loss_minus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            input[i] = orig;

            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            let analytic = d_inputs[i];
            assert_close(analytic, numeric, abs_tol, rel_tol);
        }
    }

    #[test]
    #[should_panic]
    fn forward_panics_on_input_shape_mismatch() {
        let mut rng = StdRng::seed_from_u64(0);
        let layer = Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng).unwrap();
        // in_dim is 3 but only 2 inputs are supplied.
        let input = vec![0.0_f32; 2];
        let mut out = vec![0.0_f32; 2];
        layer.forward(&input, &mut out);
    }

    #[test]
    #[should_panic]
    fn forward_panics_on_output_shape_mismatch() {
        let mut rng = StdRng::seed_from_u64(0);
        let layer = Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng).unwrap();
        // out_dim is 2 but the output buffer only holds 1 element.
        let input = vec![0.0_f32; 3];
        let mut out = vec![0.0_f32; 1];
        layer.forward(&input, &mut out);
    }
}