axonml-nn 0.6.2

Neural network modules for Axonml ML framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
//! Graph neural network layers — `GCNConv` and `GATConv`.
//!
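//! `GCNConv` (Graph Convolutional Network — message passing over a
//! caller-supplied adjacency matrix followed by a linear transform; the
//! adjacency is used exactly as given, so pass a pre-normalized matrix such as
//! `D^{-1/2} (A + I) D^{-1/2}` to obtain the Kipf & Welling formulation).
//! `GATConv` (Graph Attention Network — attention-weighted neighbor
//! aggregation with multi-head support and LeakyReLU-activated attention
//! scores). Both take a node feature tensor plus an adjacency matrix through
//! `forward_graph` and implement `Module`.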
//!
//! # File
//! `crates/axonml-nn/src/layers/graph.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// GCNConv
// =============================================================================

/// Graph Convolutional Network layer (Kipf & Welling, 2017).
///
/// Performs the graph convolution `output = adj @ x @ weight + bias`, where
/// `adj` is the adjacency matrix supplied by the caller. The layer uses `adj`
/// exactly as given: it does not add self-loops or apply the symmetric
/// `D^{-1/2} (A + I) D^{-1/2}` normalization from the paper, so normalize
/// `adj` beforehand if you want that formulation.
///
/// # Arguments
/// * `in_features` - Number of input features per node
/// * `out_features` - Number of output features per node
///
/// # Forward Signature
/// Use `forward_graph(x, adj)` to supply an adjacency matrix. `Module::forward`
/// is also implemented: it calls `forward_graph` with an identity adjacency
/// (each node aggregates only itself), which reduces the layer to a per-node
/// linear transform.
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::GCNConv;
///
/// let gcn = GCNConv::new(72, 128);
/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);     // (batch, nodes, features)
/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);         // (nodes, nodes)
/// let output = gcn.forward_graph(&x, &adj);
/// // output shape: (2, 7, 128)
/// ```
pub struct GCNConv {
    /// Learnable projection of shape `(in_features, out_features)`.
    weight: Parameter,
    /// Optional learnable bias of shape `(out_features,)`, added to every node.
    bias: Option<Parameter>,
    /// Expected size of the last axis of the input.
    in_features: usize,
    /// Size of the last axis of the output.
    out_features: usize,
}

impl GCNConv {
    /// Creates a new GCN convolution layer with bias.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        // Xavier initialization
        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
        let weight_data: Vec<f32> = (0..in_features * out_features)
            .map(|i| {
                // Simple deterministic-ish init for reproducibility
                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
                x * scale
            })
            .collect();

        let weight = Parameter::named(
            "weight",
            Tensor::from_vec(weight_data, &[in_features, out_features])
                .expect("tensor creation failed"),
            true,
        );

        let bias_data = vec![0.0; out_features];
        let bias = Some(Parameter::named(
            "bias",
            Tensor::from_vec(bias_data, &[out_features]).expect("tensor creation failed"),
            true,
        ));

        Self {
            weight,
            bias,
            in_features,
            out_features,
        }
    }

    /// Creates a GCN layer without bias.
    pub fn without_bias(in_features: usize, out_features: usize) -> Self {
        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
        let weight_data: Vec<f32> = (0..in_features * out_features)
            .map(|i| {
                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
                x * scale
            })
            .collect();

        let weight = Parameter::named(
            "weight",
            Tensor::from_vec(weight_data, &[in_features, out_features])
                .expect("tensor creation failed"),
            true,
        );

        Self {
            weight,
            bias: None,
            in_features,
            out_features,
        }
    }

    /// Graph convolution forward pass.
    ///
    /// # Arguments
    /// * `x` - Node features: `(batch, num_nodes, in_features)`
    /// * `adj` - Adjacency matrix: `(num_nodes, num_nodes)` or `(batch, num_nodes, num_nodes)`
    ///
    /// # Returns
    /// Output features: `(batch, num_nodes, out_features)`
    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
        let shape = x.shape();
        assert!(
            shape.len() == 3,
            "GCNConv expects input shape (batch, nodes, features), got {:?}",
            shape
        );
        assert_eq!(shape[2], self.in_features, "Input features mismatch");

        let batch = shape[0];
        let adj_shape = adj.shape();

        // GCN: output = adj @ x @ weight + bias
        // Use Variable operations to preserve gradient flow
        //
        // Process per-sample: for each batch element, compute adj @ x_b @ weight
        // This preserves gradient flow through x and weight via matmul backward.
        let weight = self.weight.variable();

        let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
        for b in 0..batch {
            // x_b: (nodes, in_features)
            let x_b = x.select(0, b);

            // adj_b: (nodes, nodes)
            let adj_b = if adj_shape.len() == 3 {
                adj.select(0, b)
            } else {
                adj.clone()
            };

            // message_b = adj_b @ x_b → (nodes, in_features)
            let msg_b = adj_b.matmul(&x_b);

            // out_b = msg_b @ weight → (nodes, out_features)
            let mut out_b = msg_b.matmul(&weight);

            // Add bias
            if let Some(bias) = &self.bias {
                out_b = out_b.add_var(&bias.variable());
            }

            // Unsqueeze to (1, nodes, out_features) for stacking
            per_sample.push(out_b.unsqueeze(0));
        }

        // Stack along batch dimension
        let refs: Vec<&Variable> = per_sample.iter().collect();
        Variable::cat(&refs, 0)
    }

    /// Returns the input feature dimension.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Returns the output feature dimension.
    pub fn out_features(&self) -> usize {
        self.out_features
    }
}

impl Module for GCNConv {
    /// Forward with identity adjacency (self-loops only).
    /// For proper graph convolution, use `forward_graph(x, adj)`.
    ///
    /// # Panics
    /// Panics if `input` is not shaped `(batch, nodes, features)` or its last axis
    /// differs from `in_features`.
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        assert!(
            shape.len() == 3,
            "GCNConv expects input shape (batch, nodes, features), got {:?}",
            shape
        );
        // The adjacency must be (nodes, nodes). Nodes live on axis 1; sizing it
        // from axis 0 (the batch) would mismatch `x` whenever batch != nodes.
        let n = shape[1];
        let mut eye_data = vec![0.0f32; n * n];
        // Diagonal entries of a row-major n x n matrix sit every n + 1 elements.
        eye_data.iter_mut().step_by(n + 1).for_each(|v| *v = 1.0);
        let adj = Variable::new(
            axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
                .expect("identity matrix creation failed"),
            false,
        );
        self.forward_graph(input, &adj)
    }

    /// Returns `weight`, followed by `bias` when present.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(bias) = &self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Returns parameters keyed by `"weight"` and, when present, `"bias"`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(bias) = &self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "GCNConv"
    }
}

// =============================================================================
// GATConv
// =============================================================================

/// Graph Attention Network layer (Veličković et al., 2018).
///
/// For each head, the layer:
/// 1. projects node features with `w`;
/// 2. scores every edge permitted by the adjacency mask as
///    `LeakyReLU(attn_src · h_i + attn_dst · h_j)`;
/// 3. softmax-normalizes the scores over each node's neighbors;
/// 4. sums the neighbors' projected features with those weights.
///
/// Head outputs are concatenated, giving `out_features * num_heads` features
/// per node.
///
/// # Gradients
/// NOTE(review): `forward_graph` computes on plain `f32` buffers and wraps the
/// result with `Variable::new`. As far as this file shows, the output is
/// therefore not connected to the autograd graph. `w`, `attn_src`, `attn_dst`,
/// `bias` and the input `x` receive no gradients through it. Until the forward
/// pass is rebuilt on differentiable `Variable` ops, treat these parameters as
/// fixed during training.
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::GATConv;
///
/// let gat = GATConv::new(72, 32, 4); // 4 attention heads
/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);
/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);
/// let output = gat.forward_graph(&x, &adj);
/// // output shape: (2, 7, 128)  — 32 * 4 heads
/// ```
pub struct GATConv {
    /// Shared projection of shape `(in_features, out_features * num_heads)`.
    w: Parameter,
    /// Per-head attention vector applied to the aggregating node, `(num_heads, out_features)`.
    attn_src: Parameter,
    /// Per-head attention vector applied to each neighbor, `(num_heads, out_features)`.
    attn_dst: Parameter,
    /// Optional bias of length `out_features * num_heads`, added after aggregation.
    bias: Option<Parameter>,
    /// Expected size of the last axis of the input.
    in_features: usize,
    /// Output features produced by each head.
    out_features: usize,
    /// Number of attention heads whose outputs are concatenated.
    num_heads: usize,
    /// Slope of the LeakyReLU applied to negative attention logits.
    negative_slope: f32,
}

impl GATConv {
    /// Creates a new GAT convolution layer.
    ///
    /// All parameters are filled from deterministic `(i * c) % 1` sequences
    /// (no RNG), so construction is reproducible:
    /// * `w` is scaled by `sqrt(2 / (in_features + out_features * num_heads))`;
    /// * both attention vectors are scaled by `sqrt(1 / out_features)`;
    /// * the bias starts at zero;
    /// * the LeakyReLU negative slope is 0.2.
    ///
    /// # Arguments
    /// * `in_features` - Input feature dimension per node
    /// * `out_features` - Output feature dimension per head
    /// * `num_heads` - Number of attention heads (output = out_features * num_heads)
    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
        let total_out = out_features * num_heads;
        let scale = (2.0 / (in_features + total_out) as f32).sqrt();

        let w_data: Vec<f32> = (0..in_features * total_out)
            .map(|i| {
                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
                x * scale
            })
            .collect();

        let w = Parameter::named(
            "w",
            Tensor::from_vec(w_data, &[in_features, total_out]).expect("tensor creation failed"),
            true,
        );

        // Attention vectors: one per head for source and destination
        let attn_scale = (1.0 / out_features as f32).sqrt();
        let attn_src_data: Vec<f32> = (0..total_out)
            .map(|i| {
                let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
                x * attn_scale
            })
            .collect();
        let attn_dst_data: Vec<f32> = (0..total_out)
            .map(|i| {
                let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
                x * attn_scale
            })
            .collect();

        let attn_src = Parameter::named(
            "attn_src",
            Tensor::from_vec(attn_src_data, &[num_heads, out_features])
                .expect("tensor creation failed"),
            true,
        );

        let attn_dst = Parameter::named(
            "attn_dst",
            Tensor::from_vec(attn_dst_data, &[num_heads, out_features])
                .expect("tensor creation failed"),
            true,
        );

        let bias_data = vec![0.0; total_out];
        let bias = Some(Parameter::named(
            "bias",
            Tensor::from_vec(bias_data, &[total_out]).expect("tensor creation failed"),
            true,
        ));

        Self {
            w,
            attn_src,
            attn_dst,
            bias,
            in_features,
            out_features,
            num_heads,
            negative_slope: 0.2,
        }
    }

    /// Graph attention forward pass.
    ///
    /// For every head, node `i` aggregates the projected features `h_j` of each
    /// node `j` with `adj[i][j] != 0`. The weights are a softmax over
    /// `LeakyReLU(attn_src · h_i + attn_dst · h_j)`. Adjacency values act only
    /// as a mask; their magnitudes are ignored. A node with no permitted
    /// neighbors receives only the bias.
    ///
    /// # Arguments
    /// * `x` - Node features: `(batch, num_nodes, in_features)`
    /// * `adj` - Adjacency mask, either `(num_nodes, num_nodes)` or
    ///   `(batch, num_nodes, num_nodes)`. Non-zero entries allow attention.
    ///
    /// # Returns
    /// Output features: `(batch, num_nodes, out_features * num_heads)`
    ///
    /// # Panics
    /// Panics if any of the following holds:
    /// * `x` is not rank 3, or its last axis differs from `in_features`.
    /// * `adj` is not a square `num_nodes` mask of rank 2 or 3.
    /// * `adj` is rank 3 and its leading axis differs from `batch`.
    ///
    /// # Gradients
    /// NOTE(review): the computation runs on raw `f32` buffers and the result is
    /// wrapped with `Variable::new`. As far as this file shows, the output is
    /// detached from the autograd graph, so no gradients reach `w`, `attn_src`,
    /// `attn_dst`, `bias` or `x`.
    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
        let shape = x.shape();
        assert!(
            shape.len() == 3,
            "GATConv expects (batch, nodes, features), got {:?}",
            shape
        );
        // Without this check, a wrong feature width silently mis-strides `x_data`.
        assert_eq!(shape[2], self.in_features, "Input features mismatch");

        let batch = shape[0];
        let nodes = shape[1];
        let in_features = self.in_features;
        let out_features = self.out_features;
        let total_out = out_features * self.num_heads;

        // Validate the full adjacency shape; the indexing below assumes a dense
        // (nodes, nodes) block per sample.
        let adj_shape = adj.shape();
        assert!(
            adj_shape.len() == 2 || adj_shape.len() == 3,
            "GATConv expects adjacency (nodes, nodes) or (batch, nodes, nodes), got {:?}",
            adj_shape
        );
        let rank = adj_shape.len();
        let batched_adj = rank == 3;
        assert_eq!(adj_shape[rank - 2], nodes, "Adjacency matrix size mismatch");
        assert_eq!(adj_shape[rank - 1], nodes, "Adjacency matrix size mismatch");
        if batched_adj {
            assert_eq!(adj_shape[0], batch, "Adjacency batch size mismatch");
        }

        let x_data = x.data().to_vec();
        let adj_data = adj.data().to_vec();
        let w_data = self.w.data().to_vec();
        let attn_src_data = self.attn_src.data().to_vec();
        let attn_dst_data = self.attn_dst.data().to_vec();

        let mut output = vec![0.0f32; batch * nodes * total_out];

        // Scratch buffers: allocated once, fully overwritten before every use.
        let mut h = vec![0.0f32; nodes * total_out];
        let mut src_scores = vec![0.0f32; nodes];
        let mut dst_scores = vec![0.0f32; nodes];
        let mut row = vec![f32::NEG_INFINITY; nodes];
        let mut exps = vec![0.0f32; nodes];

        for b in 0..batch {
            // Step 1: project all nodes of this sample: h = x_b @ w -> (nodes, total_out).
            for i in 0..nodes {
                let x_row = &x_data[(b * nodes + i) * in_features..][..in_features];
                let h_row = &mut h[i * total_out..(i + 1) * total_out];
                for (o, h_val) in h_row.iter_mut().enumerate() {
                    *h_val = x_row
                        .iter()
                        .enumerate()
                        .fold(0.0f32, |acc, (f, &xv)| acc + xv * w_data[f * total_out + o]);
                }
            }

            let adj_off = if batched_adj { b * nodes * nodes } else { 0 };
            let adj_b = &adj_data[adj_off..adj_off + nodes * nodes];

            for head in 0..self.num_heads {
                let head_off = head * out_features;
                let a_src = &attn_src_data[head_off..head_off + out_features];
                let a_dst = &attn_dst_data[head_off..head_off + out_features];

                // Step 2: project each node onto this head's attention vectors once.
                // This avoids recomputing the neighbor term for every (i, j) pair.
                for (n, (src, dst)) in src_scores
                    .iter_mut()
                    .zip(dst_scores.iter_mut())
                    .enumerate()
                {
                    let h_head = &h[n * total_out + head_off..][..out_features];
                    *src = Self::dot(h_head, a_src);
                    *dst = Self::dot(h_head, a_dst);
                }

                for i in 0..nodes {
                    // Step 3: masked LeakyReLU logits e_ij; non-edges stay at -inf.
                    let adj_row = &adj_b[i * nodes..(i + 1) * nodes];
                    for ((score, &a_ij), &dst) in
                        row.iter_mut().zip(adj_row).zip(dst_scores.iter())
                    {
                        *score = if a_ij != 0.0 {
                            let e = src_scores[i] + dst;
                            if e > 0.0 {
                                e
                            } else {
                                e * self.negative_slope
                            }
                        } else {
                            f32::NEG_INFINITY
                        };
                    }

                    // Step 4: numerically stable softmax over the permitted neighbors.
                    let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                    if max_val == f32::NEG_INFINITY {
                        continue; // No neighbors: this node keeps only the bias.
                    }
                    let mut sum_exp = 0.0f32;
                    for (e, &s) in exps.iter_mut().zip(row.iter()) {
                        *e = if s > f32::NEG_INFINITY {
                            let v = (s - max_val).exp();
                            sum_exp += v;
                            v
                        } else {
                            0.0
                        };
                    }

                    // Step 5: attention-weighted sum of neighbor features for this head.
                    let out_off = (b * nodes + i) * total_out + head_off;
                    let out_head = &mut output[out_off..out_off + out_features];
                    for (j, &e) in exps.iter().enumerate() {
                        if e > 0.0 {
                            let alpha = e / sum_exp;
                            let h_head = &h[j * total_out + head_off..][..out_features];
                            for (o, &hv) in out_head.iter_mut().zip(h_head) {
                                *o += alpha * hv;
                            }
                        }
                    }
                }
            }
        }

        // Add bias to every node's concatenated head output.
        if let Some(bias) = &self.bias {
            let bias_data = bias.data().to_vec();
            // `chunks_exact_mut(0)` panics; with zero width the output is empty anyway.
            if total_out > 0 {
                for node_out in output.chunks_exact_mut(total_out) {
                    for (o, &bv) in node_out.iter_mut().zip(bias_data.iter()) {
                        *o += bv;
                    }
                }
            }
        }

        Variable::new(
            Tensor::from_vec(output, &[batch, nodes, total_out]).expect("tensor creation failed"),
            x.requires_grad(),
        )
    }

    /// Dot product of two equal-length slices, accumulated left to right from zero.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
    }

    /// Total output dimension (out_features * num_heads).
    pub fn total_out_features(&self) -> usize {
        self.out_features * self.num_heads
    }
}

impl Module for GATConv {
    /// Forward with identity adjacency: each node attends only to itself.
    /// For real graph attention, use `forward_graph(x, adj)`.
    ///
    /// # Panics
    /// Panics if `input` is not shaped `(batch, nodes, features)`.
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        assert!(
            shape.len() == 3,
            "GATConv expects (batch, nodes, features), got {:?}",
            shape
        );
        // The mask must be (nodes, nodes). Nodes live on axis 1; sizing it from
        // axis 0 (the batch) trips the adjacency-size assertion whenever batch != nodes.
        let n = shape[1];
        let mut eye_data = vec![0.0f32; n * n];
        // Diagonal entries of a row-major n x n matrix sit every n + 1 elements.
        eye_data.iter_mut().step_by(n + 1).for_each(|v| *v = 1.0);
        let adj = Variable::new(
            axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
                .expect("identity matrix creation failed"),
            false,
        );
        self.forward_graph(input, &adj)
    }

    /// Returns `w`, `attn_src`, `attn_dst`, followed by `bias` when present.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
        if let Some(bias) = &self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Returns parameters keyed by `"w"`, `"attn_src"`, `"attn_dst"` and, when present, `"bias"`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("w".to_string(), self.w.clone());
        params.insert("attn_src".to_string(), self.attn_src.clone());
        params.insert("attn_dst".to_string(), self.attn_dst.clone());
        if let Some(bias) = &self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "GATConv"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gcn_conv_shape() {
        let gcn = GCNConv::new(72, 128);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).expect("tensor creation failed"),
            false,
        );
        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        // With identity adjacency, each node only sees itself
        let gcn = GCNConv::new(4, 8);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );

        // Identity adjacency
        let mut adj_data = vec![0.0; 9];
        adj_data[0] = 1.0; // (0,0)
        adj_data[4] = 1.0; // (1,1)
        adj_data[8] = 1.0; // (2,2)
        let adj = Variable::new(
            Tensor::from_vec(adj_data, &[3, 3]).expect("tensor creation failed"),
            false,
        );

        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 3, 8]);

        // All nodes have same input, so all should produce same output
        let data = output.data().to_vec();
        for i in 0..3 {
            for f in 0..8 {
                assert!(
                    (data[i * 8 + f] - data[f]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 2); // weight + bias

        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total_params, 16 * 32 + 32); // weight + bias
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        let gcn = GCNConv::without_bias(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 1); // weight only
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.named_parameters();
        assert!(params.contains_key("weight"));
        assert!(params.contains_key("bias"));
    }

    #[test]
    fn test_gcn_batched_adjacency() {
        // A rank-3 adjacency gives each sample its own graph: sample 0 uses
        // self-loops only, sample 1 has no edges at all.
        let gcn = GCNConv::new(4, 5);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 3 * 4], &[2, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let mut adj_data = vec![0.0; 2 * 3 * 3];
        for i in 0..3 {
            adj_data[i * 3 + i] = 1.0;
        }
        let adj = Variable::new(
            Tensor::from_vec(adj_data, &[2, 3, 3]).expect("tensor creation failed"),
            false,
        );

        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 3, 5]);

        let data = output.data().to_vec();
        let weight = gcn.named_parameters()["weight"].data().to_vec();
        for i in 0..3 {
            for o in 0..5 {
                // All-ones input: each node's projection is the column sum of `weight`.
                let expected: f32 = (0..4).map(|f| weight[f * 5 + o]).sum();
                assert!(
                    (data[i * 5 + o] - expected).abs() < 1e-5,
                    "Sample 0 should equal x @ weight under self-loop adjacency"
                );
            }
        }
        for val in &data[3 * 5..] {
            assert!(
                val.abs() < 1e-6,
                "Sample without edges should only receive the (zero) bias"
            );
        }
    }

    #[test]
    fn test_gat_conv_shape() {
        let gat = GATConv::new(72, 32, 4); // 4 heads
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).expect("tensor creation failed"),
            false,
        );
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]); // 32 * 4 = 128
    }

    #[test]
    fn test_gat_conv_single_head() {
        let gat = GATConv::new(16, 8, 1);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 5 * 16], &[1, 5, 16]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 5 * 5], &[5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_identity_adjacency_returns_projection() {
        // With only self-edges every softmax has a single term (alpha = 1), so
        // each node's output is its own projection x_i @ w plus the zero bias.
        let gat = GATConv::new(4, 3, 2);
        let total_out = gat.total_out_features();
        let x_vals: Vec<f32> = (0..3 * 4).map(|k| k as f32 * 0.25 - 1.0).collect();
        let x = Variable::new(
            Tensor::from_vec(x_vals.clone(), &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let mut adj_data = vec![0.0; 3 * 3];
        for i in 0..3 {
            adj_data[i * 3 + i] = 1.0;
        }
        let adj = Variable::new(
            Tensor::from_vec(adj_data, &[3, 3]).expect("tensor creation failed"),
            false,
        );

        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 3, total_out]);

        let data = output.data().to_vec();
        let w = gat.named_parameters()["w"].data().to_vec();
        for i in 0..3 {
            for o in 0..total_out {
                let expected: f32 = (0..4)
                    .map(|f| x_vals[i * 4 + f] * w[f * total_out + o])
                    .sum();
                assert!(
                    (data[i * total_out + o] - expected).abs() < 1e-5,
                    "Self-only attention should reproduce the node projection"
                );
            }
        }
    }

    #[test]
    fn test_gat_zero_adjacency() {
        // Nodes with no permitted neighbors skip aggregation and keep only the bias.
        let gat = GATConv::new(4, 2, 3);
        let x = Variable::new(
            Tensor::from_vec(vec![5.0; 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![0.0; 9], &[3, 3]).expect("tensor creation failed"),
            false,
        );
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 3, 6]);

        let data = output.data().to_vec();
        for val in &data {
            assert!(
                val.abs() < 1e-6,
                "Isolated nodes should only receive the (zero) bias"
            );
        }
    }

    #[test]
    fn test_gat_conv_parameters() {
        let gat = GATConv::new(16, 8, 4);
        let params = gat.parameters();
        assert_eq!(params.len(), 4); // w, attn_src, attn_dst, bias

        let named = gat.named_parameters();
        assert!(named.contains_key("w"));
        assert!(named.contains_key("attn_src"));
        assert!(named.contains_key("attn_dst"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_total_output() {
        let gat = GATConv::new(16, 32, 4);
        assert_eq!(gat.total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        // Zero adjacency should produce only bias in output
        let gcn = GCNConv::new(4, 4);
        let x = Variable::new(
            Tensor::from_vec(vec![99.0; 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![0.0; 9], &[3, 3]).expect("tensor creation failed"),
            false,
        );
        let output = gcn.forward_graph(&x, &adj);

        // With zero adjacency, output should be just bias (all zeros initially)
        let data = output.data().to_vec();
        for val in &data {
            assert!(
                val.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}