// axonml_nn/layers/graph.rs
//! Graph neural network layers — `GCNConv` and `GATConv`.
//!
//! `GCNConv` (Graph Convolutional Network — message passing over a
//! caller-supplied adjacency matrix, typically symmetrically normalized by
//! the caller, followed by a linear transform). `GATConv` (Graph
//! Attention Network — attention-weighted neighbor aggregation with
//! multi-head support and LeakyReLU gating). Both operate on adjacency
//! matrix + node feature matrix and implement `Module`.
//!
//! # File
//! `crates/axonml-nn/src/layers/graph.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

24use std::collections::HashMap;
25
26use axonml_autograd::Variable;
27use axonml_tensor::Tensor;
28
29use crate::module::Module;
30use crate::parameter::Parameter;
31
32// =============================================================================
33// GCNConv
34// =============================================================================
35
/// Graph Convolutional Network layer (Kipf & Welling, 2017).
///
/// Performs the graph convolution: `output = adj @ x @ weight + bias`
/// where `adj` is the (possibly learned) adjacency matrix. No normalization
/// is applied here — the caller supplies the (e.g. symmetrically normalized)
/// adjacency.
///
/// # Arguments
/// * `in_features` - Number of input features per node
/// * `out_features` - Number of output features per node
///
/// # Forward Signature
/// Uses `forward_graph(x, adj)` since it requires an adjacency matrix.
/// The standard `forward()` from `Module` falls back to an identity
/// adjacency (self-loops only); for real graphs use `forward_graph()`.
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::GCNConv;
///
/// let gcn = GCNConv::new(72, 128);
/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);     // (batch, nodes, features)
/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);         // (nodes, nodes)
/// let output = gcn.forward_graph(&x, &adj);
/// // output shape: (2, 7, 128)
/// ```
pub struct GCNConv {
    /// Linear projection applied after aggregation, shape `(in_features, out_features)`.
    weight: Parameter,
    /// Optional additive bias, shape `(out_features,)`; `None` for `without_bias`.
    bias: Option<Parameter>,
    /// Input feature dimension per node.
    in_features: usize,
    /// Output feature dimension per node.
    out_features: usize,
}
65
66impl GCNConv {
67    /// Creates a new GCN convolution layer with bias.
68    pub fn new(in_features: usize, out_features: usize) -> Self {
69        // Xavier initialization
70        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
71        let weight_data: Vec<f32> = (0..in_features * out_features)
72            .map(|i| {
73                // Simple deterministic-ish init for reproducibility
74                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
75                x * scale
76            })
77            .collect();
78
79        let weight = Parameter::named(
80            "weight",
81            Tensor::from_vec(weight_data, &[in_features, out_features])
82                .expect("tensor creation failed"),
83            true,
84        );
85
86        let bias_data = vec![0.0; out_features];
87        let bias = Some(Parameter::named(
88            "bias",
89            Tensor::from_vec(bias_data, &[out_features]).expect("tensor creation failed"),
90            true,
91        ));
92
93        Self {
94            weight,
95            bias,
96            in_features,
97            out_features,
98        }
99    }
100
101    /// Creates a GCN layer without bias.
102    pub fn without_bias(in_features: usize, out_features: usize) -> Self {
103        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
104        let weight_data: Vec<f32> = (0..in_features * out_features)
105            .map(|i| {
106                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
107                x * scale
108            })
109            .collect();
110
111        let weight = Parameter::named(
112            "weight",
113            Tensor::from_vec(weight_data, &[in_features, out_features])
114                .expect("tensor creation failed"),
115            true,
116        );
117
118        Self {
119            weight,
120            bias: None,
121            in_features,
122            out_features,
123        }
124    }
125
126    /// Graph convolution forward pass.
127    ///
128    /// # Arguments
129    /// * `x` - Node features: `(batch, num_nodes, in_features)`
130    /// * `adj` - Adjacency matrix: `(num_nodes, num_nodes)` or `(batch, num_nodes, num_nodes)`
131    ///
132    /// # Returns
133    /// Output features: `(batch, num_nodes, out_features)`
134    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
135        let shape = x.shape();
136        assert!(
137            shape.len() == 3,
138            "GCNConv expects input shape (batch, nodes, features), got {:?}",
139            shape
140        );
141        assert_eq!(shape[2], self.in_features, "Input features mismatch");
142
143        let batch = shape[0];
144        let adj_shape = adj.shape();
145
146        // GCN: output = adj @ x @ weight + bias
147        // Use Variable operations to preserve gradient flow
148        //
149        // Process per-sample: for each batch element, compute adj @ x_b @ weight
150        // This preserves gradient flow through x and weight via matmul backward.
151        let weight = self.weight.variable();
152
153        let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
154        for b in 0..batch {
155            // x_b: (nodes, in_features)
156            let x_b = x.select(0, b);
157
158            // adj_b: (nodes, nodes)
159            let adj_b = if adj_shape.len() == 3 {
160                adj.select(0, b)
161            } else {
162                adj.clone()
163            };
164
165            // message_b = adj_b @ x_b → (nodes, in_features)
166            let msg_b = adj_b.matmul(&x_b);
167
168            // out_b = msg_b @ weight → (nodes, out_features)
169            let mut out_b = msg_b.matmul(&weight);
170
171            // Add bias
172            if let Some(bias) = &self.bias {
173                out_b = out_b.add_var(&bias.variable());
174            }
175
176            // Unsqueeze to (1, nodes, out_features) for stacking
177            per_sample.push(out_b.unsqueeze(0));
178        }
179
180        // Stack along batch dimension
181        let refs: Vec<&Variable> = per_sample.iter().collect();
182        Variable::cat(&refs, 0)
183    }
184
185    /// Returns the input feature dimension.
186    pub fn in_features(&self) -> usize {
187        self.in_features
188    }
189
190    /// Returns the output feature dimension.
191    pub fn out_features(&self) -> usize {
192        self.out_features
193    }
194}
195
196impl Module for GCNConv {
197    /// Forward with identity adjacency (self-loops only).
198    /// For proper graph convolution, use `forward_graph(x, adj)`.
199    fn forward(&self, input: &Variable) -> Variable {
200        // Use identity adjacency: each node only aggregates from itself
201        let n = input.shape()[0];
202        let mut eye_data = vec![0.0f32; n * n];
203        for i in 0..n {
204            eye_data[i * n + i] = 1.0;
205        }
206        let adj = Variable::new(
207            axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
208                .expect("identity matrix creation failed"),
209            false,
210        );
211        self.forward_graph(input, &adj)
212    }
213
214    fn parameters(&self) -> Vec<Parameter> {
215        let mut params = vec![self.weight.clone()];
216        if let Some(bias) = &self.bias {
217            params.push(bias.clone());
218        }
219        params
220    }
221
222    fn named_parameters(&self) -> HashMap<String, Parameter> {
223        let mut params = HashMap::new();
224        params.insert("weight".to_string(), self.weight.clone());
225        if let Some(bias) = &self.bias {
226            params.insert("bias".to_string(), bias.clone());
227        }
228        params
229    }
230
231    fn name(&self) -> &'static str {
232        "GCNConv"
233    }
234}
235
236// =============================================================================
237// GATConv
238// =============================================================================
239
/// Graph Attention Network layer (Veličković et al., 2018).
///
/// Computes attention-weighted graph convolution where attention coefficients
/// are learned based on node features, masked by the adjacency matrix.
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::GATConv;
///
/// let gat = GATConv::new(72, 32, 4); // 4 attention heads
/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);
/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);
/// let output = gat.forward_graph(&x, &adj);
/// // output shape: (2, 7, 128)  — 32 * 4 heads
/// ```
pub struct GATConv {
    /// Shared linear projection, shape `(in_features, out_features * num_heads)`.
    w: Parameter,
    /// Per-head source attention vectors, shape `(num_heads, out_features)`.
    attn_src: Parameter,
    /// Per-head destination attention vectors, shape `(num_heads, out_features)`.
    attn_dst: Parameter,
    /// Optional bias over concatenated head outputs, shape `(out_features * num_heads,)`.
    bias: Option<Parameter>,
    /// Input feature dimension per node.
    in_features: usize,
    /// Output feature dimension per attention head.
    out_features: usize,
    /// Number of attention heads; head outputs are concatenated.
    num_heads: usize,
    /// LeakyReLU negative slope applied to raw attention scores (0.2).
    negative_slope: f32,
}
265
266impl GATConv {
267    /// Creates a new GAT convolution layer.
268    ///
269    /// # Arguments
270    /// * `in_features` - Input feature dimension per node
271    /// * `out_features` - Output feature dimension per head
272    /// * `num_heads` - Number of attention heads (output = out_features * num_heads)
273    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
274        let total_out = out_features * num_heads;
275        let scale = (2.0 / (in_features + total_out) as f32).sqrt();
276
277        let w_data: Vec<f32> = (0..in_features * total_out)
278            .map(|i| {
279                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
280                x * scale
281            })
282            .collect();
283
284        let w = Parameter::named(
285            "w",
286            Tensor::from_vec(w_data, &[in_features, total_out]).expect("tensor creation failed"),
287            true,
288        );
289
290        // Attention vectors: one per head for source and destination
291        let attn_scale = (1.0 / out_features as f32).sqrt();
292        let attn_src_data: Vec<f32> = (0..total_out)
293            .map(|i| {
294                let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
295                x * attn_scale
296            })
297            .collect();
298        let attn_dst_data: Vec<f32> = (0..total_out)
299            .map(|i| {
300                let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
301                x * attn_scale
302            })
303            .collect();
304
305        let attn_src = Parameter::named(
306            "attn_src",
307            Tensor::from_vec(attn_src_data, &[num_heads, out_features])
308                .expect("tensor creation failed"),
309            true,
310        );
311
312        let attn_dst = Parameter::named(
313            "attn_dst",
314            Tensor::from_vec(attn_dst_data, &[num_heads, out_features])
315                .expect("tensor creation failed"),
316            true,
317        );
318
319        let bias_data = vec![0.0; total_out];
320        let bias = Some(Parameter::named(
321            "bias",
322            Tensor::from_vec(bias_data, &[total_out]).expect("tensor creation failed"),
323            true,
324        ));
325
326        Self {
327            w,
328            attn_src,
329            attn_dst,
330            bias,
331            in_features,
332            out_features,
333            num_heads,
334            negative_slope: 0.2,
335        }
336    }
337
338    /// Graph attention forward pass.
339    ///
340    /// # Arguments
341    /// * `x` - Node features: `(batch, num_nodes, in_features)`
342    /// * `adj` - Adjacency mask: `(num_nodes, num_nodes)` — non-zero entries allow attention
343    ///
344    /// # Returns
345    /// Output features: `(batch, num_nodes, out_features * num_heads)`
346    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
347        let shape = x.shape();
348        assert!(
349            shape.len() == 3,
350            "GATConv expects (batch, nodes, features), got {:?}",
351            shape
352        );
353
354        let batch = shape[0];
355        let nodes = shape[1];
356        let total_out = self.out_features * self.num_heads;
357
358        let x_data = x.data().to_vec();
359        let adj_data = adj.data().to_vec();
360        let w_data = self.w.data().to_vec();
361        let attn_src_data = self.attn_src.data().to_vec();
362        let attn_dst_data = self.attn_dst.data().to_vec();
363
364        let adj_nodes = if adj.shape().len() == 3 {
365            adj.shape()[1]
366        } else {
367            adj.shape()[0]
368        };
369        assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
370
371        let mut output = vec![0.0f32; batch * nodes * total_out];
372
373        for b in 0..batch {
374            // Step 1: Project all nodes: h = x @ w → (nodes, total_out)
375            let mut h = vec![0.0f32; nodes * total_out];
376            for i in 0..nodes {
377                let x_off = (b * nodes + i) * self.in_features;
378                for o in 0..total_out {
379                    let mut val = 0.0;
380                    for f in 0..self.in_features {
381                        val += x_data[x_off + f] * w_data[f * total_out + o];
382                    }
383                    h[i * total_out + o] = val;
384                }
385            }
386
387            // Step 2: Compute attention per head
388            let adj_off = if adj.shape().len() == 3 {
389                b * nodes * nodes
390            } else {
391                0
392            };
393
394            for head in 0..self.num_heads {
395                let head_off = head * self.out_features;
396
397                // Compute attention scores for each edge
398                // e_ij = LeakyReLU(attn_src · h_i + attn_dst · h_j)
399                let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
400
401                for i in 0..nodes {
402                    // src score for node i
403                    let mut src_score = 0.0;
404                    for f in 0..self.out_features {
405                        src_score += h[i * total_out + head_off + f]
406                            * attn_src_data[head * self.out_features + f];
407                    }
408
409                    for j in 0..nodes {
410                        let a_ij = adj_data[adj_off + i * nodes + j];
411                        if a_ij != 0.0 {
412                            let mut dst_score = 0.0;
413                            for f in 0..self.out_features {
414                                dst_score += h[j * total_out + head_off + f]
415                                    * attn_dst_data[head * self.out_features + f];
416                            }
417
418                            let e = src_score + dst_score;
419                            // LeakyReLU
420                            let e = if e > 0.0 { e } else { e * self.negative_slope };
421                            attn_scores[i * nodes + j] = e;
422                        }
423                    }
424                }
425
426                // Softmax per row (per destination node)
427                for i in 0..nodes {
428                    let row_start = i * nodes;
429                    let row_end = row_start + nodes;
430                    let row = &attn_scores[row_start..row_end];
431
432                    let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
433                    if max_val == f32::NEG_INFINITY {
434                        continue; // No neighbors
435                    }
436
437                    let mut sum_exp = 0.0f32;
438                    let mut exps = vec![0.0; nodes];
439                    for j in 0..nodes {
440                        if row[j] > f32::NEG_INFINITY {
441                            exps[j] = (row[j] - max_val).exp();
442                            sum_exp += exps[j];
443                        }
444                    }
445
446                    // Weighted sum of neighbor features
447                    let out_off = (b * nodes + i) * total_out + head_off;
448                    for j in 0..nodes {
449                        if exps[j] > 0.0 {
450                            let alpha = exps[j] / sum_exp;
451                            for f in 0..self.out_features {
452                                output[out_off + f] += alpha * h[j * total_out + head_off + f];
453                            }
454                        }
455                    }
456                }
457            }
458        }
459
460        // Add bias
461        if let Some(bias) = &self.bias {
462            let bias_data = bias.data().to_vec();
463            for b in 0..batch {
464                for i in 0..nodes {
465                    let offset = (b * nodes + i) * total_out;
466                    for o in 0..total_out {
467                        output[offset + o] += bias_data[o];
468                    }
469                }
470            }
471        }
472
473        Variable::new(
474            Tensor::from_vec(output, &[batch, nodes, total_out]).expect("tensor creation failed"),
475            x.requires_grad(),
476        )
477    }
478
479    /// Total output dimension (out_features * num_heads).
480    pub fn total_out_features(&self) -> usize {
481        self.out_features * self.num_heads
482    }
483}
484
485impl Module for GATConv {
486    fn forward(&self, input: &Variable) -> Variable {
487        // Use identity adjacency: each node only attends to itself
488        let n = input.shape()[0];
489        let mut eye_data = vec![0.0f32; n * n];
490        for i in 0..n {
491            eye_data[i * n + i] = 1.0;
492        }
493        let adj = Variable::new(
494            axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
495                .expect("identity matrix creation failed"),
496            false,
497        );
498        self.forward_graph(input, &adj)
499    }
500
501    fn parameters(&self) -> Vec<Parameter> {
502        let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
503        if let Some(bias) = &self.bias {
504            params.push(bias.clone());
505        }
506        params
507    }
508
509    fn named_parameters(&self) -> HashMap<String, Parameter> {
510        let mut params = HashMap::new();
511        params.insert("w".to_string(), self.w.clone());
512        params.insert("attn_src".to_string(), self.attn_src.clone());
513        params.insert("attn_dst".to_string(), self.attn_dst.clone());
514        if let Some(bias) = &self.bias {
515            params.insert("bias".to_string(), bias.clone());
516        }
517        params
518    }
519
520    fn name(&self) -> &'static str {
521        "GATConv"
522    }
523}
524
525// =============================================================================
526// Tests
527// =============================================================================
528
#[cfg(test)]
mod tests {
    use super::*;

    /// GCN output shape: (batch, nodes, out_features).
    #[test]
    fn test_gcn_conv_shape() {
        let gcn = GCNConv::new(72, 128);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).expect("tensor creation failed"),
            false,
        );
        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        // With identity adjacency, each node only sees itself
        let gcn = GCNConv::new(4, 8);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );

        // Identity adjacency
        let mut adj_data = vec![0.0; 9];
        adj_data[0] = 1.0; // (0,0)
        adj_data[4] = 1.0; // (1,1)
        adj_data[8] = 1.0; // (2,2)
        let adj = Variable::new(
            Tensor::from_vec(adj_data, &[3, 3]).expect("tensor creation failed"),
            false,
        );

        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 3, 8]);

        // All nodes have same input, so all should produce same output
        // (each node's row is compared to node 0's row).
        let data = output.data().to_vec();
        for i in 0..3 {
            for f in 0..8 {
                assert!(
                    (data[i * 8 + f] - data[f]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    /// Parameter count: weight (in*out) plus bias (out).
    #[test]
    fn test_gcn_conv_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 2); // weight + bias

        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total_params, 16 * 32 + 32); // weight + bias
    }

    /// `without_bias` exposes only the weight parameter.
    #[test]
    fn test_gcn_conv_no_bias() {
        let gcn = GCNConv::without_bias(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 1); // weight only
    }

    /// Named parameters are keyed "weight" and "bias".
    #[test]
    fn test_gcn_conv_named_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.named_parameters();
        assert!(params.contains_key("weight"));
        assert!(params.contains_key("bias"));
    }

    /// GAT output shape concatenates heads: out_features * num_heads.
    #[test]
    fn test_gat_conv_shape() {
        let gat = GATConv::new(72, 32, 4); // 4 heads
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).expect("tensor creation failed"),
            false,
        );
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]); // 32 * 4 = 128
    }

    /// Single-head GAT: output dim equals out_features.
    #[test]
    fn test_gat_conv_single_head() {
        let gat = GATConv::new(16, 8, 1);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 5 * 16], &[1, 5, 16]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 5 * 5], &[5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 5, 8]);
    }

    /// GAT exposes four parameters with stable names.
    #[test]
    fn test_gat_conv_parameters() {
        let gat = GATConv::new(16, 8, 4);
        let params = gat.parameters();
        assert_eq!(params.len(), 4); // w, attn_src, attn_dst, bias

        let named = gat.named_parameters();
        assert!(named.contains_key("w"));
        assert!(named.contains_key("attn_src"));
        assert!(named.contains_key("attn_dst"));
        assert!(named.contains_key("bias"));
    }

    /// `total_out_features` = out_features * num_heads.
    #[test]
    fn test_gat_conv_total_output() {
        let gat = GATConv::new(16, 32, 4);
        assert_eq!(gat.total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        // Zero adjacency should produce only bias in output
        let gcn = GCNConv::new(4, 4);
        let x = Variable::new(
            Tensor::from_vec(vec![99.0; 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![0.0; 9], &[3, 3]).expect("tensor creation failed"),
            false,
        );
        let output = gcn.forward_graph(&x, &adj);

        // With zero adjacency, output should be just bias (all zeros initially)
        let data = output.data().to_vec();
        for val in &data {
            assert!(
                val.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}