Skip to main content

axonml_nn/layers/
graph.rs

//! Graph Neural Network Layers
//!
//! # File
//! `crates/axonml-nn/src/layers/graph.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::module::Module;
23use crate::parameter::Parameter;
24
// =============================================================================
// GCNConv
// =============================================================================

29/// Graph Convolutional Network layer (Kipf & Welling, 2017).
30///
31/// Performs the graph convolution: `output = adj @ x @ weight + bias`
32/// where `adj` is the (possibly learned) adjacency matrix.
33///
34/// # Arguments
35/// * `in_features` - Number of input features per node
36/// * `out_features` - Number of output features per node
37///
38/// # Forward Signature
39/// Uses `forward_graph(x, adj)` since it requires an adjacency matrix.
40/// The standard `forward()` from Module is provided but panics—use `forward_graph()`.
41///
42/// # Example
43/// ```ignore
44/// use axonml_nn::layers::GCNConv;
45///
46/// let gcn = GCNConv::new(72, 128);
47/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);     // (batch, nodes, features)
48/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);         // (nodes, nodes)
49/// let output = gcn.forward_graph(&x, &adj);
50/// // output shape: (2, 7, 128)
51/// ```
52pub struct GCNConv {
53    weight: Parameter,
54    bias: Option<Parameter>,
55    in_features: usize,
56    out_features: usize,
57}
58
59impl GCNConv {
60    /// Creates a new GCN convolution layer with bias.
61    pub fn new(in_features: usize, out_features: usize) -> Self {
62        // Xavier initialization
63        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
64        let weight_data: Vec<f32> = (0..in_features * out_features)
65            .map(|i| {
66                // Simple deterministic-ish init for reproducibility
67                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
68                x * scale
69            })
70            .collect();
71
72        let weight = Parameter::named(
73            "weight",
74            Tensor::from_vec(weight_data, &[in_features, out_features]).unwrap(),
75            true,
76        );
77
78        let bias_data = vec![0.0; out_features];
79        let bias = Some(Parameter::named(
80            "bias",
81            Tensor::from_vec(bias_data, &[out_features]).unwrap(),
82            true,
83        ));
84
85        Self {
86            weight,
87            bias,
88            in_features,
89            out_features,
90        }
91    }
92
93    /// Creates a GCN layer without bias.
94    pub fn without_bias(in_features: usize, out_features: usize) -> Self {
95        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
96        let weight_data: Vec<f32> = (0..in_features * out_features)
97            .map(|i| {
98                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
99                x * scale
100            })
101            .collect();
102
103        let weight = Parameter::named(
104            "weight",
105            Tensor::from_vec(weight_data, &[in_features, out_features]).unwrap(),
106            true,
107        );
108
109        Self {
110            weight,
111            bias: None,
112            in_features,
113            out_features,
114        }
115    }
116
117    /// Graph convolution forward pass.
118    ///
119    /// # Arguments
120    /// * `x` - Node features: `(batch, num_nodes, in_features)`
121    /// * `adj` - Adjacency matrix: `(num_nodes, num_nodes)` or `(batch, num_nodes, num_nodes)`
122    ///
123    /// # Returns
124    /// Output features: `(batch, num_nodes, out_features)`
125    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
126        let shape = x.shape();
127        assert!(
128            shape.len() == 3,
129            "GCNConv expects input shape (batch, nodes, features), got {:?}",
130            shape
131        );
132        assert_eq!(shape[2], self.in_features, "Input features mismatch");
133
134        let batch = shape[0];
135        let adj_shape = adj.shape();
136
137        // GCN: output = adj @ x @ weight + bias
138        // Use Variable operations to preserve gradient flow
139        //
140        // Process per-sample: for each batch element, compute adj @ x_b @ weight
141        // This preserves gradient flow through x and weight via matmul backward.
142        let weight = self.weight.variable();
143
144        let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
145        for b in 0..batch {
146            // x_b: (nodes, in_features)
147            let x_b = x.select(0, b);
148
149            // adj_b: (nodes, nodes)
150            let adj_b = if adj_shape.len() == 3 {
151                adj.select(0, b)
152            } else {
153                adj.clone()
154            };
155
156            // message_b = adj_b @ x_b → (nodes, in_features)
157            let msg_b = adj_b.matmul(&x_b);
158
159            // out_b = msg_b @ weight → (nodes, out_features)
160            let mut out_b = msg_b.matmul(&weight);
161
162            // Add bias
163            if let Some(bias) = &self.bias {
164                out_b = out_b.add_var(&bias.variable());
165            }
166
167            // Unsqueeze to (1, nodes, out_features) for stacking
168            per_sample.push(out_b.unsqueeze(0));
169        }
170
171        // Stack along batch dimension
172        let refs: Vec<&Variable> = per_sample.iter().collect();
173        Variable::cat(&refs, 0)
174    }
175
176    /// Returns the input feature dimension.
177    pub fn in_features(&self) -> usize {
178        self.in_features
179    }
180
181    /// Returns the output feature dimension.
182    pub fn out_features(&self) -> usize {
183        self.out_features
184    }
185}
186
187impl Module for GCNConv {
188    fn forward(&self, _input: &Variable) -> Variable {
189        panic!("GCNConv requires an adjacency matrix. Use forward_graph(x, adj) instead.")
190    }
191
192    fn parameters(&self) -> Vec<Parameter> {
193        let mut params = vec![self.weight.clone()];
194        if let Some(bias) = &self.bias {
195            params.push(bias.clone());
196        }
197        params
198    }
199
200    fn named_parameters(&self) -> HashMap<String, Parameter> {
201        let mut params = HashMap::new();
202        params.insert("weight".to_string(), self.weight.clone());
203        if let Some(bias) = &self.bias {
204            params.insert("bias".to_string(), bias.clone());
205        }
206        params
207    }
208
209    fn name(&self) -> &'static str {
210        "GCNConv"
211    }
212}

// =============================================================================
// GATConv
// =============================================================================

218/// Graph Attention Network layer (Veličković et al., 2018).
219///
220/// Computes attention-weighted graph convolution where attention coefficients
221/// are learned based on node features, masked by the adjacency matrix.
222///
223/// # Example
224/// ```ignore
225/// use axonml_nn::layers::GATConv;
226///
227/// let gat = GATConv::new(72, 32, 4); // 4 attention heads
228/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);
229/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);
230/// let output = gat.forward_graph(&x, &adj);
231/// // output shape: (2, 7, 128)  — 32 * 4 heads
232/// ```
233pub struct GATConv {
234    w: Parameter,
235    attn_src: Parameter,
236    attn_dst: Parameter,
237    bias: Option<Parameter>,
238    in_features: usize,
239    out_features: usize,
240    num_heads: usize,
241    negative_slope: f32,
242}
243
244impl GATConv {
245    /// Creates a new GAT convolution layer.
246    ///
247    /// # Arguments
248    /// * `in_features` - Input feature dimension per node
249    /// * `out_features` - Output feature dimension per head
250    /// * `num_heads` - Number of attention heads (output = out_features * num_heads)
251    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
252        let total_out = out_features * num_heads;
253        let scale = (2.0 / (in_features + total_out) as f32).sqrt();
254
255        let w_data: Vec<f32> = (0..in_features * total_out)
256            .map(|i| {
257                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
258                x * scale
259            })
260            .collect();
261
262        let w = Parameter::named(
263            "w",
264            Tensor::from_vec(w_data, &[in_features, total_out]).unwrap(),
265            true,
266        );
267
268        // Attention vectors: one per head for source and destination
269        let attn_scale = (1.0 / out_features as f32).sqrt();
270        let attn_src_data: Vec<f32> = (0..total_out)
271            .map(|i| {
272                let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
273                x * attn_scale
274            })
275            .collect();
276        let attn_dst_data: Vec<f32> = (0..total_out)
277            .map(|i| {
278                let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
279                x * attn_scale
280            })
281            .collect();
282
283        let attn_src = Parameter::named(
284            "attn_src",
285            Tensor::from_vec(attn_src_data, &[num_heads, out_features]).unwrap(),
286            true,
287        );
288
289        let attn_dst = Parameter::named(
290            "attn_dst",
291            Tensor::from_vec(attn_dst_data, &[num_heads, out_features]).unwrap(),
292            true,
293        );
294
295        let bias_data = vec![0.0; total_out];
296        let bias = Some(Parameter::named(
297            "bias",
298            Tensor::from_vec(bias_data, &[total_out]).unwrap(),
299            true,
300        ));
301
302        Self {
303            w,
304            attn_src,
305            attn_dst,
306            bias,
307            in_features,
308            out_features,
309            num_heads,
310            negative_slope: 0.2,
311        }
312    }
313
314    /// Graph attention forward pass.
315    ///
316    /// # Arguments
317    /// * `x` - Node features: `(batch, num_nodes, in_features)`
318    /// * `adj` - Adjacency mask: `(num_nodes, num_nodes)` — non-zero entries allow attention
319    ///
320    /// # Returns
321    /// Output features: `(batch, num_nodes, out_features * num_heads)`
322    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
323        let shape = x.shape();
324        assert!(
325            shape.len() == 3,
326            "GATConv expects (batch, nodes, features), got {:?}",
327            shape
328        );
329
330        let batch = shape[0];
331        let nodes = shape[1];
332        let total_out = self.out_features * self.num_heads;
333
334        let x_data = x.data().to_vec();
335        let adj_data = adj.data().to_vec();
336        let w_data = self.w.data().to_vec();
337        let attn_src_data = self.attn_src.data().to_vec();
338        let attn_dst_data = self.attn_dst.data().to_vec();
339
340        let adj_nodes = if adj.shape().len() == 3 {
341            adj.shape()[1]
342        } else {
343            adj.shape()[0]
344        };
345        assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
346
347        let mut output = vec![0.0f32; batch * nodes * total_out];
348
349        for b in 0..batch {
350            // Step 1: Project all nodes: h = x @ w → (nodes, total_out)
351            let mut h = vec![0.0f32; nodes * total_out];
352            for i in 0..nodes {
353                let x_off = (b * nodes + i) * self.in_features;
354                for o in 0..total_out {
355                    let mut val = 0.0;
356                    for f in 0..self.in_features {
357                        val += x_data[x_off + f] * w_data[f * total_out + o];
358                    }
359                    h[i * total_out + o] = val;
360                }
361            }
362
363            // Step 2: Compute attention per head
364            let adj_off = if adj.shape().len() == 3 {
365                b * nodes * nodes
366            } else {
367                0
368            };
369
370            for head in 0..self.num_heads {
371                let head_off = head * self.out_features;
372
373                // Compute attention scores for each edge
374                // e_ij = LeakyReLU(attn_src · h_i + attn_dst · h_j)
375                let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
376
377                for i in 0..nodes {
378                    // src score for node i
379                    let mut src_score = 0.0;
380                    for f in 0..self.out_features {
381                        src_score += h[i * total_out + head_off + f]
382                            * attn_src_data[head * self.out_features + f];
383                    }
384
385                    for j in 0..nodes {
386                        let a_ij = adj_data[adj_off + i * nodes + j];
387                        if a_ij != 0.0 {
388                            let mut dst_score = 0.0;
389                            for f in 0..self.out_features {
390                                dst_score += h[j * total_out + head_off + f]
391                                    * attn_dst_data[head * self.out_features + f];
392                            }
393
394                            let e = src_score + dst_score;
395                            // LeakyReLU
396                            let e = if e > 0.0 { e } else { e * self.negative_slope };
397                            attn_scores[i * nodes + j] = e;
398                        }
399                    }
400                }
401
402                // Softmax per row (per destination node)
403                for i in 0..nodes {
404                    let row_start = i * nodes;
405                    let row_end = row_start + nodes;
406                    let row = &attn_scores[row_start..row_end];
407
408                    let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
409                    if max_val == f32::NEG_INFINITY {
410                        continue; // No neighbors
411                    }
412
413                    let mut sum_exp = 0.0f32;
414                    let mut exps = vec![0.0; nodes];
415                    for j in 0..nodes {
416                        if row[j] > f32::NEG_INFINITY {
417                            exps[j] = (row[j] - max_val).exp();
418                            sum_exp += exps[j];
419                        }
420                    }
421
422                    // Weighted sum of neighbor features
423                    let out_off = (b * nodes + i) * total_out + head_off;
424                    for j in 0..nodes {
425                        if exps[j] > 0.0 {
426                            let alpha = exps[j] / sum_exp;
427                            for f in 0..self.out_features {
428                                output[out_off + f] += alpha * h[j * total_out + head_off + f];
429                            }
430                        }
431                    }
432                }
433            }
434        }
435
436        // Add bias
437        if let Some(bias) = &self.bias {
438            let bias_data = bias.data().to_vec();
439            for b in 0..batch {
440                for i in 0..nodes {
441                    let offset = (b * nodes + i) * total_out;
442                    for o in 0..total_out {
443                        output[offset + o] += bias_data[o];
444                    }
445                }
446            }
447        }
448
449        Variable::new(
450            Tensor::from_vec(output, &[batch, nodes, total_out]).unwrap(),
451            x.requires_grad(),
452        )
453    }
454
455    /// Total output dimension (out_features * num_heads).
456    pub fn total_out_features(&self) -> usize {
457        self.out_features * self.num_heads
458    }
459}
460
461impl Module for GATConv {
462    fn forward(&self, _input: &Variable) -> Variable {
463        panic!("GATConv requires an adjacency matrix. Use forward_graph(x, adj) instead.")
464    }
465
466    fn parameters(&self) -> Vec<Parameter> {
467        let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
468        if let Some(bias) = &self.bias {
469            params.push(bias.clone());
470        }
471        params
472    }
473
474    fn named_parameters(&self) -> HashMap<String, Parameter> {
475        let mut params = HashMap::new();
476        params.insert("w".to_string(), self.w.clone());
477        params.insert("attn_src".to_string(), self.attn_src.clone());
478        params.insert("attn_dst".to_string(), self.attn_dst.clone());
479        if let Some(bias) = &self.bias {
480            params.insert("bias".to_string(), bias.clone());
481        }
482        params
483    }
484
485    fn name(&self) -> &'static str {
486        "GATConv"
487    }
488}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gcn_conv_shape() {
        let gcn = GCNConv::new(72, 128);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).unwrap(),
            false,
        );
        let adj = Variable::new(Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).unwrap(), false);
        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        // With identity adjacency, each node only sees itself.
        let gcn = GCNConv::new(4, 8);
        let x = Variable::new(Tensor::from_vec(vec![1.0; 3 * 4], &[1, 3, 4]).unwrap(), false);

        // Identity adjacency
        let mut adj_data = vec![0.0; 9];
        adj_data[0] = 1.0; // (0,0)
        adj_data[4] = 1.0; // (1,1)
        adj_data[8] = 1.0; // (2,2)
        let adj = Variable::new(Tensor::from_vec(adj_data, &[3, 3]).unwrap(), false);

        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 3, 8]);

        // All nodes have the same input, so all should produce the same output.
        let data = output.data().to_vec();
        for i in 0..3 {
            for f in 0..8 {
                assert!(
                    (data[i * 8 + f] - data[f]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_batched_adjacency() {
        // A (batch, nodes, nodes) adjacency gives each sample its own graph.
        let gcn = GCNConv::new(4, 8);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            false,
        );

        // Sample 0: identity graph. Sample 1: no edges at all.
        let mut adj_data = vec![0.0; 2 * 3 * 3];
        adj_data[0] = 1.0;
        adj_data[4] = 1.0;
        adj_data[8] = 1.0;
        let adj = Variable::new(Tensor::from_vec(adj_data, &[2, 3, 3]).unwrap(), false);

        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 3, 8]);

        // Sample 1 receives no messages, so only the zero-initialised bias remains.
        let data = output.data().to_vec();
        for val in &data[3 * 8..] {
            assert!(val.abs() < 1e-6, "Edgeless sample should produce only bias");
        }
    }

    #[test]
    #[should_panic(expected = "GCNConv requires an adjacency matrix")]
    fn test_gcn_forward_without_adjacency_panics() {
        let gcn = GCNConv::new(4, 4);
        let x = Variable::new(Tensor::from_vec(vec![0.0; 3 * 4], &[1, 3, 4]).unwrap(), false);
        let _ = gcn.forward(&x);
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 2); // weight + bias

        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total_params, 16 * 32 + 32); // weight + bias
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        let gcn = GCNConv::without_bias(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 1); // weight only
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.named_parameters();
        assert!(params.contains_key("weight"));
        assert!(params.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_shape() {
        let gat = GATConv::new(72, 32, 4); // 4 heads
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).unwrap(),
            false,
        );
        let adj = Variable::new(Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).unwrap(), false);
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]); // 32 * 4 = 128
    }

    #[test]
    fn test_gat_conv_single_head() {
        let gat = GATConv::new(16, 8, 1);
        let x = Variable::new(Tensor::from_vec(vec![1.0; 5 * 16], &[1, 5, 16]).unwrap(), false);
        let adj = Variable::new(Tensor::from_vec(vec![1.0; 5 * 5], &[5, 5]).unwrap(), false);
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_isolated_node_gets_only_bias() {
        let gat = GATConv::new(4, 2, 2); // total_out = 4
        let x_data: Vec<f32> = (0..3 * 4).map(|v| v as f32).collect();
        let x = Variable::new(Tensor::from_vec(x_data, &[1, 3, 4]).unwrap(), false);

        // Row 0 is all zeros (node 0 has no neighbours); rows 1 and 2 are full.
        let adj_data = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let adj = Variable::new(Tensor::from_vec(adj_data, &[3, 3]).unwrap(), false);

        let data = gat.forward_graph(&x, &adj).data().to_vec();
        for val in &data[..4] {
            assert!(val.abs() < 1e-6, "Isolated node should produce only bias");
        }
    }

    #[test]
    fn test_gat_conv_identical_nodes_produce_identical_outputs() {
        // With identical features and a complete graph, every node attends to
        // identical projections, so every output row must match.
        let gat = GATConv::new(6, 3, 2); // total_out = 6
        let x = Variable::new(Tensor::from_vec(vec![0.5; 4 * 6], &[1, 4, 6]).unwrap(), false);
        let adj = Variable::new(Tensor::from_vec(vec![1.0; 4 * 4], &[4, 4]).unwrap(), false);

        let data = gat.forward_graph(&x, &adj).data().to_vec();
        for i in 1..4 {
            for f in 0..6 {
                assert!(
                    (data[i * 6 + f] - data[f]).abs() < 1e-5,
                    "Identical nodes on a complete graph should have identical outputs"
                );
            }
        }
    }

    #[test]
    fn test_gat_conv_batched_adjacency() {
        let gat = GATConv::new(4, 2, 2); // total_out = 4
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            false,
        );

        // Sample 0: complete graph. Sample 1: no edges.
        let mut adj_data = vec![1.0; 3 * 3];
        adj_data.resize(2 * 3 * 3, 0.0);
        let adj = Variable::new(Tensor::from_vec(adj_data, &[2, 3, 3]).unwrap(), false);

        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 3, 4]);

        let data = output.data().to_vec();
        for val in &data[3 * 4..] {
            assert!(val.abs() < 1e-6, "Edgeless sample should produce only bias");
        }
    }

    #[test]
    fn test_gat_conv_parameters() {
        let gat = GATConv::new(16, 8, 4);
        let params = gat.parameters();
        assert_eq!(params.len(), 4); // w, attn_src, attn_dst, bias

        let named = gat.named_parameters();
        assert!(named.contains_key("w"));
        assert!(named.contains_key("attn_src"));
        assert!(named.contains_key("attn_dst"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_total_output() {
        let gat = GATConv::new(16, 32, 4);
        assert_eq!(gat.total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        // Zero adjacency should produce only bias in the output.
        let gcn = GCNConv::new(4, 4);
        let x = Variable::new(Tensor::from_vec(vec![99.0; 3 * 4], &[1, 3, 4]).unwrap(), false);
        let adj = Variable::new(Tensor::from_vec(vec![0.0; 9], &[3, 3]).unwrap(), false);
        let output = gcn.forward_graph(&x, &adj);

        // With zero adjacency, output should be just the bias (all zeros initially).
        let data = output.data().to_vec();
        for val in &data {
            assert!(
                val.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}