use rand::Rng;
use serde::{Deserialize, Serialize};

use crate::activation::Activation;
use crate::linalg::cpu::CpuLinAlg;
use crate::linalg::LinAlg;
use crate::matrix::{GRAD_CLIP, WEIGHT_CLIP};

/// Declarative description of a single layer: its width (number of
/// neurons) and activation function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerDef {
    pub size: usize,
    pub activation: Activation,
}

/// A single fully connected layer, generic over the linear-algebra
/// backend `L` (CPU by default).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
    serialize = "L::Matrix: Serialize, L::Vector: Serialize",
    deserialize = "L::Matrix: for<'a> Deserialize<'a>, L::Vector: for<'a> Deserialize<'a>, L: Default"
))]
pub struct Layer<L: LinAlg = CpuLinAlg> {
    /// Weight matrix of shape `(output_size, input_size)`.
    pub weights: L::Matrix,
    /// Bias vector of length `output_size`.
    pub bias: L::Vector,
    /// Activation applied to the affine output.
    pub activation: Activation,
    /// The backend is not serialized; it is rebuilt via `Default` on deserialize.
    #[serde(skip, default)]
    pub(crate) backend: L,
}

impl<L: LinAlg> Layer<L> {
    /// Builds a layer with Xavier-initialized weights and zeroed biases.
    pub fn new(
        input_size: usize,
        output_size: usize,
        activation: Activation,
        backend: &L,
        rng: &mut impl Rng,
    ) -> Self {
        Self {
            weights: backend.xavier_mat(output_size, input_size, rng),
            bias: backend.zeros_vec(output_size),
            activation,
            backend: backend.clone(),
        }
    }

    /// Computes `activation(W * input + bias)`.
    pub fn forward(&self, input: &L::Vector) -> L::Vector {
        let linear = self.backend.mat_vec_mul(&self.weights, input);
        let biased = self.backend.vec_add(&linear, &self.bias);
        self.backend.apply_activation(&biased, self.activation)
    }

    /// Runs the layer "in reverse": multiplies by the transposed weight
    /// matrix and applies the given activation. No bias is added.
    pub fn transpose_forward(&self, input: &L::Vector, activation: Activation) -> L::Vector {
        let wt = self.backend.mat_transpose(&self.weights);
        let linear = self.backend.mat_vec_mul(&wt, input);
        self.backend.apply_activation(&linear, activation)
    }

    /// Applies one gradient step to the weights and bias and returns the
    /// error signal propagated toward the previous layer.
    ///
    /// `delta` is the upstream error at this layer's output;
    /// `surprise_scale` multiplies `lr` to modulate the step size.
    pub fn backward(
        &mut self,
        input: &L::Vector,
        output: &L::Vector,
        delta: &L::Vector,
        lr: f64,
        surprise_scale: f64,
    ) -> L::Vector {
        // Derivative of the activation, evaluated at the cached output.
        let deriv = self.backend.apply_derivative(output, self.activation);
        // Local gradient: upstream error gated by the activation slope.
        let mut grad = self.backend.vec_hadamard(delta, &deriv);
        self.backend.clip_vec(&mut grad, GRAD_CLIP);
        // Surprise modulates how aggressively this step updates parameters.
        let effective_lr = lr * surprise_scale;
        // Weight gradient is the outer product of grad and the input.
        let dw = self.backend.outer_product(&grad, input);
        self.backend
            .mat_scale_add(&mut self.weights, &dw, -effective_lr);
        // Bias follows the same descent step, then is clamped. (Weight
        // clamping to WEIGHT_CLIP is assumed to happen inside
        // `mat_scale_add`; see the WEIGHT_CLIP test below.)
        let bias_update = self.backend.vec_scale(&grad, effective_lr);
        let new_bias = self.backend.vec_sub(&self.bias, &bias_update);
        self.bias = new_bias;
        self.backend.clip_vec(&mut self.bias, WEIGHT_CLIP);
        // Propagate the error through the transposed (updated) weights.
        let wt = self.backend.mat_transpose(&self.weights);
        self.backend.mat_vec_mul(&wt, &grad)
    }
}


#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // Deterministic seed keeps the tests reproducible.
    fn make_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    fn make_backend() -> CpuLinAlg {
        CpuLinAlg::new()
    }

    #[test]
    fn test_forward_output_length_equals_output_size() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 3, Activation::Linear, &backend, &mut rng);
        let out = layer.forward(&vec![1.0, 0.0, -1.0, 0.5]);
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_forward_linear_known_value() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &backend, &mut rng);
        layer.weights.set(0, 0, 2.0);
        layer.weights.set(0, 1, 3.0);
        layer.bias[0] = 1.0;
        // 2*1 + 3*2 + 1 = 9
        let out = layer.forward(&vec![1.0, 2.0]);
        assert!((out[0] - 9.0).abs() < 1e-12);
    }
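
    // Added companion sketch to the linear case above: with the same
    // hand-set weights, a Tanh layer should emit exactly tanh(W*x + b).
    // This assumes the backend applies the standard f64 tanh.
    #[test]
    fn test_forward_tanh_known_value() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(2, 1, Activation::Tanh, &backend, &mut rng);
        layer.weights.set(0, 0, 2.0);
        layer.weights.set(0, 1, 3.0);
        layer.bias[0] = 1.0;
        let out = layer.forward(&vec![1.0, 2.0]);
        assert!((out[0] - 9.0_f64.tanh()).abs() < 1e-12);
    }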

    #[test]
    fn test_forward_tanh_output_bounded() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 5, Activation::Tanh, &backend, &mut rng);
        let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in &out {
            assert!(v > -1.0 && v < 1.0, "Tanh output {v} not in (-1,1)");
        }
    }

    #[test]
    fn test_forward_sigmoid_output_bounded() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 5, Activation::Sigmoid, &backend, &mut rng);
        let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in &out {
            assert!(v > 0.0 && v < 1.0, "Sigmoid output {v} not in (0,1)");
        }
    }

    #[test]
    fn test_forward_relu_no_negative_outputs() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 5, Activation::Relu, &backend, &mut rng);
        let out = layer.forward(&vec![10.0, -10.0, 5.0, -5.0]);
        for &v in &out {
            assert!(v >= 0.0, "ReLU output {v} is negative");
        }
    }

    #[test]
    fn test_forward_all_outputs_finite() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let out = layer.forward(&vec![1e6, -1e6, 1e3, -1e3]);
        for &v in &out {
            assert!(v.is_finite(), "Output {v} is not finite");
        }
    }

    #[test]
    #[should_panic]
    fn test_forward_panics_wrong_input_length() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 3, Activation::Linear, &backend, &mut rng);
        let _ = layer.forward(&vec![1.0, 2.0]);
    }

    #[test]
    fn test_transpose_forward_output_length_equals_input_size() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        // The transposed pass maps output-space (3) back to input-space (4).
        let out = layer.transpose_forward(&vec![0.5, -0.5, 0.0], Activation::Tanh);
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_transpose_forward_all_finite() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let out = layer.transpose_forward(&vec![1e3, -1e3, 0.0], Activation::Tanh);
        for &v in &out {
            assert!(v.is_finite(), "transpose_forward output {v} is not finite");
        }
    }

    #[test]
    fn test_transpose_forward_different_activation_changes_output() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let input = vec![0.5, -0.5, 0.3];
        let out_tanh = layer.transpose_forward(&input, Activation::Tanh);
        let out_linear = layer.transpose_forward(&input, Activation::Linear);
        let differs = out_tanh
            .iter()
            .zip(out_linear.iter())
            .any(|(a, b)| (a - b).abs() > 1e-12);
        assert!(
            differs,
            "Different activations should produce different outputs"
        );
    }

    #[test]
    #[should_panic]
    fn test_transpose_forward_panics_wrong_input_length() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let _ = layer.transpose_forward(&vec![0.5, -0.5], Activation::Tanh);
    }
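
    // Added hand-computed check mirroring test_forward_linear_known_value:
    // with the 1x2 weight matrix [[2, 3]], transpose_forward maps [v] to
    // [2v, 3v]. Note the bias is not applied on the transposed pass.
    #[test]
    fn test_transpose_forward_linear_known_value() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &backend, &mut rng);
        layer.weights.set(0, 0, 2.0);
        layer.weights.set(0, 1, 3.0);
        let out = layer.transpose_forward(&vec![0.5], Activation::Linear);
        assert!((out[0] - 1.0).abs() < 1e-12);
        assert!((out[1] - 1.5).abs() < 1e-12);
    }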

    #[test]
    fn test_backward_changes_weights() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let weights_before = layer.weights.clone();
        let _ = layer.backward(&input, &output, &delta, 0.01, 1.0);
        let changed = (0..3).any(|r| {
            (0..4).any(|c| (layer.weights.get(r, c) - weights_before.get(r, c)).abs() > 1e-15)
        });
        assert!(changed, "Weights should change after backward");
    }
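
    // Added sketch: the effective step is lr * surprise_scale, so doubling
    // surprise_scale should double each weight update. This assumes the
    // clip thresholds are not engaged at these small magnitudes.
    #[test]
    fn test_backward_surprise_scale_scales_weight_update() {
        let backend = make_backend();
        // Two identically seeded layers start from identical weights.
        let mut small: Layer = Layer::new(4, 3, Activation::Linear, &backend, &mut make_rng());
        let mut big: Layer = Layer::new(4, 3, Activation::Linear, &backend, &mut make_rng());
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let delta = vec![0.1, -0.2, 0.3];
        let weights_before = small.weights.clone();
        let output = small.forward(&input);
        let _ = small.backward(&input, &output, &delta, 0.01, 1.0);
        let _ = big.backward(&input, &output, &delta, 0.01, 2.0);
        for r in 0..3 {
            for c in 0..4 {
                let step_small = small.weights.get(r, c) - weights_before.get(r, c);
                let step_big = big.weights.get(r, c) - weights_before.get(r, c);
                assert!(
                    (step_big - 2.0 * step_small).abs() < 1e-12,
                    "surprise_scale should scale the weight step linearly"
                );
            }
        }
    }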

    #[test]
    fn test_backward_changes_bias() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let bias_before = layer.bias.clone();
        let _ = layer.backward(&input, &output, &delta, 0.01, 1.0);
        let changed = layer
            .bias
            .iter()
            .zip(bias_before.iter())
            .any(|(a, b)| (a - b).abs() > 1e-15);
        assert!(changed, "Bias should change after backward");
    }

    #[test]
    fn test_backward_returns_delta_of_correct_length() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let prop_delta = layer.backward(&input, &output, &delta, 0.01, 1.0);
        assert_eq!(prop_delta.len(), 4);
    }

    #[test]
    fn test_backward_clips_weights_to_weight_clip() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(4, 3, Activation::Linear, &backend, &mut rng);
        let input = vec![100.0, 100.0, 100.0, 100.0];
        let output = layer.forward(&input);
        let delta = vec![1e6, 1e6, 1e6];
        let _ = layer.backward(&input, &output, &delta, 1.0, 1.0);
        for r in 0..3 {
            for c in 0..4 {
                let w = layer.weights.get(r, c);
                assert!(
                    w.abs() <= WEIGHT_CLIP + 1e-12,
                    "Weight {w} exceeds WEIGHT_CLIP"
                );
            }
        }
        for &b in &layer.bias {
            assert!(
                b.abs() <= WEIGHT_CLIP + 1e-12,
                "Bias {b} exceeds WEIGHT_CLIP"
            );
        }
    }

    #[test]
    fn test_backward_returns_finite_delta() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let prop_delta = layer.backward(&input, &output, &delta, 0.01, 1.0);
        for &v in &prop_delta {
            assert!(v.is_finite(), "Propagated delta {v} is not finite");
        }
    }

    #[test]
    fn test_backward_zero_lr_does_not_change_weights() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let weights_before = layer.weights.clone();
        let bias_before = layer.bias.clone();
        let _ = layer.backward(&input, &output, &delta, 0.0, 1.0);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - weights_before.get(r, c)).abs() < 1e-15,
                    "Weights changed with zero lr"
                );
            }
        }
        for (a, b) in layer.bias.iter().zip(bias_before.iter()) {
            assert!((a - b).abs() < 1e-15, "Bias changed with zero lr");
        }
    }
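
    // Added companion to the zero-lr test above: since the effective step
    // is lr * surprise_scale, a zero surprise_scale should likewise leave
    // the weights untouched.
    #[test]
    fn test_backward_zero_surprise_scale_does_not_change_weights() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let input = vec![1.0, 0.5, -0.5, 0.0];
        let output = layer.forward(&input);
        let delta = vec![0.1, -0.2, 0.3];
        let weights_before = layer.weights.clone();
        let _ = layer.backward(&input, &output, &delta, 0.01, 0.0);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - weights_before.get(r, c)).abs() < 1e-15,
                    "Weights changed with zero surprise_scale"
                );
            }
        }
    }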

    #[test]
    fn test_serde_roundtrip_preserves_weights_and_activation() {
        let mut rng = make_rng();
        let backend = make_backend();
        let layer: Layer = Layer::new(4, 3, Activation::Tanh, &backend, &mut rng);
        let json = serde_json::to_string(&layer).unwrap();
        let restored: Layer = serde_json::from_str(&json).unwrap();
        assert_eq!(layer.bias, restored.bias);
        assert_eq!(layer.activation, restored.activation);
        for r in 0..3 {
            for c in 0..4 {
                assert!(
                    (layer.weights.get(r, c) - restored.weights.get(r, c)).abs() < 1e-15,
                    "Weights not preserved in serde roundtrip"
                );
            }
        }
    }
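
    // Added end-to-end sketch: one descent step on a single linear sample
    // should strictly shrink the prediction error, since the new error is
    // err * (1 - lr * (|x|^2 + 1)). Assumes the clip thresholds are not
    // engaged at these magnitudes; clipping would only shorten the step,
    // never flip its direction.
    #[test]
    fn test_backward_single_step_reduces_linear_error() {
        let mut rng = make_rng();
        let backend = make_backend();
        let mut layer: Layer = Layer::new(2, 1, Activation::Linear, &backend, &mut rng);
        let input = vec![0.5, -0.5];
        let target = 1.0;
        let before = layer.forward(&input);
        let err_before = (before[0] - target).abs();
        assert!(err_before > 0.0, "Seeded init should not hit the target exactly");
        // delta is dL/dy for the squared-error loss 0.5 * (y - target)^2.
        let delta = vec![before[0] - target];
        let _ = layer.backward(&input, &before, &delta, 0.1, 1.0);
        let after = layer.forward(&input);
        let err_after = (after[0] - target).abs();
        assert!(
            err_after < err_before,
            "One gradient step should reduce the error: {err_before} -> {err_after}"
        );
    }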
}