Skip to main content

oxiphysics_gpu/neural_compute/
functions.rs

1//! Auto-generated module
2//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use std::f32::consts::PI as PI_F32;
6
7use super::types::*;
8
9/// Run a forward pass for every sample in `batch`.
10pub fn batch_forward(net: &FeedForwardNet, batch: &[Vec<f32>]) -> Vec<Vec<f32>> {
11    batch.iter().map(|x| net.forward(x)).collect()
12}
13/// Predict atomic energies for a batch of (descriptor, atomic_number) pairs.
14///
15/// Returns `0.0` for atoms whose element has no registered network.
16pub fn batch_atomic_energies(
17    aann: &AtomicNeuralNetwork,
18    descriptors: &[Vec<f32>],
19    atomic_numbers: &[u8],
20) -> Vec<f32> {
21    assert_eq!(
22        descriptors.len(),
23        atomic_numbers.len(),
24        "batch_atomic_energies: descriptors and atomic_numbers must have the same length"
25    );
26    descriptors
27        .iter()
28        .zip(atomic_numbers.iter())
29        .map(|(desc, &z)| aann.atomic_energy(z, desc).unwrap_or(0.0))
30        .collect()
31}
// Binds the imported `PI_F32` alias to a constant. Nothing else in the visible
// part of this file references `PI_F32`, so this presumably exists only to keep
// the import from being reported as unused — TODO confirm before removing.
pub(super) const _PI_F32_USED: f32 = PI_F32;
33/// Load weights from a flat f32 buffer into a network, partitioning by layer sizes.
34///
35/// Returns the number of weights consumed.
36pub fn load_weights_from_buffer(net: &mut FeedForwardNet, buffer: &[f32]) -> usize {
37    let mut offset = 0;
38    for layer in &mut net.layers {
39        let n_weights = layer.out_features * layer.in_features;
40        let n_biases = layer.out_features;
41        let total = n_weights + n_biases;
42        if offset + total > buffer.len() {
43            break;
44        }
45        layer.set_weights(&buffer[offset..offset + n_weights]);
46        offset += n_weights;
47        layer.set_biases(&buffer[offset..offset + n_biases]);
48        offset += n_biases;
49    }
50    offset
51}
52/// Serialize network weights to a flat f32 buffer.
53pub fn save_weights_to_buffer(net: &FeedForwardNet) -> Vec<f32> {
54    let mut buffer = Vec::new();
55    for layer in &net.layers {
56        buffer.extend_from_slice(&layer.weights);
57        buffer.extend_from_slice(&layer.biases);
58    }
59    buffer
60}
/// Returns the softmax of `x` as a new vector of the same length.
///
/// The maximum element is subtracted from every input before exponentiation
/// so that large logits do not overflow. An empty input yields an empty
/// vector.
pub fn softmax(x: &[f32]) -> Vec<f32> {
    if x.is_empty() {
        return Vec::new();
    }
    // Largest logit, used as the shift for numerical stability.
    let mut peak = f32::NEG_INFINITY;
    for &value in x {
        peak = peak.max(value);
    }
    let mut probs: Vec<f32> = x.iter().map(|&value| (value - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= total;
    }
    probs
}
71/// Compute cross-entropy loss between predictions and one-hot target.
72///
73/// target_idx is the index of the correct class.
74pub fn cross_entropy_loss(logits: &[f32], target_idx: usize) -> f32 {
75    let probs = softmax(logits);
76    let p = probs[target_idx].max(1e-7);
77    -p.ln()
78}
/// Mean of the squared element-wise differences between `predictions` and
/// `targets`.
///
/// For empty inputs the division is `0.0 / 0.0`, so the result is NaN.
///
/// # Panics
///
/// Panics if the two slices differ in length.
pub fn mse_loss(predictions: &[f32], targets: &[f32]) -> f32 {
    assert_eq!(predictions.len(), targets.len());
    let squared_error_sum: f32 = predictions
        .iter()
        .zip(targets)
        .map(|(p, t)| {
            let diff = p - t;
            diff * diff
        })
        .sum();
    squared_error_sum / predictions.len() as f32
}
// Unit tests for the f32 helpers defined in this file (`softmax`,
// `cross_entropy_loss`, `mse_loss`, weight save/load) and for the layer,
// descriptor, normaliser and pipeline types brought in via `super::*`.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::BatchNormLayer;

    use crate::FeedForwardNet;

    // With only the biases set, the output is expected to equal the biases
    // for a non-zero input, i.e. the untouched weights contribute nothing.
    #[test]
    fn test_dense_layer_zero_weights_bias_output() {
        let mut layer = DenseLayer::new(3, 2, ActivationFn::Linear);
        layer.set_biases(&[1.0, 2.0]);
        let out = layer.forward(&[10.0, 20.0, 30.0]);
        assert!((out[0] - 1.0).abs() < 1e-6, "expected 1.0, got {}", out[0]);
        assert!((out[1] - 2.0).abs() < 1e-6, "expected 2.0, got {}", out[1]);
    }
    // tanh(0) == 0.
    #[test]
    fn test_activation_tanh_zero() {
        let act = ActivationFn::Tanh;
        assert!((act.apply(0.0)).abs() < 1e-7, "tanh(0) should be 0");
    }
    // ReLU zeroes negatives and passes positives through.
    #[test]
    fn test_activation_relu() {
        let act = ActivationFn::Relu;
        assert!((act.apply(-1.0)).abs() < 1e-7, "relu(-1) should be 0");
        assert!((act.apply(1.0) - 1.0).abs() < 1e-7, "relu(1) should be 1");
    }
    // Two stacked linear layers with hand-set weights; input [1, 2] must
    // produce a single output of 4.0.
    #[test]
    fn test_feedforward_net_two_layers() {
        let mut layer1 = DenseLayer::new(2, 3, ActivationFn::Linear);
        layer1.set_weights(&[1.0, 0.0, 0.0, 1.0, 1.0, 0.0]);
        let mut layer2 = DenseLayer::new(3, 1, ActivationFn::Linear);
        layer2.set_weights(&[1.0, 1.0, 1.0]);
        let mut net = FeedForwardNet::new();
        net.add_layer(layer1);
        net.add_layer(layer2);
        let out = net.forward(&[1.0, 2.0]);
        assert!((out[0] - 4.0).abs() < 1e-5, "expected 4.0, got {}", out[0]);
    }
    // The cutoff function must be 1 at r = 0 and 0 at and beyond r = rc.
    #[test]
    fn test_cutoff_fn_at_zero_and_beyond_rc() {
        let rc = 6.0;
        let fc0 = BehlerParrinelloDescriptor::cutoff_fn(0.0, rc);
        assert!((fc0 - 1.0).abs() < 1e-10, "fc(0) should be 1.0, got {fc0}");
        let fc_rc = BehlerParrinelloDescriptor::cutoff_fn(rc, rc);
        assert!((fc_rc).abs() < 1e-10, "fc(rc) should be 0.0, got {fc_rc}");
        let fc_beyond = BehlerParrinelloDescriptor::cutoff_fn(rc + 1.0, rc);
        assert!(
            (fc_beyond).abs() < 1e-10,
            "fc(>rc) should be 0.0, got {fc_beyond}"
        );
    }
    // The radial G2 term at distance 1.0 must exceed the one at 5.0
    // (same eta, rs and rc).
    #[test]
    fn test_radial_g2_decreases_with_distance() {
        let rc = 6.0;
        let eta = 0.5;
        let rs = 0.0;
        let g2_near = BehlerParrinelloDescriptor::radial_g2(1.0, eta, rs, rc);
        let g2_far = BehlerParrinelloDescriptor::radial_g2(5.0, eta, rs, rc);
        assert!(
            g2_near > g2_far,
            "G2 should decrease with distance: near={g2_near}, far={g2_far}"
        );
    }
    // transform followed by inverse_transform must reproduce the sample.
    #[test]
    fn test_data_normalizer_round_trip() {
        let data = vec![
            vec![1.0_f32, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let norm = DataNormalizer::fit(&data);
        let sample = &data[1];
        let transformed = norm.transform(sample);
        let recovered = norm.inverse_transform(&transformed);
        for (a, b) in recovered.iter().zip(sample.iter()) {
            assert!(
                (a - b).abs() < 1e-5,
                "round-trip failed: got {a}, expected {b}"
            );
        }
    }
    // 20 -> 64 -> 32 -> 1 network: Tanh on the hidden layers, Linear on the
    // output, and (20*64 + 64) + (64*32 + 32) + (32*1 + 1) = 3457 parameters.
    #[test]
    fn test_network_builder_simple_aann_architecture() {
        let hidden = &[64_usize, 32];
        let net = NetworkBuilder::simple_aann(20, hidden, 1);
        assert_eq!(net.layers.len(), 3);
        assert_eq!(net.input_size(), Some(20));
        assert_eq!(net.output_size(), Some(1));
        assert_eq!(net.layers[0].activation, ActivationFn::Tanh);
        assert_eq!(net.layers[1].activation, ActivationFn::Tanh);
        assert_eq!(net.layers[2].activation, ActivationFn::Linear);
        assert_eq!(net.total_parameters(), 3457);
    }
    // A default-constructed batch-norm layer must behave (approximately) as
    // the identity.
    #[test]
    fn test_batch_norm_identity_transform() {
        let bn = BatchNormLayer::new(3);
        let input = vec![1.0, 2.0, 3.0];
        let output = bn.forward(&input);
        for i in 0..3 {
            assert!(
                (output[i] - input[i]).abs() < 1e-4,
                "output[{i}]={}",
                output[i]
            );
        }
    }
    // An input equal to the first `set_stats` argument must map to ~0
    // (presumably that argument is the running mean — confirm in types.rs).
    #[test]
    fn test_batch_norm_zero_mean_unit_var() {
        let mut bn = BatchNormLayer::new(2);
        bn.set_stats(&[5.0, 10.0], &[4.0, 9.0]);
        let output = bn.forward(&[5.0, 10.0]);
        assert!(output[0].abs() < 1e-4);
        assert!(output[1].abs() < 1e-4);
    }
    // With stats (0, 1) the expected outputs are 2*1 + 1 = 3 and 3*1 - 1 = 2.
    #[test]
    fn test_batch_norm_affine() {
        let mut bn = BatchNormLayer::new(2);
        bn.set_stats(&[0.0, 0.0], &[1.0, 1.0]);
        bn.set_affine(&[2.0, 3.0], &[1.0, -1.0]);
        let output = bn.forward(&[1.0, 1.0]);
        assert!((output[0] - 3.0).abs() < 1e-4);
        assert!((output[1] - 2.0).abs() < 1e-4);
    }
    // Single dense op with weights [1, 1] and zero bias: 3 + 4 = 7.
    #[test]
    fn test_inference_pipeline_dense_only() {
        let mut pipe = InferencePipeline::new();
        let mut layer = DenseLayer::new(2, 1, ActivationFn::Linear);
        layer.set_weights(&[1.0, 1.0]);
        layer.set_biases(&[0.0]);
        pipe.add_op(InferenceOp::Dense(layer));
        let out = pipe.forward(&[3.0, 4.0]);
        assert!((out[0] - 7.0).abs() < 1e-5);
    }
    // Dense -> default batch norm -> ReLU on [1, -2] must give ~[1, 0].
    #[test]
    fn test_inference_pipeline_with_batch_norm() {
        let mut pipe = InferencePipeline::new();
        let mut layer = DenseLayer::new(2, 2, ActivationFn::Linear);
        layer.set_weights(&[1.0, 0.0, 0.0, 1.0]);
        pipe.add_op(InferenceOp::Dense(layer));
        pipe.add_op(InferenceOp::BatchNorm(BatchNormLayer::new(2)));
        pipe.add_op(InferenceOp::Activation(ActivationFn::Relu));
        let out = pipe.forward(&[1.0, -2.0]);
        assert!((out[0] - 1.0).abs() < 1e-4);
        assert!(out[1].abs() < 1e-4);
    }
    // Dense(4,3) has 15 parameters and Dense(3,1) has 4; the expected total
    // of 25 leaves 6 for the batch-norm op (presumably gamma and beta for
    // 3 features — confirm in types.rs).
    #[test]
    fn test_inference_pipeline_total_params() {
        let mut pipe = InferencePipeline::new();
        pipe.add_op(InferenceOp::Dense(DenseLayer::new(
            4,
            3,
            ActivationFn::Relu,
        )));
        pipe.add_op(InferenceOp::BatchNorm(BatchNormLayer::new(3)));
        pipe.add_op(InferenceOp::Dense(DenseLayer::new(
            3,
            1,
            ActivationFn::Linear,
        )));
        assert_eq!(pipe.total_parameters(), 25);
    }
    // Saving a 2 -> 3 layer gives 6 weights + 3 biases = 9 values; loading
    // them into a fresh net of the same shape must reproduce its outputs.
    #[test]
    fn test_save_load_weights_roundtrip() {
        let mut net = FeedForwardNet::new();
        let mut l1 = DenseLayer::new(2, 3, ActivationFn::Relu);
        l1.set_weights(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        l1.set_biases(&[0.1, 0.2, 0.3]);
        net.add_layer(l1);
        let buf = save_weights_to_buffer(&net);
        assert_eq!(buf.len(), 9);
        let mut net2 = FeedForwardNet::new();
        net2.add_layer(DenseLayer::new(2, 3, ActivationFn::Relu));
        let consumed = load_weights_from_buffer(&mut net2, &buf);
        assert_eq!(consumed, 9);
        let input = vec![1.0, 2.0];
        let out1 = net.forward(&input);
        let out2 = net2.forward(&input);
        for (a, b) in out1.iter().zip(out2.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }
    // The buffer holds exactly the 9 values of the first layer; the second
    // layer needs 4 more, so loading must stop after consuming 9.
    #[test]
    fn test_load_weights_partial() {
        let mut net = FeedForwardNet::new();
        net.add_layer(DenseLayer::new(2, 3, ActivationFn::Relu));
        net.add_layer(DenseLayer::new(3, 1, ActivationFn::Linear));
        let buf = vec![1.0_f32; 9];
        let consumed = load_weights_from_buffer(&mut net, &buf);
        assert_eq!(consumed, 9);
    }
    // Softmax outputs must sum to 1.
    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0_f32, 2.0, 3.0, 4.0];
        let probs = softmax(&logits);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "softmax sum={sum}");
    }
    // The largest logit must receive the largest probability.
    #[test]
    fn test_softmax_largest_is_max() {
        let logits = vec![1.0_f32, 5.0, 2.0];
        let probs = softmax(&logits);
        assert!(probs[1] > probs[0] && probs[1] > probs[2]);
    }
    // Empty input yields empty output.
    #[test]
    fn test_softmax_empty() {
        assert!(softmax(&[]).is_empty());
    }
    // A single logit always maps to probability 1.
    #[test]
    fn test_softmax_single() {
        let probs = softmax(&[42.0]);
        assert!((probs[0] - 1.0).abs() < 1e-5);
    }
    // Target class dominates the logits, so the loss must be near zero.
    #[test]
    fn test_cross_entropy_loss_correct_class() {
        let logits = vec![-10.0_f32, 10.0, -10.0];
        let loss = cross_entropy_loss(&logits, 1);
        assert!(loss < 0.01, "loss should be small, got {loss}");
    }
    // Target class has a very low logit, so the loss must be large.
    #[test]
    fn test_cross_entropy_loss_wrong_class() {
        let logits = vec![10.0_f32, -10.0, -10.0];
        let loss = cross_entropy_loss(&logits, 1);
        assert!(loss > 1.0, "loss should be large, got {loss}");
    }
    // Identical predictions and targets give zero MSE.
    #[test]
    fn test_mse_loss_zero() {
        let pred = vec![1.0_f32, 2.0, 3.0];
        let target = vec![1.0, 2.0, 3.0];
        assert!(mse_loss(&pred, &target).abs() < 1e-7);
    }
    // Both errors are 2, so the MSE is (4 + 4) / 2 = 4.
    #[test]
    fn test_mse_loss_positive() {
        let pred = vec![1.0_f32, 2.0];
        let target = vec![3.0, 4.0];
        let loss = mse_loss(&pred, &target);
        assert!((loss - 4.0).abs() < 1e-5);
    }
    // sigmoid'(0) = 0.25.
    #[test]
    fn test_sigmoid_derivative() {
        let act = ActivationFn::Sigmoid;
        let d = act.derivative(0.0);
        assert!((d - 0.25).abs() < 1e-5, "sigmoid'(0) = {d}");
    }
    // SiLU(0) = 0.
    #[test]
    fn test_silu_at_zero() {
        let act = ActivationFn::Silu;
        let v = act.apply(0.0);
        assert!(v.abs() < 1e-7);
    }
    // The GELU derivative at 1.0 must be positive.
    #[test]
    fn test_gelu_derivative_positive() {
        let act = ActivationFn::Gelu;
        let d = act.derivative(1.0);
        assert!(d > 0.0, "GELU derivative at 1.0 should be positive");
    }
    // The linear activation has derivative 1 everywhere.
    #[test]
    fn test_linear_derivative() {
        let act = ActivationFn::Linear;
        assert!((act.derivative(42.0) - 1.0).abs() < 1e-7);
    }
}
352/// Run the network on every position and return a force vector per atom.
353///
354/// The network is expected to take a 3-component position and output a
355/// 3-component force vector.  This is a CPU mock of a GPU batch dispatch.
356pub fn compute_forces_batch(positions: &[[f64; 3]], network: &NeuralNetwork) -> Vec<[f64; 3]> {
357    positions
358        .iter()
359        .map(|p| {
360            let out = network.forward(p);
361            let fx = out.first().copied().unwrap_or(0.0);
362            let fy = out.get(1).copied().unwrap_or(0.0);
363            let fz = out.get(2).copied().unwrap_or(0.0);
364            [fx, fy, fz]
365        })
366        .collect()
367}
368/// Compute a scalar potential energy by summing the first network output over all atoms.
369///
370/// Suitable for networks that map a position to a scalar energy contribution.
371pub fn neural_potential_energy(network: &NeuralNetwork, positions: &[[f64; 3]]) -> f64 {
372    positions
373        .iter()
374        .map(|p| network.forward(p).first().copied().unwrap_or(0.0))
375        .sum()
376}
// Unit tests for the f64 network path: `NeuralNetwork`, `ActivationFn64`,
// `GpuNeuralBuffer`, `compute_forces_batch` and `neural_potential_energy`.
#[cfg(test)]
mod neural_f64_tests {

    use crate::ActivationFn64;

    use crate::GpuNeuralBuffer;

    use crate::NeuralNetwork;

    use crate::compute_forces_batch;

    use crate::neural_potential_energy;

    // A [3, 8, 8, 3] network must return 3 outputs for a 3-component input.
    #[test]
    fn test_forward_pass_output_size() {
        let net = NeuralNetwork::new(&[3, 8, 8, 3], ActivationFn64::Tanh);
        let input = [1.0, 0.5, -0.5];
        let out = net.forward(&input);
        assert_eq!(out.len(), 3, "output should have 3 components");
    }
    // Scalar ReLU: negatives become 0, positives pass through.
    #[test]
    fn test_relu_activation_f64() {
        let act = ActivationFn64::Relu;
        assert!((act.apply(-2.0)).abs() < 1e-12, "relu(-2) = 0");
        assert!((act.apply(3.0) - 3.0).abs() < 1e-12, "relu(3) = 3");
    }
    // In-place batch ReLU over a mixed-sign vector.
    #[test]
    fn test_relu_batch() {
        let act = ActivationFn64::Relu;
        let mut v = vec![-1.0, 0.0, 2.0, -0.5, 3.0];
        act.apply_batch(&mut v);
        assert_eq!(v, vec![0.0, 0.0, 2.0, 0.0, 3.0]);
    }
    // A freshly constructed network must contain at least one non-zero weight.
    #[test]
    fn test_xavier_init_non_zero() {
        let net = NeuralNetwork::new(&[4, 16, 4], ActivationFn64::Relu);
        let all_zero = net
            .layers
            .iter()
            .all(|l| l.weights.iter().all(|row| row.iter().all(|&w| w == 0.0)));
        assert!(!all_zero, "Xavier-init weights should not all be zero");
    }
    // One force vector is returned per input position.
    #[test]
    fn test_batch_forces_count() {
        let net = NeuralNetwork::new(&[3, 8, 3], ActivationFn64::Relu);
        let positions: Vec<[f64; 3]> = (0..5).map(|i| [i as f64, 0.0, 0.0]).collect();
        let forces = compute_forces_batch(&positions, &net);
        assert_eq!(forces.len(), 5, "one force vector per position");
    }
    // Packing two positions yields batch_size 2 and 6 scalars; unpacking
    // returns the same triples.
    #[test]
    fn test_gpu_neural_buffer_roundtrip() {
        let positions = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let buf = GpuNeuralBuffer::pack_positions(&positions);
        assert_eq!(buf.batch_size, 2);
        assert_eq!(buf.data.len(), 6);
        let forces = buf.unpack_forces();
        assert_eq!(forces[0], [1.0, 2.0, 3.0]);
        assert_eq!(forces[1], [4.0, 5.0, 6.0]);
    }
    // NOTE(review): despite the name, this only checks that the energy is
    // finite, not that it is positive.
    #[test]
    fn test_neural_potential_energy_positive() {
        let net = NeuralNetwork::new(&[3, 4, 1], ActivationFn64::Relu);
        let positions = vec![[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]];
        let energy = neural_potential_energy(&net, &positions);
        assert!(energy.is_finite(), "energy should be finite");
    }
}
/// L2 regularisation penalty for a set of weights.
///
/// Evaluates `0.5 * lambda * sum(w^2)` over every element of `weights`.
pub fn l2_regularisation(weights: &[f64], lambda: f64) -> f64 {
    let sum_of_squares: f64 = weights.iter().map(|w| w * w).sum();
    0.5 * lambda * sum_of_squares
}
/// Gradient of the L2 regularisation penalty with respect to each weight.
///
/// Returns a new vector whose i-th element is `lambda * weights[i]`; the
/// input slice is not modified.
pub fn l2_regularisation_grad(weights: &[f64], lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(weights.len());
    for &w in weights {
        grad.push(lambda * w);
    }
    grad
}
/// Huber loss for a single prediction/target pair.
///
/// With `e = |pred - target|`, returns `0.5 * e^2` when `e <= delta` and
/// `delta * (e - 0.5 * delta)` otherwise, i.e. quadratic for small errors and
/// linear for large ones.
pub fn huber_loss(pred: f64, target: f64, delta: f64) -> f64 {
    let residual = (pred - target).abs();
    let in_quadratic_region = residual <= delta;
    if !in_quadratic_region {
        return delta * (residual - 0.5 * delta);
    }
    0.5 * residual * residual
}
465/// Mean Huber loss over a batch.
466pub fn mean_huber_loss(predictions: &[f64], targets: &[f64], delta: f64) -> f64 {
467    assert_eq!(predictions.len(), targets.len());
468    let n = predictions.len() as f64;
469    predictions
470        .iter()
471        .zip(targets.iter())
472        .map(|(&p, &t)| huber_loss(p, t, delta))
473        .sum::<f64>()
474        / n
475}
/// Element-wise gradient of the Huber loss with respect to each prediction.
///
/// With `e = predictions[i] - targets[i]`, the i-th entry is `e` when
/// `|e| <= delta` and `delta * sign(e)` otherwise. The result is the
/// per-element gradient; it is not divided by the batch size.
///
/// # Panics
///
/// Panics if `predictions` and `targets` differ in length. Without this check
/// `zip` would silently truncate to the shorter slice and return a gradient
/// that no longer lines up with `predictions`.
pub fn huber_loss_grad(predictions: &[f64], targets: &[f64], delta: f64) -> Vec<f64> {
    assert_eq!(
        predictions.len(),
        targets.len(),
        "huber_loss_grad: predictions and targets must have the same length"
    );
    predictions
        .iter()
        .zip(targets.iter())
        .map(|(&p, &t)| {
            let e = p - t;
            if e.abs() <= delta {
                // Quadratic region: derivative of 0.5 * e^2.
                e
            } else {
                // Linear region: constant magnitude, sign of the error.
                delta * e.signum()
            }
        })
        .collect()
}
/// L2 (Frobenius) norm of several gradient slices treated as one
/// concatenated vector.
///
/// Useful for gradient clipping: when the norm exceeds a threshold, the
/// caller should scale every gradient by `threshold / norm`.
pub fn compute_gradient_norm(gradients: &[&[f64]]) -> f64 {
    gradients
        .iter()
        .copied()
        .flatten()
        .map(|value| value * value)
        .sum::<f64>()
        .sqrt()
}
503/// Clip a collection of gradient slices in-place so that their combined L2
504/// norm does not exceed `max_norm`.
505///
506/// If the current norm is already โ‰ค `max_norm` the gradients are unchanged.
507/// Returns the pre-clip norm.
508pub fn clip_gradients_by_norm(gradients: &mut [Vec<f64>], max_norm: f64) -> f64 {
509    let refs: Vec<&[f64]> = gradients.iter().map(|v| v.as_slice()).collect();
510    let norm = compute_gradient_norm(&refs);
511    if norm > max_norm && norm > 0.0 {
512        let scale = max_norm / norm;
513        for g in gradients.iter_mut() {
514            for v in g.iter_mut() {
515                *v *= scale;
516            }
517        }
518    }
519    norm
520}
// Unit tests for the f64 training utilities: extended activations,
// `DenseLayer64`, `DropoutLayer`, `AdamOptimizer`, `GradAccumulator`, and the
// L2 / Huber loss helpers defined in this file.
#[cfg(test)]
mod extended_tests {

    use crate::AdamOptimizer;

    use crate::DenseLayer64;
    use crate::DropoutLayer;
    use crate::ExtActivation;

    use crate::GradAccumulator;

    use crate::huber_loss;
    use crate::huber_loss_grad;
    use crate::l2_regularisation;
    use crate::l2_regularisation_grad;
    use crate::mean_huber_loss;

    // Leaky ReLU passes positive inputs through unchanged.
    #[test]
    fn test_leaky_relu_positive() {
        let act = ExtActivation::LeakyRelu(0.01);
        assert!((act.apply(3.0) - 3.0).abs() < 1e-12);
    }
    // Leaky ReLU scales negative inputs by the slope: 0.1 * -2 = -0.2.
    #[test]
    fn test_leaky_relu_negative() {
        let act = ExtActivation::LeakyRelu(0.1);
        assert!((act.apply(-2.0) - (-0.2)).abs() < 1e-12);
    }
    // Derivative is 1 on the positive side.
    #[test]
    fn test_leaky_relu_derivative_positive() {
        let act = ExtActivation::LeakyRelu(0.05);
        assert!((act.derivative(1.0) - 1.0).abs() < 1e-12);
    }
    // Derivative equals the slope on the negative side.
    #[test]
    fn test_leaky_relu_derivative_negative() {
        let act = ExtActivation::LeakyRelu(0.05);
        assert!((act.derivative(-1.0) - 0.05).abs() < 1e-12);
    }
    // swish(0) = 0.
    #[test]
    fn test_swish_at_zero() {
        let act = ExtActivation::Swish(1.0);
        assert!(act.apply(0.0).abs() < 1e-12);
    }
    // For large positive inputs swish is close to the identity.
    #[test]
    fn test_swish_positive_region() {
        let act = ExtActivation::Swish(1.0);
        let v = act.apply(10.0);
        assert!((v - 10.0).abs() < 0.01, "swish(10) ≈ 10, got {v}");
    }
    // In-place vector ReLU over a mixed-sign vector.
    #[test]
    fn test_ext_activation_apply_vec() {
        let act = ExtActivation::Relu;
        let mut v = vec![-1.0, 0.0, 2.0, -3.0];
        act.apply_vec(&mut v);
        assert_eq!(v, vec![0.0, 0.0, 2.0, 0.0]);
    }
    // Identity weights with a linear activation must reproduce the input.
    #[test]
    fn test_dense64_forward_linear_known_values() {
        let mut layer = DenseLayer64::new(2, 2, ExtActivation::Linear);
        layer.weights = vec![1.0, 0.0, 0.0, 1.0];
        let out = layer.forward(&[3.0, 5.0]);
        assert!((out[0] - 3.0).abs() < 1e-12);
        assert!((out[1] - 5.0).abs() < 1e-12);
    }
    // backward returns gradients shaped like the weights (3*2), the biases (2)
    // and the layer input (3).
    #[test]
    fn test_dense64_backward_gradient_shapes() {
        let mut layer = DenseLayer64::new(3, 2, ExtActivation::Relu);
        layer.weights = vec![1.0; 6];
        let _out = layer.forward(&[1.0, 2.0, 3.0]);
        let (gw, gb, di) = layer.backward(&[1.0, 1.0]);
        assert_eq!(gw.len(), 6, "grad_weights shape");
        assert_eq!(gb.len(), 2, "grad_biases shape");
        assert_eq!(di.len(), 3, "delta_in shape");
    }
    // One SGD step on loss = out^2 (upstream gradient 2 * out) must lower
    // the loss.
    #[test]
    fn test_dense64_sgd_update_reduces_output() {
        let mut layer = DenseLayer64::new(1, 1, ExtActivation::Linear);
        layer.weights = vec![2.0];
        layer.biases = vec![0.0];
        let out = layer.forward(&[1.0]);
        let loss_before = out[0] * out[0];
        let (gw, gb, _) = layer.backward(&[2.0 * out[0]]);
        layer.apply_sgd(&gw, &gb, 0.1);
        let out2 = layer.forward(&[1.0]);
        let loss_after = out2[0] * out2[0];
        assert!(loss_after < loss_before, "SGD should reduce loss");
    }
    // Parameter count is in * out weights plus out biases.
    #[test]
    fn test_dense64_num_params() {
        let layer = DenseLayer64::new(4, 3, ExtActivation::Linear);
        assert_eq!(layer.num_params(), 4 * 3 + 3);
    }
    // With the second constructor argument `false`, forward must return the
    // input unchanged.
    #[test]
    fn test_dropout_inference_passthrough() {
        let mut drop = DropoutLayer::new(0.5, false);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let out = drop.forward(&input);
        assert_eq!(out, input, "dropout in eval mode should pass through");
    }
    // A drop rate of 0 must leave the input unchanged even in training mode.
    #[test]
    fn test_dropout_rate_zero_no_drop() {
        let mut drop = DropoutLayer::new(0.0, true);
        let input = vec![1.0, 2.0, 3.0];
        let out = drop.forward(&input);
        assert_eq!(out, input, "zero rate should not drop anything");
    }
    // A drop rate of 1 must zero every element.
    #[test]
    fn test_dropout_rate_one_all_zero() {
        let mut drop = DropoutLayer::new(1.0, true);
        let input = vec![5.0, 6.0, 7.0];
        let out = drop.forward(&input);
        assert!(
            out.iter().all(|&x| x == 0.0),
            "rate=1 should zero everything"
        );
    }
    // With rate 0.5 and a fixed seed, between 10 and 90 of 100 elements must
    // be zeroed (a loose statistical bound).
    #[test]
    fn test_dropout_training_some_zeros() {
        let mut drop = DropoutLayer::new(0.5, true);
        drop.set_seed(42);
        let input = vec![1.0_f64; 100];
        let out = drop.forward(&input);
        let n_zeros = out.iter().filter(|&&x| x == 0.0).count();
        assert!(n_zeros > 10, "expected some zeros, got {n_zeros}");
        assert!(
            n_zeros < 90,
            "expected some non-zeros, too many zeros: {n_zeros}"
        );
    }
    // NOTE(review): uses rate 0.0 with training disabled, so this only checks
    // that backward passes the gradient through unchanged in that setting.
    #[test]
    fn test_dropout_backward_applies_mask() {
        let mut drop = DropoutLayer::new(0.0, false);
        let input = vec![1.0, 2.0, 3.0];
        let _out = drop.forward(&input);
        let grad = drop.backward(&[1.0, 1.0, 1.0]);
        assert_eq!(grad, vec![1.0, 1.0, 1.0]);
    }
    // Twenty Adam steps on f(x) = x^2 (gradient 2x) must move x = 5 closer
    // to zero.
    #[test]
    fn test_adam_step_decreases_loss() {
        let mut params = vec![5.0_f64];
        let mut opt = AdamOptimizer::default_params(1);
        let initial_abs = params[0].abs();
        for _ in 0..20 {
            let grads = vec![2.0 * params[0]];
            opt.step_update(&mut params, &grads);
        }
        assert!(
            params[0].abs() < initial_abs,
            "Adam should move towards zero, final={}",
            params[0]
        );
    }
    // Each call to step_update increments the step counter by one.
    #[test]
    fn test_adam_step_increments_counter() {
        let mut params = vec![1.0_f64; 3];
        let mut opt = AdamOptimizer::default_params(3);
        assert_eq!(opt.step, 0);
        let grads = vec![0.1; 3];
        opt.step_update(&mut params, &grads);
        assert_eq!(opt.step, 1);
        opt.step_update(&mut params, &grads);
        assert_eq!(opt.step, 2);
    }
    // reset must clear the step counter and both moment buffers.
    #[test]
    fn test_adam_reset() {
        let mut params = vec![1.0_f64; 2];
        let mut opt = AdamOptimizer::default_params(2);
        let grads = vec![0.5; 2];
        opt.step_update(&mut params, &grads);
        opt.reset();
        assert_eq!(opt.step, 0);
        assert!(opt.m.iter().all(|&x| x == 0.0));
        assert!(opt.v.iter().all(|&x| x == 0.0));
    }
    // After one step with gradient 1 the moments must be m = 0.1 and
    // v = 0.001 (presumably (1 - beta1) and (1 - beta2) for the 0.9 / 0.999
    // constructor arguments — confirm the argument order in types.rs).
    #[test]
    fn test_adam_moment_accumulation() {
        let mut params = vec![1.0_f64];
        let mut opt = AdamOptimizer::new(1, 1e-3, 0.9, 0.999, 1e-8);
        let grads = vec![1.0_f64];
        opt.step_update(&mut params, &grads);
        assert!(
            (opt.m[0] - 0.1).abs() < 1e-10,
            "m after step 1 = {}",
            opt.m[0]
        );
        assert!(
            (opt.v[0] - 0.001).abs() < 1e-10,
            "v after step 1 = {}",
            opt.v[0]
        );
    }
    // Two accumulated samples: the means are (1+3)/2, (2+4)/2 and (3+1)/2.
    #[test]
    fn test_grad_accumulator_mean() {
        let mut acc = GradAccumulator::new(2, 1);
        acc.accumulate(&[1.0, 2.0], &[3.0]);
        acc.accumulate(&[3.0, 4.0], &[1.0]);
        let (gw, gb) = acc.mean_grads();
        assert!((gw[0] - 2.0).abs() < 1e-12);
        assert!((gw[1] - 3.0).abs() < 1e-12);
        assert!((gb[0] - 2.0).abs() < 1e-12);
    }
    // zero must reset the sample count and the accumulated weight gradients.
    #[test]
    fn test_grad_accumulator_zero() {
        let mut acc = GradAccumulator::new(3, 2);
        acc.accumulate(&[1.0, 2.0, 3.0], &[4.0, 5.0]);
        assert_eq!(acc.count, 1);
        acc.zero();
        assert_eq!(acc.count, 0);
        assert!(acc.grad_weights.iter().all(|&x| x == 0.0));
    }
    // 0.5 * 0.01 * (1 + 4 + 9) = 0.07.
    #[test]
    fn test_l2_regularisation() {
        let weights = vec![1.0, 2.0, 3.0];
        let reg = l2_regularisation(&weights, 0.01);
        assert!((reg - 0.07).abs() < 1e-12, "L2 reg = {reg}");
    }
    // Gradient is lambda * w per element.
    #[test]
    fn test_l2_regularisation_grad() {
        let weights = vec![2.0, -3.0];
        let grad = l2_regularisation_grad(&weights, 0.1);
        assert!((grad[0] - 0.2).abs() < 1e-12);
        assert!((grad[1] - (-0.3)).abs() < 1e-12);
    }
    // |error| = 0.1 <= delta, so the loss is quadratic: 0.5 * 0.1^2.
    #[test]
    fn test_huber_loss_small_error() {
        let loss = huber_loss(1.0, 1.1, 0.5);
        assert!((loss - 0.5 * 0.01).abs() < 1e-12, "huber loss = {loss}");
    }
    // |error| = 5 > delta = 1, so the loss is linear: 1 * (5 - 0.5) = 4.5.
    #[test]
    fn test_huber_loss_large_error() {
        let loss = huber_loss(0.0, 5.0, 1.0);
        assert!((loss - 4.5).abs() < 1e-12, "huber loss = {loss}");
    }
    // Mean of one quadratic (0.005) and one linear (4.5) term: 2.2525.
    #[test]
    fn test_mean_huber_loss() {
        let preds = vec![0.0, 0.0];
        let targets = vec![0.1, 5.0];
        let loss = mean_huber_loss(&preds, &targets, 1.0);
        assert!((loss - 2.2525).abs() < 1e-10, "mean huber loss = {loss}");
    }
    // In the quadratic region the gradient equals the signed error (-0.1).
    #[test]
    fn test_huber_loss_grad_small() {
        let preds = vec![1.0];
        let targets = vec![1.1];
        let grad = huber_loss_grad(&preds, &targets, 1.0);
        assert!((grad[0] - (-0.1)).abs() < 1e-12);
    }
    // In the linear region the gradient is delta * sign(error) = 1.
    #[test]
    fn test_huber_loss_grad_large() {
        let preds = vec![10.0];
        let targets = vec![0.0];
        let grad = huber_loss_grad(&preds, &targets, 1.0);
        assert!((grad[0] - 1.0).abs() < 1e-12);
    }
}
775#[cfg(test)]
776mod conv_rnn_tests {
777
778    use crate::BatchNormLayer;
779    use crate::Conv1DLayer;
780
781    use crate::ExtActivation;
782    use crate::FeedForwardNet;
783
784    use crate::LayerNorm;
785    use crate::LayerNormLayer;
786
787    use crate::RnnCell;
788
789    use crate::clip_gradients_by_norm;
790
791    use crate::compute_gradient_norm;
792
793    #[test]
794    fn test_conv1d_zero_weights_output_is_bias() {
795        let mut conv = Conv1DLayer::new(2, 3, 2, ExtActivation::Linear);
796        conv.biases = vec![1.0, 2.0, 3.0];
797        let input = vec![vec![0.5, 0.5]; 4];
798        let out = conv.forward(&input);
799        assert_eq!(out.len(), 4);
800        for row in &out {
801            assert_eq!(row.len(), 3);
802            assert!((row[0] - 1.0).abs() < 1e-12, "out[0] = {}", row[0]);
803            assert!((row[1] - 2.0).abs() < 1e-12, "out[1] = {}", row[1]);
804            assert!((row[2] - 3.0).abs() < 1e-12, "out[2] = {}", row[2]);
805        }
806    }
807    #[test]
808    fn test_conv1d_num_params() {
809        let conv = Conv1DLayer::new(4, 8, 3, ExtActivation::Relu);
810        assert_eq!(conv.num_params(), 104);
811    }
812    #[test]
813    fn test_conv1d_output_shape() {
814        let conv = Conv1DLayer::new(3, 5, 3, ExtActivation::Tanh);
815        let input: Vec<Vec<f64>> = (0..10).map(|_| vec![1.0, 0.0, -1.0]).collect();
816        let out = conv.forward(&input);
817        assert_eq!(out.len(), 10, "seq_len preserved");
818        assert_eq!(out[0].len(), 5, "out_channels");
819    }
820    #[test]
821    fn test_conv1d_causal_first_step_only_sees_t0() {
822        let mut conv = Conv1DLayer::new(1, 1, 3, ExtActivation::Linear);
823        conv.weights[0][0][0] = 1.0;
824        conv.weights[0][1][0] = 100.0;
825        conv.weights[0][2][0] = 100.0;
826        let input = vec![vec![5.0], vec![0.0], vec![0.0]];
827        let out = conv.forward(&input);
828        assert!(
829            (out[0][0] - 5.0).abs() < 1e-12,
830            "t=0 output = {}",
831            out[0][0]
832        );
833    }
834    #[test]
835    fn test_conv1d_kernel1_is_pointwise() {
836        let mut conv = Conv1DLayer::new(2, 1, 1, ExtActivation::Linear);
837        conv.weights[0][0][0] = 2.0;
838        conv.weights[0][0][1] = 3.0;
839        let input = vec![vec![1.0, 1.0], vec![2.0, 2.0]];
840        let out = conv.forward(&input);
841        assert!((out[0][0] - 5.0).abs() < 1e-12);
842        assert!((out[1][0] - 10.0).abs() < 1e-12);
843    }
/// A ReLU-activated layer must never emit a negative value, even for
/// strictly negative inputs.
#[test]
fn test_conv1d_relu_clips_negative() {
    let layer = Conv1DLayer::new(1, 1, 1, ExtActivation::Relu);
    let negatives: Vec<Vec<f64>> = vec![vec![-5.0], vec![-3.0]];
    let result = layer.forward(&negatives);
    assert!(result[0][0] >= 0.0, "relu should clip negative");
    assert!(result[1][0] >= 0.0);
}
/// The output of a freshly constructed `LayerNorm` must have (near) zero mean.
#[test]
fn test_layer_norm_zero_mean_after_forward() {
    let norm_layer = LayerNorm::new(4);
    let sample = vec![1.0, 2.0, 3.0, 4.0];
    let normalized = norm_layer.forward(&sample);
    let count = normalized.len() as f64;
    let mean: f64 = normalized.iter().sum::<f64>() / count;
    assert!(mean.abs() < 1e-10, "mean after LayerNorm = {mean}");
}
/// The output of a freshly constructed `LayerNorm` must have population
/// variance of approximately 1.
#[test]
fn test_layer_norm_unit_variance() {
    let norm_layer = LayerNorm::new(4);
    let sample = vec![1.0, 2.0, 3.0, 4.0];
    let normalized = norm_layer.forward(&sample);
    let count = normalized.len() as f64;
    let mean: f64 = normalized.iter().sum::<f64>() / count;
    let squared_deviations = normalized.iter().map(|&value| {
        let deviation = value - mean;
        deviation * deviation
    });
    let var: f64 = squared_deviations.sum::<f64>() / count;
    assert!((var - 1.0).abs() < 1e-3, "variance after LayerNorm = {var}");
}
/// A constant input vector must normalize to (near) zero in every position
/// when gamma and beta are left at their constructor values.
#[test]
fn test_layer_norm_identity_gamma_beta() {
    let norm_layer = LayerNorm::new(3);
    let constant_input = vec![5.0, 5.0, 5.0];
    let normalized = norm_layer.forward(&constant_input);
    for v in normalized.iter().copied() {
        assert!(v.abs() < 1e-4, "constant input โ†’ near-zero output, got {v}");
    }
}
/// `LayerNorm::forward` must return as many values as it was given.
#[test]
fn test_layer_norm_output_length() {
    let norm_layer = LayerNorm::new(6);
    let ones = vec![1.0; 6];
    let normalized = norm_layer.forward(&ones);
    assert_eq!(normalized.len(), 6);
}
/// With gamma = (2, 3) and beta = (1, -1), the input (0, 4) must map to
/// approximately (-1, 2).
#[test]
fn test_layer_norm_custom_gamma_beta() {
    let mut norm_layer = LayerNorm::new(2);
    norm_layer.gamma = vec![2.0, 3.0];
    norm_layer.beta = vec![1.0, -1.0];
    let sample = vec![0.0, 4.0];
    let result = norm_layer.forward(&sample);
    let (first, second) = (result[0], result[1]);
    assert!((first - (-1.0)).abs() < 1e-4, "out[0] = {}", first);
    assert!((second - 2.0).abs() < 1e-4, "out[1] = {}", second);
}
/// A freshly constructed linear `RnnCell` must produce an all-zero hidden
/// state regardless of the input and previous state it is given.
#[test]
fn test_rnn_cell_zero_weights_output_is_activated_bias() {
    let cell = RnnCell::new(2, 3, ExtActivation::Linear);
    let input = vec![10.0, 20.0];
    let previous_state = vec![1.0, 2.0, 3.0];
    let next_state = cell.step(&input, &previous_state);
    for v in next_state.iter().copied() {
        assert!(v.abs() < 1e-12, "zero weights โ†’ zero output, got {v}");
    }
}
/// `RnnCell::step` on a `(4, 8)` cell must return a hidden state of length 8.
#[test]
fn test_rnn_cell_output_length() {
    let cell = RnnCell::new(4, 8, ExtActivation::Tanh);
    let input = vec![0.0; 4];
    let previous_state = vec![0.0; 8];
    let next_state = cell.step(&input, &previous_state);
    assert_eq!(next_state.len(), 8);
}
/// With `w_x[0]` and `w_x[3]` set to 1.0 (presumably the diagonal of a
/// row-major 2x2 matrix) and a zero previous state, the step must return
/// the input unchanged.
#[test]
fn test_rnn_cell_identity_weights_copies_input() {
    let mut cell = RnnCell::new(2, 2, ExtActivation::Linear);
    for diagonal_index in [0, 3] {
        cell.w_x[diagonal_index] = 1.0;
    }
    let input = vec![3.0, 7.0];
    let previous_state = vec![0.0, 0.0];
    let next_state = cell.step(&input, &previous_state);
    let (first, second) = (next_state[0], next_state[1]);
    assert!((first - 3.0).abs() < 1e-12, "h[0] = {}", first);
    assert!((second - 7.0).abs() < 1e-12, "h[1] = {}", second);
}
/// `forward_sequence` over 7 inputs on a `(3, 5)` cell must return 7 hidden
/// states, the first of which has length 5.
#[test]
fn test_rnn_cell_sequence_length() {
    let cell = RnnCell::new(3, 5, ExtActivation::Relu);
    let inputs: Vec<Vec<f64>> = vec![vec![0.0; 3]; 7];
    let hidden_states = cell.forward_sequence(&inputs);
    assert_eq!(hidden_states.len(), 7);
    assert_eq!(hidden_states[0].len(), 5);
}
/// With `w_x = 0` and `w_h = 2`, each step must double the previous hidden
/// state: 1 -> 2 -> 4.
#[test]
fn test_rnn_cell_sequence_accumulates_state() {
    let mut cell = RnnCell::new(1, 1, ExtActivation::Linear);
    cell.w_x[0] = 0.0;
    cell.w_h[0] = 2.0;

    let initial = vec![1.0_f64];
    let after_one = cell.step(&[0.0], &initial);
    assert!((after_one[0] - 2.0).abs() < 1e-12, "h1 = {}", after_one[0]);

    let after_two = cell.step(&[0.0], &after_one);
    assert!((after_two[0] - 4.0).abs() < 1e-12, "h2 = {}", after_two[0]);
}
/// A freshly constructed `LayerNormLayer` must produce output with mean ~0
/// and population variance ~1.
#[test]
fn test_layer_norm_zero_mean_unit_variance() {
    let layer = LayerNormLayer::new(4);
    let sample = vec![1.0, 2.0, 3.0, 4.0];
    let normalized = layer.forward(&sample);
    let count = normalized.len() as f64;
    let mean: f64 = normalized.iter().sum::<f64>() / count;
    let squared_deviations = normalized.iter().map(|&value| {
        let deviation = value - mean;
        deviation * deviation
    });
    let var: f64 = squared_deviations.sum::<f64>() / count;
    assert!(mean.abs() < 1e-10, "output mean should be ~0, got {mean}");
    assert!(
        (var - 1.0).abs() < 1e-5,
        "output var should be ~1, got {var}"
    );
}
/// Setting every gamma to 2 must scale the output's standard deviation to ~2.
#[test]
fn test_layer_norm_gamma_scales_output() {
    let mut layer = LayerNormLayer::new(3);
    layer.gamma = vec![2.0, 2.0, 2.0];
    let sample = vec![1.0, 2.0, 3.0];
    let normalized = layer.forward(&sample);
    let count = normalized.len() as f64;
    let mean: f64 = normalized.iter().sum::<f64>() / count;
    let squared_deviations = normalized.iter().map(|&value| {
        let deviation = value - mean;
        deviation * deviation
    });
    let variance: f64 = squared_deviations.sum::<f64>() / count;
    assert!(
        (variance.sqrt() - 2.0).abs() < 1e-3,
        "std should be ~2 with gamma=2, got {}",
        variance.sqrt()
    );
}
/// Setting every beta to 5 must shift the output mean to ~5.
#[test]
fn test_layer_norm_beta_shifts_output() {
    let mut layer = LayerNormLayer::new(2);
    layer.beta = vec![5.0, 5.0];
    let sample = vec![0.0, 1.0];
    let shifted = layer.forward(&sample);
    let count = shifted.len() as f64;
    let mean: f64 = shifted.iter().sum::<f64>() / count;
    assert!(
        (mean - 5.0).abs() < 1e-5,
        "mean should be ~5 with beta=5, got {mean}"
    );
}
/// The beta gradient returned by `backward` must equal the upstream
/// gradient element by element.
#[test]
fn test_layer_norm_backward_d_beta_equals_d_output() {
    let layer = LayerNormLayer::new(3);
    let sample = vec![1.0, 2.0, 3.0];
    let upstream = vec![0.1, 0.2, 0.3];
    // Only the third element of the returned tuple (d_beta) is checked here.
    let (_d_input, _d_gamma, d_beta) = layer.backward(&sample, &upstream);
    for (i, (&beta_grad, &upstream_grad)) in d_beta.iter().zip(upstream.iter()).enumerate() {
        let difference = (beta_grad - upstream_grad).abs();
        assert!(
            difference < 1e-12,
            "d_beta[{i}] should equal d_output[{i}]"
        );
    }
}
/// The norm of an all-zero gradient must be zero.
#[test]
fn test_compute_gradient_norm_zero() {
    let zero_gradient: Vec<f64> = vec![0.0; 5];
    let magnitude = compute_gradient_norm(&[&zero_gradient]);
    assert!(magnitude.abs() < 1e-12, "norm of zero gradient should be 0");
}
/// The gradient (3, 4) must have norm 5.
#[test]
fn test_compute_gradient_norm_known_value() {
    let gradient = vec![3.0_f64, 4.0];
    let norm = compute_gradient_norm(&[&gradient]);
    let error = (norm - 5.0).abs();
    assert!(error < 1e-10, "norm should be 5.0, got {norm}");
}
/// When the global norm (sqrt(2)) is below the limit (5.0), the call must
/// return that norm and leave the gradients untouched.
#[test]
fn test_clip_gradients_no_clip_when_below_max() {
    let mut gradients = vec![vec![1.0_f64, 0.0], vec![0.0, 1.0]];
    let reported_norm = clip_gradients_by_norm(&mut gradients, 5.0);
    let expected_norm = 2.0_f64.sqrt();
    assert!((reported_norm - expected_norm).abs() < 1e-10);
    assert!((gradients[0][0] - 1.0).abs() < 1e-12);
}
/// Clipping the gradient (3, 4) to a maximum norm of 1.0 must leave a
/// gradient whose norm is 1.0.
#[test]
fn test_clip_gradients_clips_correctly() {
    let mut gradients = vec![vec![3.0_f64, 4.0]];
    clip_gradients_by_norm(&mut gradients, 1.0);
    let sum_of_squares: f64 = gradients[0].iter().map(|&component| component * component).sum::<f64>();
    let new_norm: f64 = sum_of_squares.sqrt();
    assert!(
        (new_norm - 1.0).abs() < 1e-10,
        "clipped norm should be 1.0, got {new_norm}"
    );
}
/// After 50 updates with rate 0.1 on a batch of identical rows (2, 3), the
/// running mean must be within 0.1 of (2, 3).
#[test]
fn test_batch_norm_update_running_stats_mean_converges() {
    let mut layer = BatchNormLayer::new(2);
    let constant_batch = vec![vec![2.0_f32, 3.0]; 3];
    for _step in 0..50 {
        layer.update_running_stats(&constant_batch, 0.1);
    }
    let first_error = (layer.running_mean[0] - 2.0).abs();
    let second_error = (layer.running_mean[1] - 3.0).abs();
    assert!(first_error < 0.1, "mean[0] should converge to 2.0");
    assert!(second_error < 0.1, "mean[1] should converge to 3.0");
}
/// After 100 updates with rate 0.2 on a constant batch, the running
/// variance of the first feature must have dropped below 0.05.
#[test]
fn test_batch_norm_update_running_stats_variance_zero() {
    let mut layer = BatchNormLayer::new(2);
    let constant_batch = vec![vec![1.0_f32, 1.0]; 4];
    for _step in 0..100 {
        layer.update_running_stats(&constant_batch, 0.2);
    }
    let first_variance = layer.running_var[0];
    assert!(
        first_variance < 0.05,
        "variance should be ~0 for constant input, got {}",
        first_variance
    );
}
/// `FeedForwardNet::compute_gradient_norm` must return 5 for the gradient (3, 4).
#[test]
fn test_ffnet_compute_gradient_norm_correct() {
    let network = FeedForwardNet::new();
    let gradients = vec![vec![3.0_f32, 4.0]];
    let norm = network.compute_gradient_norm(&gradients);
    let error = (norm - 5.0).abs();
    assert!(error < 1e-5, "gradient norm should be 5.0, got {norm}");
}
/// `FeedForwardNet::clip_gradients` with limit 1.0 must rescale the
/// gradient (3, 4) so that its norm becomes 1.0.
#[test]
fn test_ffnet_clip_gradients_applies_scaling() {
    let network = FeedForwardNet::new();
    let mut gradients = vec![vec![3.0_f32, 4.0]];
    network.clip_gradients(&mut gradients, 1.0);
    let sum_of_squares: f32 = gradients[0].iter().map(|&component| component * component).sum::<f32>();
    let new_norm: f32 = sum_of_squares.sqrt();
    assert!(
        (new_norm - 1.0).abs() < 1e-5,
        "clipped norm should be 1.0, got {new_norm}"
    );
}
1074}
/// Compute scaled dot-product attention.
///
/// Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) V
///
/// All matrices use row-major flat storage with shapes:
/// - Q: `[seq_q × d_k]`
/// - K: `[seq_k × d_k]`
/// - V: `[seq_k × d_v]`
///
/// `mask`, when present, is an additive `[seq_q × seq_k]` row-major matrix that
/// is added to the scaled scores before the softmax. Use a large negative
/// value or `f64::NEG_INFINITY` to block a position.
///
/// Returns output of shape `[seq_q × d_v]` (flat row-major).
///
/// # Edge cases
/// - A query row whose scores are all `-inf` (every key masked out), or any
///   row when `seq_k == 0`, gets zero attention weights and therefore an
///   all-zero output row rather than NaN.
/// - When `d_k == 0` every score is treated as `0.0`, giving uniform
///   attention over the keys instead of `0/0 = NaN`.
///
/// # Panics
/// Panics if `q`, `k`, `v` or `mask` do not have the lengths implied by
/// `seq_q`, `seq_k`, `d_k` and `d_v`.
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub fn scaled_dot_product_attention(
    q: &[f64],
    k: &[f64],
    v: &[f64],
    seq_q: usize,
    seq_k: usize,
    d_k: usize,
    d_v: usize,
    mask: Option<&[f64]>,
) -> Vec<f64> {
    assert_eq!(q.len(), seq_q * d_k);
    assert_eq!(k.len(), seq_k * d_k);
    assert_eq!(v.len(), seq_k * d_v);
    if let Some(m) = mask {
        assert_eq!(m.len(), seq_q * seq_k);
    }

    // With d_k == 0 every dot product is an empty sum (0.0); dividing by
    // sqrt(0) would turn that into NaN, so fall back to a unit scale.
    let scale = if d_k == 0 { 1.0 } else { (d_k as f64).sqrt() };

    let mut output = vec![0.0_f64; seq_q * d_v];
    // Scratch buffer reused for every query row: it first holds the scores
    // and is then overwritten in place with the softmax weights.
    let mut row = vec![0.0_f64; seq_k];

    for i in 0..seq_q {
        let q_row = &q[i * d_k..(i + 1) * d_k];

        // Scaled scores: (q_i · k_j) / sqrt(d_k).
        for (j, score) in row.iter_mut().enumerate() {
            let k_row = &k[j * d_k..(j + 1) * d_k];
            let dot = q_row
                .iter()
                .zip(k_row)
                .fold(0.0_f64, |acc, (a, b)| acc + a * b);
            *score = dot / scale;
        }

        // Additive mask.
        if let Some(m) = mask {
            let mask_row = &m[i * seq_k..(i + 1) * seq_k];
            for (score, bias) in row.iter_mut().zip(mask_row) {
                *score += bias;
            }
        }

        // A fully masked row has no position to attend to. Subtracting the
        // row maximum would evaluate -inf - (-inf) = NaN, so leave the
        // (already zeroed) output row untouched instead. NaN scores are not
        // equal to -inf and therefore still propagate as NaN.
        if row.iter().all(|&s| s == f64::NEG_INFINITY) {
            continue;
        }

        // Numerically stable softmax: shift by the row maximum before exp.
        let max_val = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut sum_exp = 0.0_f64;
        for weight in row.iter_mut() {
            *weight = (*weight - max_val).exp();
            sum_exp += *weight;
        }
        let denom = sum_exp.max(1e-30);
        for weight in row.iter_mut() {
            *weight /= denom;
        }

        // Weighted sum of the value rows.
        let out_row = &mut output[i * d_v..(i + 1) * d_v];
        for (dv, out) in out_row.iter_mut().enumerate() {
            *out = row
                .iter()
                .enumerate()
                .fold(0.0_f64, |acc, (j, weight)| acc + weight * v[j * d_v + dv]);
        }
    }
    output
}
/// Unit tests for positional encoding, scaled dot-product attention,
/// multi-head attention, transformer blocks, GNN layers, message passing
/// and attention readout.
#[cfg(test)]
mod attention_gnn_tests {
    use crate::scaled_dot_product_attention;
    use crate::{
        AttentionReadout, ExtActivation, GnnLayer, MessagePassingNet, MultiHeadAttention,
        PositionalEncoding, TransformerBlock, TransformerFfn,
    };

    // ---- PositionalEncoding -------------------------------------------------

    /// `PositionalEncoding::new(8, 16)` must build a table of 16 rows with 8 columns.
    #[test]
    fn test_positional_encoding_shape() {
        let encoding = PositionalEncoding::new(8, 16);
        assert_eq!(encoding.table.len(), 16);
        assert_eq!(encoding.table[0].len(), 8);
    }

    /// The first entry of the position-0 row must be 0.
    #[test]
    fn test_positional_encoding_position_zero_first_dim_sin_zero() {
        let encoding = PositionalEncoding::new(4, 10);
        let entry = encoding.table[0][0];
        assert!(entry.abs() < 1e-12, "PE[0,0] should be 0.0 (sin(0))");
    }

    /// The second entry of the position-0 row must be 1.
    #[test]
    fn test_positional_encoding_first_dim_cos_at_zero() {
        let encoding = PositionalEncoding::new(4, 10);
        let entry = encoding.table[0][1];
        assert!((entry - 1.0).abs() < 1e-12, "PE[0,1] should be 1.0 (cos(0))");
    }

    /// Adding the encoding to an all-zero sequence must leave `sin(1)` in the
    /// first slot of position 1.
    #[test]
    fn test_positional_encoding_add_to_sequence() {
        let encoding = PositionalEncoding::new(4, 5);
        let mut sequence = vec![vec![0.0_f64; 4]; 3];
        encoding.add_to_sequence(&mut sequence);
        let expected = 1.0_f64.sin();
        let actual = sequence[1][0];
        assert!((actual - expected).abs() < 1e-12, "seq[1][0] = {}", actual);
    }

    /// `get` must return a row with as many entries as the first constructor argument.
    #[test]
    fn test_positional_encoding_get_returns_slice() {
        let encoding = PositionalEncoding::new(8, 10);
        let third_row = encoding.get(3);
        assert_eq!(third_row.len(), 8);
    }

    /// The rows for positions 0 and 1 must not be element-wise identical.
    #[test]
    fn test_positional_encoding_different_positions_differ() {
        let encoding = PositionalEncoding::new(8, 10);
        let first = encoding.get(0);
        let second = encoding.get(1);
        let identical = first
            .iter()
            .zip(second.iter())
            .all(|(lhs, rhs)| (lhs - rhs).abs() < 1e-12);
        assert!(!identical, "PE at pos=0 and pos=1 should differ");
    }

    // ---- scaled_dot_product_attention ---------------------------------------

    /// The output must hold `seq_q * d_v` values.
    #[test]
    fn test_sdpa_output_shape() {
        let queries = vec![0.1_f64; 3 * 4];
        let keys = vec![0.2_f64; 3 * 4];
        let values = vec![0.3_f64; 3 * 4];
        let result = scaled_dot_product_attention(&queries, &keys, &values, 3, 3, 4, 4, None);
        assert_eq!(
            result.len(),
            3 * 4,
            "output should have seq_q * d_v elements"
        );
    }

    /// All-zero queries and keys give uniform weights, so the output is the
    /// plain average of the value rows.
    #[test]
    fn test_sdpa_uniform_attention_averages_values() {
        let queries = vec![0.0_f64; 2 * 2];
        let keys = vec![0.0_f64; 3 * 2];
        let values = vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let out = scaled_dot_product_attention(&queries, &keys, &values, 2, 3, 2, 2, None);
        let one_third = 1.0 / 3.0;
        assert!((out[0] - one_third).abs() < 1e-8, "out[0]={}", out[0]);
        assert!((out[1] - one_third).abs() < 1e-8, "out[1]={}", out[1]);
    }

    /// A mask entry of -1e9 must remove the second key from the first
    /// query's attention, leaving only the first value.
    #[test]
    fn test_sdpa_masking_blocks_position() {
        let queries = vec![0.0_f64; 2];
        let keys = vec![0.0_f64; 2];
        let values = vec![10.0_f64, 20.0];
        let additive_mask = vec![0.0_f64, -1e9, 0.0, 0.0];
        let out = scaled_dot_product_attention(
            &queries,
            &keys,
            &values,
            2,
            2,
            1,
            1,
            Some(&additive_mask),
        );
        assert!((out[0] - 10.0).abs() < 1e-4, "masked output[0]={}", out[0]);
    }

    /// With every value equal to 1, each output equals the sum of that row's
    /// attention weights, which must be 1.
    #[test]
    fn test_sdpa_attention_weights_sum_to_one() {
        let seq_len = 4;
        let key_dim = 3;
        let queries = vec![0.0_f64; seq_len * key_dim];
        let keys = vec![0.0_f64; seq_len * key_dim];
        let values = vec![1.0_f64; seq_len];
        let out = scaled_dot_product_attention(
            &queries, &keys, &values, seq_len, seq_len, key_dim, 1, None,
        );
        for o in out.iter().copied() {
            assert!(
                (o - 1.0).abs() < 1e-8,
                "attention weight sum = 1 check: out={o}"
            );
        }
    }

    // ---- MultiHeadAttention -------------------------------------------------

    /// Forwarding 5 tokens of width 8 must return 5 * 8 values.
    #[test]
    fn test_mha_output_shape() {
        let attention = MultiHeadAttention::new(8, 2);
        let tokens = vec![0.0_f64; 5 * 8];
        let result = attention.forward(&tokens, 5);
        assert_eq!(result.len(), 5 * 8, "MHA output shape mismatch");
    }

    /// `MultiHeadAttention::new(16, 4)` must report 1040 parameters.
    #[test]
    fn test_mha_num_params() {
        let attention = MultiHeadAttention::new(16, 4);
        let param_count = attention.num_params();
        assert_eq!(param_count, 1040);
    }

    /// With only the output bias set (to 1.0) and zero input, every output
    /// value must equal that bias.
    #[test]
    fn test_mha_zero_weights_zero_output_except_bias() {
        let mut attention = MultiHeadAttention::new(4, 2);
        attention.b_o = vec![1.0_f64; 4];
        let tokens = vec![0.0_f64; 3 * 4];
        let result = attention.forward(&tokens, 3);
        for v in result.iter().copied() {
            assert!(
                (v - 1.0).abs() < 1e-10,
                "out should equal bias=1.0, got {v}"
            );
        }
    }

    /// After `init_identity`, forwarding a ramp input must give finite values
    /// of the expected count.
    #[test]
    fn test_mha_identity_init_output_finite() {
        let mut attention = MultiHeadAttention::new(4, 2);
        attention.init_identity();
        let tokens: Vec<f64> = (0..3 * 4).map(|index| (index as f64) * 0.1).collect();
        let result = attention.forward(&tokens, 3);
        assert_eq!(result.len(), 3 * 4);
        for v in result.iter().copied() {
            assert!(v.is_finite(), "output must be finite, got {v}");
        }
    }

    // ---- TransformerFfn -----------------------------------------------------

    /// The feed-forward sub-layer must preserve the `5 * 8` input length.
    #[test]
    fn test_transformer_ffn_output_shape() {
        let feed_forward = TransformerFfn::new(8, 32);
        let tokens = vec![1.0_f64; 5 * 8];
        let result = feed_forward.forward(&tokens, 5);
        assert_eq!(result.len(), 5 * 8);
    }

    /// A freshly constructed feed-forward sub-layer must map any input to zeros.
    #[test]
    fn test_transformer_ffn_zero_weights_zero_output() {
        let feed_forward = TransformerFfn::new(4, 16);
        let tokens = vec![1.0_f64; 3 * 4];
        let result = feed_forward.forward(&tokens, 3);
        for v in result.iter().copied() {
            assert!(v.abs() < 1e-12, "zero weights โ†’ zero output, got {v}");
        }
    }

    /// With all first-layer weights at -1 and positive inputs, the output
    /// must stay zero.
    #[test]
    fn test_transformer_ffn_relu_activation() {
        let mut feed_forward = TransformerFfn::new(2, 2);
        feed_forward.w1 = vec![-1.0_f64; 4];
        let tokens = vec![1.0, 1.0, 1.0, 1.0];
        let result = feed_forward.forward(&tokens, 2);
        for v in result.iter().copied() {
            assert!(
                v.abs() < 1e-12,
                "relu-clipped hidden โ†’ zero output, got {v}"
            );
        }
    }

    // ---- TransformerBlock ---------------------------------------------------

    /// A transformer block must preserve the `4 * 8` input length.
    #[test]
    fn test_transformer_block_output_shape() {
        let block = TransformerBlock::new(8, 2, 32);
        let tokens = vec![0.5_f64; 4 * 8];
        let result = block.forward(&tokens, 4);
        assert_eq!(
            result.len(),
            4 * 8,
            "transformer block output shape mismatch"
        );
    }

    /// A freshly constructed block must return a finite output of the same
    /// length as its input.
    #[test]
    fn test_transformer_block_residual_preserves_input_with_zero_weights() {
        let block = TransformerBlock::new(4, 2, 16);
        let tokens = vec![1.0_f64; 3 * 4];
        let result = block.forward(&tokens, 3);
        assert_eq!(result.len(), tokens.len());
        for value in result.iter().copied() {
            assert!(value.is_finite(), "transformer block output must be finite");
        }
    }

    /// With identity-initialised attention and a non-constant input, every
    /// output value must be finite.
    #[test]
    fn test_transformer_block_output_differs_from_input() {
        let mut block = TransformerBlock::new(4, 2, 8);
        block.mha.init_identity();
        let tokens: Vec<f64> = (0..4 * 4)
            .map(|index| ((index % 7) as f64) * 0.3 - 0.5)
            .collect();
        let result = block.forward(&tokens, 4);
        for value in result.iter().copied() {
            assert!(value.is_finite());
        }
    }

    // ---- GnnLayer -----------------------------------------------------------

    /// A `(4, 8)` layer over 3 nodes must return `3 * 8` values.
    #[test]
    fn test_gnn_layer_output_shape() {
        let layer = GnnLayer::new(4, 8, ExtActivation::Relu);
        let node_count = 3;
        let features = vec![0.5_f64; node_count * 4];
        let adjacency = vec![vec![1usize, 2], vec![0], vec![0, 1]];
        let result = layer.forward(&features, node_count, &adjacency);
        assert_eq!(result.len(), node_count * 8, "GNN output shape mismatch");
    }

    /// `GnnLayer::new(4, 8, _)` must report 72 parameters.
    #[test]
    fn test_gnn_layer_num_params() {
        let layer = GnnLayer::new(4, 8, ExtActivation::Relu);
        let param_count = layer.num_params();
        assert_eq!(param_count, 72);
    }

    /// A freshly constructed linear layer must map any features to zeros.
    #[test]
    fn test_gnn_layer_zero_weights_zero_output() {
        let layer = GnnLayer::new(3, 5, ExtActivation::Linear);
        let features = vec![1.0_f64; 4 * 3];
        let adjacency = vec![vec![1usize], vec![0], vec![3], vec![2]];
        let result = layer.forward(&features, 4, &adjacency);
        for v in result.iter().copied() {
            assert!(v.abs() < 1e-12, "zero weights โ†’ zero output, got {v}");
        }
    }

    /// With `w_self` set to (1, 0, 0, 1) and no neighbours, node 0 must
    /// reproduce its own features.
    #[test]
    fn test_gnn_layer_isolated_node_uses_only_self() {
        let mut layer = GnnLayer::new(2, 2, ExtActivation::Linear);
        layer.w_self = vec![1.0, 0.0, 0.0, 1.0];
        let features = vec![3.0, 7.0, 0.0, 0.0];
        let adjacency: Vec<Vec<usize>> = vec![Vec::new(), Vec::new()];
        let result = layer.forward(&features, 2, &adjacency);
        let (first, second) = (result[0], result[1]);
        assert!((first - 3.0).abs() < 1e-12, "node 0 out[0] = {}", first);
        assert!((second - 7.0).abs() < 1e-12, "node 0 out[1] = {}", second);
    }

    /// With `w_neigh` set to (1, 0, 0, 1), node 0 must receive the sum of its
    /// two neighbours' features: (1, 0) + (1, 0) = (2, 0).
    #[test]
    fn test_gnn_layer_neighbour_aggregation_sum() {
        let mut layer = GnnLayer::new(2, 2, ExtActivation::Linear);
        layer.w_neigh = vec![1.0, 0.0, 0.0, 1.0];
        let features = vec![0.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let adjacency = vec![vec![1usize, 2], vec![], vec![]];
        let result = layer.forward(&features, 3, &adjacency);
        let (first, second) = (result[0], result[1]);
        assert!(
            (first - 2.0).abs() < 1e-12,
            "aggregated out[0] = {}",
            first
        );
        assert!((second).abs() < 1e-12, "aggregated out[1] = {}", second);
    }

    // ---- MessagePassingNet --------------------------------------------------

    /// Stacking a 4->8 and an 8->4 layer over 5 nodes must return `5 * 4` values.
    #[test]
    fn test_mpnn_output_shape_two_layers() {
        let mut network = MessagePassingNet::new();
        network.add_layer(GnnLayer::new(4, 8, ExtActivation::Relu));
        network.add_layer(GnnLayer::new(8, 4, ExtActivation::Linear));
        let features = vec![0.1_f64; 5 * 4];
        // Chain graph: node i lists node i - 1 as its only neighbour; node 0 has none.
        let adjacency: Vec<Vec<usize>> = (0..5)
            .map(|node| match node {
                0 => Vec::new(),
                _ => vec![node - 1],
            })
            .collect();
        let result = network.forward(&features, 5, &adjacency);
        assert_eq!(result.len(), 5 * 4);
    }

    /// Mean pooling 4 nodes of width 3 must return 3 values.
    #[test]
    fn test_mpnn_global_mean_pool_shape() {
        let network = MessagePassingNet::new();
        let features = vec![1.0_f64; 4 * 3];
        let pooled = network.global_mean_pool(&features, 4, 3);
        assert_eq!(pooled.len(), 3);
    }

    /// Mean pooling rows (1, 2), (3, 4), (5, 6) must give (3, 4).
    #[test]
    fn test_mpnn_global_mean_pool_known_values() {
        let network = MessagePassingNet::new();
        let features = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let pooled = network.global_mean_pool(&features, 3, 2);
        let (first, second) = (pooled[0], pooled[1]);
        assert!((first - 3.0).abs() < 1e-12, "mean[0] = {}", first);
        assert!((second - 4.0).abs() < 1e-12, "mean[1] = {}", second);
    }

    /// Mean pooling zero nodes must return an all-zero vector of the
    /// requested width.
    #[test]
    fn test_mpnn_empty_returns_zero_pool() {
        let network = MessagePassingNet::new();
        let pooled = network.global_mean_pool(&[], 0, 4);
        let expected = vec![0.0_f64; 4];
        assert_eq!(pooled, expected);
    }

    /// A default-constructed network must contain no layers.
    #[test]
    fn test_mpnn_default_is_empty() {
        let network = MessagePassingNet::default();
        assert_eq!(network.layers.len(), 0);
    }

    // ---- AttentionReadout ---------------------------------------------------

    /// Reading out 5 nodes of width 8 must return 8 values.
    #[test]
    fn test_attention_readout_output_shape() {
        let readout = AttentionReadout::new(8);
        let features = vec![0.0_f64; 5 * 8];
        let result = readout.forward(&features, 5);
        assert_eq!(result.len(), 8);
    }

    /// A freshly constructed readout over rows (1, 2) and (3, 4) must return (2, 3).
    #[test]
    fn test_attention_readout_zero_weights_equal_scores() {
        let readout = AttentionReadout::new(2);
        let features = vec![1.0_f64, 2.0, 3.0, 4.0];
        let result = readout.forward(&features, 2);
        let (first, second) = (result[0], result[1]);
        assert!((first - 2.0).abs() < 1e-10, "readout[0] = {}", first);
        assert!((second - 3.0).abs() < 1e-10, "readout[1] = {}", second);
    }

    /// All-zero node features must read out as all zeros.
    #[test]
    fn test_attention_readout_all_zeros_input() {
        let readout = AttentionReadout::new(4);
        let features = vec![0.0_f64; 3 * 4];
        let result = readout.forward(&features, 3);
        for v in result.iter().copied() {
            assert!(v.abs() < 1e-12, "zero input โ†’ zero output, got {v}");
        }
    }

    /// A single node (2, 4, 6) must read out as 0.5 times its own features.
    #[test]
    fn test_attention_readout_single_node() {
        let readout = AttentionReadout::new(3);
        let features = vec![2.0, 4.0, 6.0];
        let result = readout.forward(&features, 1);
        // NOTE(review): 0.5 is presumably the gate value a zero-weight readout
        // assigns to every node — confirm against AttentionReadout::forward.
        let score = 0.5_f64;
        for index in 0..3 {
            assert!((result[index] - score * features[index]).abs() < 1e-10);
        }
    }
}