//! Graph Neural Network Layers
//!
//! # File
//! `crates/axonml-nn/src/layers/graph.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// GCNConv
// =============================================================================

29/// Graph Convolutional Network layer (Kipf & Welling, 2017).
30///
31/// Performs the graph convolution: `output = adj @ x @ weight + bias`
32/// where `adj` is the (possibly learned) adjacency matrix.
33///
34/// # Arguments
35/// * `in_features` - Number of input features per node
36/// * `out_features` - Number of output features per node
37///
38/// # Forward Signature
39/// Uses `forward_graph(x, adj)` since it requires an adjacency matrix.
40/// The standard `forward()` from Module is provided but panics—use `forward_graph()`.
41///
42/// # Example
43/// ```ignore
44/// use axonml_nn::layers::GCNConv;
45///
46/// let gcn = GCNConv::new(72, 128);
47/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);     // (batch, nodes, features)
48/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);         // (nodes, nodes)
49/// let output = gcn.forward_graph(&x, &adj);
50/// // output shape: (2, 7, 128)
51/// ```
52pub struct GCNConv {
53    weight: Parameter,
54    bias: Option<Parameter>,
55    in_features: usize,
56    out_features: usize,
57}
59impl GCNConv {
60    /// Creates a new GCN convolution layer with bias.
61    pub fn new(in_features: usize, out_features: usize) -> Self {
62        // Xavier initialization
63        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
64        let weight_data: Vec<f32> = (0..in_features * out_features)
65            .map(|i| {
66                // Simple deterministic-ish init for reproducibility
67                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
68                x * scale
69            })
70            .collect();
71
72        let weight = Parameter::named(
73            "weight",
74            Tensor::from_vec(weight_data, &[in_features, out_features]).expect("tensor creation failed"),
75            true,
76        );
77
78        let bias_data = vec![0.0; out_features];
79        let bias = Some(Parameter::named(
80            "bias",
81            Tensor::from_vec(bias_data, &[out_features]).expect("tensor creation failed"),
82            true,
83        ));
84
85        Self {
86            weight,
87            bias,
88            in_features,
89            out_features,
90        }
91    }
92
93    /// Creates a GCN layer without bias.
94    pub fn without_bias(in_features: usize, out_features: usize) -> Self {
95        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
96        let weight_data: Vec<f32> = (0..in_features * out_features)
97            .map(|i| {
98                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
99                x * scale
100            })
101            .collect();
102
103        let weight = Parameter::named(
104            "weight",
105            Tensor::from_vec(weight_data, &[in_features, out_features]).expect("tensor creation failed"),
106            true,
107        );
108
109        Self {
110            weight,
111            bias: None,
112            in_features,
113            out_features,
114        }
115    }
116
117    /// Graph convolution forward pass.
118    ///
119    /// # Arguments
120    /// * `x` - Node features: `(batch, num_nodes, in_features)`
121    /// * `adj` - Adjacency matrix: `(num_nodes, num_nodes)` or `(batch, num_nodes, num_nodes)`
122    ///
123    /// # Returns
124    /// Output features: `(batch, num_nodes, out_features)`
125    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
126        let shape = x.shape();
127        assert!(
128            shape.len() == 3,
129            "GCNConv expects input shape (batch, nodes, features), got {:?}",
130            shape
131        );
132        assert_eq!(shape[2], self.in_features, "Input features mismatch");
133
134        let batch = shape[0];
135        let adj_shape = adj.shape();
136
137        // GCN: output = adj @ x @ weight + bias
138        // Use Variable operations to preserve gradient flow
139        //
140        // Process per-sample: for each batch element, compute adj @ x_b @ weight
141        // This preserves gradient flow through x and weight via matmul backward.
142        let weight = self.weight.variable();
143
144        let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
145        for b in 0..batch {
146            // x_b: (nodes, in_features)
147            let x_b = x.select(0, b);
148
149            // adj_b: (nodes, nodes)
150            let adj_b = if adj_shape.len() == 3 {
151                adj.select(0, b)
152            } else {
153                adj.clone()
154            };
155
156            // message_b = adj_b @ x_b → (nodes, in_features)
157            let msg_b = adj_b.matmul(&x_b);
158
159            // out_b = msg_b @ weight → (nodes, out_features)
160            let mut out_b = msg_b.matmul(&weight);
161
162            // Add bias
163            if let Some(bias) = &self.bias {
164                out_b = out_b.add_var(&bias.variable());
165            }
166
167            // Unsqueeze to (1, nodes, out_features) for stacking
168            per_sample.push(out_b.unsqueeze(0));
169        }
170
171        // Stack along batch dimension
172        let refs: Vec<&Variable> = per_sample.iter().collect();
173        Variable::cat(&refs, 0)
174    }
175
176    /// Returns the input feature dimension.
177    pub fn in_features(&self) -> usize {
178        self.in_features
179    }
180
181    /// Returns the output feature dimension.
182    pub fn out_features(&self) -> usize {
183        self.out_features
184    }
185}
187impl Module for GCNConv {
188    /// Forward with identity adjacency (self-loops only).
189    /// For proper graph convolution, use `forward_graph(x, adj)`.
190    fn forward(&self, input: &Variable) -> Variable {
191        // Use identity adjacency: each node only aggregates from itself
192        let n = input.shape()[0];
193        let mut eye_data = vec![0.0f32; n * n];
194        for i in 0..n {
195            eye_data[i * n + i] = 1.0;
196        }
197        let adj = Variable::new(
198            axonml_tensor::Tensor::from_vec(eye_data, &[n, n]).expect("identity matrix creation failed"),
199            false,
200        );
201        self.forward_graph(input, &adj)
202    }
203
204    fn parameters(&self) -> Vec<Parameter> {
205        let mut params = vec![self.weight.clone()];
206        if let Some(bias) = &self.bias {
207            params.push(bias.clone());
208        }
209        params
210    }
211
212    fn named_parameters(&self) -> HashMap<String, Parameter> {
213        let mut params = HashMap::new();
214        params.insert("weight".to_string(), self.weight.clone());
215        if let Some(bias) = &self.bias {
216            params.insert("bias".to_string(), bias.clone());
217        }
218        params
219    }
220
221    fn name(&self) -> &'static str {
222        "GCNConv"
223    }
224}

// =============================================================================
// GATConv
// =============================================================================

230/// Graph Attention Network layer (Veličković et al., 2018).
231///
232/// Computes attention-weighted graph convolution where attention coefficients
233/// are learned based on node features, masked by the adjacency matrix.
234///
235/// # Example
236/// ```ignore
237/// use axonml_nn::layers::GATConv;
238///
239/// let gat = GATConv::new(72, 32, 4); // 4 attention heads
240/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);
241/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);
242/// let output = gat.forward_graph(&x, &adj);
243/// // output shape: (2, 7, 128)  — 32 * 4 heads
244/// ```
245pub struct GATConv {
246    w: Parameter,
247    attn_src: Parameter,
248    attn_dst: Parameter,
249    bias: Option<Parameter>,
250    in_features: usize,
251    out_features: usize,
252    num_heads: usize,
253    negative_slope: f32,
254}
256impl GATConv {
257    /// Creates a new GAT convolution layer.
258    ///
259    /// # Arguments
260    /// * `in_features` - Input feature dimension per node
261    /// * `out_features` - Output feature dimension per head
262    /// * `num_heads` - Number of attention heads (output = out_features * num_heads)
263    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
264        let total_out = out_features * num_heads;
265        let scale = (2.0 / (in_features + total_out) as f32).sqrt();
266
267        let w_data: Vec<f32> = (0..in_features * total_out)
268            .map(|i| {
269                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
270                x * scale
271            })
272            .collect();
273
274        let w = Parameter::named(
275            "w",
276            Tensor::from_vec(w_data, &[in_features, total_out]).expect("tensor creation failed"),
277            true,
278        );
279
280        // Attention vectors: one per head for source and destination
281        let attn_scale = (1.0 / out_features as f32).sqrt();
282        let attn_src_data: Vec<f32> = (0..total_out)
283            .map(|i| {
284                let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
285                x * attn_scale
286            })
287            .collect();
288        let attn_dst_data: Vec<f32> = (0..total_out)
289            .map(|i| {
290                let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
291                x * attn_scale
292            })
293            .collect();
294
295        let attn_src = Parameter::named(
296            "attn_src",
297            Tensor::from_vec(attn_src_data, &[num_heads, out_features]).expect("tensor creation failed"),
298            true,
299        );
300
301        let attn_dst = Parameter::named(
302            "attn_dst",
303            Tensor::from_vec(attn_dst_data, &[num_heads, out_features]).expect("tensor creation failed"),
304            true,
305        );
306
307        let bias_data = vec![0.0; total_out];
308        let bias = Some(Parameter::named(
309            "bias",
310            Tensor::from_vec(bias_data, &[total_out]).expect("tensor creation failed"),
311            true,
312        ));
313
314        Self {
315            w,
316            attn_src,
317            attn_dst,
318            bias,
319            in_features,
320            out_features,
321            num_heads,
322            negative_slope: 0.2,
323        }
324    }
325
326    /// Graph attention forward pass.
327    ///
328    /// # Arguments
329    /// * `x` - Node features: `(batch, num_nodes, in_features)`
330    /// * `adj` - Adjacency mask: `(num_nodes, num_nodes)` — non-zero entries allow attention
331    ///
332    /// # Returns
333    /// Output features: `(batch, num_nodes, out_features * num_heads)`
334    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
335        let shape = x.shape();
336        assert!(
337            shape.len() == 3,
338            "GATConv expects (batch, nodes, features), got {:?}",
339            shape
340        );
341
342        let batch = shape[0];
343        let nodes = shape[1];
344        let total_out = self.out_features * self.num_heads;
345
346        let x_data = x.data().to_vec();
347        let adj_data = adj.data().to_vec();
348        let w_data = self.w.data().to_vec();
349        let attn_src_data = self.attn_src.data().to_vec();
350        let attn_dst_data = self.attn_dst.data().to_vec();
351
352        let adj_nodes = if adj.shape().len() == 3 {
353            adj.shape()[1]
354        } else {
355            adj.shape()[0]
356        };
357        assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
358
359        let mut output = vec![0.0f32; batch * nodes * total_out];
360
361        for b in 0..batch {
362            // Step 1: Project all nodes: h = x @ w → (nodes, total_out)
363            let mut h = vec![0.0f32; nodes * total_out];
364            for i in 0..nodes {
365                let x_off = (b * nodes + i) * self.in_features;
366                for o in 0..total_out {
367                    let mut val = 0.0;
368                    for f in 0..self.in_features {
369                        val += x_data[x_off + f] * w_data[f * total_out + o];
370                    }
371                    h[i * total_out + o] = val;
372                }
373            }
374
375            // Step 2: Compute attention per head
376            let adj_off = if adj.shape().len() == 3 {
377                b * nodes * nodes
378            } else {
379                0
380            };
381
382            for head in 0..self.num_heads {
383                let head_off = head * self.out_features;
384
385                // Compute attention scores for each edge
386                // e_ij = LeakyReLU(attn_src · h_i + attn_dst · h_j)
387                let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
388
389                for i in 0..nodes {
390                    // src score for node i
391                    let mut src_score = 0.0;
392                    for f in 0..self.out_features {
393                        src_score += h[i * total_out + head_off + f]
394                            * attn_src_data[head * self.out_features + f];
395                    }
396
397                    for j in 0..nodes {
398                        let a_ij = adj_data[adj_off + i * nodes + j];
399                        if a_ij != 0.0 {
400                            let mut dst_score = 0.0;
401                            for f in 0..self.out_features {
402                                dst_score += h[j * total_out + head_off + f]
403                                    * attn_dst_data[head * self.out_features + f];
404                            }
405
406                            let e = src_score + dst_score;
407                            // LeakyReLU
408                            let e = if e > 0.0 { e } else { e * self.negative_slope };
409                            attn_scores[i * nodes + j] = e;
410                        }
411                    }
412                }
413
414                // Softmax per row (per destination node)
415                for i in 0..nodes {
416                    let row_start = i * nodes;
417                    let row_end = row_start + nodes;
418                    let row = &attn_scores[row_start..row_end];
419
420                    let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
421                    if max_val == f32::NEG_INFINITY {
422                        continue; // No neighbors
423                    }
424
425                    let mut sum_exp = 0.0f32;
426                    let mut exps = vec![0.0; nodes];
427                    for j in 0..nodes {
428                        if row[j] > f32::NEG_INFINITY {
429                            exps[j] = (row[j] - max_val).exp();
430                            sum_exp += exps[j];
431                        }
432                    }
433
434                    // Weighted sum of neighbor features
435                    let out_off = (b * nodes + i) * total_out + head_off;
436                    for j in 0..nodes {
437                        if exps[j] > 0.0 {
438                            let alpha = exps[j] / sum_exp;
439                            for f in 0..self.out_features {
440                                output[out_off + f] += alpha * h[j * total_out + head_off + f];
441                            }
442                        }
443                    }
444                }
445            }
446        }
447
448        // Add bias
449        if let Some(bias) = &self.bias {
450            let bias_data = bias.data().to_vec();
451            for b in 0..batch {
452                for i in 0..nodes {
453                    let offset = (b * nodes + i) * total_out;
454                    for o in 0..total_out {
455                        output[offset + o] += bias_data[o];
456                    }
457                }
458            }
459        }
460
461        Variable::new(
462            Tensor::from_vec(output, &[batch, nodes, total_out]).expect("tensor creation failed"),
463            x.requires_grad(),
464        )
465    }
466
467    /// Total output dimension (out_features * num_heads).
468    pub fn total_out_features(&self) -> usize {
469        self.out_features * self.num_heads
470    }
471}
473impl Module for GATConv {
474    fn forward(&self, input: &Variable) -> Variable {
475        // Use identity adjacency: each node only attends to itself
476        let n = input.shape()[0];
477        let mut eye_data = vec![0.0f32; n * n];
478        for i in 0..n {
479            eye_data[i * n + i] = 1.0;
480        }
481        let adj = Variable::new(
482            axonml_tensor::Tensor::from_vec(eye_data, &[n, n]).expect("identity matrix creation failed"),
483            false,
484        );
485        self.forward_graph(input, &adj)
486    }
487
488    fn parameters(&self) -> Vec<Parameter> {
489        let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
490        if let Some(bias) = &self.bias {
491            params.push(bias.clone());
492        }
493        params
494    }
495
496    fn named_parameters(&self) -> HashMap<String, Parameter> {
497        let mut params = HashMap::new();
498        params.insert("w".to_string(), self.w.clone());
499        params.insert("attn_src".to_string(), self.attn_src.clone());
500        params.insert("attn_dst".to_string(), self.attn_dst.clone());
501        if let Some(bias) = &self.bias {
502            params.insert("bias".to_string(), bias.clone());
503        }
504        params
505    }
506
507    fn name(&self) -> &'static str {
508        "GATConv"
509    }
510}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gcn_conv_shape() {
        let gcn = GCNConv::new(72, 128);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).expect("tensor creation failed"),
            false,
        );
        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        // With identity adjacency, each node only sees itself.
        let gcn = GCNConv::new(4, 8);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );

        // Identity adjacency
        let mut adj_data = vec![0.0; 9];
        adj_data[0] = 1.0; // (0,0)
        adj_data[4] = 1.0; // (1,1)
        adj_data[8] = 1.0; // (2,2)
        let adj = Variable::new(
            Tensor::from_vec(adj_data, &[3, 3]).expect("tensor creation failed"),
            false,
        );

        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 3, 8]);

        // All nodes have the same input, so all should produce the same output.
        let data = output.data().to_vec();
        for i in 0..3 {
            for f in 0..8 {
                assert!(
                    (data[i * 8 + f] - data[f]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 2); // weight + bias

        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total_params, 16 * 32 + 32); // weight + bias
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        let gcn = GCNConv::without_bias(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 1); // weight only
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.named_parameters();
        assert!(params.contains_key("weight"));
        assert!(params.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_shape() {
        let gat = GATConv::new(72, 32, 4); // 4 heads
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).expect("tensor creation failed"),
            false,
        );
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]); // 32 * 4 = 128
    }

    #[test]
    fn test_gat_conv_single_head() {
        let gat = GATConv::new(16, 8, 1);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 5 * 16], &[1, 5, 16]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![1.0; 5 * 5], &[5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_parameters() {
        let gat = GATConv::new(16, 8, 4);
        let params = gat.parameters();
        assert_eq!(params.len(), 4); // w, attn_src, attn_dst, bias

        let named = gat.named_parameters();
        assert!(named.contains_key("w"));
        assert!(named.contains_key("attn_src"));
        assert!(named.contains_key("attn_dst"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_total_output() {
        let gat = GATConv::new(16, 32, 4);
        assert_eq!(gat.total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        // Zero adjacency should produce only bias in the output.
        let gcn = GCNConv::new(4, 4);
        let x = Variable::new(
            Tensor::from_vec(vec![99.0; 1 * 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let adj = Variable::new(
            Tensor::from_vec(vec![0.0; 9], &[3, 3]).expect("tensor creation failed"),
            false,
        );
        let output = gcn.forward_graph(&x, &adj);

        // With zero adjacency, output should be just bias (all zeros initially).
        let data = output.data().to_vec();
        for val in &data {
            assert!(
                val.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}