use rand::Rng;
use serde::{Deserialize, Serialize};

use crate::activation::Activation;
use crate::linalg::cpu::CpuLinAlg;
use crate::linalg::LinAlg;
use crate::matrix::{GRAD_CLIP, WEIGHT_CLIP};

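/// Configuration for a single layer: its width and activation function.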
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerDef {
    /// Number of neurons in the layer.
    pub size: usize,
    /// Activation function applied to the layer's output.
    pub activation: Activation,
}

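/// A fully connected layer, generic over the linear-algebra backend `L`
/// (the CPU backend by default).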
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
    serialize = "L::Matrix: Serialize, L::Vector: Serialize",
    deserialize = "L::Matrix: for<'a> Deserialize<'a>, L::Vector: for<'a> Deserialize<'a>"
))]
pub struct Layer<L: LinAlg = CpuLinAlg> {
    /// Weight matrix with shape `(output_size, input_size)`.
    pub weights: L::Matrix,
    /// Bias vector of length `output_size`.
    pub bias: L::Vector,
    /// Activation applied after the affine transform.
    pub activation: Activation,
}

impl<L: LinAlg> Layer<L> {
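    /// Creates a layer with Xavier-initialized weights and zero biases.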
    pub fn new(
        input_size: usize,
        output_size: usize,
        activation: Activation,
        rng: &mut impl Rng,
    ) -> Self {
        Self {
            weights: L::xavier_mat(output_size, input_size, rng),
            bias: L::zeros_vec(output_size),
            activation,
        }
    }

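    /// Forward pass: computes `activation(W * input + b)`.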
    pub fn forward(&self, input: &L::Vector) -> L::Vector {
        let linear = L::mat_vec_mul(&self.weights, input);
        let biased = L::vec_add(&linear, &self.bias);
        L::apply_activation(&biased, self.activation)
    }

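    /// Runs the input back through the transposed weights (no bias term),
    /// i.e. `activation(W^T * input)`, with a caller-chosen activation.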
    pub fn transpose_forward(&self, input: &L::Vector, activation: Activation) -> L::Vector {
        let wt = L::mat_transpose(&self.weights);
        let linear = L::mat_vec_mul(&wt, input);
        L::apply_activation(&linear, activation)
    }

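    /// Performs one gradient-descent step and returns the delta to
    /// propagate to the previous layer.
    ///
    /// The error signal is the elementwise product of `delta` and
    /// `f'(output)`, clipped to `GRAD_CLIP`; weights and bias are then
    /// updated with an effective learning rate of `lr * surprise_scale`.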
    pub fn backward(
        &mut self,
        input: &L::Vector,
        output: &L::Vector,
        delta: &L::Vector,
        lr: f64,
        surprise_scale: f64,
    ) -> L::Vector {
        // Gradient through the activation: grad = delta * f'(output), elementwise.
        let deriv = L::apply_derivative(output, self.activation);
        let mut grad = L::vec_hadamard(delta, &deriv);

        // Keep the update bounded.
        L::clip_vec(&mut grad, GRAD_CLIP);

        // The surprise signal scales the step size.
        let effective_lr = lr * surprise_scale;

        // Weight gradient: dW = grad * input^T, so W -= effective_lr * dW.
        let dw = L::outer_product(&grad, input);
        L::mat_scale_add(&mut self.weights, &dw, -effective_lr);

        // b -= effective_lr * grad, then clip the bias.
        let bias_update = L::vec_scale(&grad, effective_lr);
        self.bias = L::vec_sub(&self.bias, &bias_update);
        L::clip_vec(&mut self.bias, WEIGHT_CLIP);

        // Delta for the previous layer: W^T * grad.
        let wt = L::mat_transpose(&self.weights);
        L::mat_vec_mul(&wt, &grad)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    fn make_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    #[test]
    fn test_forward_output_length_equals_output_size() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
        let out = layer.forward(&vec![1.0, 0.0, -1.0, 0.5]);
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_forward_linear_known_value() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &mut rng);
        layer.weights.set(0, 0, 2.0);
        layer.weights.set(0, 1, 3.0);
        layer.bias[0] = 1.0;
        // 2.0 * 1.0 + 3.0 * 2.0 + 1.0 = 9.0
        let out = layer.forward(&vec![1.0, 2.0]);
        assert!((out[0] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_forward_tanh_output_bounded() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 5, Activation::Tanh, &mut rng);
        let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in &out {
            assert!(v > -1.0 && v < 1.0, "Tanh output {v} not in (-1,1)");
        }
    }

    #[test]
    fn test_forward_sigmoid_output_bounded() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 5, Activation::Sigmoid, &mut rng);
        let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in &out {
            assert!(v > 0.0 && v < 1.0, "Sigmoid output {v} not in (0,1)");
        }
    }

    #[test]
    fn test_forward_relu_no_negative_outputs() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 5, Activation::Relu, &mut rng);
        let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in &out {
            assert!(v >= 0.0, "ReLU output {v} is negative");
        }
    }

    #[test]
    fn test_forward_all_outputs_finite() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let out = layer.forward(&vec![1e6, -1e6, 1e3, -1e3]);
        for &v in &out {
            assert!(v.is_finite(), "Output {v} is not finite");
        }
    }

    #[test]
    #[should_panic]
    fn test_forward_panics_wrong_input_length() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
        // Input length 2 does not match the layer's input_size of 4.
        let _ = layer.forward(&vec![1.0, 2.0]);
    }

    #[test]
    fn test_transpose_forward_output_length_equals_input_size() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        // A 3-element input maps back through W^T to a 4-element output.
        let out = layer.transpose_forward(&vec![0.5, -0.5, 0.0], Activation::Tanh);
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_transpose_forward_all_finite() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let out = layer.transpose_forward(&vec![1e3, -1e3, 0.0], Activation::Tanh);
        for &v in &out {
            assert!(v.is_finite(), "transpose_forward output {v} is not finite");
        }
    }

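    // Not part of the original suite: a sketch exercising the apparent
    // encode/decode pairing — forward maps 4 -> 3 and transpose_forward
    // maps the 3-element code back to the 4-element input space.
    #[test]
    fn test_forward_then_transpose_forward_restores_input_length() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![0.5, -0.25, 0.75, 0.0];
        let code = layer.forward(&input);
        let reconstruction = layer.transpose_forward(&code, Activation::Tanh);
        assert_eq!(reconstruction.len(), input.len());
    }
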
    #[test]
    fn test_transpose_forward_different_activation_changes_output() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![0.5, -0.5, 0.3];
        let out_tanh = layer.transpose_forward(&input, Activation::Tanh);
        let out_linear = layer.transpose_forward(&input, Activation::Linear);
        let differs = out_tanh
            .iter()
            .zip(out_linear.iter())
            .any(|(a, b)| (a - b).abs() > 1e-12);
        assert!(
            differs,
            "Different activations should produce different outputs"
        );
    }

    #[test]
    #[should_panic]
    fn test_transpose_forward_panics_wrong_input_length() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        // Input length 2 does not match the layer's output_size of 3.
        let _ = layer.transpose_forward(&vec![0.5, -0.5], Activation::Tanh);
    }

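    // Not part of the original suite: a hedged known-value check that
    // transpose_forward with a Linear activation computes W^T * x exactly
    // (the transposed pass applies no bias).
    #[test]
    fn test_transpose_forward_linear_known_value() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &mut rng);
        layer.weights.set(0, 0, 2.0);
        layer.weights.set(0, 1, 3.0);
        // W^T is 2x1, so the length-1 input [2.0] maps to [4.0, 6.0].
        let out = layer.transpose_forward(&vec![2.0], Activation::Linear);
        assert!((out[0] - 4.0).abs() < 1e-12);
        assert!((out[1] - 6.0).abs() < 1e-12);
    }
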
    #[test]
    fn test_backward_changes_weights() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let weights_before = layer.weights.clone();
        let _ = layer.backward(&input, &output, &delta, 0.01, 1.0);
        let changed = (0..3).any(|r| {
            (0..4).any(|c| (layer.weights.get(r, c) - weights_before.get(r, c)).abs() > 1e-15)
        });
        assert!(changed, "Weights should change after backward");
    }

    #[test]
    fn test_backward_changes_bias() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let bias_before = layer.bias.clone();
        let _ = layer.backward(&input, &output, &delta, 0.01, 1.0);
        let changed = layer
            .bias
            .iter()
            .zip(bias_before.iter())
            .any(|(a, b)| (a - b).abs() > 1e-15);
        assert!(changed, "Bias should change after backward");
    }

    #[test]
    fn test_backward_returns_delta_of_correct_length() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let prop_delta = layer.backward(&input, &output, &delta, 0.01, 1.0);
        assert_eq!(prop_delta.len(), 4);
    }

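    // Not part of the original suite: a hedged check that surprise_scale
    // simply scales the effective learning rate, so doubling it takes a
    // larger weight step for an identical layer, input, and delta.
    #[test]
    fn test_backward_larger_surprise_scale_takes_larger_step() {
        let input = vec![1.0, 0.5, -0.5, 0.25];
        let delta = vec![0.1, -0.2, 0.3];
        // Two identical layers from the same seed.
        let mut small: Layer = Layer::new(4, 3, Activation::Tanh, &mut make_rng());
        let mut large: Layer = Layer::new(4, 3, Activation::Tanh, &mut make_rng());
        let before = small.weights.clone();
        let output_small = small.forward(&input);
        let output_large = large.forward(&input);
        let _ = small.backward(&input, &output_small, &delta, 0.01, 1.0);
        let _ = large.backward(&input, &output_large, &delta, 0.01, 2.0);
        let mut small_step = 0.0;
        let mut large_step = 0.0;
        for r in 0..3 {
            for c in 0..4 {
                small_step += (small.weights.get(r, c) - before.get(r, c)).abs();
                large_step += (large.weights.get(r, c) - before.get(r, c)).abs();
            }
        }
        assert!(
            large_step > small_step,
            "Doubling surprise_scale should produce a larger weight update"
        );
    }
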
    #[test]
    fn test_backward_clips_weights_to_weight_clip() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Linear, &mut rng);
        let input = vec![100.0, 100.0, 100.0, 100.0];
        let output = layer.forward(&input);
        let delta = vec![1e6, 1e6, 1e6];
        let _ = layer.backward(&input, &output, &delta, 1.0, 1.0);
        for r in 0..3 {
            for c in 0..4 {
                let w = layer.weights.get(r, c);
                assert!(
                    w.abs() <= WEIGHT_CLIP + 1e-12,
                    "Weight {w} exceeds WEIGHT_CLIP"
                );
            }
        }
        for &b in &layer.bias {
            assert!(
                b.abs() <= WEIGHT_CLIP + 1e-12,
                "Bias {b} exceeds WEIGHT_CLIP"
            );
        }
    }

    #[test]
    fn test_backward_returns_finite_delta() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let prop_delta = layer.backward(&input, &output, &delta, 0.01, 1.0);
        for &v in &prop_delta {
            assert!(v.is_finite(), "Propagated delta {v} is not finite");
        }
    }

    #[test]
    fn test_backward_zero_lr_does_not_change_weights() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let weights_before = layer.weights.clone();
        let bias_before = layer.bias.clone();
        let _ = layer.backward(&input, &output, &delta, 0.0, 1.0);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - weights_before.get(r, c)).abs() < 1e-15,
                    "Weights changed with zero lr"
                );
            }
        }
        for (a, b) in layer.bias.iter().zip(bias_before.iter()) {
            assert!((a - b).abs() < 1e-15, "Bias changed with zero lr");
        }
    }

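    // Not part of the original suite: a hedged sketch of a tiny training
    // loop on one sample — with a squared-error objective the output delta
    // is (prediction - target), and repeated steps should shrink the error.
    #[test]
    fn test_backward_reduces_error_on_single_sample() {
        let mut rng = make_rng();
        let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &mut rng);
        let input = vec![1.0, -0.5];
        let target = 0.25;
        let initial_err = (layer.forward(&input)[0] - target).abs();
        for _ in 0..50 {
            let output = layer.forward(&input);
            let delta = vec![output[0] - target];
            let _ = layer.backward(&input, &output, &delta, 0.05, 1.0);
        }
        let final_err = (layer.forward(&input)[0] - target).abs();
        assert!(
            final_err < initial_err,
            "Repeated backward steps should reduce the error"
        );
    }
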
    #[test]
    fn test_serde_roundtrip_preserves_weights_and_activation() {
        let mut rng = make_rng();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &mut rng);
        let json = serde_json::to_string(&layer).unwrap();
        let restored: Layer = serde_json::from_str(&json).unwrap();
        assert_eq!(layer.bias, restored.bias);
        assert_eq!(layer.activation, restored.activation);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - restored.weights.get(r, c)).abs() < 1e-15,
                    "Weights not preserved in serde roundtrip"
                );
            }
        }
    }
}