1use std::f32::consts::PI as PI_F32;
6
7use super::types::*;
8
9pub fn batch_forward(net: &FeedForwardNet, batch: &[Vec<f32>]) -> Vec<Vec<f32>> {
11 batch.iter().map(|x| net.forward(x)).collect()
12}
13pub fn batch_atomic_energies(
17 aann: &AtomicNeuralNetwork,
18 descriptors: &[Vec<f32>],
19 atomic_numbers: &[u8],
20) -> Vec<f32> {
21 assert_eq!(
22 descriptors.len(),
23 atomic_numbers.len(),
24 "batch_atomic_energies: descriptors and atomic_numbers must have the same length"
25 );
26 descriptors
27 .iter()
28 .zip(atomic_numbers.iter())
29 .map(|(desc, &z)| aann.atomic_energy(z, desc).unwrap_or(0.0))
30 .collect()
31}
// Parent-module-visible (`pub(super)`) copy of the `PI_F32` constant imported at the
// top of this file. The underscore prefix and `_USED` suffix suggest it exists only
// to keep that import referenced — TODO confirm before removing.
32pub(super) const _PI_F32_USED: f32 = PI_F32;
33pub fn load_weights_from_buffer(net: &mut FeedForwardNet, buffer: &[f32]) -> usize {
37 let mut offset = 0;
38 for layer in &mut net.layers {
39 let n_weights = layer.out_features * layer.in_features;
40 let n_biases = layer.out_features;
41 let total = n_weights + n_biases;
42 if offset + total > buffer.len() {
43 break;
44 }
45 layer.set_weights(&buffer[offset..offset + n_weights]);
46 offset += n_weights;
47 layer.set_biases(&buffer[offset..offset + n_biases]);
48 offset += n_biases;
49 }
50 offset
51}
52pub fn save_weights_to_buffer(net: &FeedForwardNet) -> Vec<f32> {
54 let mut buffer = Vec::new();
55 for layer in &net.layers {
56 buffer.extend_from_slice(&layer.weights);
57 buffer.extend_from_slice(&layer.biases);
58 }
59 buffer
60}
/// Computes the softmax of `x`.
///
/// The largest input is subtracted from every element before exponentiation,
/// so large finite inputs do not overflow `exp`. An empty input yields an
/// empty vector.
pub fn softmax(x: &[f32]) -> Vec<f32> {
    if x.is_empty() {
        return Vec::new();
    }
    let peak = x
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, |acc, v| acc.max(v));
    // Exponentiate the shifted values, then normalise them in place.
    let mut probs: Vec<f32> = x.iter().map(|&v| (v - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= total;
    }
    probs
}
71pub fn cross_entropy_loss(logits: &[f32], target_idx: usize) -> f32 {
75 let probs = softmax(logits);
76 let p = probs[target_idx].max(1e-7);
77 -p.ln()
78}
/// Mean squared error between `predictions` and `targets`.
///
/// Returns the average of `(p - t)^2` over all element pairs. For empty
/// inputs the loss is defined as `0.0` rather than the `NaN` that dividing
/// `0.0` by a zero element count would produce.
///
/// # Panics
///
/// Panics if `predictions` and `targets` differ in length.
pub fn mse_loss(predictions: &[f32], targets: &[f32]) -> f32 {
    assert_eq!(
        predictions.len(),
        targets.len(),
        "mse_loss: predictions and targets must have the same length"
    );
    if predictions.is_empty() {
        // Nothing to average; avoid 0.0 / 0.0 = NaN leaking into callers.
        return 0.0;
    }
    let sum_sq: f32 = predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| (p - t) * (p - t))
        .sum();
    sum_sq / predictions.len() as f32
}
// Unit tests for the f32 helpers defined in this file (weight buffer save/load,
// softmax, cross-entropy, MSE) and for the layer, descriptor and pipeline types
// reached through `use super::*` and the `crate::` imports below.
90#[cfg(test)]
91mod tests {
92 use super::*;
93
94 use crate::BatchNormLayer;
95
96 use crate::FeedForwardNet;
97
98 #[test]
99 fn test_dense_layer_zero_weights_bias_output() {
        // No weights are set here, so the test relies on a freshly constructed layer
        // contributing nothing but the bias (see the test name).
100 let mut layer = DenseLayer::new(3, 2, ActivationFn::Linear);
101 layer.set_biases(&[1.0, 2.0]);
102 let out = layer.forward(&[10.0, 20.0, 30.0]);
103 assert!((out[0] - 1.0).abs() < 1e-6, "expected 1.0, got {}", out[0]);
104 assert!((out[1] - 2.0).abs() < 1e-6, "expected 2.0, got {}", out[1]);
105 }
106 #[test]
107 fn test_activation_tanh_zero() {
108 let act = ActivationFn::Tanh;
109 assert!((act.apply(0.0)).abs() < 1e-7, "tanh(0) should be 0");
110 }
111 #[test]
112 fn test_activation_relu() {
113 let act = ActivationFn::Relu;
114 assert!((act.apply(-1.0)).abs() < 1e-7, "relu(-1) should be 0");
115 assert!((act.apply(1.0) - 1.0).abs() < 1e-7, "relu(1) should be 1");
116 }
117 #[test]
118 fn test_feedforward_net_two_layers() {
        // The expected 4.0 is consistent with row-major (out x in) weights:
        // hidden = [1, 2, 1] for input [1, 2], summed by layer2.
        // NOTE(review): weight layout is presumed from the numbers; confirm in DenseLayer.
119 let mut layer1 = DenseLayer::new(2, 3, ActivationFn::Linear);
120 layer1.set_weights(&[1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
121 let mut layer2 = DenseLayer::new(3, 1, ActivationFn::Linear);
122 layer2.set_weights(&[1.0, 1.0, 1.0]);
123 let mut net = FeedForwardNet::new();
124 net.add_layer(layer1);
125 net.add_layer(layer2);
126 let out = net.forward(&[1.0, 2.0]);
127 assert!((out[0] - 4.0).abs() < 1e-5, "expected 4.0, got {}", out[0]);
128 }
129 #[test]
130 fn test_cutoff_fn_at_zero_and_beyond_rc() {
        // Checks the cutoff at r = 0, r = rc and r > rc.
131 let rc = 6.0;
132 let fc0 = BehlerParrinelloDescriptor::cutoff_fn(0.0, rc);
133 assert!((fc0 - 1.0).abs() < 1e-10, "fc(0) should be 1.0, got {fc0}");
134 let fc_rc = BehlerParrinelloDescriptor::cutoff_fn(rc, rc);
135 assert!((fc_rc).abs() < 1e-10, "fc(rc) should be 0.0, got {fc_rc}");
136 let fc_beyond = BehlerParrinelloDescriptor::cutoff_fn(rc + 1.0, rc);
137 assert!(
138 (fc_beyond).abs() < 1e-10,
139 "fc(>rc) should be 0.0, got {fc_beyond}"
140 );
141 }
142 #[test]
143 fn test_radial_g2_decreases_with_distance() {
144 let rc = 6.0;
145 let eta = 0.5;
146 let rs = 0.0;
147 let g2_near = BehlerParrinelloDescriptor::radial_g2(1.0, eta, rs, rc);
148 let g2_far = BehlerParrinelloDescriptor::radial_g2(5.0, eta, rs, rc);
149 assert!(
150 g2_near > g2_far,
151 "G2 should decrease with distance: near={g2_near}, far={g2_far}"
152 );
153 }
154 #[test]
155 fn test_data_normalizer_round_trip() {
        // transform followed by inverse_transform must reproduce the sample.
156 let data = vec![
157 vec![1.0_f32, 2.0, 3.0],
158 vec![4.0, 5.0, 6.0],
159 vec![7.0, 8.0, 9.0],
160 ];
161 let norm = DataNormalizer::fit(&data);
162 let sample = &data[1];
163 let transformed = norm.transform(sample);
164 let recovered = norm.inverse_transform(&transformed);
165 for (a, b) in recovered.iter().zip(sample.iter()) {
166 assert!(
167 (a - b).abs() < 1e-5,
168 "round-trip failed: got {a}, expected {b}"
169 );
170 }
171 }
172 #[test]
173 fn test_network_builder_simple_aann_architecture() {
        // 3457 = (20*64 + 64) + (64*32 + 32) + (32*1 + 1), i.e. weights plus biases
        // per layer — presumably how `total_parameters` counts; confirm there.
174 let hidden = &[64_usize, 32];
175 let net = NetworkBuilder::simple_aann(20, hidden, 1);
176 assert_eq!(net.layers.len(), 3);
177 assert_eq!(net.input_size(), Some(20));
178 assert_eq!(net.output_size(), Some(1));
179 assert_eq!(net.layers[0].activation, ActivationFn::Tanh);
180 assert_eq!(net.layers[1].activation, ActivationFn::Tanh);
181 assert_eq!(net.layers[2].activation, ActivationFn::Linear);
182 assert_eq!(net.total_parameters(), 3457);
183 }
184 #[test]
185 fn test_batch_norm_identity_transform() {
        // A freshly constructed batch-norm layer is expected to leave inputs
        // unchanged within 1e-4.
186 let bn = BatchNormLayer::new(3);
187 let input = vec![1.0, 2.0, 3.0];
188 let output = bn.forward(&input);
189 for i in 0..3 {
190 assert!(
191 (output[i] - input[i]).abs() < 1e-4,
192 "output[{i}]={}",
193 output[i]
194 );
195 }
196 }
197 #[test]
198 fn test_batch_norm_zero_mean_unit_var() {
        // Inputs equal to the configured means must map to ~0.
199 let mut bn = BatchNormLayer::new(2);
200 bn.set_stats(&[5.0, 10.0], &[4.0, 9.0]);
201 let output = bn.forward(&[5.0, 10.0]);
202 assert!(output[0].abs() < 1e-4);
203 assert!(output[1].abs() < 1e-4);
204 }
205 #[test]
206 fn test_batch_norm_affine() {
        // Expected values are 2*1 + 1 = 3 and 3*1 - 1 = 2 for the two channels.
207 let mut bn = BatchNormLayer::new(2);
208 bn.set_stats(&[0.0, 0.0], &[1.0, 1.0]);
209 bn.set_affine(&[2.0, 3.0], &[1.0, -1.0]);
210 let output = bn.forward(&[1.0, 1.0]);
211 assert!((output[0] - 3.0).abs() < 1e-4);
212 assert!((output[1] - 2.0).abs() < 1e-4);
213 }
214 #[test]
215 fn test_inference_pipeline_dense_only() {
216 let mut pipe = InferencePipeline::new();
217 let mut layer = DenseLayer::new(2, 1, ActivationFn::Linear);
218 layer.set_weights(&[1.0, 1.0]);
219 layer.set_biases(&[0.0]);
220 pipe.add_op(InferenceOp::Dense(layer));
221 let out = pipe.forward(&[3.0, 4.0]);
222 assert!((out[0] - 7.0).abs() < 1e-5);
223 }
224 #[test]
225 fn test_inference_pipeline_with_batch_norm() {
        // Identity dense layer -> default batch norm -> ReLU: [1, -2] becomes [1, 0].
226 let mut pipe = InferencePipeline::new();
227 let mut layer = DenseLayer::new(2, 2, ActivationFn::Linear);
228 layer.set_weights(&[1.0, 0.0, 0.0, 1.0]);
229 pipe.add_op(InferenceOp::Dense(layer));
230 pipe.add_op(InferenceOp::BatchNorm(BatchNormLayer::new(2)));
231 pipe.add_op(InferenceOp::Activation(ActivationFn::Relu));
232 let out = pipe.forward(&[1.0, -2.0]);
233 assert!((out[0] - 1.0).abs() < 1e-4);
234 assert!(out[1].abs() < 1e-4);
235 }
236 #[test]
237 fn test_inference_pipeline_total_params() {
        // 25 = 15 (dense 4->3: 12 weights + 3 biases) + 4 (dense 3->1) + 6 for the
        // batch-norm op — presumably its scale and shift vectors; confirm against
        // `InferencePipeline::total_parameters`.
238 let mut pipe = InferencePipeline::new();
239 pipe.add_op(InferenceOp::Dense(DenseLayer::new(
240 4,
241 3,
242 ActivationFn::Relu,
243 )));
244 pipe.add_op(InferenceOp::BatchNorm(BatchNormLayer::new(3)));
245 pipe.add_op(InferenceOp::Dense(DenseLayer::new(
246 3,
247 1,
248 ActivationFn::Linear,
249 )));
250 assert_eq!(pipe.total_parameters(), 25);
251 }
252 #[test]
253 fn test_save_load_weights_roundtrip() {
        // One 2->3 layer serialises to 6 weights + 3 biases = 9 floats; loading them
        // into a second, identically shaped net must reproduce the forward output.
254 let mut net = FeedForwardNet::new();
255 let mut l1 = DenseLayer::new(2, 3, ActivationFn::Relu);
256 l1.set_weights(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
257 l1.set_biases(&[0.1, 0.2, 0.3]);
258 net.add_layer(l1);
259 let buf = save_weights_to_buffer(&net);
260 assert_eq!(buf.len(), 9);
261 let mut net2 = FeedForwardNet::new();
262 net2.add_layer(DenseLayer::new(2, 3, ActivationFn::Relu));
263 let consumed = load_weights_from_buffer(&mut net2, &buf);
264 assert_eq!(consumed, 9);
265 let input = vec![1.0, 2.0];
266 let out1 = net.forward(&input);
267 let out2 = net2.forward(&input);
268 for (a, b) in out1.iter().zip(out2.iter()) {
269 assert!((a - b).abs() < 1e-6);
270 }
271 }
272 #[test]
273 fn test_load_weights_partial() {
        // The 9-float buffer covers only the first layer (2*3 + 3); the second layer
        // needs 3*1 + 1 more values, so loading stops after consuming 9.
274 let mut net = FeedForwardNet::new();
275 net.add_layer(DenseLayer::new(2, 3, ActivationFn::Relu));
276 net.add_layer(DenseLayer::new(3, 1, ActivationFn::Linear));
277 let buf = vec![1.0_f32; 9];
278 let consumed = load_weights_from_buffer(&mut net, &buf);
279 assert_eq!(consumed, 9);
280 }
281 #[test]
282 fn test_softmax_sums_to_one() {
283 let logits = vec![1.0_f32, 2.0, 3.0, 4.0];
284 let probs = softmax(&logits);
285 let sum: f32 = probs.iter().sum();
286 assert!((sum - 1.0).abs() < 1e-5, "softmax sum={sum}");
287 }
288 #[test]
289 fn test_softmax_largest_is_max() {
290 let logits = vec![1.0_f32, 5.0, 2.0];
291 let probs = softmax(&logits);
292 assert!(probs[1] > probs[0] && probs[1] > probs[2]);
293 }
294 #[test]
295 fn test_softmax_empty() {
296 assert!(softmax(&[]).is_empty());
297 }
298 #[test]
299 fn test_softmax_single() {
300 let probs = softmax(&[42.0]);
301 assert!((probs[0] - 1.0).abs() < 1e-5);
302 }
303 #[test]
304 fn test_cross_entropy_loss_correct_class() {
305 let logits = vec![-10.0_f32, 10.0, -10.0];
306 let loss = cross_entropy_loss(&logits, 1);
307 assert!(loss < 0.01, "loss should be small, got {loss}");
308 }
309 #[test]
310 fn test_cross_entropy_loss_wrong_class() {
311 let logits = vec![10.0_f32, -10.0, -10.0];
312 let loss = cross_entropy_loss(&logits, 1);
313 assert!(loss > 1.0, "loss should be large, got {loss}");
314 }
315 #[test]
316 fn test_mse_loss_zero() {
317 let pred = vec![1.0_f32, 2.0, 3.0];
318 let target = vec![1.0, 2.0, 3.0];
319 assert!(mse_loss(&pred, &target).abs() < 1e-7);
320 }
321 #[test]
322 fn test_mse_loss_positive() {
        // ((1-3)^2 + (2-4)^2) / 2 = 4.
323 let pred = vec![1.0_f32, 2.0];
324 let target = vec![3.0, 4.0];
325 let loss = mse_loss(&pred, &target);
326 assert!((loss - 4.0).abs() < 1e-5);
327 }
328 #[test]
329 fn test_sigmoid_derivative() {
330 let act = ActivationFn::Sigmoid;
331 let d = act.derivative(0.0);
332 assert!((d - 0.25).abs() < 1e-5, "sigmoid'(0) = {d}");
333 }
334 #[test]
335 fn test_silu_at_zero() {
336 let act = ActivationFn::Silu;
337 let v = act.apply(0.0);
338 assert!(v.abs() < 1e-7);
339 }
340 #[test]
341 fn test_gelu_derivative_positive() {
342 let act = ActivationFn::Gelu;
343 let d = act.derivative(1.0);
344 assert!(d > 0.0, "GELU derivative at 1.0 should be positive");
345 }
346 #[test]
347 fn test_linear_derivative() {
348 let act = ActivationFn::Linear;
349 assert!((act.derivative(42.0) - 1.0).abs() < 1e-7);
350 }
351}
352pub fn compute_forces_batch(positions: &[[f64; 3]], network: &NeuralNetwork) -> Vec<[f64; 3]> {
357 positions
358 .iter()
359 .map(|p| {
360 let out = network.forward(p);
361 let fx = out.first().copied().unwrap_or(0.0);
362 let fy = out.get(1).copied().unwrap_or(0.0);
363 let fz = out.get(2).copied().unwrap_or(0.0);
364 [fx, fy, fz]
365 })
366 .collect()
367}
368pub fn neural_potential_energy(network: &NeuralNetwork, positions: &[[f64; 3]]) -> f64 {
372 positions
373 .iter()
374 .map(|p| network.forward(p).first().copied().unwrap_or(0.0))
375 .sum()
376}
// Tests for the f64 network path: `NeuralNetwork`, `ActivationFn64`,
// `GpuNeuralBuffer` and the batch helpers `compute_forces_batch` /
// `neural_potential_energy`, all imported from the crate root.
377#[cfg(test)]
378mod neural_f64_tests {
379
380 use crate::ActivationFn64;
381
382 use crate::GpuNeuralBuffer;
383
384 use crate::NeuralNetwork;
385
386 use crate::compute_forces_batch;
387
388 use crate::neural_potential_energy;
389
390 #[test]
391 fn test_forward_pass_output_size() {
        // Layer sizes [3, 8, 8, 3]: the output length must match the last entry.
392 let net = NeuralNetwork::new(&[3, 8, 8, 3], ActivationFn64::Tanh);
393 let input = [1.0, 0.5, -0.5];
394 let out = net.forward(&input);
395 assert_eq!(out.len(), 3, "output should have 3 components");
396 }
397 #[test]
398 fn test_relu_activation_f64() {
399 let act = ActivationFn64::Relu;
400 assert!((act.apply(-2.0)).abs() < 1e-12, "relu(-2) = 0");
401 assert!((act.apply(3.0) - 3.0).abs() < 1e-12, "relu(3) = 3");
402 }
403 #[test]
404 fn test_relu_batch() {
        // `apply_batch` mutates the vector in place.
405 let act = ActivationFn64::Relu;
406 let mut v = vec![-1.0, 0.0, 2.0, -0.5, 3.0];
407 act.apply_batch(&mut v);
408 assert_eq!(v, vec![0.0, 0.0, 2.0, 0.0, 3.0]);
409 }
410 #[test]
411 fn test_xavier_init_non_zero() {
        // Fails only if every weight of every layer is exactly zero.
412 let net = NeuralNetwork::new(&[4, 16, 4], ActivationFn64::Relu);
413 let all_zero = net
414 .layers
415 .iter()
416 .all(|l| l.weights.iter().all(|row| row.iter().all(|&w| w == 0.0)));
417 assert!(!all_zero, "Xavier-init weights should not all be zero");
418 }
419 #[test]
420 fn test_batch_forces_count() {
421 let net = NeuralNetwork::new(&[3, 8, 3], ActivationFn64::Relu);
422 let positions: Vec<[f64; 3]> = (0..5).map(|i| [i as f64, 0.0, 0.0]).collect();
423 let forces = compute_forces_batch(&positions, &net);
424 assert_eq!(forces.len(), 5, "one force vector per position");
425 }
426 #[test]
427 fn test_gpu_neural_buffer_roundtrip() {
        // Packing two 3-vectors gives 6 flat values; `unpack_forces` is expected to
        // return the same triples that were packed.
428 let positions = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
429 let buf = GpuNeuralBuffer::pack_positions(&positions);
430 assert_eq!(buf.batch_size, 2);
431 assert_eq!(buf.data.len(), 6);
432 let forces = buf.unpack_forces();
433 assert_eq!(forces[0], [1.0, 2.0, 3.0]);
434 assert_eq!(forces[1], [4.0, 5.0, 6.0]);
435 }
436 #[test]
437 fn test_neural_potential_energy_positive() {
        // NOTE(review): despite the name, this only asserts that the energy is finite,
        // not that it is positive.
438 let net = NeuralNetwork::new(&[3, 4, 1], ActivationFn64::Relu);
439 let positions = vec![[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]];
440 let energy = neural_potential_energy(&net, &positions);
441 assert!(energy.is_finite(), "energy should be finite");
442 }
443}
/// L2 penalty `0.5 * lambda * sum(w^2)` over all `weights`.
///
/// An empty `weights` slice gives a penalty of zero.
pub fn l2_regularisation(weights: &[f64], lambda: f64) -> f64 {
    let sum_of_squares: f64 = weights.iter().map(|w| w * w).sum();
    0.5 * lambda * sum_of_squares
}
/// Gradient of the L2 penalty with respect to each weight: `lambda * w`.
///
/// The result has the same length and order as `weights`.
pub fn l2_regularisation_grad(weights: &[f64], lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(weights.len());
    for &w in weights {
        grad.push(lambda * w);
    }
    grad
}
/// Huber loss of a single prediction.
///
/// With `e = |pred - target|`, returns `0.5 * e^2` when `e <= delta` and
/// `delta * (e - 0.5 * delta)` otherwise, i.e. quadratic near zero and linear
/// for large errors.
pub fn huber_loss(pred: f64, target: f64, delta: f64) -> f64 {
    let residual = (pred - target).abs();
    match residual {
        // Quadratic region (boundary `e == delta` included).
        r if r <= delta => 0.5 * r * r,
        // Linear region.
        r => delta * (r - 0.5 * delta),
    }
}
465pub fn mean_huber_loss(predictions: &[f64], targets: &[f64], delta: f64) -> f64 {
467 assert_eq!(predictions.len(), targets.len());
468 let n = predictions.len() as f64;
469 predictions
470 .iter()
471 .zip(targets.iter())
472 .map(|(&p, &t)| huber_loss(p, t, delta))
473 .sum::<f64>()
474 / n
475}
/// Element-wise gradient of the Huber loss with respect to each prediction.
///
/// With `e = p - t`, the entry is `e` when `|e| <= delta` and
/// `delta * sign(e)` otherwise. The values are per-element gradients; they are
/// not divided by the number of elements.
///
/// # Panics
///
/// Panics if `predictions` and `targets` differ in length. Without this check
/// the pairing would silently stop at the shorter slice and return a gradient
/// shorter than `predictions`.
pub fn huber_loss_grad(predictions: &[f64], targets: &[f64], delta: f64) -> Vec<f64> {
    assert_eq!(
        predictions.len(),
        targets.len(),
        "huber_loss_grad: predictions and targets must have the same length"
    );
    predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| {
            let e = p - t;
            if e.abs() <= delta {
                e
            } else {
                delta * e.signum()
            }
        })
        .collect()
}
/// Global L2 norm of a set of gradient slices.
///
/// All slices are treated as one flat vector: the squares of every element
/// are summed in order and the square root of that sum is returned. An empty
/// set (or only empty slices) has norm zero.
pub fn compute_gradient_norm(gradients: &[&[f64]]) -> f64 {
    gradients
        .iter()
        .copied()
        .flatten()
        .map(|v| v * v)
        .sum::<f64>()
        .sqrt()
}
503pub fn clip_gradients_by_norm(gradients: &mut [Vec<f64>], max_norm: f64) -> f64 {
509 let refs: Vec<&[f64]> = gradients.iter().map(|v| v.as_slice()).collect();
510 let norm = compute_gradient_norm(&refs);
511 if norm > max_norm && norm > 0.0 {
512 let scale = max_norm / norm;
513 for g in gradients.iter_mut() {
514 for v in g.iter_mut() {
515 *v *= scale;
516 }
517 }
518 }
519 norm
520}
// Tests for the f64 training utilities: `ExtActivation`, `DenseLayer64`,
// `DropoutLayer`, `AdamOptimizer`, `GradAccumulator`, and the L2 / Huber helper
// functions defined in this file, all imported from the crate root.
521#[cfg(test)]
522mod extended_tests {
523
524 use crate::AdamOptimizer;
525
526 use crate::DenseLayer64;
527 use crate::DropoutLayer;
528 use crate::ExtActivation;
529
530 use crate::GradAccumulator;
531
532 use crate::huber_loss;
533 use crate::huber_loss_grad;
534 use crate::l2_regularisation;
535 use crate::l2_regularisation_grad;
536 use crate::mean_huber_loss;
537
538 #[test]
539 fn test_leaky_relu_positive() {
540 let act = ExtActivation::LeakyRelu(0.01);
541 assert!((act.apply(3.0) - 3.0).abs() < 1e-12);
542 }
543 #[test]
544 fn test_leaky_relu_negative() {
        // Negative inputs are scaled by the variant's payload: 0.1 * -2.0 = -0.2.
545 let act = ExtActivation::LeakyRelu(0.1);
546 assert!((act.apply(-2.0) - (-0.2)).abs() < 1e-12);
547 }
548 #[test]
549 fn test_leaky_relu_derivative_positive() {
550 let act = ExtActivation::LeakyRelu(0.05);
551 assert!((act.derivative(1.0) - 1.0).abs() < 1e-12);
552 }
553 #[test]
554 fn test_leaky_relu_derivative_negative() {
555 let act = ExtActivation::LeakyRelu(0.05);
556 assert!((act.derivative(-1.0) - 0.05).abs() < 1e-12);
557 }
558 #[test]
559 fn test_swish_at_zero() {
560 let act = ExtActivation::Swish(1.0);
561 assert!(act.apply(0.0).abs() < 1e-12);
562 }
563 #[test]
564 fn test_swish_positive_region() {
        // NOTE(review): the assertion message below contains a mis-encoded character
        // (it looks like a garbled "approximately equal" sign); fix the literal
        // separately.
565 let act = ExtActivation::Swish(1.0);
566 let v = act.apply(10.0);
567 assert!((v - 10.0).abs() < 0.01, "swish(10) โ 10, got {v}");
568 }
569 #[test]
570 fn test_ext_activation_apply_vec() {
        // `apply_vec` mutates the vector in place.
571 let act = ExtActivation::Relu;
572 let mut v = vec![-1.0, 0.0, 2.0, -3.0];
573 act.apply_vec(&mut v);
574 assert_eq!(v, vec![0.0, 0.0, 2.0, 0.0]);
575 }
576 #[test]
577 fn test_dense64_forward_linear_known_values() {
        // Identity 2x2 weights with a linear activation must return the input.
578 let mut layer = DenseLayer64::new(2, 2, ExtActivation::Linear);
579 layer.weights = vec![1.0, 0.0, 0.0, 1.0];
580 let out = layer.forward(&[3.0, 5.0]);
581 assert!((out[0] - 3.0).abs() < 1e-12);
582 assert!((out[1] - 5.0).abs() < 1e-12);
583 }
584 #[test]
585 fn test_dense64_backward_gradient_shapes() {
        // For a 3->2 layer: 6 weight gradients, 2 bias gradients, 3 input deltas.
586 let mut layer = DenseLayer64::new(3, 2, ExtActivation::Relu);
587 layer.weights = vec![1.0; 6];
588 let _out = layer.forward(&[1.0, 2.0, 3.0]);
589 let (gw, gb, di) = layer.backward(&[1.0, 1.0]);
590 assert_eq!(gw.len(), 6, "grad_weights shape");
591 assert_eq!(gb.len(), 2, "grad_biases shape");
592 assert_eq!(di.len(), 3, "delta_in shape");
593 }
594 #[test]
595 fn test_dense64_sgd_update_reduces_output() {
        // Loss is out^2, so the upstream gradient passed to `backward` is 2 * out.
596 let mut layer = DenseLayer64::new(1, 1, ExtActivation::Linear);
597 layer.weights = vec![2.0];
598 layer.biases = vec![0.0];
599 let out = layer.forward(&[1.0]);
600 let loss_before = out[0] * out[0];
601 let (gw, gb, _) = layer.backward(&[2.0 * out[0]]);
602 layer.apply_sgd(&gw, &gb, 0.1);
603 let out2 = layer.forward(&[1.0]);
604 let loss_after = out2[0] * out2[0];
605 assert!(loss_after < loss_before, "SGD should reduce loss");
606 }
607 #[test]
608 fn test_dense64_num_params() {
609 let layer = DenseLayer64::new(4, 3, ExtActivation::Linear);
610 assert_eq!(layer.num_params(), 4 * 3 + 3);
611 }
612 #[test]
613 fn test_dropout_inference_passthrough() {
        // Second constructor argument `false`: per the assertion message this is
        // eval mode, where the input passes through unchanged.
614 let mut drop = DropoutLayer::new(0.5, false);
615 let input = vec![1.0, 2.0, 3.0, 4.0];
616 let out = drop.forward(&input);
617 assert_eq!(out, input, "dropout in eval mode should pass through");
618 }
619 #[test]
620 fn test_dropout_rate_zero_no_drop() {
621 let mut drop = DropoutLayer::new(0.0, true);
622 let input = vec![1.0, 2.0, 3.0];
623 let out = drop.forward(&input);
624 assert_eq!(out, input, "zero rate should not drop anything");
625 }
626 #[test]
627 fn test_dropout_rate_one_all_zero() {
628 let mut drop = DropoutLayer::new(1.0, true);
629 let input = vec![5.0, 6.0, 7.0];
630 let out = drop.forward(&input);
631 assert!(
632 out.iter().all(|&x| x == 0.0),
633 "rate=1 should zero everything"
634 );
635 }
636 #[test]
637 fn test_dropout_training_some_zeros() {
        // Seeded so the run is reproducible; with rate 0.5 over 100 inputs the zero
        // count must land strictly between 10 and 90.
638 let mut drop = DropoutLayer::new(0.5, true);
639 drop.set_seed(42);
640 let input = vec![1.0_f64; 100];
641 let out = drop.forward(&input);
642 let n_zeros = out.iter().filter(|&&x| x == 0.0).count();
643 assert!(n_zeros > 10, "expected some zeros, got {n_zeros}");
644 assert!(
645 n_zeros < 90,
646 "expected some non-zeros, too many zeros: {n_zeros}"
647 );
648 }
649 #[test]
650 fn test_dropout_backward_applies_mask() {
        // NOTE(review): with rate 0.0 and the training flag off, this only checks that
        // the gradient passes through unchanged; no non-trivial mask is exercised.
651 let mut drop = DropoutLayer::new(0.0, false);
652 let input = vec![1.0, 2.0, 3.0];
653 let _out = drop.forward(&input);
654 let grad = drop.backward(&[1.0, 1.0, 1.0]);
655 assert_eq!(grad, vec![1.0, 1.0, 1.0]);
656 }
657 #[test]
658 fn test_adam_step_decreases_loss() {
        // Minimises x^2 (gradient 2x) from x = 5 for 20 steps; |x| must shrink.
659 let mut params = vec![5.0_f64];
660 let mut opt = AdamOptimizer::default_params(1);
661 let initial_abs = params[0].abs();
662 for _ in 0..20 {
663 let grads = vec![2.0 * params[0]];
664 opt.step_update(&mut params, &grads);
665 }
666 assert!(
667 params[0].abs() < initial_abs,
668 "Adam should move towards zero, final={}",
669 params[0]
670 );
671 }
672 #[test]
673 fn test_adam_step_increments_counter() {
674 let mut params = vec![1.0_f64; 3];
675 let mut opt = AdamOptimizer::default_params(3);
676 assert_eq!(opt.step, 0);
677 let grads = vec![0.1; 3];
678 opt.step_update(&mut params, &grads);
679 assert_eq!(opt.step, 1);
680 opt.step_update(&mut params, &grads);
681 assert_eq!(opt.step, 2);
682 }
683 #[test]
684 fn test_adam_reset() {
        // After `reset` the step counter and both moment vectors must be zero.
685 let mut params = vec![1.0_f64; 2];
686 let mut opt = AdamOptimizer::default_params(2);
687 let grads = vec![0.5; 2];
688 opt.step_update(&mut params, &grads);
689 opt.reset();
690 assert_eq!(opt.step, 0);
691 assert!(opt.m.iter().all(|&x| x == 0.0));
692 assert!(opt.v.iter().all(|&x| x == 0.0));
693 }
694 #[test]
695 fn test_adam_moment_accumulation() {
        // Expected m = 0.1 and v = 0.001 match (1 - 0.9) * g and (1 - 0.999) * g^2 for
        // g = 1 — assumes the constructor arguments are (n, lr, beta1, beta2, eps);
        // TODO confirm against `AdamOptimizer::new`.
696 let mut params = vec![1.0_f64];
697 let mut opt = AdamOptimizer::new(1, 1e-3, 0.9, 0.999, 1e-8);
698 let grads = vec![1.0_f64];
699 opt.step_update(&mut params, &grads);
700 assert!(
701 (opt.m[0] - 0.1).abs() < 1e-10,
702 "m after step 1 = {}",
703 opt.m[0]
704 );
705 assert!(
706 (opt.v[0] - 0.001).abs() < 1e-10,
707 "v after step 1 = {}",
708 opt.v[0]
709 );
710 }
711 #[test]
712 fn test_grad_accumulator_mean() {
        // Two accumulations: means are (1+3)/2, (2+4)/2 for weights and (3+1)/2 for the bias.
713 let mut acc = GradAccumulator::new(2, 1);
714 acc.accumulate(&[1.0, 2.0], &[3.0]);
715 acc.accumulate(&[3.0, 4.0], &[1.0]);
716 let (gw, gb) = acc.mean_grads();
717 assert!((gw[0] - 2.0).abs() < 1e-12);
718 assert!((gw[1] - 3.0).abs() < 1e-12);
719 assert!((gb[0] - 2.0).abs() < 1e-12);
720 }
721 #[test]
722 fn test_grad_accumulator_zero() {
723 let mut acc = GradAccumulator::new(3, 2);
724 acc.accumulate(&[1.0, 2.0, 3.0], &[4.0, 5.0]);
725 assert_eq!(acc.count, 1);
726 acc.zero();
727 assert_eq!(acc.count, 0);
728 assert!(acc.grad_weights.iter().all(|&x| x == 0.0));
729 }
730 #[test]
731 fn test_l2_regularisation() {
        // 0.5 * 0.01 * (1 + 4 + 9) = 0.07.
732 let weights = vec![1.0, 2.0, 3.0];
733 let reg = l2_regularisation(&weights, 0.01);
734 assert!((reg - 0.07).abs() < 1e-12, "L2 reg = {reg}");
735 }
736 #[test]
737 fn test_l2_regularisation_grad() {
738 let weights = vec![2.0, -3.0];
739 let grad = l2_regularisation_grad(&weights, 0.1);
740 assert!((grad[0] - 0.2).abs() < 1e-12);
741 assert!((grad[1] - (-0.3)).abs() < 1e-12);
742 }
743 #[test]
744 fn test_huber_loss_small_error() {
        // |error| = 0.1 <= delta = 0.5, so the quadratic branch applies: 0.5 * 0.1^2.
745 let loss = huber_loss(1.0, 1.1, 0.5);
746 assert!((loss - 0.5 * 0.01).abs() < 1e-12, "huber loss = {loss}");
747 }
748 #[test]
749 fn test_huber_loss_large_error() {
        // |error| = 5 > delta = 1, so the linear branch applies: 1 * (5 - 0.5) = 4.5.
750 let loss = huber_loss(0.0, 5.0, 1.0);
751 assert!((loss - 4.5).abs() < 1e-12, "huber loss = {loss}");
752 }
753 #[test]
754 fn test_mean_huber_loss() {
        // (0.5 * 0.1^2 + 4.5) / 2 = 2.2525.
755 let preds = vec![0.0, 0.0];
756 let targets = vec![0.1, 5.0];
757 let loss = mean_huber_loss(&preds, &targets, 1.0);
758 assert!((loss - 2.2525).abs() < 1e-10, "mean huber loss = {loss}");
759 }
760 #[test]
761 fn test_huber_loss_grad_small() {
762 let preds = vec![1.0];
763 let targets = vec![1.1];
764 let grad = huber_loss_grad(&preds, &targets, 1.0);
765 assert!((grad[0] - (-0.1)).abs() < 1e-12);
766 }
767 #[test]
768 fn test_huber_loss_grad_large() {
        // Error of 10 exceeds delta = 1, so the gradient is clamped to delta * sign = 1.
769 let preds = vec![10.0];
770 let targets = vec![0.0];
771 let grad = huber_loss_grad(&preds, &targets, 1.0);
772 assert!((grad[0] - 1.0).abs() < 1e-12);
773 }
774}
// Tests for `Conv1DLayer`, `LayerNorm` / `LayerNormLayer`, `RnnCell`, the f64
// gradient-norm helpers defined in this file, and the running-statistics and
// gradient-clipping methods on `BatchNormLayer` / `FeedForwardNet`.
775#[cfg(test)]
776mod conv_rnn_tests {
777
778 use crate::BatchNormLayer;
779 use crate::Conv1DLayer;
780
781 use crate::ExtActivation;
782 use crate::FeedForwardNet;
783
784 use crate::LayerNorm;
785 use crate::LayerNormLayer;
786
787 use crate::RnnCell;
788
789 use crate::clip_gradients_by_norm;
790
791 use crate::compute_gradient_norm;
792
793 #[test]
794 fn test_conv1d_zero_weights_output_is_bias() {
        // No weights are set, so the test relies on a freshly constructed layer
        // contributing nothing but the bias at every time step (see the test name).
795 let mut conv = Conv1DLayer::new(2, 3, 2, ExtActivation::Linear);
796 conv.biases = vec![1.0, 2.0, 3.0];
797 let input = vec![vec![0.5, 0.5]; 4];
798 let out = conv.forward(&input);
799 assert_eq!(out.len(), 4);
800 for row in &out {
801 assert_eq!(row.len(), 3);
802 assert!((row[0] - 1.0).abs() < 1e-12, "out[0] = {}", row[0]);
803 assert!((row[1] - 2.0).abs() < 1e-12, "out[1] = {}", row[1]);
804 assert!((row[2] - 3.0).abs() < 1e-12, "out[2] = {}", row[2]);
805 }
806 }
807 #[test]
808 fn test_conv1d_num_params() {
        // 104 = 4 * 8 * 3 weights + 8 biases — presumably how `num_params` counts;
        // confirm against Conv1DLayer.
809 let conv = Conv1DLayer::new(4, 8, 3, ExtActivation::Relu);
810 assert_eq!(conv.num_params(), 104);
811 }
812 #[test]
813 fn test_conv1d_output_shape() {
814 let conv = Conv1DLayer::new(3, 5, 3, ExtActivation::Tanh);
815 let input: Vec<Vec<f64>> = (0..10).map(|_| vec![1.0, 0.0, -1.0]).collect();
816 let out = conv.forward(&input);
817 assert_eq!(out.len(), 10, "seq_len preserved");
818 assert_eq!(out[0].len(), 5, "out_channels");
819 }
820 #[test]
821 fn test_conv1d_causal_first_step_only_sees_t0() {
        // Kernel taps 1 and 2 carry large weights; the t = 0 output must still equal
        // the t = 0 input times tap 0, i.e. those taps must not see later inputs.
822 let mut conv = Conv1DLayer::new(1, 1, 3, ExtActivation::Linear);
823 conv.weights[0][0][0] = 1.0;
824 conv.weights[0][1][0] = 100.0;
825 conv.weights[0][2][0] = 100.0;
826 let input = vec![vec![5.0], vec![0.0], vec![0.0]];
827 let out = conv.forward(&input);
828 assert!(
829 (out[0][0] - 5.0).abs() < 1e-12,
830 "t=0 output = {}",
831 out[0][0]
832 );
833 }
834 #[test]
835 fn test_conv1d_kernel1_is_pointwise() {
        // Kernel size 1: each output is 2*x0 + 3*x1 of the same time step.
836 let mut conv = Conv1DLayer::new(2, 1, 1, ExtActivation::Linear);
837 conv.weights[0][0][0] = 2.0;
838 conv.weights[0][0][1] = 3.0;
839 let input = vec![vec![1.0, 1.0], vec![2.0, 2.0]];
840 let out = conv.forward(&input);
841 assert!((out[0][0] - 5.0).abs() < 1e-12);
842 assert!((out[1][0] - 10.0).abs() < 1e-12);
843 }
844 #[test]
845 fn test_conv1d_relu_clips_negative() {
846 let conv = Conv1DLayer::new(1, 1, 1, ExtActivation::Relu);
847 let input: Vec<Vec<f64>> = vec![vec![-5.0], vec![-3.0]];
848 let out = conv.forward(&input);
849 assert!(out[0][0] >= 0.0, "relu should clip negative");
850 assert!(out[1][0] >= 0.0);
851 }
852 #[test]
853 fn test_layer_norm_zero_mean_after_forward() {
854 let ln = LayerNorm::new(4);
855 let x = vec![1.0, 2.0, 3.0, 4.0];
856 let out = ln.forward(&x);
857 let mean: f64 = out.iter().sum::<f64>() / out.len() as f64;
858 assert!(mean.abs() < 1e-10, "mean after LayerNorm = {mean}");
859 }
860 #[test]
861 fn test_layer_norm_unit_variance() {
        // Population variance (divide by n) of the output, checked to within 1e-3.
862 let ln = LayerNorm::new(4);
863 let x = vec![1.0, 2.0, 3.0, 4.0];
864 let out = ln.forward(&x);
865 let mean: f64 = out.iter().sum::<f64>() / out.len() as f64;
866 let var: f64 = out.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>() / out.len() as f64;
867 assert!((var - 1.0).abs() < 1e-3, "variance after LayerNorm = {var}");
868 }
869 #[test]
870 fn test_layer_norm_identity_gamma_beta() {
        // NOTE(review): the assertion message below contains a mis-encoded character
        // (it looks like a garbled arrow); fix the literal separately.
871 let ln = LayerNorm::new(3);
872 let x = vec![5.0, 5.0, 5.0];
873 let out = ln.forward(&x);
874 for &v in &out {
875 assert!(v.abs() < 1e-4, "constant input โ near-zero output, got {v}");
876 }
877 }
878 #[test]
879 fn test_layer_norm_output_length() {
880 let ln = LayerNorm::new(6);
881 let x = vec![1.0; 6];
882 let out = ln.forward(&x);
883 assert_eq!(out.len(), 6);
884 }
885 #[test]
886 fn test_layer_norm_custom_gamma_beta() {
        // Input [0, 4] normalises to roughly [-1, 1]; scaling by gamma and shifting by
        // beta then gives 2 * -1 + 1 = -1 and 3 * 1 - 1 = 2.
887 let mut ln = LayerNorm::new(2);
888 ln.gamma = vec![2.0, 3.0];
889 ln.beta = vec![1.0, -1.0];
890 let x = vec![0.0, 4.0];
891 let out = ln.forward(&x);
892 assert!((out[0] - (-1.0)).abs() < 1e-4, "out[0] = {}", out[0]);
893 assert!((out[1] - 2.0).abs() < 1e-4, "out[1] = {}", out[1]);
894 }
895 #[test]
896 fn test_rnn_cell_zero_weights_output_is_activated_bias() {
        // No weights or biases are set, so the test relies on a freshly constructed
        // cell producing an all-zero state.
        // NOTE(review): the assertion message below contains a mis-encoded character
        // (it looks like a garbled arrow); fix the literal separately.
897 let cell = RnnCell::new(2, 3, ExtActivation::Linear);
898 let x = vec![10.0, 20.0];
899 let h_prev = vec![1.0, 2.0, 3.0];
900 let h = cell.step(&x, &h_prev);
901 for &v in &h {
902 assert!(v.abs() < 1e-12, "zero weights โ zero output, got {v}");
903 }
904 }
905 #[test]
906 fn test_rnn_cell_output_length() {
907 let cell = RnnCell::new(4, 8, ExtActivation::Tanh);
908 let x = vec![0.0; 4];
909 let h_prev = vec![0.0; 8];
910 let h = cell.step(&x, &h_prev);
911 assert_eq!(h.len(), 8);
912 }
913 #[test]
914 fn test_rnn_cell_identity_weights_copies_input() {
        // Indices 0 and 3 are the diagonal of a flat 2x2 `w_x`, making it the identity.
915 let mut cell = RnnCell::new(2, 2, ExtActivation::Linear);
916 cell.w_x[0] = 1.0;
917 cell.w_x[3] = 1.0;
918 let x = vec![3.0, 7.0];
919 let h_prev = vec![0.0, 0.0];
920 let h = cell.step(&x, &h_prev);
921 assert!((h[0] - 3.0).abs() < 1e-12, "h[0] = {}", h[0]);
922 assert!((h[1] - 7.0).abs() < 1e-12, "h[1] = {}", h[1]);
923 }
924 #[test]
925 fn test_rnn_cell_sequence_length() {
        // One hidden state (length 5) is expected per input step (7 steps).
926 let cell = RnnCell::new(3, 5, ExtActivation::Relu);
927 let seq: Vec<Vec<f64>> = (0..7).map(|_| vec![0.0; 3]).collect();
928 let states = cell.forward_sequence(&seq);
929 assert_eq!(states.len(), 7);
930 assert_eq!(states[0].len(), 5);
931 }
932 #[test]
933 fn test_rnn_cell_sequence_accumulates_state() {
        // With w_x = 0 and w_h = 2, each step doubles the previous state: 1 -> 2 -> 4.
934 let mut cell = RnnCell::new(1, 1, ExtActivation::Linear);
935 cell.w_x[0] = 0.0;
936 cell.w_h[0] = 2.0;
937 let h0 = vec![1.0_f64];
938 let h1 = cell.step(&[0.0], &h0);
939 assert!((h1[0] - 2.0).abs() < 1e-12, "h1 = {}", h1[0]);
940 let h2 = cell.step(&[0.0], &h1);
941 assert!((h2[0] - 4.0).abs() < 1e-12, "h2 = {}", h2[0]);
942 }
943 #[test]
944 fn test_layer_norm_zero_mean_unit_variance() {
        // Same property as the `LayerNorm` tests above, here for `LayerNormLayer`.
945 let ln = LayerNormLayer::new(4);
946 let input = vec![1.0, 2.0, 3.0, 4.0];
947 let out = ln.forward(&input);
948 let mean: f64 = out.iter().sum::<f64>() / out.len() as f64;
949 let var: f64 = out.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / out.len() as f64;
950 assert!(mean.abs() < 1e-10, "output mean should be ~0, got {mean}");
951 assert!(
952 (var - 1.0).abs() < 1e-5,
953 "output var should be ~1, got {var}"
954 );
955 }
956 #[test]
957 fn test_layer_norm_gamma_scales_output() {
958 let mut ln = LayerNormLayer::new(3);
959 ln.gamma = vec![2.0, 2.0, 2.0];
960 let input = vec![1.0, 2.0, 3.0];
961 let out = ln.forward(&input);
962 let mean: f64 = out.iter().sum::<f64>() / out.len() as f64;
963 let var: f64 = out.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / out.len() as f64;
964 assert!(
965 (var.sqrt() - 2.0).abs() < 1e-3,
966 "std should be ~2 with gamma=2, got {}",
967 var.sqrt()
968 );
969 }
970 #[test]
971 fn test_layer_norm_beta_shifts_output() {
972 let mut ln = LayerNormLayer::new(2);
973 ln.beta = vec![5.0, 5.0];
974 let input = vec![0.0, 1.0];
975 let out = ln.forward(&input);
976 let mean: f64 = out.iter().sum::<f64>() / out.len() as f64;
977 assert!(
978 (mean - 5.0).abs() < 1e-5,
979 "mean should be ~5 with beta=5, got {mean}"
980 );
981 }
982 #[test]
983 fn test_layer_norm_backward_d_beta_equals_d_output() {
        // The third element returned by `backward` (beta gradient) must equal the
        // upstream gradient element-wise.
984 let ln = LayerNormLayer::new(3);
985 let input = vec![1.0, 2.0, 3.0];
986 let d_out = vec![0.1, 0.2, 0.3];
987 let (_d_in, _d_gamma, d_beta) = ln.backward(&input, &d_out);
988 for (i, (&db, &dout)) in d_beta.iter().zip(d_out.iter()).enumerate() {
989 assert!(
990 (db - dout).abs() < 1e-12,
991 "d_beta[{i}] should equal d_output[{i}]"
992 );
993 }
994 }
995 #[test]
996 fn test_compute_gradient_norm_zero() {
997 let g: Vec<f64> = vec![0.0; 5];
998 let norm = compute_gradient_norm(&[&g]);
999 assert!(norm.abs() < 1e-12, "norm of zero gradient should be 0");
1000 }
1001 #[test]
1002 fn test_compute_gradient_norm_known_value() {
        // sqrt(3^2 + 4^2) = 5.
1003 let g = vec![3.0_f64, 4.0];
1004 let norm = compute_gradient_norm(&[&g]);
1005 assert!((norm - 5.0).abs() < 1e-10, "norm should be 5.0, got {norm}");
1006 }
1007 #[test]
1008 fn test_clip_gradients_no_clip_when_below_max() {
        // Global norm sqrt(2) is below max_norm = 5, so values stay untouched and the
        // returned (pre-clip) norm is sqrt(2).
1009 let mut grads = vec![vec![1.0_f64, 0.0], vec![0.0, 1.0]];
1010 let norm_before = clip_gradients_by_norm(&mut grads, 5.0);
1011 assert!((norm_before - 2.0_f64.sqrt()).abs() < 1e-10);
1012 assert!((grads[0][0] - 1.0).abs() < 1e-12);
1013 }
1014 #[test]
1015 fn test_clip_gradients_clips_correctly() {
        // Norm 5 is scaled down to max_norm = 1.
1016 let mut grads = vec![vec![3.0_f64, 4.0]];
1017 clip_gradients_by_norm(&mut grads, 1.0);
1018 let new_norm: f64 = grads[0].iter().map(|&v| v * v).sum::<f64>().sqrt();
1019 assert!(
1020 (new_norm - 1.0).abs() < 1e-10,
1021 "clipped norm should be 1.0, got {new_norm}"
1022 );
1023 }
1024 #[test]
1025 fn test_batch_norm_update_running_stats_mean_converges() {
        // Repeating the same constant batch 50 times must bring the running mean
        // within 0.1 of the batch values.
1026 let mut bn = BatchNormLayer::new(2);
1027 let batch = vec![vec![2.0_f32, 3.0], vec![2.0, 3.0], vec![2.0, 3.0]];
1028 for _ in 0..50 {
1029 bn.update_running_stats(&batch, 0.1);
1030 }
1031 assert!(
1032 (bn.running_mean[0] - 2.0).abs() < 0.1,
1033 "mean[0] should converge to 2.0"
1034 );
1035 assert!(
1036 (bn.running_mean[1] - 3.0).abs() < 0.1,
1037 "mean[1] should converge to 3.0"
1038 );
1039 }
1040 #[test]
1041 fn test_batch_norm_update_running_stats_variance_zero() {
1042 let mut bn = BatchNormLayer::new(2);
1043 let batch = vec![vec![1.0_f32, 1.0]; 4];
1044 for _ in 0..100 {
1045 bn.update_running_stats(&batch, 0.2);
1046 }
1047 assert!(
1048 bn.running_var[0] < 0.05,
1049 "variance should be ~0 for constant input, got {}",
1050 bn.running_var[0]
1051 );
1052 }
1053 #[test]
1054 fn test_ffnet_compute_gradient_norm_correct() {
        // f32 method on `FeedForwardNet`, distinct from the free f64 function above.
1055 let net = FeedForwardNet::new();
1056 let grads = vec![vec![3.0_f32, 4.0]];
1057 let norm = net.compute_gradient_norm(&grads);
1058 assert!(
1059 (norm - 5.0).abs() < 1e-5,
1060 "gradient norm should be 5.0, got {norm}"
1061 );
1062 }
1063 #[test]
1064 fn test_ffnet_clip_gradients_applies_scaling() {
1065 let net = FeedForwardNet::new();
1066 let mut grads = vec![vec![3.0_f32, 4.0]];
1067 net.clip_gradients(&mut grads, 1.0);
1068 let new_norm: f32 = grads[0].iter().map(|&v| v * v).sum::<f32>().sqrt();
1069 assert!(
1070 (new_norm - 1.0).abs() < 1e-5,
1071 "clipped norm should be 1.0, got {new_norm}"
1072 );
1073 }
1074}
/// Scaled dot-product attention on flat, row-major buffers.
///
/// Computes `softmax(Q·Kᵀ / sqrt(d_k) + mask) · V`, where the softmax is taken
/// independently over every query row.
///
/// # Arguments
/// * `q` – queries, `seq_q * d_k` values (row `i` starts at index `i * d_k`).
/// * `k` – keys, `seq_k * d_k` values.
/// * `v` – values, `seq_k * d_v` values.
/// * `seq_q`, `seq_k` – number of query and key/value rows.
/// * `d_k`, `d_v` – width of a query/key row and of a value row.
/// * `mask` – optional additive mask with `seq_q * seq_k` entries that is added
///   to the scaled scores before the softmax. Use a large negative number or
///   `f64::NEG_INFINITY` to block a position.
///
/// # Returns
/// A row-major buffer with `seq_q * d_v` values.
///
/// # Edge cases
/// * A query row whose scores are all `-inf` (every key masked out), and the
///   `seq_k == 0` case, yield an all-zero output row rather than NaN.
/// * With `d_k == 0` every dot product is the empty sum `0`, so attention is
///   uniform unless the mask says otherwise.
///
/// # Panics
/// Panics if a buffer length does not match the given dimensions.
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub fn scaled_dot_product_attention(
    q: &[f64],
    k: &[f64],
    v: &[f64],
    seq_q: usize,
    seq_k: usize,
    d_k: usize,
    d_v: usize,
    mask: Option<&[f64]>,
) -> Vec<f64> {
    assert_eq!(q.len(), seq_q * d_k, "q must hold seq_q * d_k values");
    assert_eq!(k.len(), seq_k * d_k, "k must hold seq_k * d_k values");
    assert_eq!(v.len(), seq_k * d_v, "v must hold seq_k * d_v values");
    if let Some(m) = mask {
        assert_eq!(m.len(), seq_q * seq_k, "mask must hold seq_q * seq_k values");
    }

    // sqrt(0) would turn every score into 0/0 = NaN; an empty dot product is
    // simply 0, so any non-zero divisor gives the right answer.
    let scale = if d_k == 0 { 1.0 } else { (d_k as f64).sqrt() };

    let mut output = vec![0.0_f64; seq_q * d_v];
    // One reusable buffer: holds the scores of the current query row first and
    // the softmax weights afterwards.
    let mut weights = vec![0.0_f64; seq_k];

    for i in 0..seq_q {
        let q_row = &q[i * d_k..(i + 1) * d_k];
        let mask_row = mask.map(|m| &m[i * seq_k..(i + 1) * seq_k]);

        // Scaled (and optionally masked) scores for query row `i`.
        for (j, score) in weights.iter_mut().enumerate() {
            let k_row = &k[j * d_k..(j + 1) * d_k];
            let dot = q_row
                .iter()
                .zip(k_row)
                .fold(0.0_f64, |acc, (a, b)| acc + a * b);
            *score = dot / scale;
            if let Some(m) = mask_row {
                *score += m[j];
            }
        }

        // Every key blocked with -inf (or no keys at all): nothing to attend
        // to. Subtracting the row maximum would compute `-inf - (-inf) = NaN`,
        // so leave this output row at zero instead.
        if weights.iter().all(|&s| s == f64::NEG_INFINITY) {
            continue;
        }

        // Numerically stable softmax: shift by the row maximum before exp.
        let max_val = weights.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut sum_exp = 0.0_f64;
        for w in weights.iter_mut() {
            *w = (*w - max_val).exp();
            sum_exp += *w;
        }
        let denom = sum_exp.max(1e-30);
        for w in weights.iter_mut() {
            *w /= denom;
        }

        // Weighted sum of the value rows.
        let out_row = &mut output[i * d_v..(i + 1) * d_v];
        for (dv, out) in out_row.iter_mut().enumerate() {
            *out = weights
                .iter()
                .enumerate()
                .fold(0.0_f64, |acc, (j, &w)| acc + w * v[j * d_v + dv]);
        }
    }
    output
}
#[cfg(test)]
mod attention_gnn_tests {
    // Unit tests for positional encoding, scaled dot-product attention,
    // multi-head attention, transformer FFN/blocks, GNN layers, the
    // message-passing network and the attention readout.

    use crate::AttentionReadout;
    use crate::ExtActivation;
    use crate::GnnLayer;
    use crate::MessagePassingNet;
    use crate::MultiHeadAttention;
    use crate::PositionalEncoding;
    use crate::TransformerBlock;
    use crate::TransformerFfn;

    use crate::scaled_dot_product_attention;

    /// `PositionalEncoding::new(8, 16)` must build a table with 16 rows of 8 entries.
    #[test]
    fn test_positional_encoding_shape() {
        let pe = PositionalEncoding::new(8, 16);
        assert_eq!(pe.table.len(), 16);
        assert_eq!(pe.table[0].len(), 8);
    }

    /// The first entry of the position-0 row must be 0 (sin(0)).
    #[test]
    fn test_positional_encoding_position_zero_first_dim_sin_zero() {
        let pe = PositionalEncoding::new(4, 10);
        assert!(
            pe.table[0][0].abs() < 1e-12,
            "PE[0,0] should be 0.0 (sin(0))"
        );
    }

    /// The second entry of the position-0 row must be 1 (cos(0)).
    #[test]
    fn test_positional_encoding_first_dim_cos_at_zero() {
        let pe = PositionalEncoding::new(4, 10);
        assert!(
            (pe.table[0][1] - 1.0).abs() < 1e-12,
            "PE[0,1] should be 1.0 (cos(0))"
        );
    }

    /// Adding the encoding to an all-zero sequence must leave the raw
    /// encoding behind; position 1, dimension 0 is expected to be sin(1).
    #[test]
    fn test_positional_encoding_add_to_sequence() {
        let pe = PositionalEncoding::new(4, 5);
        let mut seq = vec![vec![0.0_f64; 4]; 3];
        pe.add_to_sequence(&mut seq);
        // NOTE(review): presumably pos / base^0 for pos = 1 — confirm against
        // PositionalEncoding::new.
        let expected = (1.0_f64 / 1.0_f64).sin();
        assert!(
            (seq[1][0] - expected).abs() < 1e-12,
            "seq[1][0] = {}",
            seq[1][0]
        );
    }

    /// `get` must return a row with as many entries as the first constructor argument.
    #[test]
    fn test_positional_encoding_get_returns_slice() {
        let pe = PositionalEncoding::new(8, 10);
        let row = pe.get(3);
        assert_eq!(row.len(), 8);
    }

    /// The rows for positions 0 and 1 must not be identical.
    #[test]
    fn test_positional_encoding_different_positions_differ() {
        let pe = PositionalEncoding::new(8, 10);
        let row0 = pe.get(0);
        let row1 = pe.get(1);
        let same = row0
            .iter()
            .zip(row1.iter())
            .all(|(a, b)| (a - b).abs() < 1e-12);
        assert!(!same, "PE at pos=0 and pos=1 should differ");
    }

    /// The attention output must contain `seq_q * d_v` values.
    #[test]
    fn test_sdpa_output_shape() {
        let q = vec![0.1_f64; 3 * 4];
        let k = vec![0.2_f64; 3 * 4];
        let v = vec![0.3_f64; 3 * 4];
        let out = scaled_dot_product_attention(&q, &k, &v, 3, 3, 4, 4, None);
        assert_eq!(out.len(), 3 * 4, "output should have seq_q * d_v elements");
    }

    /// All-zero queries/keys give equal scores, so the output is the plain
    /// average of the three value rows.
    #[test]
    fn test_sdpa_uniform_attention_averages_values() {
        let q = vec![0.0_f64; 2 * 2];
        let k = vec![0.0_f64; 3 * 2];
        let v = vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let out = scaled_dot_product_attention(&q, &k, &v, 2, 3, 2, 2, None);
        assert!((out[0] - 1.0 / 3.0).abs() < 1e-8, "out[0]={}", out[0]);
        assert!((out[1] - 1.0 / 3.0).abs() < 1e-8, "out[1]={}", out[1]);
    }

    /// A large negative mask entry on key 1 for query 0 must make query 0
    /// attend to key 0 only.
    #[test]
    fn test_sdpa_masking_blocks_position() {
        let q = vec![0.0_f64; 2];
        let k = vec![0.0_f64; 2];
        let v = vec![10.0_f64, 20.0];
        let mask = vec![0.0_f64, -1e9, 0.0, 0.0];
        let out = scaled_dot_product_attention(&q, &k, &v, 2, 2, 1, 1, Some(&mask));
        assert!((out[0] - 10.0).abs() < 1e-4, "masked output[0]={}", out[0]);
    }

    /// With every value equal to 1 and `d_v = 1`, each output equals the sum
    /// of that row's attention weights, which must be 1.
    #[test]
    fn test_sdpa_attention_weights_sum_to_one() {
        let seq = 4;
        let dk = 3;
        let q = vec![0.0_f64; seq * dk];
        let k = vec![0.0_f64; seq * dk];
        let v = vec![1.0_f64; seq];
        let out = scaled_dot_product_attention(&q, &k, &v, seq, seq, dk, 1, None);
        for &o in &out {
            assert!(
                (o - 1.0).abs() < 1e-8,
                "attention weight sum = 1 check: out={o}"
            );
        }
    }

    /// Multi-head attention must return one 8-wide row per input position.
    #[test]
    fn test_mha_output_shape() {
        let mha = MultiHeadAttention::new(8, 2);
        let x = vec![0.0_f64; 5 * 8];
        let out = mha.forward(&x, 5);
        assert_eq!(out.len(), 5 * 8, "MHA output shape mismatch");
    }

    /// Parameter count for `MultiHeadAttention::new(16, 4)`.
    #[test]
    fn test_mha_num_params() {
        let mha = MultiHeadAttention::new(16, 4);
        // NOTE(review): 1040 = 4 * 16 * 16 + 16, presumably four projection
        // matrices plus one output bias — confirm against num_params.
        assert_eq!(mha.num_params(), 1040);
    }

    /// With a freshly constructed layer and zero input, the output must equal
    /// the output bias `b_o`.
    #[test]
    fn test_mha_zero_weights_zero_output_except_bias() {
        let mut mha = MultiHeadAttention::new(4, 2);
        mha.b_o = vec![1.0_f64; 4];
        let x = vec![0.0_f64; 3 * 4];
        let out = mha.forward(&x, 3);
        for &v in &out {
            assert!(
                (v - 1.0).abs() < 1e-10,
                "out should equal bias=1.0, got {v}"
            );
        }
    }

    /// After `init_identity` the forward pass must keep the shape and produce
    /// finite values.
    #[test]
    fn test_mha_identity_init_output_finite() {
        let mut mha = MultiHeadAttention::new(4, 2);
        mha.init_identity();
        let x: Vec<f64> = (0..3 * 4).map(|i| (i as f64) * 0.1).collect();
        let out = mha.forward(&x, 3);
        assert_eq!(out.len(), 3 * 4);
        for &v in &out {
            assert!(v.is_finite(), "output must be finite, got {v}");
        }
    }

    /// The transformer FFN must preserve the `seq * d_model` shape.
    #[test]
    fn test_transformer_ffn_output_shape() {
        let ffn = TransformerFfn::new(8, 32);
        let x = vec![1.0_f64; 5 * 8];
        let out = ffn.forward(&x, 5);
        assert_eq!(out.len(), 5 * 8);
    }

    /// A freshly constructed FFN is expected to map any input to zeros.
    #[test]
    fn test_transformer_ffn_zero_weights_zero_output() {
        let ffn = TransformerFfn::new(4, 16);
        let x = vec![1.0_f64; 3 * 4];
        let out = ffn.forward(&x, 3);
        for &v in &out {
            assert!(v.abs() < 1e-12, "zero weights → zero output, got {v}");
        }
    }

    /// Negative first-layer weights with positive input give a negative
    /// pre-activation; the output is expected to be zero.
    #[test]
    fn test_transformer_ffn_relu_activation() {
        let mut ffn = TransformerFfn::new(2, 2);
        ffn.w1 = vec![-1.0_f64; 4];
        let x = vec![1.0, 1.0, 1.0, 1.0];
        let out = ffn.forward(&x, 2);
        for &v in &out {
            assert!(
                v.abs() < 1e-12,
                "relu-clipped hidden → zero output, got {v}"
            );
        }
    }

    /// A transformer block must preserve the `seq * d_model` shape.
    #[test]
    fn test_transformer_block_output_shape() {
        let block = TransformerBlock::new(8, 2, 32);
        let x = vec![0.5_f64; 4 * 8];
        let out = block.forward(&x, 4);
        assert_eq!(out.len(), 4 * 8, "transformer block output shape mismatch");
    }

    /// A freshly constructed block must keep the shape and produce finite values.
    ///
    /// NOTE(review): the name promises that the residual path preserves the
    /// input, but only length and finiteness are asserted.
    #[test]
    fn test_transformer_block_residual_preserves_input_with_zero_weights() {
        let block = TransformerBlock::new(4, 2, 16);
        let x = vec![1.0_f64; 3 * 4];
        let out = block.forward(&x, 3);
        assert_eq!(out.len(), x.len());
        for &v in &out {
            assert!(v.is_finite(), "transformer block output must be finite");
        }
    }

    /// With identity-initialised attention the block output must be finite.
    ///
    /// NOTE(review): the name promises that the output differs from the
    /// input, but only finiteness is asserted.
    #[test]
    fn test_transformer_block_output_differs_from_input() {
        let mut block = TransformerBlock::new(4, 2, 8);
        block.mha.init_identity();
        let x: Vec<f64> = (0..4 * 4).map(|i| ((i % 7) as f64) * 0.3 - 0.5).collect();
        let out = block.forward(&x, 4);
        for &v in &out {
            assert!(v.is_finite());
        }
    }

    /// A 4 -> 8 GNN layer must emit 8 features per node.
    #[test]
    fn test_gnn_layer_output_shape() {
        let gnn = GnnLayer::new(4, 8, ExtActivation::Relu);
        let n_nodes = 3;
        let feats = vec![0.5_f64; n_nodes * 4];
        let adj = vec![vec![1usize, 2], vec![0], vec![0, 1]];
        let out = gnn.forward(&feats, n_nodes, &adj);
        assert_eq!(out.len(), n_nodes * 8, "GNN output shape mismatch");
    }

    /// Parameter count for `GnnLayer::new(4, 8, ..)`.
    #[test]
    fn test_gnn_layer_num_params() {
        let gnn = GnnLayer::new(4, 8, ExtActivation::Relu);
        // NOTE(review): 72 = 2 * 4 * 8 + 8, presumably self and neighbour
        // weight matrices plus one bias vector — confirm against num_params.
        assert_eq!(gnn.num_params(), 72);
    }

    /// A freshly constructed linear GNN layer is expected to output zeros.
    #[test]
    fn test_gnn_layer_zero_weights_zero_output() {
        let gnn = GnnLayer::new(3, 5, ExtActivation::Linear);
        let feats = vec![1.0_f64; 4 * 3];
        let adj = vec![vec![1usize], vec![0], vec![3], vec![2]];
        let out = gnn.forward(&feats, 4, &adj);
        for &v in &out {
            assert!(v.abs() < 1e-12, "zero weights → zero output, got {v}");
        }
    }

    /// With identity `w_self` and no neighbours, node 0 must pass its own
    /// features through unchanged.
    #[test]
    fn test_gnn_layer_isolated_node_uses_only_self() {
        let mut gnn = GnnLayer::new(2, 2, ExtActivation::Linear);
        gnn.w_self = vec![1.0, 0.0, 0.0, 1.0];
        let feats = vec![3.0, 7.0, 0.0, 0.0];
        let adj = vec![vec![], vec![]];
        let out = gnn.forward(&feats, 2, &adj);
        assert!((out[0] - 3.0).abs() < 1e-12, "node 0 out[0] = {}", out[0]);
        assert!((out[1] - 7.0).abs() < 1e-12, "node 0 out[1] = {}", out[1]);
    }

    /// With identity `w_neigh`, node 0 must receive the sum of the features
    /// of its two neighbours.
    #[test]
    fn test_gnn_layer_neighbour_aggregation_sum() {
        let mut gnn = GnnLayer::new(2, 2, ExtActivation::Linear);
        gnn.w_neigh = vec![1.0, 0.0, 0.0, 1.0];
        let feats = vec![0.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let adj = vec![vec![1usize, 2], vec![], vec![]];
        let out = gnn.forward(&feats, 3, &adj);
        assert!(
            (out[0] - 2.0).abs() < 1e-12,
            "aggregated out[0] = {}",
            out[0]
        );
        assert!(out[1].abs() < 1e-12, "aggregated out[1] = {}", out[1]);
    }

    /// Stacking a 4 -> 8 and an 8 -> 4 layer must give 4 features per node.
    #[test]
    fn test_mpnn_output_shape_two_layers() {
        let mut mpnn = MessagePassingNet::new();
        mpnn.add_layer(GnnLayer::new(4, 8, ExtActivation::Relu));
        mpnn.add_layer(GnnLayer::new(8, 4, ExtActivation::Linear));
        let feats = vec![0.1_f64; 5 * 4];
        // Chain graph: every node except the first points at its predecessor.
        let adj: Vec<Vec<usize>> = (0..5)
            .map(|i| if i > 0 { vec![i - 1] } else { vec![] })
            .collect();
        let out = mpnn.forward(&feats, 5, &adj);
        assert_eq!(out.len(), 5 * 4);
    }

    /// Mean pooling must return one value per feature.
    #[test]
    fn test_mpnn_global_mean_pool_shape() {
        let mpnn = MessagePassingNet::new();
        let feats = vec![1.0_f64; 4 * 3];
        let pooled = mpnn.global_mean_pool(&feats, 4, 3);
        assert_eq!(pooled.len(), 3);
    }

    /// Mean pooling over the nodes (1,2), (3,4), (5,6) must give (3, 4).
    #[test]
    fn test_mpnn_global_mean_pool_known_values() {
        let mpnn = MessagePassingNet::new();
        let feats = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let pooled = mpnn.global_mean_pool(&feats, 3, 2);
        assert!((pooled[0] - 3.0).abs() < 1e-12, "mean[0] = {}", pooled[0]);
        assert!((pooled[1] - 4.0).abs() < 1e-12, "mean[1] = {}", pooled[1]);
    }

    /// Pooling zero nodes must return a zero vector of the requested width.
    #[test]
    fn test_mpnn_empty_returns_zero_pool() {
        let mpnn = MessagePassingNet::new();
        let pooled = mpnn.global_mean_pool(&[], 0, 4);
        assert_eq!(pooled, vec![0.0_f64; 4]);
    }

    /// A default-constructed network must have no layers.
    #[test]
    fn test_mpnn_default_is_empty() {
        let mpnn = MessagePassingNet::default();
        assert_eq!(mpnn.layers.len(), 0);
    }

    /// The readout must collapse all nodes into a single 8-wide vector.
    #[test]
    fn test_attention_readout_output_shape() {
        let ar = AttentionReadout::new(8);
        let feats = vec![0.0_f64; 5 * 8];
        let out = ar.forward(&feats, 5);
        assert_eq!(out.len(), 8);
    }

    /// A freshly constructed readout over the nodes (1,2) and (3,4) is
    /// expected to return (2, 3).
    #[test]
    fn test_attention_readout_zero_weights_equal_scores() {
        let ar = AttentionReadout::new(2);
        let feats = vec![1.0_f64, 2.0, 3.0, 4.0];
        let out = ar.forward(&feats, 2);
        assert!((out[0] - 2.0).abs() < 1e-10, "readout[0] = {}", out[0]);
        assert!((out[1] - 3.0).abs() < 1e-10, "readout[1] = {}", out[1]);
    }

    /// All-zero node features must produce an all-zero readout.
    #[test]
    fn test_attention_readout_all_zeros_input() {
        let ar = AttentionReadout::new(4);
        let feats = vec![0.0_f64; 3 * 4];
        let out = ar.forward(&feats, 3);
        for &v in &out {
            assert!(v.abs() < 1e-12, "zero input → zero output, got {v}");
        }
    }

    /// A single node is expected to be scaled by a score of 0.5.
    #[test]
    fn test_attention_readout_single_node() {
        let ar = AttentionReadout::new(3);
        let feats = vec![2.0, 4.0, 6.0];
        let out = ar.forward(&feats, 1);
        // NOTE(review): 0.5 is presumably sigmoid(0) from zero-initialised
        // gate weights — confirm against AttentionReadout::forward.
        let score = 0.5_f64;
        assert!((out[0] - score * 2.0).abs() < 1e-10);
        assert!((out[1] - score * 4.0).abs() < 1e-10);
        assert!((out[2] - score * 6.0).abs() < 1e-10);
    }
}