Skip to main content

oxiphysics_gpu/neural_compute/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5#![allow(clippy::needless_range_loop, clippy::ptr_arg)]
6use std::collections::HashMap;
7use std::f64::consts::PI as PI_F64;
8
9use super::functions::scaled_dot_product_attention;
10#[allow(unused_imports)]
11use super::functions::*;
12
/// Batch normalization over a fixed number of features.
///
/// In inference mode each feature is standardised with the stored running
/// mean and variance, then rescaled by `gamma` and shifted by `beta`.
#[derive(Debug, Clone)]
pub struct BatchNormLayer {
    /// Per-feature running mean.
    pub running_mean: Vec<f32>,
    /// Per-feature running variance.
    pub running_var: Vec<f32>,
    /// Per-feature scale (gamma).
    pub gamma: Vec<f32>,
    /// Per-feature shift (beta).
    pub beta: Vec<f32>,
    /// Constant added to the variance before taking the square root.
    pub epsilon: f32,
    /// Number of features the layer was built for.
    pub n_features: usize,
}
impl BatchNormLayer {
    /// Build a layer that starts as the identity transform up to `epsilon`:
    /// mean 0, variance 1, gamma 1, beta 0, `epsilon = 1e-5`.
    pub fn new(n_features: usize) -> Self {
        let running_mean = vec![0.0; n_features];
        let running_var = vec![1.0; n_features];
        let gamma = vec![1.0; n_features];
        let beta = vec![0.0; n_features];
        Self {
            running_mean,
            running_var,
            gamma,
            beta,
            epsilon: 1e-5,
            n_features,
        }
    }
    /// Normalize one sample using the stored statistics (inference mode).
    ///
    /// `output[i] = gamma[i] * (input[i] - mean[i]) / sqrt(var[i] + eps) + beta[i]`
    ///
    /// # Panics
    /// Panics if `input.len() != n_features`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        assert_eq!(input.len(), self.n_features);
        (0..self.n_features)
            .map(|k| {
                let std_dev = (self.running_var[k] + self.epsilon).sqrt();
                let normalized = (input[k] - self.running_mean[k]) / std_dev;
                self.gamma[k] * normalized + self.beta[k]
            })
            .collect()
    }
    /// Overwrite the running mean and variance.
    ///
    /// # Panics
    /// Panics if either slice does not have `n_features` entries.
    pub fn set_stats(&mut self, mean: &[f32], var: &[f32]) {
        assert_eq!(mean.len(), self.n_features);
        assert_eq!(var.len(), self.n_features);
        self.running_mean.as_mut_slice().copy_from_slice(mean);
        self.running_var.as_mut_slice().copy_from_slice(var);
    }
    /// Overwrite the affine parameters `gamma` and `beta`.
    ///
    /// # Panics
    /// Panics if either slice does not have `n_features` entries.
    pub fn set_affine(&mut self, gamma: &[f32], beta: &[f32]) {
        assert_eq!(gamma.len(), self.n_features);
        assert_eq!(beta.len(), self.n_features);
        self.gamma.as_mut_slice().copy_from_slice(gamma);
        self.beta.as_mut_slice().copy_from_slice(beta);
    }
    /// Fold the statistics of a mini-batch into the running statistics.
    ///
    /// The batch mean and the population variance (divided by the batch
    /// size) are blended in with an exponential moving average:
    /// `running = (1 - momentum) * running + momentum * batch`.
    ///
    /// # Panics
    /// Panics if `batch` is empty or if any sample has the wrong feature count.
    pub fn update_running_stats(&mut self, batch: &[Vec<f32>], momentum: f32) {
        assert!(
            !batch.is_empty(),
            "update_running_stats: batch must not be empty"
        );
        let width = self.n_features;
        let count = batch.len() as f32;

        // Pass 1: per-feature mean; every sample is length-checked here.
        let mut mean = vec![0.0_f32; width];
        for row in batch {
            assert_eq!(row.len(), width, "sample length mismatch");
            for (slot, &value) in mean.iter_mut().zip(row) {
                *slot += value;
            }
        }
        mean.iter_mut().for_each(|m| *m /= count);

        // Pass 2: per-feature mean squared deviation from that mean.
        let mut var = vec![0.0_f32; width];
        for row in batch {
            for ((slot, &value), &mu) in var.iter_mut().zip(row).zip(&mean) {
                let dev = value - mu;
                *slot += dev * dev;
            }
        }
        var.iter_mut().for_each(|v| *v /= count);

        let keep = 1.0 - momentum;
        for k in 0..width {
            self.running_mean[k] = keep * self.running_mean[k] + momentum * mean[k];
            self.running_var[k] = keep * self.running_var[k] + momentum * var[k];
        }
    }
}
113/// A single-step Elman RNN cell:
114/// `h_t = activation(W_x * x_t + W_h * h_{t-1} + b)`.
115#[allow(dead_code)]
116#[derive(Debug, Clone)]
117pub struct RnnCell {
118    /// Input-to-hidden weight matrix `[hidden_size × input_size]`.
119    pub w_x: Vec<f64>,
120    /// Hidden-to-hidden weight matrix `[hidden_size × hidden_size]`.
121    pub w_h: Vec<f64>,
122    /// Bias vector `[hidden_size]`.
123    pub b: Vec<f64>,
124    /// Input dimensionality.
125    pub input_size: usize,
126    /// Hidden state dimensionality.
127    pub hidden_size: usize,
128    /// Activation function.
129    pub activation: ExtActivation,
130}
131impl RnnCell {
132    /// Create a new RNN cell with zero weights.
133    pub fn new(input_size: usize, hidden_size: usize, activation: ExtActivation) -> Self {
134        Self {
135            w_x: vec![0.0_f64; hidden_size * input_size],
136            w_h: vec![0.0_f64; hidden_size * hidden_size],
137            b: vec![0.0_f64; hidden_size],
138            input_size,
139            hidden_size,
140            activation,
141        }
142    }
143    /// One forward step.
144    ///
145    /// Returns the new hidden state `h_t` of length `hidden_size`.
146    pub fn step(&self, x: &[f64], h_prev: &[f64]) -> Vec<f64> {
147        assert_eq!(x.len(), self.input_size);
148        assert_eq!(h_prev.len(), self.hidden_size);
149        let mut h = Vec::with_capacity(self.hidden_size);
150        for o in 0..self.hidden_size {
151            let mut acc = self.b[o];
152            for i in 0..self.input_size {
153                acc += self.w_x[o * self.input_size + i] * x[i];
154            }
155            for i in 0..self.hidden_size {
156                acc += self.w_h[o * self.hidden_size + i] * h_prev[i];
157            }
158            h.push(self.activation.apply(acc));
159        }
160        h
161    }
162    /// Run the RNN over a full sequence `[seq_len][input_size]`.
163    ///
164    /// Returns all hidden states `[seq_len][hidden_size]`.
165    pub fn forward_sequence(&self, sequence: &[Vec<f64>]) -> Vec<Vec<f64>> {
166        let mut h = vec![0.0_f64; self.hidden_size];
167        let mut hidden_states = Vec::with_capacity(sequence.len());
168        for x in sequence {
169            h = self.step(x, &h);
170            hidden_states.push(h.clone());
171        }
172        hidden_states
173    }
174}
175/// An inference pipeline that chains DenseLayer and BatchNormLayer operations.
176#[derive(Debug, Clone)]
177pub struct InferencePipeline {
178    /// Ordered list of operations.
179    pub ops: Vec<InferenceOp>,
180}
181impl InferencePipeline {
182    /// Create an empty pipeline.
183    pub fn new() -> Self {
184        Self { ops: Vec::new() }
185    }
186    /// Add an operation to the pipeline.
187    pub fn add_op(&mut self, op: InferenceOp) {
188        self.ops.push(op);
189    }
190    /// Run forward pass through all operations.
191    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
192        let mut current = input.to_vec();
193        for op in &self.ops {
194            current = match op {
195                InferenceOp::Dense(layer) => layer.forward(&current),
196                InferenceOp::BatchNorm(bn) => bn.forward(&current),
197                InferenceOp::Activation(act) => current.iter().map(|&x| act.apply(x)).collect(),
198            };
199        }
200        current
201    }
202    /// Total number of trainable parameters.
203    pub fn total_parameters(&self) -> usize {
204        self.ops
205            .iter()
206            .map(|op| match op {
207                InferenceOp::Dense(layer) => layer.parameter_count(),
208                InferenceOp::BatchNorm(bn) => 2 * bn.n_features,
209                InferenceOp::Activation(_) => 0,
210            })
211            .sum()
212    }
213}
214/// Convenience builder for standard AANN architectures.
215pub struct NetworkBuilder;
216impl NetworkBuilder {
217    /// Build a simple element-specific network:
218    ///
219    /// `n_descriptors → hidden[0] (Tanh) → hidden[1] (Tanh) → … → 1 (Linear)`
220    ///
221    /// At least one hidden size must be provided.
222    pub fn simple_aann(
223        n_descriptors: usize,
224        hidden_sizes: &[usize],
225        _element: u8,
226    ) -> FeedForwardNet {
227        let mut net = FeedForwardNet::new();
228        let mut prev = n_descriptors;
229        for &h in hidden_sizes {
230            net.add_layer(DenseLayer::new(prev, h, ActivationFn::Tanh));
231            prev = h;
232        }
233        net.add_layer(DenseLayer::new(prev, 1, ActivationFn::Linear));
234        net
235    }
236}
/// Position-wise feed-forward network used inside a transformer block.
///
/// `FFN(x) = max(0, x W1 + b1) W2 + b2`
///
/// The same weights are applied independently to every position of the
/// sequence.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct TransformerFfn {
    /// Input / output dimensionality.
    pub d_model: usize,
    /// Inner (hidden) dimensionality.
    pub d_ff: usize,
    /// W1, flat row-major `[d_ff × d_model]`.
    pub w1: Vec<f64>,
    /// b1, length `d_ff`.
    pub b1: Vec<f64>,
    /// W2, flat row-major `[d_model × d_ff]`.
    pub w2: Vec<f64>,
    /// b2, length `d_model`.
    pub b2: Vec<f64>,
}
impl TransformerFfn {
    /// Create a network whose weights and biases are all zero.
    pub fn new(d_model: usize, d_ff: usize) -> Self {
        Self {
            d_model,
            d_ff,
            w1: vec![0.0_f64; d_ff * d_model],
            b1: vec![0.0_f64; d_ff],
            w2: vec![0.0_f64; d_model * d_ff],
            b2: vec![0.0_f64; d_model],
        }
    }
    /// Forward pass over a sequence `[seq_len × d_model]` (flat row-major).
    ///
    /// Returns a flat row-major `[seq_len × d_model]` buffer. Only the first
    /// `seq_len * d_model` values of `x` are read.
    ///
    /// # Panics
    /// Panics if `x` holds fewer than `seq_len * d_model` values.
    pub fn forward(&self, x: &[f64], seq_len: usize) -> Vec<f64> {
        let dm = self.d_model;
        let df = self.d_ff;
        assert!(
            x.len() >= seq_len * dm,
            "TransformerFfn::forward: input has {} values, need at least seq_len * d_model = {}",
            x.len(),
            seq_len * dm
        );
        let mut out = vec![0.0_f64; seq_len * dm];
        // One scratch buffer for the hidden activations, reused for every
        // position; each entry is fully overwritten before it is read.
        let mut hidden = vec![0.0_f64; df];
        for t in 0..seq_len {
            let row = &x[t * dm..(t + 1) * dm];
            for (j, h) in hidden.iter_mut().enumerate() {
                let mut acc = self.b1[j];
                for i in 0..dm {
                    acc += row[i] * self.w1[j * dm + i];
                }
                // ReLU.
                *h = acc.max(0.0);
            }
            let out_row = &mut out[t * dm..(t + 1) * dm];
            for (j, o) in out_row.iter_mut().enumerate() {
                let mut acc = self.b2[j];
                for i in 0..df {
                    acc += hidden[i] * self.w2[j * df + i];
                }
                *o = acc;
            }
        }
        out
    }
}
295/// A 1-D convolutional layer operating on a sequence of feature vectors.
296///
297/// Applies a set of `out_channels` filters, each of length `kernel_size`
298/// spanning `in_channels` input channels, using causal (left) padding so the
299/// output length equals the input length.
300///
301/// Layout:
302/// - `weights[o][k][c]` = weight for output channel `o`, kernel position `k`,
303///   input channel `c`.
304/// - `biases[o]` = bias for output channel `o`.
305#[allow(dead_code)]
306#[derive(Debug, Clone)]
307pub struct Conv1DLayer {
308    /// Number of input channels per time step.
309    pub in_channels: usize,
310    /// Number of output channels per time step.
311    pub out_channels: usize,
312    /// Kernel (filter) length along the time axis.
313    pub kernel_size: usize,
314    /// Filter weights: `weights[out_ch][kernel_pos][in_ch]`.
315    pub weights: Vec<Vec<Vec<f64>>>,
316    /// Bias per output channel.
317    pub biases: Vec<f64>,
318    /// Activation function applied after convolution.
319    pub activation: ExtActivation,
320}
321impl Conv1DLayer {
322    /// Create a new Conv1D layer with zero-initialised weights.
323    pub fn new(
324        in_channels: usize,
325        out_channels: usize,
326        kernel_size: usize,
327        activation: ExtActivation,
328    ) -> Self {
329        let weights = vec![vec![vec![0.0_f64; in_channels]; kernel_size]; out_channels];
330        let biases = vec![0.0_f64; out_channels];
331        Self {
332            in_channels,
333            out_channels,
334            kernel_size,
335            weights,
336            biases,
337            activation,
338        }
339    }
340    /// Forward pass.
341    ///
342    /// `input` has shape `[seq_len][in_channels]`.  Returns a tensor of shape
343    /// `[seq_len][out_channels]` using causal (left-zero) padding.
344    pub fn forward(&self, input: &[Vec<f64>]) -> Vec<Vec<f64>> {
345        let seq_len = input.len();
346        let mut output = vec![vec![0.0_f64; self.out_channels]; seq_len];
347        for t in 0..seq_len {
348            for o in 0..self.out_channels {
349                let mut acc = self.biases[o];
350                for k in 0..self.kernel_size {
351                    let src_t = t as isize - k as isize;
352                    if src_t < 0 {
353                        continue;
354                    }
355                    let src_t = src_t as usize;
356                    for c in 0..self.in_channels {
357                        acc += self.weights[o][k][c] * input[src_t][c];
358                    }
359                }
360                output[t][o] = self.activation.apply(acc);
361            }
362        }
363        output
364    }
365    /// Total number of trainable parameters.
366    pub fn num_params(&self) -> usize {
367        self.out_channels * self.kernel_size * self.in_channels + self.out_channels
368    }
369}
370/// A single operation in the inference pipeline.
371#[derive(Debug, Clone)]
372pub enum InferenceOp {
373    /// Dense (fully-connected) layer.
374    Dense(DenseLayer),
375    /// Batch normalization layer.
376    BatchNorm(BatchNormLayer),
377    /// Activation function (standalone).
378    Activation(ActivationFn),
379}
/// Adam optimizer operating on a flat parameter vector.
///
/// Reference: Kingma & Ba (2015) "Adam: A Method for Stochastic Optimization".
#[derive(Debug, Clone)]
pub struct AdamOptimizer {
    /// Learning rate α.
    pub lr: f64,
    /// Decay rate β1 of the first-moment estimate.
    pub beta1: f64,
    /// Decay rate β2 of the second-moment estimate.
    pub beta2: f64,
    /// Constant ε added to the denominator for numerical stability.
    pub epsilon: f64,
    /// First-moment estimates (m), one per parameter.
    pub m: Vec<f64>,
    /// Second-moment estimates (v), one per parameter.
    pub v: Vec<f64>,
    /// Number of updates applied so far (t).
    pub step: u64,
}
impl AdamOptimizer {
    /// Create an optimizer for `n_params` parameters with zeroed moments.
    pub fn new(n_params: usize, lr: f64, beta1: f64, beta2: f64, epsilon: f64) -> Self {
        let zero_moments = vec![0.0; n_params];
        Self {
            lr,
            beta1,
            beta2,
            epsilon,
            m: zero_moments.clone(),
            v: zero_moments,
            step: 0,
        }
    }
    /// Create an optimizer with the usual defaults
    /// (lr=1e-3, β1=0.9, β2=0.999, ε=1e-8).
    pub fn default_params(n_params: usize) -> Self {
        Self::new(n_params, 1e-3, 0.9, 0.999, 1e-8)
    }
    /// Apply one Adam update to `params` using `grads`.
    ///
    /// `params` is modified in place and the step counter is incremented.
    ///
    /// # Panics
    /// Panics if `params` or `grads` differ in length from the moment vectors.
    pub fn step_update(&mut self, params: &mut [f64], grads: &[f64]) {
        assert_eq!(
            params.len(),
            self.m.len(),
            "AdamOptimizer::step_update: params/m length mismatch"
        );
        assert_eq!(
            grads.len(),
            self.m.len(),
            "AdamOptimizer::step_update: grads/m length mismatch"
        );
        self.step += 1;
        let t = self.step as f64;
        let (b1, b2) = (self.beta1, self.beta2);
        let (lr, eps) = (self.lr, self.epsilon);
        // Bias-correction denominators depend only on the step count.
        let bias_corr1 = 1.0 - b1.powf(t);
        let bias_corr2 = 1.0 - b2.powf(t);
        for (i, (p, &g)) in params.iter_mut().zip(grads).enumerate() {
            let m_i = b1 * self.m[i] + (1.0 - b1) * g;
            let v_i = b2 * self.v[i] + (1.0 - b2) * g * g;
            self.m[i] = m_i;
            self.v[i] = v_i;
            let m_hat = m_i / bias_corr1;
            let v_hat = v_i / bias_corr2;
            *p -= lr * m_hat / (v_hat.sqrt() + eps);
        }
    }
    /// Clear both moment vectors and set the step counter back to zero.
    pub fn reset(&mut self) {
        self.m.fill(0.0);
        self.v.fill(0.0);
        self.step = 0;
    }
}
450/// A single graph neural network layer implementing the sum-aggregation
451/// message passing update:
452///
453/// h_i^(l+1) = σ(W_self * h_i^(l) + W_neigh * Σ_{j ∈ N(i)} h_j^(l) + b)
454///
455/// All nodes share the same weight matrices.
456#[allow(dead_code)]
457#[derive(Debug, Clone)]
458pub struct GnnLayer {
459    /// Input feature dimension.
460    pub in_dim: usize,
461    /// Output feature dimension.
462    pub out_dim: usize,
463    /// Self-loop weight matrix W_self \[out_dim × in_dim\].
464    pub w_self: Vec<f64>,
465    /// Neighbour aggregation weight matrix W_neigh \[out_dim × in_dim\].
466    pub w_neigh: Vec<f64>,
467    /// Bias vector \[out_dim\].
468    pub bias: Vec<f64>,
469    /// Activation function.
470    pub activation: ExtActivation,
471}
472impl GnnLayer {
473    /// Create a new GNN layer with zero-initialised weights.
474    pub fn new(in_dim: usize, out_dim: usize, activation: ExtActivation) -> Self {
475        Self {
476            in_dim,
477            out_dim,
478            w_self: vec![0.0_f64; out_dim * in_dim],
479            w_neigh: vec![0.0_f64; out_dim * in_dim],
480            bias: vec![0.0_f64; out_dim],
481            activation,
482        }
483    }
484    /// Forward pass.
485    ///
486    /// * `node_feats` – `[n_nodes × in_dim]` flat row-major node feature matrix.
487    /// * `adj`        – adjacency list: `adj[i]` contains the neighbour indices of node `i`.
488    ///
489    /// Returns `[n_nodes × out_dim]` flat row-major.
490    pub fn forward(&self, node_feats: &[f64], n_nodes: usize, adj: &[Vec<usize>]) -> Vec<f64> {
491        assert_eq!(node_feats.len(), n_nodes * self.in_dim);
492        assert_eq!(adj.len(), n_nodes);
493        let in_d = self.in_dim;
494        let out_d = self.out_dim;
495        let mut out = vec![0.0_f64; n_nodes * out_d];
496        for i in 0..n_nodes {
497            let h_self = &node_feats[i * in_d..(i + 1) * in_d];
498            let mut agg = vec![0.0_f64; in_d];
499            for &j in &adj[i] {
500                let h_j = &node_feats[j * in_d..(j + 1) * in_d];
501                for d in 0..in_d {
502                    agg[d] += h_j[d];
503                }
504            }
505            for o in 0..out_d {
506                let mut acc = self.bias[o];
507                for d in 0..in_d {
508                    acc += self.w_self[o * in_d + d] * h_self[d];
509                    acc += self.w_neigh[o * in_d + d] * agg[d];
510                }
511                out[i * out_d + o] = self.activation.apply(acc);
512            }
513        }
514        out
515    }
516    /// Total number of trainable parameters.
517    pub fn num_params(&self) -> usize {
518        2 * self.out_dim * self.in_dim + self.out_dim
519    }
520}
521/// A multi-layer message passing neural network stacking `GnnLayer`s.
522#[allow(dead_code)]
523#[derive(Debug, Clone)]
524pub struct MessagePassingNet {
525    /// Ordered list of GNN layers.
526    pub layers: Vec<GnnLayer>,
527}
528impl MessagePassingNet {
529    /// Create an empty MPNN.
530    pub fn new() -> Self {
531        Self { layers: Vec::new() }
532    }
533    /// Add a GNN layer to the stack.
534    pub fn add_layer(&mut self, layer: GnnLayer) {
535        self.layers.push(layer);
536    }
537    /// Run all layers in sequence over a fixed graph.
538    ///
539    /// Returns the final node feature matrix `[n_nodes × last_out_dim]`.
540    pub fn forward(&self, node_feats: &[f64], n_nodes: usize, adj: &[Vec<usize>]) -> Vec<f64> {
541        let mut h = node_feats.to_vec();
542        for layer in &self.layers {
543            h = layer.forward(&h, n_nodes, adj);
544        }
545        h
546    }
547    /// Aggregate node features to a single graph-level representation (mean pooling).
548    pub fn global_mean_pool(&self, node_feats: &[f64], n_nodes: usize, out_dim: usize) -> Vec<f64> {
549        if n_nodes == 0 {
550            return vec![0.0_f64; out_dim];
551        }
552        let mut pooled = vec![0.0_f64; out_dim];
553        for i in 0..n_nodes {
554            for d in 0..out_dim {
555                pooled[d] += node_feats[i * out_dim + d];
556            }
557        }
558        let inv_n = 1.0 / n_nodes as f64;
559        for v in &mut pooled {
560            *v *= inv_n;
561        }
562        pooled
563    }
564}
/// Running sum of gradients across several backward passes, used for
/// mini-batch training.
#[derive(Debug, Clone)]
pub struct GradAccumulator {
    /// Summed weight gradients.
    pub grad_weights: Vec<f64>,
    /// Summed bias gradients.
    pub grad_biases: Vec<f64>,
    /// How many gradient sets have been added.
    pub count: usize,
}
impl GradAccumulator {
    /// Create a zeroed accumulator for `n_weights` weights and `n_biases` biases.
    pub fn new(n_weights: usize, n_biases: usize) -> Self {
        let grad_weights = vec![0.0; n_weights];
        let grad_biases = vec![0.0; n_biases];
        Self {
            grad_weights,
            grad_biases,
            count: 0,
        }
    }
    /// Add one set of gradients to the running sums (no averaging here).
    ///
    /// # Panics
    /// Panics if `gw` or `gb` does not match the accumulator's sizes.
    pub fn accumulate(&mut self, gw: &[f64], gb: &[f64]) {
        assert_eq!(gw.len(), self.grad_weights.len());
        assert_eq!(gb.len(), self.grad_biases.len());
        for i in 0..gw.len() {
            self.grad_weights[i] += gw[i];
        }
        for i in 0..gb.len() {
            self.grad_biases[i] += gb[i];
        }
        self.count += 1;
    }
    /// Return the mean weight and bias gradients.
    ///
    /// The sums are divided by `count`; with nothing accumulated the divisor
    /// is 1, so the (zero) sums are returned unchanged.
    pub fn mean_grads(&self) -> (Vec<f64>, Vec<f64>) {
        /// Divide every entry of `sums` by `divisor`.
        fn averaged(sums: &[f64], divisor: f64) -> Vec<f64> {
            sums.iter().map(|&g| g / divisor).collect()
        }
        let divisor = self.count.max(1) as f64;
        (
            averaged(&self.grad_weights, divisor),
            averaged(&self.grad_biases, divisor),
        )
    }
    /// Clear the accumulated sums and reset the sample count.
    pub fn zero(&mut self) {
        self.grad_weights.fill(0.0);
        self.grad_biases.fill(0.0);
        self.count = 0;
    }
}
610/// Multi-head attention module.
611///
612/// Projects Q, K, V with learned linear projections, runs `n_heads`
613/// parallel attention heads, then concatenates and projects the output.
614///
615/// All weight matrices are stored flat row-major.
616#[allow(dead_code)]
617#[derive(Debug, Clone)]
618pub struct MultiHeadAttention {
619    /// Model dimensionality.
620    pub d_model: usize,
621    /// Number of attention heads.
622    pub n_heads: usize,
623    /// Dimensionality per head: `d_model / n_heads`.
624    pub d_head: usize,
625    /// W_Q projection \[d_model × d_model\].
626    pub w_q: Vec<f64>,
627    /// W_K projection \[d_model × d_model\].
628    pub w_k: Vec<f64>,
629    /// W_V projection \[d_model × d_model\].
630    pub w_v: Vec<f64>,
631    /// W_O output projection \[d_model × d_model\].
632    pub w_o: Vec<f64>,
633    /// Output bias \[d_model\].
634    pub b_o: Vec<f64>,
635}
636impl MultiHeadAttention {
637    /// Create a new MHA module with zero-initialised projections.
638    pub fn new(d_model: usize, n_heads: usize) -> Self {
639        assert_eq!(d_model % n_heads, 0, "d_model must be divisible by n_heads");
640        let d_head = d_model / n_heads;
641        let dm2 = d_model * d_model;
642        Self {
643            d_model,
644            n_heads,
645            d_head,
646            w_q: vec![0.0_f64; dm2],
647            w_k: vec![0.0_f64; dm2],
648            w_v: vec![0.0_f64; dm2],
649            w_o: vec![0.0_f64; dm2],
650            b_o: vec![0.0_f64; d_model],
651        }
652    }
653    /// Initialise W_Q, W_K, W_V, W_O with identity-like weights for testing.
654    pub fn init_identity(&mut self) {
655        let dm = self.d_model;
656        for row in 0..dm {
657            self.w_q[row * dm + row] = 1.0;
658            self.w_k[row * dm + row] = 1.0;
659            self.w_v[row * dm + row] = 1.0;
660            self.w_o[row * dm + row] = 1.0;
661        }
662    }
663    /// Linear projection: `output = input @ W^T`  where W is `[out × in]`.
664    fn project(
665        input: &[f64],
666        w: &[f64],
667        seq_len: usize,
668        in_dim: usize,
669        out_dim: usize,
670    ) -> Vec<f64> {
671        let mut out = vec![0.0_f64; seq_len * out_dim];
672        for t in 0..seq_len {
673            for o in 0..out_dim {
674                let mut acc = 0.0_f64;
675                for i in 0..in_dim {
676                    acc += input[t * in_dim + i] * w[o * in_dim + i];
677                }
678                out[t * out_dim + o] = acc;
679            }
680        }
681        out
682    }
683    /// Forward pass.
684    ///
685    /// `x` has shape `[seq_len × d_model]` (flat row-major).
686    /// Returns output of shape `[seq_len × d_model]`.
687    pub fn forward(&self, x: &[f64], seq_len: usize) -> Vec<f64> {
688        let dm = self.d_model;
689        let dh = self.d_head;
690        let nh = self.n_heads;
691        let q_full = Self::project(x, &self.w_q, seq_len, dm, dm);
692        let k_full = Self::project(x, &self.w_k, seq_len, dm, dm);
693        let v_full = Self::project(x, &self.w_v, seq_len, dm, dm);
694        let mut concat = vec![0.0_f64; seq_len * dm];
695        for h in 0..nh {
696            let mut q_h = vec![0.0_f64; seq_len * dh];
697            let mut k_h = vec![0.0_f64; seq_len * dh];
698            let mut v_h = vec![0.0_f64; seq_len * dh];
699            for t in 0..seq_len {
700                for d in 0..dh {
701                    q_h[t * dh + d] = q_full[t * dm + h * dh + d];
702                    k_h[t * dh + d] = k_full[t * dm + h * dh + d];
703                    v_h[t * dh + d] = v_full[t * dm + h * dh + d];
704                }
705            }
706            let head_out =
707                scaled_dot_product_attention(&q_h, &k_h, &v_h, seq_len, seq_len, dh, dh, None);
708            for t in 0..seq_len {
709                for d in 0..dh {
710                    concat[t * dm + h * dh + d] = head_out[t * dh + d];
711                }
712            }
713        }
714        let projected = Self::project(&concat, &self.w_o, seq_len, dm, dm);
715        let mut output = projected;
716        for t in 0..seq_len {
717            for d in 0..dm {
718                output[t * dm + d] += self.b_o[d];
719            }
720        }
721        output
722    }
723    /// Total number of trainable parameters.
724    pub fn num_params(&self) -> usize {
725        4 * self.d_model * self.d_model + self.d_model
726    }
727}
728/// A single fully-connected layer with f64 weights.
729#[derive(Debug, Clone)]
730pub struct NeuralLayer {
731    /// Weight matrix: `weights[out][in]`.
732    pub weights: Vec<Vec<f64>>,
733    /// Bias vector of length `out_features`.
734    pub biases: Vec<f64>,
735    /// Activation function applied after the affine transform.
736    pub activation: ActivationFn64,
737}
738impl NeuralLayer {
739    /// Create a new layer with Xavier-uniform initialised weights.
740    ///
741    /// Xavier uniform: U(-limit, limit) where limit = sqrt(6 / (fan_in + fan_out)).
742    pub fn new_xavier(in_features: usize, out_features: usize, activation: ActivationFn64) -> Self {
743        let limit = (6.0_f64 / (in_features + out_features) as f64).sqrt();
744        let mut state: u64 = 0x123456789abcdef0;
745        let lcg_next = |s: &mut u64| -> f64 {
746            *s = s
747                .wrapping_mul(6364136223846793005)
748                .wrapping_add(1442695040888963407);
749            let bits = (*s >> 33) as f64;
750            bits / (u64::MAX as f64) * 2.0 * limit - limit
751        };
752        let weights: Vec<Vec<f64>> = (0..out_features)
753            .map(|_| (0..in_features).map(|_| lcg_next(&mut state)).collect())
754            .collect();
755        let biases = vec![0.0_f64; out_features];
756        Self {
757            weights,
758            biases,
759            activation,
760        }
761    }
762    /// Forward pass: activation(W * input + b).
763    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
764        let out_features = self.weights.len();
765        let mut output = Vec::with_capacity(out_features);
766        for o in 0..out_features {
767            let mut acc = self.biases[o];
768            for (i, &x) in input.iter().enumerate() {
769                acc += self.weights[o][i] * x;
770            }
771            output.push(self.activation.apply(acc));
772        }
773        output
774    }
775}
/// Graph readout that sums node features weighted by a learned gate.
///
/// Each node `i` gets a scalar score `a_i = sigmoid(w · h_i + b)`; the
/// readout is `Σ_i a_i * h_i`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct AttentionReadout {
    /// Feature dimensionality.
    pub d_feat: usize,
    /// Attention weight vector, length `d_feat`.
    pub w_attn: Vec<f64>,
    /// Scalar attention bias.
    pub b_attn: f64,
}
impl AttentionReadout {
    /// Build a readout with zero attention weights and bias.
    pub fn new(d_feat: usize) -> Self {
        let w_attn = vec![0.0_f64; d_feat];
        Self {
            d_feat,
            w_attn,
            b_attn: 0.0,
        }
    }
    /// Compute the attention-weighted sum over all nodes.
    ///
    /// `node_feats` is flat row-major `[n_nodes × d_feat]`. The result has
    /// length `d_feat` and is all zeros when `n_nodes == 0`.
    pub fn forward(&self, node_feats: &[f64], n_nodes: usize) -> Vec<f64> {
        let width = self.d_feat;
        let mut readout = vec![0.0_f64; width];
        for node in 0..n_nodes {
            let feats = &node_feats[node * width..(node + 1) * width];
            let logit: f64 = feats
                .iter()
                .zip(self.w_attn.iter())
                .map(|(&x, &w)| x * w)
                .sum::<f64>()
                + self.b_attn;
            // Logistic sigmoid gate for this node.
            let gate = 1.0 / (1.0 + (-logit).exp());
            for (total, &value) in readout.iter_mut().zip(feats) {
                *total += gate * value;
            }
        }
        readout
    }
}
826/// A single transformer encoder block:
827/// x → MHA(LayerNorm(x)) + x → FFN(LayerNorm(·)) + ·
828#[allow(dead_code)]
829#[derive(Debug, Clone)]
830pub struct TransformerBlock {
831    /// Multi-head self-attention module.
832    pub mha: MultiHeadAttention,
833    /// Feed-forward network.
834    pub ffn: TransformerFfn,
835    /// Layer norm before MHA.
836    pub ln1: LayerNorm,
837    /// Layer norm before FFN.
838    pub ln2: LayerNorm,
839    /// Model dimensionality.
840    pub d_model: usize,
841}
842impl TransformerBlock {
843    /// Create a new transformer block with zero weights.
844    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Self {
845        Self {
846            mha: MultiHeadAttention::new(d_model, n_heads),
847            ffn: TransformerFfn::new(d_model, d_ff),
848            ln1: LayerNorm::new(d_model),
849            ln2: LayerNorm::new(d_model),
850            d_model,
851        }
852    }
853    /// Forward pass with pre-norm style residual connections.
854    ///
855    /// `x` is flat row-major `[seq_len × d_model]`.
856    pub fn forward(&self, x: &[f64], seq_len: usize) -> Vec<f64> {
857        let dm = self.d_model;
858        let mut normed1 = vec![0.0_f64; seq_len * dm];
859        for t in 0..seq_len {
860            let row = &x[t * dm..(t + 1) * dm];
861            let n = self.ln1.forward(row);
862            normed1[t * dm..(t + 1) * dm].copy_from_slice(&n);
863        }
864        let attn_out = self.mha.forward(&normed1, seq_len);
865        let mut x1 = vec![0.0_f64; seq_len * dm];
866        for i in 0..x1.len() {
867            x1[i] = x[i] + attn_out[i];
868        }
869        let mut normed2 = vec![0.0_f64; seq_len * dm];
870        for t in 0..seq_len {
871            let row = &x1[t * dm..(t + 1) * dm];
872            let n = self.ln2.forward(row);
873            normed2[t * dm..(t + 1) * dm].copy_from_slice(&n);
874        }
875        let ffn_out = self.ffn.forward(&normed2, seq_len);
876        let mut x2 = vec![0.0_f64; seq_len * dm];
877        for i in 0..x2.len() {
878            x2[i] = x1[i] + ffn_out[i];
879        }
880        x2
881    }
882}
/// Activation functions for neural network layers (f32 precision).
#[derive(Debug, Clone, PartialEq)]
pub enum ActivationFn {
    /// Hyperbolic tangent.
    Tanh,
    /// Rectified linear unit: `max(0, x)`.
    Relu,
    /// Logistic sigmoid: `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Sigmoid-weighted linear unit: `x * sigmoid(x)`.
    Silu,
    /// Gaussian error linear unit, tanh approximation:
    /// `0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 x³)))`.
    Gelu,
    /// Identity / no activation.
    Linear,
}
impl ActivationFn {
    /// Cubic coefficient of the tanh-based GELU approximation.
    const GELU_CUBIC: f32 = 0.044715;

    /// `sqrt(2 / π)`, the scale inside the tanh-based GELU approximation.
    fn gelu_scale() -> f32 {
        // FRAC_2_PI is 2/π; FRAC_2_SQRT_PI (2/√π) would give the wrong constant.
        std::f32::consts::FRAC_2_PI.sqrt()
    }

    /// Logistic sigmoid shared by the `Sigmoid` and `Silu` variants.
    fn logistic(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Evaluate the activation function at `x`.
    pub fn apply(&self, x: f32) -> f32 {
        match self {
            ActivationFn::Tanh => x.tanh(),
            ActivationFn::Relu => x.max(0.0),
            ActivationFn::Sigmoid => Self::logistic(x),
            ActivationFn::Silu => x / (1.0 + (-x).exp()),
            ActivationFn::Gelu => {
                let inner = Self::gelu_scale() * (x + Self::GELU_CUBIC * x * x * x);
                let cdf = 0.5 * (1.0 + inner.tanh());
                x * cdf
            }
            ActivationFn::Linear => x,
        }
    }
    /// Evaluate the derivative of the activation function at `x`.
    ///
    /// All variants use closed-form derivatives. `Relu` returns 0 at `x == 0`.
    pub fn derivative(&self, x: f32) -> f32 {
        match self {
            ActivationFn::Tanh => {
                let t = x.tanh();
                1.0 - t * t
            }
            ActivationFn::Relu => {
                if x > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            ActivationFn::Sigmoid => {
                let s = Self::logistic(x);
                s * (1.0 - s)
            }
            ActivationFn::Silu => {
                let s = Self::logistic(x);
                s + x * s * (1.0 - s)
            }
            ActivationFn::Gelu => {
                // d/dx [x * 0.5 * (1 + tanh u)] with u = c * (x + a x³):
                //   0.5 * (1 + tanh u) + 0.5 * x * sech²(u) * c * (1 + 3 a x²)
                let u = Self::gelu_scale() * (x + Self::GELU_CUBIC * x * x * x);
                let t = u.tanh();
                let sech2 = 1.0 - t * t;
                let cdf_term = 0.5 * (1.0 + t);
                // Once tanh saturates, sech² is exactly 0; skip the product so a
                // huge `x` cannot turn it into `0 * inf = NaN`.
                if sech2 == 0.0 {
                    cdf_term
                } else {
                    let du_dx = Self::gelu_scale() * (1.0 + 3.0 * Self::GELU_CUBIC * x * x);
                    cdf_term + 0.5 * x * sech2 * du_dx
                }
            }
            ActivationFn::Linear => 1.0,
        }
    }
}
/// Layer normalisation (Ba et al., 2016) with forward and backward passes.
///
/// A single sample's feature vector is shifted to zero mean and scaled to unit
/// variance, after which a per-feature scale (gamma) and shift (beta) are applied.
#[derive(Debug, Clone)]
pub struct LayerNormLayer {
    /// Length of the feature vector this layer accepts.
    pub n_features: usize,
    /// Per-feature scale (gamma); `new` sets every entry to 1.
    pub gamma: Vec<f64>,
    /// Per-feature shift (beta); `new` sets every entry to 0.
    pub beta: Vec<f64>,
    /// Constant added to the variance before taking the square root.
    pub epsilon: f64,
}
impl LayerNormLayer {
    /// Build an identity-initialised layer (gamma = 1, beta = 0, epsilon = 1e-5).
    pub fn new(n_features: usize) -> Self {
        let gamma = vec![1.0; n_features];
        let beta = vec![0.0; n_features];
        Self {
            n_features,
            gamma,
            beta,
            epsilon: 1e-5,
        }
    }
    /// Mean of `input` and `1 / sqrt(var + epsilon)` using the population variance.
    fn mean_and_inv_std(&self, input: &[f64]) -> (f64, f64) {
        let count = self.n_features as f64;
        let mean: f64 = input.iter().sum::<f64>() / count;
        let var: f64 = input
            .iter()
            .map(|&x| {
                let centred = x - mean;
                centred * centred
            })
            .sum::<f64>()
            / count;
        (mean, 1.0 / (var + self.epsilon).sqrt())
    }
    /// Normalise one sample.
    ///
    /// output\[i\] = gamma\[i\] * (input\[i\] - mean) / sqrt(var + eps) + beta\[i\]
    ///
    /// # Panics
    /// Panics if `input.len() != self.n_features`.
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        assert_eq!(
            input.len(),
            self.n_features,
            "LayerNorm: input size mismatch"
        );
        let (mean, inv_std) = self.mean_and_inv_std(input);
        let mut normalised = Vec::with_capacity(self.n_features);
        for (i, &x) in input.iter().enumerate() {
            normalised.push(self.gamma[i] * (x - mean) * inv_std + self.beta[i]);
        }
        normalised
    }
    /// Back-propagate through the layer.
    ///
    /// Given the forward `input` and the upstream gradient `d_output`, returns
    /// `(d_input, d_gamma, d_beta)`.
    ///
    /// # Panics
    /// Panics if `input` or `d_output` does not have `n_features` elements.
    #[allow(non_snake_case)]
    pub fn backward(&self, input: &[f64], d_output: &[f64]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        assert_eq!(input.len(), self.n_features);
        assert_eq!(d_output.len(), self.n_features);
        let count = self.n_features as f64;
        let (mean, inv_std) = self.mean_and_inv_std(input);

        // Normalised activations, recomputed because forward caches nothing.
        let x_hat: Vec<f64> = input.iter().map(|&x| (x - mean) * inv_std).collect();

        let d_gamma: Vec<f64> = d_output
            .iter()
            .zip(&x_hat)
            .map(|(&g, &xh)| g * xh)
            .collect();
        let d_beta: Vec<f64> = d_output.to_vec();

        // Gradient with respect to the normalised activations.
        let d_x_hat: Vec<f64> = d_output
            .iter()
            .enumerate()
            .map(|(i, &g)| g * self.gamma[i])
            .collect();
        let total: f64 = d_x_hat.iter().sum();
        let total_weighted: f64 = d_x_hat.iter().zip(&x_hat).map(|(&a, &b)| a * b).sum();

        // Subtract the components that flow through the mean and the variance.
        let d_input: Vec<f64> = d_x_hat
            .iter()
            .zip(&x_hat)
            .map(|(&g, &xh)| inv_std * (g - (total + xh * total_weighted) / count))
            .collect();
        (d_input, d_gamma, d_beta)
    }
}
/// Per-feature Z-score normalizer.
///
/// Holds the mean and standard deviation of every feature column as measured
/// by [`DataNormalizer::fit`].
#[derive(Debug, Clone)]
pub struct DataNormalizer {
    /// Mean of each feature column.
    pub mean: Vec<f32>,
    /// Population standard deviation of each feature column.
    pub std_dev: Vec<f32>,
}
impl DataNormalizer {
    /// Measure per-feature mean and population standard deviation over `data`.
    ///
    /// A column whose standard deviation is below `1e-8` gets a standard
    /// deviation of `1.0` so that `transform` never divides by (nearly) zero.
    ///
    /// # Panics
    /// Panics if `data` is empty or if sample vectors have inconsistent lengths.
    pub fn fit(data: &[Vec<f32>]) -> Self {
        assert!(
            !data.is_empty(),
            "DataNormalizer::fit: data must be non-empty"
        );
        let width = data[0].len();
        for row in data {
            assert_eq!(
                row.len(),
                width,
                "DataNormalizer::fit: inconsistent sample length"
            );
        }
        let count = data.len() as f32;

        // Column sums are accumulated in sample order, then divided once.
        let mean: Vec<f32> = (0..width)
            .map(|k| data.iter().fold(0.0_f32, |acc, row| acc + row[k]) / count)
            .collect();

        let std_dev: Vec<f32> = mean
            .iter()
            .enumerate()
            .map(|(k, &mu)| {
                let sum_sq = data.iter().fold(0.0_f32, |acc, row| {
                    let diff = row[k] - mu;
                    acc + diff * diff
                });
                let sigma = (sum_sq / count).sqrt();
                if sigma < 1e-8 { 1.0 } else { sigma }
            })
            .collect();

        DataNormalizer { mean, std_dev }
    }
    /// Standardise one sample element-wise: `(x - mean) / std`.
    ///
    /// The output is as long as the shortest of `x`, `mean` and `std_dev`.
    pub fn transform(&self, x: &[f32]) -> Vec<f32> {
        let stats = self.mean.iter().zip(self.std_dev.iter());
        x.iter()
            .zip(stats)
            .map(|(&value, (&mu, &sigma))| (value - mu) / sigma)
            .collect()
    }
    /// Undo `transform` element-wise: `x * std + mean`.
    ///
    /// The output is as long as the shortest of `x`, `mean` and `std_dev`.
    pub fn inverse_transform(&self, x: &[f32]) -> Vec<f32> {
        let stats = self.mean.iter().zip(self.std_dev.iter());
        x.iter()
            .zip(stats)
            .map(|(&value, (&mu, &sigma))| value * sigma + mu)
            .collect()
    }
}
/// Activation functions for the f64-precision layers.
#[derive(Debug, Clone, PartialEq)]
pub enum ActivationFn64 {
    /// Rectified linear unit: max(0, x).
    Relu,
    /// Logistic sigmoid: 1 / (1 + exp(-x)).
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Identity (input returned unchanged).
    Linear,
}
impl ActivationFn64 {
    /// Evaluate the activation for one scalar.
    pub fn apply(&self, x: f64) -> f64 {
        match *self {
            ActivationFn64::Linear => x,
            ActivationFn64::Tanh => x.tanh(),
            ActivationFn64::Relu => x.max(0.0),
            ActivationFn64::Sigmoid => (1.0 + (-x).exp()).recip(),
        }
    }
    /// Overwrite every element of `v` with its activated value.
    pub fn apply_batch(&self, v: &mut Vec<f64>) {
        v.iter_mut().for_each(|slot| *slot = self.apply(*slot));
    }
}
1114/// A sequential feed-forward neural network using f64 precision.
1115#[derive(Debug, Clone)]
1116pub struct NeuralNetwork {
1117    /// Ordered list of layers.
1118    pub layers: Vec<NeuralLayer>,
1119}
1120impl NeuralNetwork {
1121    /// Build a network with Xavier-initialised weights from a list of layer sizes.
1122    ///
1123    /// All hidden layers use `activation`; the final layer uses `ActivationFn64::Linear`.
1124    pub fn new(layer_sizes: &[usize], activation: ActivationFn64) -> Self {
1125        assert!(
1126            layer_sizes.len() >= 2,
1127            "need at least input and output size"
1128        );
1129        let mut layers = Vec::new();
1130        for i in 0..layer_sizes.len() - 1 {
1131            let act = if i == layer_sizes.len() - 2 {
1132                ActivationFn64::Linear
1133            } else {
1134                activation.clone()
1135            };
1136            layers.push(NeuralLayer::new_xavier(
1137                layer_sizes[i],
1138                layer_sizes[i + 1],
1139                act,
1140            ));
1141        }
1142        Self { layers }
1143    }
1144    /// Run a forward pass through all layers.
1145    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
1146        let mut current: Vec<f64> = input.to_vec();
1147        for layer in &self.layers {
1148            current = layer.forward(&current);
1149        }
1150        current
1151    }
1152    /// Expected input dimension (size of the first layer's input).
1153    pub fn input_dim(&self) -> usize {
1154        self.layers
1155            .first()
1156            .map_or(0, |l| l.weights.first().map_or(0, |r| r.len()))
1157    }
1158    /// Expected output dimension (size of the last layer's output).
1159    pub fn output_dim(&self) -> usize {
1160        self.layers.last().map_or(0, |l| l.biases.len())
1161    }
1162}
1163/// A sequential feed-forward neural network.
1164#[derive(Debug, Clone)]
1165pub struct FeedForwardNet {
1166    /// Ordered list of dense layers.
1167    pub layers: Vec<DenseLayer>,
1168}
1169impl FeedForwardNet {
1170    /// Create an empty network.
1171    pub fn new() -> Self {
1172        FeedForwardNet { layers: Vec::new() }
1173    }
1174    /// Append a layer to the end of the network.
1175    pub fn add_layer(&mut self, layer: DenseLayer) {
1176        self.layers.push(layer);
1177    }
1178    /// Run a forward pass through all layers.
1179    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
1180        let mut current: Vec<f32> = input.to_vec();
1181        for layer in &self.layers {
1182            current = layer.forward(&current);
1183        }
1184        current
1185    }
1186    /// Returns the expected input width (`None` if the network has no layers).
1187    pub fn input_size(&self) -> Option<usize> {
1188        self.layers.first().map(|l| l.in_features)
1189    }
1190    /// Returns the output width of the last layer (`None` if empty).
1191    pub fn output_size(&self) -> Option<usize> {
1192        self.layers.last().map(|l| l.out_features)
1193    }
1194    /// Sum of parameters across all layers.
1195    pub fn total_parameters(&self) -> usize {
1196        self.layers.iter().map(|l| l.parameter_count()).sum()
1197    }
1198}
1199impl FeedForwardNet {
1200    /// Compute the total gradient norm across all layers.
1201    ///
1202    /// `layer_grads[i]` is the concatenated `[grad_weights, grad_biases]` for
1203    /// layer `i`.  Returns the L2 norm of all gradients combined.
1204    pub fn compute_gradient_norm(&self, layer_grads: &[Vec<f32>]) -> f32 {
1205        let sum_sq: f32 = layer_grads
1206            .iter()
1207            .flat_map(|g| g.iter())
1208            .map(|&v| v * v)
1209            .sum();
1210        sum_sq.sqrt()
1211    }
1212    /// Clip per-layer gradient vectors in-place so their combined norm ≤ `max_norm`.
1213    /// Returns the pre-clip norm.
1214    pub fn clip_gradients(&self, layer_grads: &mut Vec<Vec<f32>>, max_norm: f32) -> f32 {
1215        let norm = self.compute_gradient_norm(layer_grads);
1216        if norm > max_norm && norm > 0.0 {
1217            let scale = max_norm / norm;
1218            for g in layer_grads.iter_mut() {
1219                for v in g.iter_mut() {
1220                    *v *= scale;
1221                }
1222            }
1223        }
1224        norm
1225    }
1226}
/// Layer normalisation for a single feature vector (one time step).
///
/// The vector of length `n_features` is shifted to zero mean and scaled to
/// unit variance, then multiplied by `gamma` and offset by `beta` per feature.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct LayerNorm {
    /// Length of the feature vector this layer accepts.
    pub n_features: usize,
    /// Per-feature scale.
    pub gamma: Vec<f64>,
    /// Per-feature shift.
    pub beta: Vec<f64>,
    /// Constant added to the variance before the square root.
    pub epsilon: f64,
}
impl LayerNorm {
    /// Build an identity-initialised layer norm (gamma = 1, beta = 0, epsilon = 1e-5).
    pub fn new(n_features: usize) -> Self {
        let gamma = vec![1.0_f64; n_features];
        let beta = vec![0.0_f64; n_features];
        Self {
            n_features,
            gamma,
            beta,
            epsilon: 1e-5,
        }
    }
    /// Normalise one feature vector using its own mean and population variance.
    ///
    /// # Panics
    /// Panics if `x.len() != self.n_features`.
    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
        assert_eq!(x.len(), self.n_features);
        let count = self.n_features as f64;
        let mean = x.iter().sum::<f64>() / count;
        // Centre once and reuse the deviations for both variance and output.
        let centred: Vec<f64> = x.iter().map(|&v| v - mean).collect();
        let var = centred.iter().map(|&c| c * c).sum::<f64>() / count;
        let sigma = (var + self.epsilon).sqrt();
        let mut out = Vec::with_capacity(centred.len());
        for (i, &c) in centred.iter().enumerate() {
            out.push(self.gamma[i] * c / sigma + self.beta[i]);
        }
        out
    }
}
/// Host-memory stand-in for a GPU buffer used in batched network inference.
///
/// A real GPU backend would keep this on the device; here the values live in
/// a flat f64 array in host memory.
#[derive(Debug, Clone)]
pub struct GpuNeuralBuffer {
    /// Number of samples in the batch.
    pub batch_size: usize,
    /// Number of values per input sample.
    pub input_dim: usize,
    /// Number of values per output sample.
    pub output_dim: usize,
    /// Flat data storage: `batch_size * max(input_dim, output_dim)` elements.
    pub data: Vec<f64>,
}
impl GpuNeuralBuffer {
    /// Flatten 3-D positions into a buffer with `input_dim == output_dim == 3`.
    ///
    /// Each `[x, y, z]` contributes three consecutive values to `data`, in the
    /// order the positions are given.
    pub fn pack_positions(positions: &[[f64; 3]]) -> Self {
        const COMPONENTS: usize = 3;
        let batch_size = positions.len();
        let mut data = Vec::with_capacity(batch_size * COMPONENTS);
        for xyz in positions {
            data.extend_from_slice(xyz);
        }
        Self {
            batch_size,
            input_dim: COMPONENTS,
            output_dim: COMPONENTS,
            data,
        }
    }
    /// Regroup `data` into consecutive 3-D force vectors.
    ///
    /// # Panics
    /// Panics if `data.len()` is not a multiple of 3, because the trailing
    /// partial group cannot fill a vector.
    pub fn unpack_forces(&self) -> Vec<[f64; 3]> {
        let mut forces = Vec::new();
        for triple in self.data.chunks(3) {
            forces.push([triple[0], triple[1], triple[2]]);
        }
        forces
    }
}
/// Sinusoidal positional encoding (Vaswani et al., 2017).
///
/// For each position `pos` and pair index `i` in a `d_model`-dimensional
/// embedding space:
///   PE\[pos, 2i\]   = sin(pos / 10000^(2i/d_model))
///   PE\[pos, 2i+1\] = cos(pos / 10000^(2i/d_model))
///
/// When `d_model` is odd, the last dimension is the sine term of the final
/// pair; its cosine partner does not exist.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct PositionalEncoding {
    /// Embedding dimensionality.
    pub d_model: usize,
    /// Maximum sequence length supported.
    pub max_len: usize,
    /// Pre-computed encoding table: `table[pos][dim]`.
    pub table: Vec<Vec<f64>>,
}
impl PositionalEncoding {
    /// Build the positional encoding table for positions `0..max_len`.
    ///
    /// Every dimension of every row is filled, including the trailing sine
    /// term when `d_model` is odd.
    pub fn new(d_model: usize, max_len: usize) -> Self {
        let width = d_model as f64;
        let table: Vec<Vec<f64>> = (0..max_len)
            .map(|pos| {
                let mut row = vec![0.0_f64; d_model];
                // Walk the even dimensions; `even` equals `2i` in the formula.
                for even in (0..d_model).step_by(2) {
                    let angle = pos as f64 / 10000.0_f64.powf(even as f64 / width);
                    row[even] = angle.sin();
                    // The cosine partner is absent only for the last pair of an odd width.
                    if let Some(odd_slot) = row.get_mut(even + 1) {
                        *odd_slot = angle.cos();
                    }
                }
                row
            })
            .collect();
        Self {
            d_model,
            max_len,
            table,
        }
    }
    /// Add positional encoding to a sequence of embeddings in-place.
    ///
    /// `embeddings[t]` is a feature vector; only its first `d_model` entries
    /// are modified. Time steps at or beyond `max_len` are left untouched.
    pub fn add_to_sequence(&self, embeddings: &mut Vec<Vec<f64>>) {
        let rows = self.table.iter().take(self.max_len);
        for (emb, row) in embeddings.iter_mut().zip(rows) {
            for (value, &pe) in emb.iter_mut().zip(row.iter().take(self.d_model)) {
                *value += pe;
            }
        }
    }
    /// Return the positional encoding vector for position `pos`.
    ///
    /// Positions past the end are clamped to the last row. If the table has no
    /// rows (`max_len == 0`) an empty slice is returned instead of panicking.
    pub fn get(&self, pos: usize) -> &[f64] {
        // saturating_sub avoids the `0 - 1` underflow for an empty table.
        let clamped = pos.min(self.max_len.saturating_sub(1));
        self.table.get(clamped).map(Vec::as_slice).unwrap_or(&[])
    }
}
1363/// A fully-connected layer with f64 weights supporting forward pass and
1364/// gradient computation for backpropagation.
1365#[derive(Debug, Clone)]
1366pub struct DenseLayer64 {
1367    /// Weight matrix in row-major layout: `weights[out * in_features + in]`.
1368    pub weights: Vec<f64>,
1369    /// Bias vector of length `out_features`.
1370    pub biases: Vec<f64>,
1371    /// Number of input features.
1372    pub in_features: usize,
1373    /// Number of output features.
1374    pub out_features: usize,
1375    /// Activation function.
1376    pub activation: ExtActivation,
1377    /// Pre-activation outputs from the last forward pass (z = W*x + b).
1378    pub last_pre_act: Vec<f64>,
1379    /// Post-activation outputs from the last forward pass.
1380    pub last_output: Vec<f64>,
1381    /// Last input fed to this layer.
1382    pub last_input: Vec<f64>,
1383}
1384impl DenseLayer64 {
1385    /// Create a new layer with zero-initialised weights and biases.
1386    pub fn new(in_features: usize, out_features: usize, activation: ExtActivation) -> Self {
1387        Self {
1388            weights: vec![0.0_f64; out_features * in_features],
1389            biases: vec![0.0_f64; out_features],
1390            in_features,
1391            out_features,
1392            activation,
1393            last_pre_act: Vec::new(),
1394            last_output: Vec::new(),
1395            last_input: Vec::new(),
1396        }
1397    }
1398    /// Forward pass: computes `activation(W * input + b)`.
1399    /// Caches `pre_act`, `output`, and `input` for backprop.
1400    pub fn forward(&mut self, input: &[f64]) -> Vec<f64> {
1401        assert_eq!(
1402            input.len(),
1403            self.in_features,
1404            "DenseLayer64::forward: input size mismatch"
1405        );
1406        self.last_input = input.to_vec();
1407        let mut pre_act = Vec::with_capacity(self.out_features);
1408        for o in 0..self.out_features {
1409            let row = o * self.in_features;
1410            let mut acc = self.biases[o];
1411            for i in 0..self.in_features {
1412                acc += self.weights[row + i] * input[i];
1413            }
1414            pre_act.push(acc);
1415        }
1416        let output: Vec<f64> = pre_act.iter().map(|&z| self.activation.apply(z)).collect();
1417        self.last_pre_act = pre_act;
1418        self.last_output = output.clone();
1419        output
1420    }
1421    /// Backward pass: computes gradients w.r.t. weights, biases, and input.
1422    ///
1423    /// `delta_out` is the gradient of the loss w.r.t. this layer's output
1424    /// (same shape as `last_output`).
1425    ///
1426    /// Returns `(grad_weights, grad_biases, delta_in)` where `delta_in` is the
1427    /// gradient passed to the previous layer.
1428    #[allow(clippy::too_many_arguments)]
1429    pub fn backward(&self, delta_out: &[f64]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
1430        assert_eq!(
1431            delta_out.len(),
1432            self.out_features,
1433            "DenseLayer64::backward: delta_out size mismatch"
1434        );
1435        let delta_pre: Vec<f64> = delta_out
1436            .iter()
1437            .zip(self.last_pre_act.iter())
1438            .map(|(&d, &z)| d * self.activation.derivative(z))
1439            .collect();
1440        let mut grad_weights = vec![0.0_f64; self.out_features * self.in_features];
1441        for o in 0..self.out_features {
1442            let row = o * self.in_features;
1443            for i in 0..self.in_features {
1444                grad_weights[row + i] = delta_pre[o] * self.last_input[i];
1445            }
1446        }
1447        let grad_biases = delta_pre.clone();
1448        let mut delta_in = vec![0.0_f64; self.in_features];
1449        for o in 0..self.out_features {
1450            let row = o * self.in_features;
1451            for i in 0..self.in_features {
1452                delta_in[i] += self.weights[row + i] * delta_pre[o];
1453            }
1454        }
1455        (grad_weights, grad_biases, delta_in)
1456    }
1457    /// Apply gradient updates using a simple SGD step.
1458    pub fn apply_sgd(&mut self, grad_weights: &[f64], grad_biases: &[f64], lr: f64) {
1459        for (w, &gw) in self.weights.iter_mut().zip(grad_weights.iter()) {
1460            *w -= lr * gw;
1461        }
1462        for (b, &gb) in self.biases.iter_mut().zip(grad_biases.iter()) {
1463            *b -= lr * gb;
1464        }
1465    }
1466    /// Total number of parameters.
1467    pub fn num_params(&self) -> usize {
1468        self.out_features * self.in_features + self.out_features
1469    }
1470}
1471/// Atomic neural network potential (NNP) with one sub-network per element.
1472///
1473/// Architecture follows Behler (2011): each atom contributes an atomic energy
1474/// predicted by an element-specific feed-forward network whose input is the
1475/// Behler-Parrinello descriptor vector.
1476#[derive(Debug)]
1477pub struct AtomicNeuralNetwork {
1478    /// Element-specific networks keyed by atomic number.
1479    pub networks: HashMap<u8, FeedForwardNet>,
1480    /// Symmetry-function descriptor shared by all elements.
1481    pub descriptor: BehlerParrinelloDescriptor,
1482}
1483impl AtomicNeuralNetwork {
1484    /// Create a new AANN with the given descriptor.
1485    pub fn new(descriptor: BehlerParrinelloDescriptor) -> Self {
1486        AtomicNeuralNetwork {
1487            networks: HashMap::new(),
1488            descriptor,
1489        }
1490    }
1491    /// Register a sub-network for the given atomic number.
1492    pub fn add_element_network(&mut self, atomic_number: u8, net: FeedForwardNet) {
1493        self.networks.insert(atomic_number, net);
1494    }
1495    /// Predict the atomic energy for one atom given its descriptor.
1496    ///
1497    /// Returns `None` if no network is registered for `atomic_number`.
1498    pub fn atomic_energy(&self, atomic_number: u8, descriptor: &[f32]) -> Option<f32> {
1499        self.networks
1500            .get(&atomic_number)
1501            .map(|net| net.forward(descriptor)[0])
1502    }
1503    /// Sum of atomic energies over all atoms.
1504    ///
1505    /// Atoms whose element has no registered network contribute 0.
1506    pub fn total_energy(&self, positions: &[[f64; 3]], atomic_numbers: &[u8]) -> f64 {
1507        assert_eq!(
1508            positions.len(),
1509            atomic_numbers.len(),
1510            "total_energy: positions and atomic_numbers must have the same length"
1511        );
1512        let mut e_total = 0.0_f64;
1513        for (i, &z) in atomic_numbers.iter().enumerate() {
1514            let desc_f64 = self.descriptor.descriptor_vector(positions, i);
1515            let desc_f32: Vec<f32> = desc_f64.iter().map(|&v| v as f32).collect();
1516            if let Some(e) = self.atomic_energy(z, &desc_f32) {
1517                e_total += e as f64;
1518            }
1519        }
1520        e_total
1521    }
1522}
1523/// A single fully-connected (dense) layer with an activation function.
1524///
1525/// Weights are stored in row-major order: `weights[out * in_features + in]`.
1526#[derive(Debug, Clone)]
1527pub struct DenseLayer {
1528    /// Weight matrix in row-major layout `[out_features × in_features]`.
1529    pub weights: Vec<f32>,
1530    /// Bias vector of length `out_features`.
1531    pub biases: Vec<f32>,
1532    /// Number of input features.
1533    pub in_features: usize,
1534    /// Number of output features.
1535    pub out_features: usize,
1536    /// Activation function applied after the affine transform.
1537    pub activation: ActivationFn,
1538}
1539impl DenseLayer {
1540    /// Create a new layer with zero-initialised weights and biases.
1541    pub fn new(in_features: usize, out_features: usize, activation: ActivationFn) -> Self {
1542        DenseLayer {
1543            weights: vec![0.0_f32; out_features * in_features],
1544            biases: vec![0.0_f32; out_features],
1545            in_features,
1546            out_features,
1547            activation,
1548        }
1549    }
1550    /// Compute `activation(W * input + b)`.
1551    ///
1552    /// # Panics
1553    /// Panics if `input.len() != self.in_features`.
1554    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
1555        assert_eq!(
1556            input.len(),
1557            self.in_features,
1558            "DenseLayer::forward: input length {} != in_features {}",
1559            input.len(),
1560            self.in_features
1561        );
1562        let mut output = Vec::with_capacity(self.out_features);
1563        for o in 0..self.out_features {
1564            let row_offset = o * self.in_features;
1565            let mut acc = self.biases[o];
1566            for i in 0..self.in_features {
1567                acc += self.weights[row_offset + i] * input[i];
1568            }
1569            output.push(self.activation.apply(acc));
1570        }
1571        output
1572    }
1573    /// Replace the weight matrix (must have length `out_features * in_features`).
1574    ///
1575    /// # Panics
1576    /// Panics if `w.len()` does not match.
1577    pub fn set_weights(&mut self, w: &[f32]) {
1578        assert_eq!(
1579            w.len(),
1580            self.out_features * self.in_features,
1581            "set_weights: expected {} elements, got {}",
1582            self.out_features * self.in_features,
1583            w.len()
1584        );
1585        self.weights.copy_from_slice(w);
1586    }
1587    /// Replace the bias vector (must have length `out_features`).
1588    ///
1589    /// # Panics
1590    /// Panics if `b.len()` does not match.
1591    pub fn set_biases(&mut self, b: &[f32]) {
1592        assert_eq!(
1593            b.len(),
1594            self.out_features,
1595            "set_biases: expected {} elements, got {}",
1596            self.out_features,
1597            b.len()
1598        );
1599        self.biases.copy_from_slice(b);
1600    }
1601    /// Total number of trainable parameters (weights + biases).
1602    pub fn parameter_count(&self) -> usize {
1603        self.out_features * self.in_features + self.out_features
1604    }
1605}
1606/// Dropout regularisation layer.
1607///
1608/// During training, each neuron is set to zero with probability `rate`.
1609/// During inference (training=false) the layer passes inputs through unchanged
1610/// but scales by `1 - rate` to maintain expected magnitude.
1611#[derive(Debug, Clone)]
1612pub struct DropoutLayer {
1613    /// Probability of dropping a unit (0.0 = no dropout, 1.0 = all dropped).
1614    pub rate: f64,
1615    /// Whether the layer is in training mode.
1616    pub training: bool,
1617    /// Mask used in the last forward pass (1.0 = kept, 0.0 = dropped).
1618    pub last_mask: Vec<f64>,
1619    /// Seed for the deterministic LCG used to generate the mask.
1620    pub(super) seed: u64,
1621}
1622impl DropoutLayer {
1623    /// Create a new dropout layer.
1624    pub fn new(rate: f64, training: bool) -> Self {
1625        assert!(
1626            (0.0..=1.0).contains(&rate),
1627            "dropout rate must be in [0, 1]"
1628        );
1629        Self {
1630            rate,
1631            training,
1632            last_mask: Vec::new(),
1633            seed: 0xdeadbeefcafe1234,
1634        }
1635    }
1636    /// Set the seed for reproducible mask generation.
1637    pub fn set_seed(&mut self, seed: u64) {
1638        self.seed = seed;
1639    }
1640    /// Apply dropout to `input`.
1641    ///
1642    /// In training mode, randomly zeros elements with probability `rate`
1643    /// and scales the rest by `1 / (1 - rate)` (inverted dropout).
1644    /// In eval mode, passes through unchanged.
1645    pub fn forward(&mut self, input: &[f64]) -> Vec<f64> {
1646        if !self.training || self.rate == 0.0 {
1647            self.last_mask = vec![1.0; input.len()];
1648            return input.to_vec();
1649        }
1650        if self.rate == 1.0 {
1651            self.last_mask = vec![0.0; input.len()];
1652            return vec![0.0; input.len()];
1653        }
1654        let scale = 1.0 / (1.0 - self.rate);
1655        let mut mask = Vec::with_capacity(input.len());
1656        let mut output = Vec::with_capacity(input.len());
1657        for &x in input {
1658            self.seed = self
1659                .seed
1660                .wrapping_mul(6364136223846793005)
1661                .wrapping_add(1442695040888963407);
1662            let u = (self.seed >> 11) as f64 / (1u64 << 53) as f64;
1663            let m = if u >= self.rate { scale } else { 0.0 };
1664            mask.push(m);
1665            output.push(x * m);
1666        }
1667        self.last_mask = mask;
1668        output
1669    }
1670    /// Backward pass: applies the stored mask to the upstream gradient.
1671    pub fn backward(&self, delta_out: &[f64]) -> Vec<f64> {
1672        delta_out
1673            .iter()
1674            .zip(self.last_mask.iter())
1675            .map(|(&d, &m)| d * m)
1676            .collect()
1677    }
1678}
/// Behler-Parrinello symmetry functions for constructing atomic descriptors.
///
/// Reference: J. Behler and M. Parrinello, PRL 98, 146401 (2007).
#[derive(Debug, Clone)]
pub struct BehlerParrinelloDescriptor {
    /// Radial decay parameters η for G2 functions.
    pub eta: Vec<f64>,
    /// Shift parameters R_s for G2 functions (paired index-wise with `eta`).
    pub rs: Vec<f64>,
    /// Cutoff radius R_c in Ångström (or same units as positions).
    pub cutoff: f64,
}
impl BehlerParrinelloDescriptor {
    /// Smooth cutoff function.
    ///
    /// f_c(r) = 0.5 * (cos(π r / R_c) + 1)  for r < R_c, else 0.
    pub fn cutoff_fn(r: f64, rc: f64) -> f64 {
        if r < rc {
            0.5 * ((PI_F64 * r / rc).cos() + 1.0)
        } else {
            0.0
        }
    }
    /// G1 radial symmetry function: G1 = f_c(r).
    pub fn radial_g1(r: f64, rc: f64) -> f64 {
        Self::cutoff_fn(r, rc)
    }
    /// G2 radial symmetry function: G2 = exp(-η (r - R_s)²) * f_c(r).
    pub fn radial_g2(r: f64, eta: f64, rs: f64, rc: f64) -> f64 {
        let diff = r - rs;
        (-eta * diff * diff).exp() * Self::cutoff_fn(r, rc)
    }
    /// G4 angular symmetry function: contribution of a single triplet i-j-k.
    ///
    /// G4 = 2^(1-ζ) * (1 + λ cos θ)^ζ * exp(-η (r_ij² + r_ik² + r_jk²)) * f_c(r_ij) f_c(r_ik) f_c(r_jk)
    ///
    /// `cos_theta` is clamped to `[-1, 1]` before use. A cosine computed from
    /// dot products can overshoot ±1 by a rounding error; without the clamp the
    /// base `1 + λ cos θ` may become slightly negative (for λ = ±1) and raising
    /// it to a non-integer ζ would yield NaN. A NaN `cos_theta` stays NaN.
    #[allow(clippy::too_many_arguments)]
    pub fn angular_g4(
        r_ij: f64,
        r_ik: f64,
        r_jk: f64,
        cos_theta: f64,
        eta: f64,
        zeta: f64,
        lambda: f64,
        rc: f64,
    ) -> f64 {
        // Guard against floating-point overshoot of the cosine (see doc comment).
        let cos_theta = cos_theta.clamp(-1.0, 1.0);
        let angular = (1.0 + lambda * cos_theta).powf(zeta);
        let radial = (-eta * (r_ij * r_ij + r_ik * r_ik + r_jk * r_jk)).exp();
        let fc = Self::cutoff_fn(r_ij, rc) * Self::cutoff_fn(r_ik, rc) * Self::cutoff_fn(r_jk, rc);
        2.0_f64.powf(1.0 - zeta) * angular * radial * fc
    }
    /// Compute a single G2 descriptor value (convenience wrapper around
    /// [`Self::radial_g2`]).
    pub fn compute(r_ij: f64, eta: f64, rs: f64, cutoff: f64) -> f64 {
        Self::radial_g2(r_ij, eta, rs, cutoff)
    }
    /// Build a full G2 descriptor vector for atom `center_idx`.
    ///
    /// For every (η, R_s) pair the function sums G2(r_ij, η, R_s, R_c) over all
    /// neighbours j ≠ center_idx that lie within the cutoff radius. The result
    /// has `self.eta.len()` entries.
    ///
    /// # Panics
    ///
    /// Panics if `center_idx` is out of bounds for `positions`, or if `rs` has
    /// fewer entries than `eta` and at least one neighbour lies inside the
    /// cutoff.
    pub fn descriptor_vector(&self, positions: &[[f64; 3]], center_idx: usize) -> Vec<f64> {
        let n_descriptors = self.eta.len();
        let mut desc = vec![0.0_f64; n_descriptors];
        let ci = positions[center_idx];
        for (j, pos_j) in positions.iter().enumerate() {
            if j == center_idx {
                continue;
            }
            let dx = pos_j[0] - ci[0];
            let dy = pos_j[1] - ci[1];
            let dz = pos_j[2] - ci[2];
            let r = (dx * dx + dy * dy + dz * dz).sqrt();
            // Atoms at or beyond the cutoff contribute exactly zero; skip the work.
            if r >= self.cutoff {
                continue;
            }
            for (k, slot) in desc.iter_mut().enumerate() {
                *slot += Self::radial_g2(r, self.eta[k], self.rs[k], self.cutoff);
            }
        }
        desc
    }
}
/// Scalar activation functions for the `f64` code paths, including
/// parameterised variants.
#[derive(Debug, Clone, PartialEq)]
pub enum ExtActivation {
    /// Leaky ReLU: `x` for `x >= 0`, `alpha * x` otherwise, where the payload
    /// `alpha` is the negative-side slope.
    LeakyRelu(f64),
    /// Swish: `x * sigmoid(beta * x)` with payload `beta`; `beta = 1` gives SiLU.
    Swish(f64),
    /// Standard ReLU: `max(x, 0)`.
    Relu,
    /// Logistic sigmoid: `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Identity (no non-linearity).
    Linear,
}
impl ExtActivation {
    /// Compute the activation value for the scalar input `x`.
    pub fn apply(&self, x: f64) -> f64 {
        match *self {
            Self::Linear => x,
            Self::Relu => x.max(0.0),
            Self::LeakyRelu(_) if x >= 0.0 => x,
            Self::LeakyRelu(slope) => slope * x,
            Self::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Self::Tanh => x.tanh(),
            Self::Swish(gain) => x / (1.0 + (-gain * x).exp()),
        }
    }
    /// Compute the first derivative of the activation with respect to `x`.
    ///
    /// At the kink `x == 0` ReLU reports `0.0` while Leaky ReLU reports `1.0`.
    pub fn derivative(&self, x: f64) -> f64 {
        match *self {
            Self::Linear => 1.0,
            Self::Relu if x > 0.0 => 1.0,
            Self::Relu => 0.0,
            Self::LeakyRelu(_) if x >= 0.0 => 1.0,
            Self::LeakyRelu(slope) => slope,
            Self::Sigmoid => {
                let s = 1.0 / (1.0 + (-x).exp());
                s * (1.0 - s)
            }
            Self::Tanh => {
                let t = x.tanh();
                1.0 - t * t
            }
            Self::Swish(gain) => {
                // d/dx [x * σ(βx)] = σ(βx) + βx σ(βx) (1 - σ(βx))
                let gate = 1.0 / (1.0 + (-gain * x).exp());
                gate + gain * x * gate * (1.0 - gate)
            }
        }
    }
    /// Replace every element of `v` with its activation value.
    pub fn apply_vec(&self, v: &mut [f64]) {
        v.iter_mut().for_each(|slot| *slot = self.apply(*slot));
    }
}