Skip to main content

pc_rl_core/
layer.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-03-25
4
5//! Dense neural network layer.
6//!
7//! Provides forward propagation, transpose forward (PC top-down pass),
8//! and backward propagation with gradient/weight clipping. Building
9//! block for both [`crate::PcActor`] and [`crate::MlpCritic`].
10
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13
14use crate::activation::Activation;
15use crate::linalg::cpu::CpuLinAlg;
16use crate::linalg::LinAlg;
17use crate::matrix::{GRAD_CLIP, WEIGHT_CLIP};
18
19/// Definition of a layer's shape and activation, used for topology configuration.
20///
21/// # Examples
22///
23/// ```
24/// use pc_rl_core::activation::Activation;
25/// use pc_rl_core::layer::LayerDef;
26///
27/// let def = LayerDef { size: 64, activation: Activation::Tanh };
28/// assert_eq!(def.size, 64);
29/// ```
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct LayerDef {
32    /// Number of neurons in this layer.
33    pub size: usize,
34    /// Activation function applied after the linear transform.
35    pub activation: Activation,
36}
37
/// A single dense layer with weights, bias, and activation function.
///
/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`] for
/// backward compatibility.
///
/// Weights have shape `[output_size × input_size]`. Bias has length `output_size`.
///
/// # Serde and the `backend` field
///
/// The `backend` field is annotated with `#[serde(skip, default)]` so it is
/// **not** included in the JSON format. On deserialization, `L::default()` is
/// used to fill the field. However, direct serde deserialization of
/// `Layer<GpuLinAlg>` is **not** the intended usage path. Use
/// [`PcActor::from_weights`](crate::pc_actor::PcActor::from_weights) or
/// [`MlpCritic::from_weights`](crate::mlp_critic::MlpCritic::from_weights)
/// instead, which inject the correct backend instance into every layer.
///
/// # Examples
///
/// ```
/// use pc_rl_core::activation::Activation;
/// use pc_rl_core::layer::Layer;
/// use pc_rl_core::linalg::cpu::CpuLinAlg;
/// use rand::SeedableRng;
/// use rand::rngs::StdRng;
///
/// let backend = CpuLinAlg::new();
/// let mut rng = StdRng::seed_from_u64(42);
/// let layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
/// let output: Vec<f64> = layer.forward(&vec![1.0, 0.0, -1.0, 0.5]);
/// assert_eq!(output.len(), 3);
/// ```
// The explicit serde bounds are required because `L::Matrix`/`L::Vector` are
// associated types; derive cannot infer them, and `L: Default` is needed so
// the skipped `backend` field can be filled on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
    serialize = "L::Matrix: Serialize, L::Vector: Serialize",
    deserialize = "L::Matrix: for<'a> Deserialize<'a>, L::Vector: for<'a> Deserialize<'a>, L: Default"
))]
pub struct Layer<L: LinAlg = CpuLinAlg> {
    /// Weight matrix of shape `[output_size × input_size]`.
    pub weights: L::Matrix,
    /// Bias vector of length `output_size` (one entry per neuron).
    pub bias: L::Vector,
    /// Activation function applied element-wise after the linear transform.
    pub activation: Activation,
    /// Backend used for linear algebra operations. Not serialized; restored
    /// via `L::default()` on deserialization (see struct docs above).
    #[serde(skip, default)]
    pub(crate) backend: L,
}
86
87impl<L: LinAlg> Layer<L> {
88    /// Creates a new layer with Xavier-initialized weights and zero bias.
89    ///
90    /// # Arguments
91    ///
92    /// * `input_size` - Number of inputs to this layer.
93    /// * `output_size` - Number of neurons (outputs) in this layer.
94    /// * `activation` - Activation function to apply after the linear transform.
95    /// * `rng` - Random number generator for weight initialization.
96    pub fn new(
97        input_size: usize,
98        output_size: usize,
99        activation: Activation,
100        backend: &L,
101        rng: &mut impl Rng,
102    ) -> Self {
103        Self {
104            weights: backend.xavier_mat(output_size, input_size, rng),
105            bias: backend.zeros_vec(output_size),
106            activation,
107            backend: backend.clone(),
108        }
109    }
110
111    /// Computes `activation(W * input + bias)`.
112    ///
113    /// # Panics
114    ///
115    /// Panics if `input.len() != input_size` (number of columns in weights).
116    pub fn forward(&self, input: &L::Vector) -> L::Vector {
117        let linear = self.backend.mat_vec_mul(&self.weights, input);
118        let biased = self.backend.vec_add(&linear, &self.bias);
119        self.backend.apply_activation(&biased, self.activation)
120    }
121
122    /// Computes `activation(W^T * input)` (no bias).
123    ///
124    /// Used for PC top-down predictions. The `activation` parameter is
125    /// separate from `self.activation` because at the output→last-hidden
126    /// boundary, different activations may apply.
127    ///
128    /// # Panics
129    ///
130    /// Panics if `input.len() != output_size` (number of rows in weights).
131    pub fn transpose_forward(&self, input: &L::Vector, activation: Activation) -> L::Vector {
132        let wt = self.backend.mat_transpose(&self.weights);
133        let linear = self.backend.mat_vec_mul(&wt, input);
134        self.backend.apply_activation(&linear, activation)
135    }
136
137    /// Backpropagation with gradient and weight clipping.
138    ///
139    /// Returns the propagated delta for the layer below (length = input_size).
140    ///
141    /// # Arguments
142    ///
143    /// * `input` - Input that was fed to this layer during forward pass.
144    /// * `output` - Output of this layer from the forward pass (post-activation).
145    /// * `delta` - Error signal from the layer above.
146    /// * `lr` - Base learning rate.
147    /// * `surprise_scale` - Multiplier on `lr` based on surprise score.
148    ///
149    /// # Panics
150    ///
151    /// Panics on dimension mismatches.
152    pub fn backward(
153        &mut self,
154        input: &L::Vector,
155        output: &L::Vector,
156        delta: &L::Vector,
157        lr: f64,
158        surprise_scale: f64,
159    ) -> L::Vector {
160        // 1. Activation derivative
161        let deriv = self.backend.apply_derivative(output, self.activation);
162
163        // 2. Local gradient = delta * deriv (element-wise Hadamard product)
164        let mut grad = self.backend.vec_hadamard(delta, &deriv);
165
166        // 3. Clip gradient
167        self.backend.clip_vec(&mut grad, GRAD_CLIP);
168
169        // 4. Effective learning rate
170        let effective_lr = lr * surprise_scale;
171
172        // 5. Weight gradient: dW = outer(grad, input)
173        let dw = self.backend.outer_product(&grad, input);
174
175        // 6. Update weights (scale_add includes WEIGHT_CLIP clamping)
176        self.backend
177            .mat_scale_add(&mut self.weights, &dw, -effective_lr);
178
179        // 7. Update bias with clamping
180        let bias_update = self.backend.vec_scale(&grad, effective_lr);
181        let new_bias = self.backend.vec_sub(&self.bias, &bias_update);
182        self.bias = new_bias;
183        self.backend.clip_vec(&mut self.bias, WEIGHT_CLIP);
184
185        // 8. Propagated delta: W^T * grad
186        let wt = self.backend.mat_transpose(&self.weights);
187        self.backend.mat_vec_mul(&wt, &grad)
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Builds a CPU-backed layer with a deterministic RNG (seed 42).
    fn build_layer(input_size: usize, output_size: usize, act: Activation) -> Layer {
        let backend = CpuLinAlg::new();
        let mut rng = StdRng::seed_from_u64(42);
        Layer::new(input_size, output_size, act, &backend, &mut rng)
    }

    /// Runs one backward pass on a 4→3 layer with a fixed input/delta pair;
    /// returns the propagated delta.
    fn standard_backward(layer: &mut Layer, lr: f64) -> Vec<f64> {
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        layer.backward(&input, &output, &vec![0.1, -0.2, 0.3], lr, 1.0)
    }

    // ── forward tests ──────────────────────────────────────────────

    #[test]
    fn test_forward_output_length_equals_output_size() {
        let layer = build_layer(4, 3, Activation::Linear);
        assert_eq!(layer.forward(&vec![1.0, 0.0, -1.0, 0.5]).len(), 3);
    }

    #[test]
    fn test_forward_linear_known_value() {
        let mut layer = build_layer(2, 1, Activation::Linear);
        // Overwrite the random init with known weights and bias.
        layer.weights.set(0, 0, 2.0);
        layer.weights.set(0, 1, 3.0);
        layer.bias[0] = 1.0;
        // Expected: 2*1 + 3*2 + 1 = 9
        let out = layer.forward(&vec![1.0, 2.0]);
        assert!((out[0] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_forward_tanh_output_bounded() {
        let layer = build_layer(4, 5, Activation::Tanh);
        for &v in &layer.forward(&vec![10.0, -10.0, 5.0, -5.0]) {
            assert!(v > -1.0 && v < 1.0, "Tanh output {v} not in (-1,1)");
        }
    }

    #[test]
    fn test_forward_sigmoid_output_bounded() {
        let layer = build_layer(4, 5, Activation::Sigmoid);
        for &v in &layer.forward(&vec![10.0, -10.0, 5.0, -5.0]) {
            assert!(v > 0.0 && v < 1.0, "Sigmoid output {v} not in (0,1)");
        }
    }

    #[test]
    fn test_forward_relu_no_negative_outputs() {
        let layer = build_layer(4, 5, Activation::Relu);
        for &v in &layer.forward(&vec![10.0, -10.0, 5.0, -5.0]) {
            assert!(v >= 0.0, "ReLU output {v} is negative");
        }
    }

    #[test]
    fn test_forward_all_outputs_finite() {
        let layer = build_layer(4, 3, Activation::Tanh);
        for &v in &layer.forward(&vec![1e6, -1e6, 1e3, -1e3]) {
            assert!(v.is_finite(), "Output {v} is not finite");
        }
    }

    #[test]
    #[should_panic]
    fn test_forward_panics_wrong_input_length() {
        let layer = build_layer(4, 3, Activation::Linear);
        let _ = layer.forward(&vec![1.0, 2.0]); // length 2 instead of 4
    }

    // ── transpose_forward tests ────────────────────────────────────

    #[test]
    fn test_transpose_forward_output_length_equals_input_size() {
        let layer = build_layer(4, 3, Activation::Tanh);
        // Takes an output_size-length vector, returns input_size-length.
        let out = layer.transpose_forward(&vec![0.5, -0.5, 0.0], Activation::Tanh);
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_transpose_forward_all_finite() {
        let layer = build_layer(4, 3, Activation::Tanh);
        let out = layer.transpose_forward(&vec![1e3, -1e3, 0.0], Activation::Tanh);
        for &v in &out {
            assert!(v.is_finite(), "transpose_forward output {v} is not finite");
        }
    }

    #[test]
    fn test_transpose_forward_different_activation_changes_output() {
        let layer = build_layer(4, 3, Activation::Tanh);
        let input = vec![0.5, -0.5, 0.3];
        let out_tanh = layer.transpose_forward(&input, Activation::Tanh);
        let out_linear = layer.transpose_forward(&input, Activation::Linear);
        // At least one element must differ between the two activations.
        let differs = out_tanh
            .iter()
            .zip(out_linear.iter())
            .any(|(a, b)| (a - b).abs() > 1e-12);
        assert!(
            differs,
            "Different activations should produce different outputs"
        );
    }

    #[test]
    #[should_panic]
    fn test_transpose_forward_panics_wrong_input_length() {
        let layer = build_layer(4, 3, Activation::Tanh);
        let _ = layer.transpose_forward(&vec![0.5, -0.5], Activation::Tanh); // length 2 instead of 3
    }

    // ── backward tests ─────────────────────────────────────────────

    #[test]
    fn test_backward_changes_weights() {
        let mut layer = build_layer(4, 3, Activation::Tanh);
        let weights_before = layer.weights.clone();
        standard_backward(&mut layer, 0.01);
        let changed = (0..3).any(|r| {
            (0..4).any(|c| (layer.weights.get(r, c) - weights_before.get(r, c)).abs() > 1e-15)
        });
        assert!(changed, "Weights should change after backward");
    }

    #[test]
    fn test_backward_changes_bias() {
        let mut layer = build_layer(4, 3, Activation::Tanh);
        let bias_before = layer.bias.clone();
        standard_backward(&mut layer, 0.01);
        let changed = layer
            .bias
            .iter()
            .zip(bias_before.iter())
            .any(|(a, b)| (a - b).abs() > 1e-15);
        assert!(changed, "Bias should change after backward");
    }

    #[test]
    fn test_backward_returns_delta_of_correct_length() {
        let mut layer = build_layer(4, 3, Activation::Tanh);
        assert_eq!(standard_backward(&mut layer, 0.01).len(), 4);
    }

    #[test]
    fn test_backward_clips_weights_to_weight_clip() {
        let mut layer = build_layer(4, 3, Activation::Linear);
        // Huge inputs and deltas force the updates into the clipping regime.
        let input = vec![100.0; 4];
        let output = layer.forward(&input);
        let _ = layer.backward(&input, &output, &vec![1e6; 3], 1.0, 1.0);
        for r in 0..3 {
            for c in 0..4 {
                let w = layer.weights.get(r, c);
                assert!(
                    w.abs() <= WEIGHT_CLIP + 1e-12,
                    "Weight {w} exceeds WEIGHT_CLIP"
                );
            }
        }
        for &b in &layer.bias {
            assert!(
                b.abs() <= WEIGHT_CLIP + 1e-12,
                "Bias {b} exceeds WEIGHT_CLIP"
            );
        }
    }

    #[test]
    fn test_backward_returns_finite_delta() {
        let mut layer = build_layer(4, 3, Activation::Tanh);
        for &v in &standard_backward(&mut layer, 0.01) {
            assert!(v.is_finite(), "Propagated delta {v} is not finite");
        }
    }

    #[test]
    fn test_backward_zero_lr_does_not_change_weights() {
        let mut layer = build_layer(4, 3, Activation::Tanh);
        let weights_before = layer.weights.clone();
        let bias_before = layer.bias.clone();
        standard_backward(&mut layer, 0.0);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - weights_before.get(r, c)).abs() < 1e-15,
                    "Weights changed with zero lr"
                );
            }
        }
        for (a, b) in layer.bias.iter().zip(bias_before.iter()) {
            assert!((a - b).abs() < 1e-15, "Bias changed with zero lr");
        }
    }

    // ── serde test ─────────────────────────────────────────────────

    #[test]
    fn test_serde_roundtrip_preserves_weights_and_activation() {
        let layer = build_layer(4, 3, Activation::Tanh);
        let json = serde_json::to_string(&layer).unwrap();
        let restored: Layer = serde_json::from_str(&json).unwrap();
        assert_eq!(layer.bias, restored.bias);
        assert_eq!(layer.activation, restored.activation);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - restored.weights.get(r, c)).abs() < 1e-15,
                    "Weights not preserved in serde roundtrip"
                );
            }
        }
    }
}