use rand::Rng;
use serde::{Deserialize, Serialize};
use crate::activation::Activation;
use crate::linalg::cpu::CpuLinAlg;
use crate::linalg::LinAlg;
use crate::matrix::{GRAD_CLIP, WEIGHT_CLIP};
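/// Description of a single layer: its size (number of neurons) and activation function.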
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerDef {
pub size: usize,
pub activation: Activation,
}
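/// A fully connected layer holding a weight matrix, bias vector, and activation,
/// generic over the linear-algebra backend `L` (the CPU backend by default).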
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
serialize = "L::Matrix: Serialize, L::Vector: Serialize",
deserialize = "L::Matrix: for<'a> Deserialize<'a>, L::Vector: for<'a> Deserialize<'a>"
))]
pub struct Layer<L: LinAlg = CpuLinAlg> {
pub weights: L::Matrix,
pub bias: L::Vector,
pub activation: Activation,
}
impl<L: LinAlg> Layer<L> {
pub fn new(
input_size: usize,
output_size: usize,
activation: Activation,
rng: &mut impl Rng,
) -> Self {
Self {
weights: L::xavier_mat(output_size, input_size, rng),
bias: L::zeros_vec(output_size),
activation,
}
}
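    /// Forward pass: returns `activation(W * input + bias)`.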
pub fn forward(&self, input: &L::Vector) -> L::Vector {
let linear = L::mat_vec_mul(&self.weights, input);
let biased = L::vec_add(&linear, &self.bias);
L::apply_activation(&biased, self.activation)
}
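    /// Propagates `input` back through the transposed weights and applies the
    /// caller-supplied activation: returns `activation(W^T * input)`. The bias
    /// is not applied here.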
pub fn transpose_forward(&self, input: &L::Vector, activation: Activation) -> L::Vector {
let wt = L::mat_transpose(&self.weights);
let linear = L::mat_vec_mul(&wt, input);
L::apply_activation(&linear, activation)
}
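    /// Performs one gradient step on this layer and returns the delta to
    /// propagate to the previous layer.
    ///
    /// Let `g` be the element-wise product of `delta` and the activation
    /// derivative of `output`, clipped to `GRAD_CLIP`, and let
    /// `eta = lr * surprise_scale`. Then:
    ///
    /// * weights: `W <- W - eta * g * input^T`
    /// * bias:    `b <- b - eta * g`, then clipped to `WEIGHT_CLIP`
    ///
    /// The returned vector is `W^T * g`, computed with the updated weights.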
pub fn backward(
&mut self,
input: &L::Vector,
output: &L::Vector,
delta: &L::Vector,
lr: f64,
surprise_scale: f64,
) -> L::Vector {
let deriv = L::apply_derivative(output, self.activation);
let mut grad = L::vec_hadamard(delta, &deriv);
L::clip_vec(&mut grad, GRAD_CLIP);
let effective_lr = lr * surprise_scale;
let dw = L::outer_product(&grad, input);
L::mat_scale_add(&mut self.weights, &dw, -effective_lr);
let bias_update = L::vec_scale(&grad, effective_lr);
        self.bias = L::vec_sub(&self.bias, &bias_update);
L::clip_vec(&mut self.bias, WEIGHT_CLIP);
let wt = L::mat_transpose(&self.weights);
L::mat_vec_mul(&wt, &grad)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
fn make_rng() -> StdRng {
StdRng::seed_from_u64(42)
}
#[test]
fn test_forward_output_length_equals_output_size() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
let out = layer.forward(&vec![1.0, 0.0, -1.0, 0.5]);
assert_eq!(out.len(), 3);
}
#[test]
fn test_forward_linear_known_value() {
let mut rng = make_rng();
let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &mut rng);
layer.weights.set(0, 0, 2.0);
layer.weights.set(0, 1, 3.0);
layer.bias[0] = 1.0;
let out = layer.forward(&vec![1.0, 2.0]);
assert!((out[0] - 9.0).abs() < 1e-12);
}
#[test]
fn test_forward_tanh_output_bounded() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 5, Activation::Tanh, &mut rng);
let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
for &v in &out {
assert!(v > -1.0 && v < 1.0, "Tanh output {v} not in (-1,1)");
}
}
#[test]
fn test_forward_sigmoid_output_bounded() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 5, Activation::Sigmoid, &mut rng);
let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
for &v in &out {
assert!(v > 0.0 && v < 1.0, "Sigmoid output {v} not in (0,1)");
}
}
#[test]
fn test_forward_relu_no_negative_outputs() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 5, Activation::Relu, &mut rng);
let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
for &v in &out {
assert!(v >= 0.0, "ReLU output {v} is negative");
}
}
#[test]
fn test_forward_all_outputs_finite() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let out = layer.forward(&vec![1e6, -1e6, 1e3, -1e3]);
for &v in &out {
assert!(v.is_finite(), "Output {v} is not finite");
}
}
#[test]
#[should_panic]
fn test_forward_panics_wrong_input_length() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
        let _ = layer.forward(&vec![1.0, 2.0]);
    }
#[test]
fn test_transpose_forward_output_length_equals_input_size() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let out = layer.transpose_forward(&vec![0.5, -0.5, 0.0], Activation::Tanh);
assert_eq!(out.len(), 4);
}
#[test]
fn test_transpose_forward_all_finite() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let out = layer.transpose_forward(&vec![1e3, -1e3, 0.0], Activation::Tanh);
for &v in &out {
assert!(v.is_finite(), "transpose_forward output {v} is not finite");
}
}
#[test]
fn test_transpose_forward_different_activation_changes_output() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let input = vec![0.5, -0.5, 0.3];
let out_tanh = layer.transpose_forward(&input, Activation::Tanh);
let out_linear = layer.transpose_forward(&input, Activation::Linear);
let differs = out_tanh
.iter()
.zip(out_linear.iter())
.any(|(a, b)| (a - b).abs() > 1e-12);
assert!(
differs,
"Different activations should produce different outputs"
);
}
#[test]
#[should_panic]
fn test_transpose_forward_panics_wrong_input_length() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let _ = layer.transpose_forward(&vec![0.5, -0.5], Activation::Tanh);
    }
#[test]
fn test_backward_changes_weights() {
let mut rng = make_rng();
let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let input = vec![1.0, 0.5, -0.5, 0.0];
let output = layer.forward(&input);
let delta = vec![0.1, -0.2, 0.3];
let weights_before = layer.weights.clone();
let _ = layer.backward(&input, &output, &delta, 0.01, 1.0);
let changed = (0..3).any(|r| {
(0..4).any(|c| (layer.weights.get(r, c) - weights_before.get(r, c)).abs() > 1e-15)
});
assert!(changed, "Weights should change after backward");
}
#[test]
fn test_backward_changes_bias() {
let mut rng = make_rng();
let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let input = vec![1.0, 0.5, -0.5, 0.0];
let output = layer.forward(&input);
let delta = vec![0.1, -0.2, 0.3];
let bias_before = layer.bias.clone();
let _ = layer.backward(&input, &output, &delta, 0.01, 1.0);
let changed = layer
.bias
.iter()
.zip(bias_before.iter())
.any(|(a, b)| (a - b).abs() > 1e-15);
assert!(changed, "Bias should change after backward");
}
#[test]
fn test_backward_returns_delta_of_correct_length() {
let mut rng = make_rng();
let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let input = vec![1.0, 0.5, -0.5, 0.0];
let output = layer.forward(&input);
let delta = vec![0.1, -0.2, 0.3];
let prop_delta = layer.backward(&input, &output, &delta, 0.01, 1.0);
assert_eq!(prop_delta.len(), 4);
}
#[test]
fn test_backward_clips_weights_to_weight_clip() {
let mut rng = make_rng();
let mut layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
let input = vec![100.0, 100.0, 100.0, 100.0];
let output = layer.forward(&input);
let delta = vec![1e6, 1e6, 1e6];
let _ = layer.backward(&input, &output, &delta, 1.0, 1.0);
for r in 0..3 {
for c in 0..4 {
let w = layer.weights.get(r, c);
assert!(
w.abs() <= WEIGHT_CLIP + 1e-12,
"Weight {w} exceeds WEIGHT_CLIP"
);
}
}
for &b in &layer.bias {
assert!(
b.abs() <= WEIGHT_CLIP + 1e-12,
"Bias {b} exceeds WEIGHT_CLIP"
);
}
}
#[test]
fn test_backward_returns_finite_delta() {
let mut rng = make_rng();
let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let input = vec![1.0, 0.5, -0.5, 0.0];
let output = layer.forward(&input);
let delta = vec![0.1, -0.2, 0.3];
let prop_delta = layer.backward(&input, &output, &delta, 0.01, 1.0);
for &v in &prop_delta {
assert!(v.is_finite(), "Propagated delta {v} is not finite");
}
}
#[test]
fn test_backward_zero_lr_does_not_change_weights() {
let mut rng = make_rng();
let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let input = vec![1.0, 0.5, -0.5, 0.0];
let output = layer.forward(&input);
let delta = vec![0.1, -0.2, 0.3];
let weights_before = layer.weights.clone();
let bias_before = layer.bias.clone();
let _ = layer.backward(&input, &output, &delta, 0.0, 1.0);
for r in 0..3 {
for c in 0..4 {
assert!(
(layer.weights.get(r, c) - weights_before.get(r, c)).abs() < 1e-15,
"Weights changed with zero lr"
);
}
}
for (a, b) in layer.bias.iter().zip(bias_before.iter()) {
assert!((a - b).abs() < 1e-15, "Bias changed with zero lr");
}
}
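    // Illustrative end-to-end check (a sketch added here, not part of the original
    // suite): assuming CpuLinAlg vectors are plain Vec<f64> as in the tests above,
    // repeated backward steps with the squared-error delta (output - target)
    // should shrink the error on a fixed sample.
    #[test]
    fn test_backward_repeated_steps_reduce_error() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &mut rng);
        let input = vec![1.0, -0.5];
        let target = 0.25;
        let initial_error = (layer.forward(&input)[0] - target).powi(2);
        for _ in 0..50 {
            let output = layer.forward(&input);
            // For squared error, the output-layer delta is (output - target).
            let delta = vec![output[0] - target];
            let _ = layer.backward(&input, &output, &delta, 0.05, 1.0);
        }
        let final_error = (layer.forward(&input)[0] - target).powi(2);
        assert!(
            final_error < initial_error,
            "Repeated backward steps should reduce the squared error"
        );
    }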
#[test]
fn test_serde_roundtrip_preserves_weights_and_activation() {
let mut rng = make_rng();
let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
let json = serde_json::to_string(&layer).unwrap();
let restored: Layer = serde_json::from_str(&json).unwrap();
assert_eq!(layer.bias, restored.bias);
assert_eq!(layer.activation, restored.activation);
for r in 0..3 {
for c in 0..4 {
assert!(
(layer.weights.get(r, c) - restored.weights.get(r, c)).abs() < 1e-15,
"Weights not preserved in serde roundtrip"
);
}
}
}
}