// axonml-nn: graph neural network layers (GCNConv, GATConv).
//! Graph Neural Network Layers
//!
//! # File
//! `crates/axonml-nn/src/layers/graph.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::module::Module;
use crate::parameter::Parameter;
// =============================================================================
// GCNConv
// =============================================================================

/// Graph Convolutional Network layer (Kipf & Welling, 2017).
///
/// Performs the graph convolution: `output = adj @ x @ weight + bias`
/// where `adj` is the (possibly learned) adjacency matrix.
///
/// # Arguments
/// * `in_features` - Number of input features per node
/// * `out_features` - Number of output features per node
///
/// # Forward Signature
/// Uses `forward_graph(x, adj)` since it requires an adjacency matrix.
/// The standard `forward()` from Module falls back to an identity adjacency
/// (self-loops only) — use `forward_graph()` for real graph structure.
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::GCNConv;
///
/// let gcn = GCNConv::new(72, 128);
/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);     // (batch, nodes, features)
/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);         // (nodes, nodes)
/// let output = gcn.forward_graph(&x, &adj);
/// // output shape: (2, 7, 128)
/// ```
pub struct GCNConv {
    // Weight matrix, shape (in_features, out_features).
    weight: Parameter,
    // Optional bias, shape (out_features,); None when built via `without_bias`.
    bias: Option<Parameter>,
    // Input feature dimension per node.
    in_features: usize,
    // Output feature dimension per node.
    out_features: usize,
}

59impl GCNConv {
60    /// Creates a new GCN convolution layer with bias.
61    pub fn new(in_features: usize, out_features: usize) -> Self {
62        // Xavier initialization
63        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
64        let weight_data: Vec<f32> = (0..in_features * out_features)
65            .map(|i| {
66                // Simple deterministic-ish init for reproducibility
67                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
68                x * scale
69            })
70            .collect();
71
72        let weight = Parameter::named(
73            "weight",
74            Tensor::from_vec(weight_data, &[in_features, out_features])
75                .expect("tensor creation failed"),
76            true,
77        );
78
79        let bias_data = vec![0.0; out_features];
80        let bias = Some(Parameter::named(
81            "bias",
82            Tensor::from_vec(bias_data, &[out_features]).expect("tensor creation failed"),
83            true,
84        ));
85
86        Self {
87            weight,
88            bias,
89            in_features,
90            out_features,
91        }
92    }
93
94    /// Creates a GCN layer without bias.
95    pub fn without_bias(in_features: usize, out_features: usize) -> Self {
96        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
97        let weight_data: Vec<f32> = (0..in_features * out_features)
98            .map(|i| {
99                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
100                x * scale
101            })
102            .collect();
103
104        let weight = Parameter::named(
105            "weight",
106            Tensor::from_vec(weight_data, &[in_features, out_features])
107                .expect("tensor creation failed"),
108            true,
109        );
110
111        Self {
112            weight,
113            bias: None,
114            in_features,
115            out_features,
116        }
117    }
118
119    /// Graph convolution forward pass.
120    ///
121    /// # Arguments
122    /// * `x` - Node features: `(batch, num_nodes, in_features)`
123    /// * `adj` - Adjacency matrix: `(num_nodes, num_nodes)` or `(batch, num_nodes, num_nodes)`
124    ///
125    /// # Returns
126    /// Output features: `(batch, num_nodes, out_features)`
127    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
128        let shape = x.shape();
129        assert!(
130            shape.len() == 3,
131            "GCNConv expects input shape (batch, nodes, features), got {:?}",
132            shape
133        );
134        assert_eq!(shape[2], self.in_features, "Input features mismatch");
135
136        let batch = shape[0];
137        let adj_shape = adj.shape();
138
139        // GCN: output = adj @ x @ weight + bias
140        // Use Variable operations to preserve gradient flow
141        //
142        // Process per-sample: for each batch element, compute adj @ x_b @ weight
143        // This preserves gradient flow through x and weight via matmul backward.
144        let weight = self.weight.variable();
145
146        let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
147        for b in 0..batch {
148            // x_b: (nodes, in_features)
149            let x_b = x.select(0, b);
150
151            // adj_b: (nodes, nodes)
152            let adj_b = if adj_shape.len() == 3 {
153                adj.select(0, b)
154            } else {
155                adj.clone()
156            };
157
158            // message_b = adj_b @ x_b → (nodes, in_features)
159            let msg_b = adj_b.matmul(&x_b);
160
161            // out_b = msg_b @ weight → (nodes, out_features)
162            let mut out_b = msg_b.matmul(&weight);
163
164            // Add bias
165            if let Some(bias) = &self.bias {
166                out_b = out_b.add_var(&bias.variable());
167            }
168
169            // Unsqueeze to (1, nodes, out_features) for stacking
170            per_sample.push(out_b.unsqueeze(0));
171        }
172
173        // Stack along batch dimension
174        let refs: Vec<&Variable> = per_sample.iter().collect();
175        Variable::cat(&refs, 0)
176    }
177
178    /// Returns the input feature dimension.
179    pub fn in_features(&self) -> usize {
180        self.in_features
181    }
182
183    /// Returns the output feature dimension.
184    pub fn out_features(&self) -> usize {
185        self.out_features
186    }
187}
188
189impl Module for GCNConv {
190    /// Forward with identity adjacency (self-loops only).
191    /// For proper graph convolution, use `forward_graph(x, adj)`.
192    fn forward(&self, input: &Variable) -> Variable {
193        // Use identity adjacency: each node only aggregates from itself
194        let n = input.shape()[0];
195        let mut eye_data = vec![0.0f32; n * n];
196        for i in 0..n {
197            eye_data[i * n + i] = 1.0;
198        }
199        let adj = Variable::new(
200            axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
201                .expect("identity matrix creation failed"),
202            false,
203        );
204        self.forward_graph(input, &adj)
205    }
206
207    fn parameters(&self) -> Vec<Parameter> {
208        let mut params = vec![self.weight.clone()];
209        if let Some(bias) = &self.bias {
210            params.push(bias.clone());
211        }
212        params
213    }
214
215    fn named_parameters(&self) -> HashMap<String, Parameter> {
216        let mut params = HashMap::new();
217        params.insert("weight".to_string(), self.weight.clone());
218        if let Some(bias) = &self.bias {
219            params.insert("bias".to_string(), bias.clone());
220        }
221        params
222    }
223
224    fn name(&self) -> &'static str {
225        "GCNConv"
226    }
227}
228
// =============================================================================
// GATConv
// =============================================================================

/// Graph Attention Network layer (Veličković et al., 2018).
///
/// Computes attention-weighted graph convolution where attention coefficients
/// are learned based on node features, masked by the adjacency matrix.
///
/// # Example
/// ```ignore
/// use axonml_nn::layers::GATConv;
///
/// let gat = GATConv::new(72, 32, 4); // 4 attention heads
/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);
/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);
/// let output = gat.forward_graph(&x, &adj);
/// // output shape: (2, 7, 128)  — 32 * 4 heads
/// ```
pub struct GATConv {
    // Linear projection, shape (in_features, out_features * num_heads).
    w: Parameter,
    // Per-head source attention vectors, shape (num_heads, out_features).
    attn_src: Parameter,
    // Per-head destination attention vectors, shape (num_heads, out_features).
    attn_dst: Parameter,
    // Optional bias over the concatenated head outputs,
    // shape (out_features * num_heads,).
    bias: Option<Parameter>,
    // Input feature dimension per node.
    in_features: usize,
    // Output feature dimension per head.
    out_features: usize,
    // Number of attention heads; total output = out_features * num_heads.
    num_heads: usize,
    // Negative slope for the LeakyReLU applied to raw attention logits.
    negative_slope: f32,
}

impl GATConv {
    /// Creates a new GAT convolution layer.
    ///
    /// # Arguments
    /// * `in_features` - Input feature dimension per node
    /// * `out_features` - Output feature dimension per head
    /// * `num_heads` - Number of attention heads (output = out_features * num_heads)
    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
        let total_out = out_features * num_heads;
        // Xavier-style scale computed over the full (in, heads*out) projection.
        let scale = (2.0 / (in_features + total_out) as f32).sqrt();

        // Deterministic golden-ratio init for reproducibility (no RNG).
        let w_data: Vec<f32> = (0..in_features * total_out)
            .map(|i| {
                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
                x * scale
            })
            .collect();

        let w = Parameter::named(
            "w",
            Tensor::from_vec(w_data, &[in_features, total_out]).expect("tensor creation failed"),
            true,
        );

        // Attention vectors: one per head for source and destination.
        // Different multipliers decorrelate the src/dst initializations.
        let attn_scale = (1.0 / out_features as f32).sqrt();
        let attn_src_data: Vec<f32> = (0..total_out)
            .map(|i| {
                let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
                x * attn_scale
            })
            .collect();
        let attn_dst_data: Vec<f32> = (0..total_out)
            .map(|i| {
                let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
                x * attn_scale
            })
            .collect();

        let attn_src = Parameter::named(
            "attn_src",
            Tensor::from_vec(attn_src_data, &[num_heads, out_features])
                .expect("tensor creation failed"),
            true,
        );

        let attn_dst = Parameter::named(
            "attn_dst",
            Tensor::from_vec(attn_dst_data, &[num_heads, out_features])
                .expect("tensor creation failed"),
            true,
        );

        // Bias over the concatenated head outputs, zero-initialized.
        let bias_data = vec![0.0; total_out];
        let bias = Some(Parameter::named(
            "bias",
            Tensor::from_vec(bias_data, &[total_out]).expect("tensor creation failed"),
            true,
        ));

        Self {
            w,
            attn_src,
            attn_dst,
            bias,
            in_features,
            out_features,
            num_heads,
            // Standard GAT LeakyReLU slope.
            negative_slope: 0.2,
        }
    }

    /// Graph attention forward pass.
    ///
    /// # Arguments
    /// * `x` - Node features: `(batch, num_nodes, in_features)`
    /// * `adj` - Adjacency mask: `(num_nodes, num_nodes)` — non-zero entries allow attention
    ///
    /// # Returns
    /// Output features: `(batch, num_nodes, out_features * num_heads)`
    ///
    /// NOTE(review): unlike `GCNConv::forward_graph`, this computes on raw
    /// `f32` buffers pulled out of the Variables and wraps the result in a
    /// fresh `Variable`, so no autograd graph links the output to `x` or the
    /// parameters even though `requires_grad` is propagated — confirm whether
    /// GATConv is meant to be trainable end-to-end.
    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
        let shape = x.shape();
        assert!(
            shape.len() == 3,
            "GATConv expects (batch, nodes, features), got {:?}",
            shape
        );

        let batch = shape[0];
        let nodes = shape[1];
        let total_out = self.out_features * self.num_heads;

        // Detach everything into flat host buffers for the manual kernel.
        let x_data = x.data().to_vec();
        let adj_data = adj.data().to_vec();
        let w_data = self.w.data().to_vec();
        let attn_src_data = self.attn_src.data().to_vec();
        let attn_dst_data = self.attn_dst.data().to_vec();

        // Accept either (nodes, nodes) or batched (batch, nodes, nodes) masks.
        let adj_nodes = if adj.shape().len() == 3 {
            adj.shape()[1]
        } else {
            adj.shape()[0]
        };
        assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");

        let mut output = vec![0.0f32; batch * nodes * total_out];

        for b in 0..batch {
            // Step 1: Project all nodes: h = x @ w → (nodes, total_out).
            let mut h = vec![0.0f32; nodes * total_out];
            for i in 0..nodes {
                let x_off = (b * nodes + i) * self.in_features;
                for o in 0..total_out {
                    let mut val = 0.0;
                    for f in 0..self.in_features {
                        val += x_data[x_off + f] * w_data[f * total_out + o];
                    }
                    h[i * total_out + o] = val;
                }
            }

            // Step 2: Compute attention per head. For a 2-D mask the same
            // offset (0) is reused for every batch element.
            let adj_off = if adj.shape().len() == 3 {
                b * nodes * nodes
            } else {
                0
            };

            for head in 0..self.num_heads {
                let head_off = head * self.out_features;

                // Edge scores: e_ij = LeakyReLU(attn_src · h_i + attn_dst · h_j).
                // NEG_INFINITY marks masked-out (non-)edges so the softmax
                // below ignores them.
                let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];

                for i in 0..nodes {
                    // src score for node i — hoisted out of the j-loop since
                    // it does not depend on j.
                    let mut src_score = 0.0;
                    for f in 0..self.out_features {
                        src_score += h[i * total_out + head_off + f]
                            * attn_src_data[head * self.out_features + f];
                    }

                    for j in 0..nodes {
                        let a_ij = adj_data[adj_off + i * nodes + j];
                        if a_ij != 0.0 {
                            let mut dst_score = 0.0;
                            for f in 0..self.out_features {
                                dst_score += h[j * total_out + head_off + f]
                                    * attn_dst_data[head * self.out_features + f];
                            }

                            let e = src_score + dst_score;
                            // LeakyReLU
                            let e = if e > 0.0 { e } else { e * self.negative_slope };
                            attn_scores[i * nodes + j] = e;
                        }
                    }
                }

                // Softmax per row (per destination node), numerically
                // stabilized by subtracting the row max.
                for i in 0..nodes {
                    let row_start = i * nodes;
                    let row_end = row_start + nodes;
                    let row = &attn_scores[row_start..row_end];

                    let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                    if max_val == f32::NEG_INFINITY {
                        continue; // No neighbors — leave output row at zero.
                    }

                    let mut sum_exp = 0.0f32;
                    let mut exps = vec![0.0; nodes];
                    for j in 0..nodes {
                        if row[j] > f32::NEG_INFINITY {
                            exps[j] = (row[j] - max_val).exp();
                            sum_exp += exps[j];
                        }
                    }

                    // Weighted sum of neighbor features; alpha = softmax weight.
                    let out_off = (b * nodes + i) * total_out + head_off;
                    for j in 0..nodes {
                        if exps[j] > 0.0 {
                            let alpha = exps[j] / sum_exp;
                            for f in 0..self.out_features {
                                output[out_off + f] += alpha * h[j * total_out + head_off + f];
                            }
                        }
                    }
                }
            }
        }

        // Add bias across every (batch, node) row.
        if let Some(bias) = &self.bias {
            let bias_data = bias.data().to_vec();
            for b in 0..batch {
                for i in 0..nodes {
                    let offset = (b * nodes + i) * total_out;
                    for o in 0..total_out {
                        output[offset + o] += bias_data[o];
                    }
                }
            }
        }

        Variable::new(
            Tensor::from_vec(output, &[batch, nodes, total_out]).expect("tensor creation failed"),
            x.requires_grad(),
        )
    }

    /// Total output dimension (out_features * num_heads).
    pub fn total_out_features(&self) -> usize {
        self.out_features * self.num_heads
    }
}

478impl Module for GATConv {
479    fn forward(&self, input: &Variable) -> Variable {
480        // Use identity adjacency: each node only attends to itself
481        let n = input.shape()[0];
482        let mut eye_data = vec![0.0f32; n * n];
483        for i in 0..n {
484            eye_data[i * n + i] = 1.0;
485        }
486        let adj = Variable::new(
487            axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
488                .expect("identity matrix creation failed"),
489            false,
490        );
491        self.forward_graph(input, &adj)
492    }
493
494    fn parameters(&self) -> Vec<Parameter> {
495        let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
496        if let Some(bias) = &self.bias {
497            params.push(bias.clone());
498        }
499        params
500    }
501
502    fn named_parameters(&self) -> HashMap<String, Parameter> {
503        let mut params = HashMap::new();
504        params.insert("w".to_string(), self.w.clone());
505        params.insert("attn_src".to_string(), self.attn_src.clone());
506        params.insert("attn_dst".to_string(), self.attn_dst.clone());
507        if let Some(bias) = &self.bias {
508            params.insert("bias".to_string(), bias.clone());
509        }
510        params
511    }
512
513    fn name(&self) -> &'static str {
514        "GATConv"
515    }
516}
517
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Builds a non-grad Variable filled with `value` in the given shape.
    fn const_var(value: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![value; numel], shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_gcn_conv_shape() {
        let layer = GCNConv::new(72, 128);
        let features = const_var(1.0, &[2, 7, 72]);
        let adjacency = const_var(1.0, &[7, 7]);
        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        // With an identity adjacency every node only sees itself.
        let layer = GCNConv::new(4, 8);
        let features = const_var(1.0, &[1, 3, 4]);

        // 3x3 identity: ones on the diagonal (flat positions 0, 4, 8).
        let eye: Vec<f32> = (0..9).map(|k| if k % 4 == 0 { 1.0 } else { 0.0 }).collect();
        let adjacency = Variable::new(
            Tensor::from_vec(eye, &[3, 3]).expect("tensor creation failed"),
            false,
        );

        let output = layer.forward_graph(&features, &adjacency);
        assert_eq!(output.shape(), vec![1, 3, 8]);

        // Same input everywhere + identity adjacency => every node's output
        // row must match node 0's row.
        let values = output.data().to_vec();
        for node in 0..3 {
            for feat in 0..8 {
                assert!(
                    (values[node * 8 + feat] - values[feat]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let layer = GCNConv::new(16, 32);
        let params = layer.parameters();
        assert_eq!(params.len(), 2); // weight + bias

        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 16 * 32 + 32); // weight + bias
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        // weight only
        assert_eq!(GCNConv::without_bias(16, 32).parameters().len(), 1);
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let named = GCNConv::new(16, 32).named_parameters();
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_shape() {
        let layer = GATConv::new(72, 32, 4); // 4 heads
        let features = const_var(1.0, &[2, 7, 72]);
        let adjacency = const_var(1.0, &[7, 7]);
        let output = layer.forward_graph(&features, &adjacency);
        assert_eq!(output.shape(), vec![2, 7, 128]); // 32 * 4 = 128
    }

    #[test]
    fn test_gat_conv_single_head() {
        let layer = GATConv::new(16, 8, 1);
        let features = const_var(1.0, &[1, 5, 16]);
        let adjacency = const_var(1.0, &[5, 5]);
        let output = layer.forward_graph(&features, &adjacency);
        assert_eq!(output.shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_parameters() {
        let layer = GATConv::new(16, 8, 4);
        assert_eq!(layer.parameters().len(), 4); // w, attn_src, attn_dst, bias

        let named = layer.named_parameters();
        for key in ["w", "attn_src", "attn_dst", "bias"] {
            assert!(named.contains_key(key));
        }
    }

    #[test]
    fn test_gat_conv_total_output() {
        assert_eq!(GATConv::new(16, 32, 4).total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        // A zero adjacency suppresses all message passing, so only the
        // (zero-initialized) bias remains in the output.
        let layer = GCNConv::new(4, 4);
        let features = const_var(99.0, &[1, 3, 4]);
        let adjacency = const_var(0.0, &[3, 3]);

        let output = layer.forward_graph(&features, &adjacency);
        for value in output.data().to_vec() {
            assert!(
                value.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}