// pc_rl_core/layer.rs

// Author: Julian Bolivar
// Version: 1.0.0
// Date: 2026-03-25

//! Dense neural network layer.
//!
//! Provides forward propagation, transpose forward (PC top-down pass),
//! and backward propagation with gradient/weight clipping. Building
//! block for both [`crate::PcActor`] and [`crate::MlpCritic`].
use rand::Rng;
use serde::{Deserialize, Serialize};

use crate::activation::Activation;
use crate::linalg::cpu::CpuLinAlg;
use crate::linalg::LinAlg;
use crate::matrix::{GRAD_CLIP, WEIGHT_CLIP};
18
19/// Definition of a layer's shape and activation, used for topology configuration.
20///
21/// # Examples
22///
23/// ```
24/// use pc_rl_core::activation::Activation;
25/// use pc_rl_core::layer::LayerDef;
26///
27/// let def = LayerDef { size: 64, activation: Activation::Tanh };
28/// assert_eq!(def.size, 64);
29/// ```
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct LayerDef {
32    /// Number of neurons in this layer.
33    pub size: usize,
34    /// Activation function applied after the linear transform.
35    pub activation: Activation,
36}
37
38/// A single dense layer with weights, bias, and activation function.
39///
40/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`] for
41/// backward compatibility.
42///
43/// Weights have shape `[output_size × input_size]`. Bias has length `output_size`.
44///
45/// # Examples
46///
47/// ```
48/// use pc_rl_core::activation::Activation;
49/// use pc_rl_core::layer::Layer;
50/// use rand::SeedableRng;
51/// use rand::rngs::StdRng;
52///
53/// let mut rng = StdRng::seed_from_u64(42);
54/// let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
55/// let output: Vec<f64> = layer.forward(&vec![1.0, 0.0, -1.0, 0.5]);
56/// assert_eq!(output.len(), 3);
57/// ```
58#[derive(Debug, Clone, Serialize, Deserialize)]
59#[serde(bound(
60    serialize = "L::Matrix: Serialize, L::Vector: Serialize",
61    deserialize = "L::Matrix: for<'a> Deserialize<'a>, L::Vector: for<'a> Deserialize<'a>"
62))]
63pub struct Layer<L: LinAlg = CpuLinAlg> {
64    /// Weight matrix of shape `[output_size × input_size]`.
65    pub weights: L::Matrix,
66    /// Bias vector of length `output_size`.
67    pub bias: L::Vector,
68    /// Activation function applied element-wise after the linear transform.
69    pub activation: Activation,
70}
71
72impl<L: LinAlg> Layer<L> {
73    /// Creates a new layer with Xavier-initialized weights and zero bias.
74    ///
75    /// # Arguments
76    ///
77    /// * `input_size` - Number of inputs to this layer.
78    /// * `output_size` - Number of neurons (outputs) in this layer.
79    /// * `activation` - Activation function to apply after the linear transform.
80    /// * `rng` - Random number generator for weight initialization.
81    pub fn new(
82        input_size: usize,
83        output_size: usize,
84        activation: Activation,
85        rng: &mut impl Rng,
86    ) -> Self {
87        Self {
88            weights: L::xavier_mat(output_size, input_size, rng),
89            bias: L::zeros_vec(output_size),
90            activation,
91        }
92    }
93
94    /// Computes `activation(W * input + bias)`.
95    ///
96    /// # Panics
97    ///
98    /// Panics if `input.len() != input_size` (number of columns in weights).
99    pub fn forward(&self, input: &L::Vector) -> L::Vector {
100        let linear = L::mat_vec_mul(&self.weights, input);
101        let biased = L::vec_add(&linear, &self.bias);
102        L::apply_activation(&biased, self.activation)
103    }
104
105    /// Computes `activation(W^T * input)` (no bias).
106    ///
107    /// Used for PC top-down predictions. The `activation` parameter is
108    /// separate from `self.activation` because at the output→last-hidden
109    /// boundary, different activations may apply.
110    ///
111    /// # Panics
112    ///
113    /// Panics if `input.len() != output_size` (number of rows in weights).
114    pub fn transpose_forward(&self, input: &L::Vector, activation: Activation) -> L::Vector {
115        let wt = L::mat_transpose(&self.weights);
116        let linear = L::mat_vec_mul(&wt, input);
117        L::apply_activation(&linear, activation)
118    }
119
120    /// Backpropagation with gradient and weight clipping.
121    ///
122    /// Returns the propagated delta for the layer below (length = input_size).
123    ///
124    /// # Arguments
125    ///
126    /// * `input` - Input that was fed to this layer during forward pass.
127    /// * `output` - Output of this layer from the forward pass (post-activation).
128    /// * `delta` - Error signal from the layer above.
129    /// * `lr` - Base learning rate.
130    /// * `surprise_scale` - Multiplier on `lr` based on surprise score.
131    ///
132    /// # Panics
133    ///
134    /// Panics on dimension mismatches.
135    pub fn backward(
136        &mut self,
137        input: &L::Vector,
138        output: &L::Vector,
139        delta: &L::Vector,
140        lr: f64,
141        surprise_scale: f64,
142    ) -> L::Vector {
143        // 1. Activation derivative
144        let deriv = L::apply_derivative(output, self.activation);
145
146        // 2. Local gradient = delta * deriv (element-wise Hadamard product)
147        let mut grad = L::vec_hadamard(delta, &deriv);
148
149        // 3. Clip gradient
150        L::clip_vec(&mut grad, GRAD_CLIP);
151
152        // 4. Effective learning rate
153        let effective_lr = lr * surprise_scale;
154
155        // 5. Weight gradient: dW = outer(grad, input)
156        let dw = L::outer_product(&grad, input);
157
158        // 6. Update weights (scale_add includes WEIGHT_CLIP clamping)
159        L::mat_scale_add(&mut self.weights, &dw, -effective_lr);
160
161        // 7. Update bias with clamping
162        let bias_update = L::vec_scale(&grad, effective_lr);
163        let new_bias = L::vec_sub(&self.bias, &bias_update);
164        self.bias = new_bias;
165        L::clip_vec(&mut self.bias, WEIGHT_CLIP);
166
167        // 8. Propagated delta: W^T * grad
168        let wt = L::mat_transpose(&self.weights);
169        L::mat_vec_mul(&wt, &grad)
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Fixed-seed RNG so every test sees identical initial weights.
    fn make_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    // ── forward tests ──────────────────────────────────────────────

    #[test]
    fn test_forward_output_length_equals_output_size() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
        let result = layer.forward(&vec![1.0, 0.0, -1.0, 0.5]);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_forward_linear_known_value() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &mut rng);
        // Overwrite the random init with known values: W = [2, 3], b = 1.
        layer.weights.set(0, 0, 2.0);
        layer.weights.set(0, 1, 3.0);
        layer.bias[0] = 1.0;
        // Expected: 2*1 + 3*2 + 1 = 9.
        let result = layer.forward(&vec![1.0, 2.0]);
        assert!((result[0] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_forward_tanh_output_bounded() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 5, Activation::Tanh, &mut rng);
        let result = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in result.iter() {
            assert!(v > -1.0 && v < 1.0, "Tanh output {v} not in (-1,1)");
        }
    }

    #[test]
    fn test_forward_sigmoid_output_bounded() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 5, Activation::Sigmoid, &mut rng);
        let result = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in result.iter() {
            assert!(v > 0.0 && v < 1.0, "Sigmoid output {v} not in (0,1)");
        }
    }

    #[test]
    fn test_forward_relu_no_negative_outputs() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 5, Activation::Relu, &mut rng);
        let result = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in result.iter() {
            assert!(v >= 0.0, "ReLU output {v} is negative");
        }
    }

    #[test]
    fn test_forward_all_outputs_finite() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        // Extreme inputs must not produce NaN/inf.
        let result = layer.forward(&vec![1e6, -1e6, 1e3, -1e3]);
        for &v in result.iter() {
            assert!(v.is_finite(), "Output {v} is not finite");
        }
    }

    #[test]
    #[should_panic]
    fn test_forward_panics_wrong_input_length() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
        let _ = layer.forward(&vec![1.0, 2.0]); // wrong length
    }

    // ── transpose_forward tests ────────────────────────────────────

    #[test]
    fn test_transpose_forward_output_length_equals_input_size() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        // transpose_forward consumes an output_size vector, yields input_size.
        let result = layer.transpose_forward(&vec![0.5, -0.5, 0.0], Activation::Tanh);
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_transpose_forward_all_finite() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let result = layer.transpose_forward(&vec![1e3, -1e3, 0.0], Activation::Tanh);
        for &v in result.iter() {
            assert!(v.is_finite(), "transpose_forward output {v} is not finite");
        }
    }

    #[test]
    fn test_transpose_forward_different_activation_changes_output() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![0.5, -0.5, 0.3];
        let out_tanh = layer.transpose_forward(&input, Activation::Tanh);
        let out_linear = layer.transpose_forward(&input, Activation::Linear);
        // At least one element should differ between the two activations.
        let mut differs = false;
        for (a, b) in out_tanh.iter().zip(out_linear.iter()) {
            if (a - b).abs() > 1e-12 {
                differs = true;
                break;
            }
        }
        assert!(
            differs,
            "Different activations should produce different outputs"
        );
    }

    #[test]
    #[should_panic]
    fn test_transpose_forward_panics_wrong_input_length() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let _ = layer.transpose_forward(&vec![0.5, -0.5], Activation::Tanh); // wrong length
    }

    // ── backward tests ─────────────────────────────────────────────

    #[test]
    fn test_backward_changes_weights() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let snapshot = layer.weights.clone();
        let _ = layer.backward(&input, &output, &delta, 0.01, 1.0);
        // At least one weight must have moved.
        let changed = (0..3)
            .flat_map(|r| (0..4).map(move |c| (r, c)))
            .any(|(r, c)| (layer.weights.get(r, c) - snapshot.get(r, c)).abs() > 1e-15);
        assert!(changed, "Weights should change after backward");
    }

    #[test]
    fn test_backward_changes_bias() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let snapshot = layer.bias.clone();
        let _ = layer.backward(&input, &output, &delta, 0.01, 1.0);
        let changed = layer
            .bias
            .iter()
            .zip(snapshot.iter())
            .any(|(a, b)| (a - b).abs() > 1e-15);
        assert!(changed, "Bias should change after backward");
    }

    #[test]
    fn test_backward_returns_delta_of_correct_length() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let propagated = layer.backward(&input, &output, &delta, 0.01, 1.0);
        assert_eq!(propagated.len(), 4);
    }

    #[test]
    fn test_backward_clips_weights_to_weight_clip() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
        // Huge inputs and deltas would blow up weights without clipping.
        let input = vec![100.0, 100.0, 100.0, 100.0];
        let output = layer.forward(&input);
        let delta = vec![1e6, 1e6, 1e6];
        let _ = layer.backward(&input, &output, &delta, 1.0, 1.0);
        for r in 0..3 {
            for c in 0..4 {
                let w = layer.weights.get(r, c);
                assert!(
                    w.abs() <= WEIGHT_CLIP + 1e-12,
                    "Weight {w} exceeds WEIGHT_CLIP"
                );
            }
        }
        for &b in layer.bias.iter() {
            assert!(
                b.abs() <= WEIGHT_CLIP + 1e-12,
                "Bias {b} exceeds WEIGHT_CLIP"
            );
        }
    }

    #[test]
    fn test_backward_returns_finite_delta() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let propagated = layer.backward(&input, &output, &delta, 0.01, 1.0);
        for &v in propagated.iter() {
            assert!(v.is_finite(), "Propagated delta {v} is not finite");
        }
    }

    #[test]
    fn test_backward_zero_lr_does_not_change_weights() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let weights_snapshot = layer.weights.clone();
        let bias_snapshot = layer.bias.clone();
        let _ = layer.backward(&input, &output, &delta, 0.0, 1.0);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - weights_snapshot.get(r, c)).abs() < 1e-15,
                    "Weights changed with zero lr"
                );
            }
        }
        for (a, b) in layer.bias.iter().zip(bias_snapshot.iter()) {
            assert!((a - b).abs() < 1e-15, "Bias changed with zero lr");
        }
    }

    // ── serde test ─────────────────────────────────────────────────

    #[test]
    fn test_serde_roundtrip_preserves_weights_and_activation() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let json = serde_json::to_string(&layer).unwrap();
        let restored: Layer = serde_json::from_str(&json).unwrap();
        assert_eq!(layer.bias, restored.bias);
        assert_eq!(layer.activation, restored.activation);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - restored.weights.get(r, c)).abs() < 1e-15,
                    "Weights not preserved in serde roundtrip"
                );
            }
        }
    }
}