Skip to main content

axonml_nn/layers/
graph.rs

1//! Graph Neural Network Layers
2//!
3//! Provides graph convolution layers for learning on graph-structured data.
4//! Includes GCN (Graph Convolutional Network) and GAT (Graph Attention Network).
5//!
6//! @version 0.1.0
7//! @author AutomataNexus Development Team
8
9use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12use axonml_tensor::Tensor;
13
14use crate::module::Module;
15use crate::parameter::Parameter;
16
17
18// =============================================================================
19// GCNConv
20// =============================================================================
21
22/// Graph Convolutional Network layer (Kipf & Welling, 2017).
23///
24/// Performs the graph convolution: `output = adj @ x @ weight + bias`
25/// where `adj` is the (possibly learned) adjacency matrix.
26///
27/// # Arguments
28/// * `in_features` - Number of input features per node
29/// * `out_features` - Number of output features per node
30///
31/// # Forward Signature
32/// Uses `forward_graph(x, adj)` since it requires an adjacency matrix.
33/// The standard `forward()` from Module is provided but panics—use `forward_graph()`.
34///
35/// # Example
36/// ```ignore
37/// use axonml_nn::layers::GCNConv;
38///
39/// let gcn = GCNConv::new(72, 128);
40/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);     // (batch, nodes, features)
41/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);         // (nodes, nodes)
42/// let output = gcn.forward_graph(&x, &adj);
43/// // output shape: (2, 7, 128)
44/// ```
45pub struct GCNConv {
46    weight: Parameter,
47    bias: Option<Parameter>,
48    in_features: usize,
49    out_features: usize,
50}
51
52impl GCNConv {
53    /// Creates a new GCN convolution layer with bias.
54    pub fn new(in_features: usize, out_features: usize) -> Self {
55        // Xavier initialization
56        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
57        let weight_data: Vec<f32> = (0..in_features * out_features)
58            .map(|i| {
59                // Simple deterministic-ish init for reproducibility
60                let x = ((i as f32 * 0.6180339887) % 1.0) * 2.0 - 1.0;
61                x * scale
62            })
63            .collect();
64
65        let weight = Parameter::named(
66            "weight",
67            Tensor::from_vec(weight_data, &[in_features, out_features]).unwrap(),
68            true,
69        );
70
71        let bias_data = vec![0.0; out_features];
72        let bias = Some(Parameter::named(
73            "bias",
74            Tensor::from_vec(bias_data, &[out_features]).unwrap(),
75            true,
76        ));
77
78        Self {
79            weight,
80            bias,
81            in_features,
82            out_features,
83        }
84    }
85
86    /// Creates a GCN layer without bias.
87    pub fn without_bias(in_features: usize, out_features: usize) -> Self {
88        let scale = (2.0 / (in_features + out_features) as f32).sqrt();
89        let weight_data: Vec<f32> = (0..in_features * out_features)
90            .map(|i| {
91                let x = ((i as f32 * 0.6180339887) % 1.0) * 2.0 - 1.0;
92                x * scale
93            })
94            .collect();
95
96        let weight = Parameter::named(
97            "weight",
98            Tensor::from_vec(weight_data, &[in_features, out_features]).unwrap(),
99            true,
100        );
101
102        Self {
103            weight,
104            bias: None,
105            in_features,
106            out_features,
107        }
108    }
109
110    /// Graph convolution forward pass.
111    ///
112    /// # Arguments
113    /// * `x` - Node features: `(batch, num_nodes, in_features)`
114    /// * `adj` - Adjacency matrix: `(num_nodes, num_nodes)` or `(batch, num_nodes, num_nodes)`
115    ///
116    /// # Returns
117    /// Output features: `(batch, num_nodes, out_features)`
118    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
119        let shape = x.shape();
120        assert!(shape.len() == 3, "GCNConv expects input shape (batch, nodes, features), got {:?}", shape);
121        assert_eq!(shape[2], self.in_features, "Input features mismatch");
122
123        let batch = shape[0];
124        let nodes = shape[1];
125        let adj_shape = adj.shape();
126
127        let x_data = x.data().to_vec();
128        let adj_data = adj.data().to_vec();
129        let w_data = self.weight.data().to_vec();
130
131        let mut output = vec![0.0f32; batch * nodes * self.out_features];
132
133        for b in 0..batch {
134            // Get adjacency for this batch
135            let adj_offset = if adj_shape.len() == 3 {
136                b * nodes * nodes
137            } else {
138                0 // shared adjacency
139            };
140
141            // Step 1: message = adj @ x  → (nodes, in_features)
142            // Step 2: output = message @ weight → (nodes, out_features)
143            for i in 0..nodes {
144                // Aggregate neighbor features: message_i = sum_j adj[i,j] * x[b,j,:]
145                let mut message = vec![0.0f32; self.in_features];
146                for j in 0..nodes {
147                    let a_ij = adj_data[adj_offset + i * nodes + j];
148                    if a_ij != 0.0 {
149                        let x_offset = (b * nodes + j) * self.in_features;
150                        for f in 0..self.in_features {
151                            message[f] += a_ij * x_data[x_offset + f];
152                        }
153                    }
154                }
155
156                // Transform: out_i = message_i @ weight
157                let out_offset = (b * nodes + i) * self.out_features;
158                for o in 0..self.out_features {
159                    let mut val = 0.0;
160                    for f in 0..self.in_features {
161                        val += message[f] * w_data[f * self.out_features + o];
162                    }
163                    output[out_offset + o] = val;
164                }
165            }
166        }
167
168        // Add bias
169        if let Some(bias) = &self.bias {
170            let bias_data = bias.data().to_vec();
171            for b in 0..batch {
172                for i in 0..nodes {
173                    let offset = (b * nodes + i) * self.out_features;
174                    for o in 0..self.out_features {
175                        output[offset + o] += bias_data[o];
176                    }
177                }
178            }
179        }
180
181        Variable::new(
182            Tensor::from_vec(output, &[batch, nodes, self.out_features]).unwrap(),
183            x.requires_grad() || adj.requires_grad(),
184        )
185    }
186
187    /// Returns the input feature dimension.
188    pub fn in_features(&self) -> usize {
189        self.in_features
190    }
191
192    /// Returns the output feature dimension.
193    pub fn out_features(&self) -> usize {
194        self.out_features
195    }
196}
197
198impl Module for GCNConv {
199    fn forward(&self, _input: &Variable) -> Variable {
200        panic!("GCNConv requires an adjacency matrix. Use forward_graph(x, adj) instead.")
201    }
202
203    fn parameters(&self) -> Vec<Parameter> {
204        let mut params = vec![self.weight.clone()];
205        if let Some(bias) = &self.bias {
206            params.push(bias.clone());
207        }
208        params
209    }
210
211    fn named_parameters(&self) -> HashMap<String, Parameter> {
212        let mut params = HashMap::new();
213        params.insert("weight".to_string(), self.weight.clone());
214        if let Some(bias) = &self.bias {
215            params.insert("bias".to_string(), bias.clone());
216        }
217        params
218    }
219
220    fn name(&self) -> &'static str {
221        "GCNConv"
222    }
223}
224
225// =============================================================================
226// GATConv
227// =============================================================================
228
229/// Graph Attention Network layer (Veličković et al., 2018).
230///
231/// Computes attention-weighted graph convolution where attention coefficients
232/// are learned based on node features, masked by the adjacency matrix.
233///
234/// # Example
235/// ```ignore
236/// use axonml_nn::layers::GATConv;
237///
238/// let gat = GATConv::new(72, 32, 4); // 4 attention heads
239/// let x = Variable::new(Tensor::randn(&[2, 7, 72]), true);
240/// let adj = Variable::new(Tensor::ones(&[7, 7]), false);
241/// let output = gat.forward_graph(&x, &adj);
242/// // output shape: (2, 7, 128)  — 32 * 4 heads
243/// ```
244pub struct GATConv {
245    w: Parameter,
246    attn_src: Parameter,
247    attn_dst: Parameter,
248    bias: Option<Parameter>,
249    in_features: usize,
250    out_features: usize,
251    num_heads: usize,
252    negative_slope: f32,
253}
254
255impl GATConv {
256    /// Creates a new GAT convolution layer.
257    ///
258    /// # Arguments
259    /// * `in_features` - Input feature dimension per node
260    /// * `out_features` - Output feature dimension per head
261    /// * `num_heads` - Number of attention heads (output = out_features * num_heads)
262    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
263        let total_out = out_features * num_heads;
264        let scale = (2.0 / (in_features + total_out) as f32).sqrt();
265
266        let w_data: Vec<f32> = (0..in_features * total_out)
267            .map(|i| {
268                let x = ((i as f32 * 0.6180339887) % 1.0) * 2.0 - 1.0;
269                x * scale
270            })
271            .collect();
272
273        let w = Parameter::named(
274            "w",
275            Tensor::from_vec(w_data, &[in_features, total_out]).unwrap(),
276            true,
277        );
278
279        // Attention vectors: one per head for source and destination
280        let attn_scale = (1.0 / out_features as f32).sqrt();
281        let attn_src_data: Vec<f32> = (0..total_out)
282            .map(|i| {
283                let x = ((i as f32 * 0.7236067977) % 1.0) * 2.0 - 1.0;
284                x * attn_scale
285            })
286            .collect();
287        let attn_dst_data: Vec<f32> = (0..total_out)
288            .map(|i| {
289                let x = ((i as f32 * 0.3819660113) % 1.0) * 2.0 - 1.0;
290                x * attn_scale
291            })
292            .collect();
293
294        let attn_src = Parameter::named(
295            "attn_src",
296            Tensor::from_vec(attn_src_data, &[num_heads, out_features]).unwrap(),
297            true,
298        );
299
300        let attn_dst = Parameter::named(
301            "attn_dst",
302            Tensor::from_vec(attn_dst_data, &[num_heads, out_features]).unwrap(),
303            true,
304        );
305
306        let bias_data = vec![0.0; total_out];
307        let bias = Some(Parameter::named(
308            "bias",
309            Tensor::from_vec(bias_data, &[total_out]).unwrap(),
310            true,
311        ));
312
313        Self {
314            w,
315            attn_src,
316            attn_dst,
317            bias,
318            in_features,
319            out_features,
320            num_heads,
321            negative_slope: 0.2,
322        }
323    }
324
325    /// Graph attention forward pass.
326    ///
327    /// # Arguments
328    /// * `x` - Node features: `(batch, num_nodes, in_features)`
329    /// * `adj` - Adjacency mask: `(num_nodes, num_nodes)` — non-zero entries allow attention
330    ///
331    /// # Returns
332    /// Output features: `(batch, num_nodes, out_features * num_heads)`
333    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
334        let shape = x.shape();
335        assert!(shape.len() == 3, "GATConv expects (batch, nodes, features), got {:?}", shape);
336
337        let batch = shape[0];
338        let nodes = shape[1];
339        let total_out = self.out_features * self.num_heads;
340
341        let x_data = x.data().to_vec();
342        let adj_data = adj.data().to_vec();
343        let w_data = self.w.data().to_vec();
344        let attn_src_data = self.attn_src.data().to_vec();
345        let attn_dst_data = self.attn_dst.data().to_vec();
346
347        let adj_nodes = if adj.shape().len() == 3 { adj.shape()[1] } else { adj.shape()[0] };
348        assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
349
350        let mut output = vec![0.0f32; batch * nodes * total_out];
351
352        for b in 0..batch {
353            // Step 1: Project all nodes: h = x @ w → (nodes, total_out)
354            let mut h = vec![0.0f32; nodes * total_out];
355            for i in 0..nodes {
356                let x_off = (b * nodes + i) * self.in_features;
357                for o in 0..total_out {
358                    let mut val = 0.0;
359                    for f in 0..self.in_features {
360                        val += x_data[x_off + f] * w_data[f * total_out + o];
361                    }
362                    h[i * total_out + o] = val;
363                }
364            }
365
366            // Step 2: Compute attention per head
367            let adj_off = if adj.shape().len() == 3 { b * nodes * nodes } else { 0 };
368
369            for head in 0..self.num_heads {
370                let head_off = head * self.out_features;
371
372                // Compute attention scores for each edge
373                // e_ij = LeakyReLU(attn_src · h_i + attn_dst · h_j)
374                let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
375
376                for i in 0..nodes {
377                    // src score for node i
378                    let mut src_score = 0.0;
379                    for f in 0..self.out_features {
380                        src_score += h[i * total_out + head_off + f] * attn_src_data[head * self.out_features + f];
381                    }
382
383                    for j in 0..nodes {
384                        let a_ij = adj_data[adj_off + i * nodes + j];
385                        if a_ij != 0.0 {
386                            let mut dst_score = 0.0;
387                            for f in 0..self.out_features {
388                                dst_score += h[j * total_out + head_off + f] * attn_dst_data[head * self.out_features + f];
389                            }
390
391                            let e = src_score + dst_score;
392                            // LeakyReLU
393                            let e = if e > 0.0 { e } else { e * self.negative_slope };
394                            attn_scores[i * nodes + j] = e;
395                        }
396                    }
397                }
398
399                // Softmax per row (per destination node)
400                for i in 0..nodes {
401                    let row_start = i * nodes;
402                    let row_end = row_start + nodes;
403                    let row = &attn_scores[row_start..row_end];
404
405                    let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
406                    if max_val == f32::NEG_INFINITY {
407                        continue; // No neighbors
408                    }
409
410                    let mut sum_exp = 0.0f32;
411                    let mut exps = vec![0.0; nodes];
412                    for j in 0..nodes {
413                        if row[j] > f32::NEG_INFINITY {
414                            exps[j] = (row[j] - max_val).exp();
415                            sum_exp += exps[j];
416                        }
417                    }
418
419                    // Weighted sum of neighbor features
420                    let out_off = (b * nodes + i) * total_out + head_off;
421                    for j in 0..nodes {
422                        if exps[j] > 0.0 {
423                            let alpha = exps[j] / sum_exp;
424                            for f in 0..self.out_features {
425                                output[out_off + f] += alpha * h[j * total_out + head_off + f];
426                            }
427                        }
428                    }
429                }
430            }
431        }
432
433        // Add bias
434        if let Some(bias) = &self.bias {
435            let bias_data = bias.data().to_vec();
436            for b in 0..batch {
437                for i in 0..nodes {
438                    let offset = (b * nodes + i) * total_out;
439                    for o in 0..total_out {
440                        output[offset + o] += bias_data[o];
441                    }
442                }
443            }
444        }
445
446        Variable::new(
447            Tensor::from_vec(output, &[batch, nodes, total_out]).unwrap(),
448            x.requires_grad(),
449        )
450    }
451
452    /// Total output dimension (out_features * num_heads).
453    pub fn total_out_features(&self) -> usize {
454        self.out_features * self.num_heads
455    }
456}
457
458impl Module for GATConv {
459    fn forward(&self, _input: &Variable) -> Variable {
460        panic!("GATConv requires an adjacency matrix. Use forward_graph(x, adj) instead.")
461    }
462
463    fn parameters(&self) -> Vec<Parameter> {
464        let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
465        if let Some(bias) = &self.bias {
466            params.push(bias.clone());
467        }
468        params
469    }
470
471    fn named_parameters(&self) -> HashMap<String, Parameter> {
472        let mut params = HashMap::new();
473        params.insert("w".to_string(), self.w.clone());
474        params.insert("attn_src".to_string(), self.attn_src.clone());
475        params.insert("attn_dst".to_string(), self.attn_dst.clone());
476        if let Some(bias) = &self.bias {
477            params.insert("bias".to_string(), bias.clone());
478        }
479        params
480    }
481
482    fn name(&self) -> &'static str {
483        "GATConv"
484    }
485}
486
487// =============================================================================
488// Tests
489// =============================================================================
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-differentiable variable from raw row-major data.
    fn from_data(data: Vec<f32>, shape: &[usize]) -> Variable {
        Variable::new(Tensor::from_vec(data, shape).unwrap(), false)
    }

    /// Builds a non-differentiable variable whose elements all equal `value`.
    fn filled(value: f32, shape: &[usize]) -> Variable {
        let len: usize = shape.iter().product();
        from_data(vec![value; len], shape)
    }

    /// Row-major `n x n` identity matrix.
    fn identity(n: usize) -> Vec<f32> {
        (0..n * n)
            .map(|k| if k % (n + 1) == 0 { 1.0 } else { 0.0 })
            .collect()
    }

    #[test]
    fn test_gcn_conv_shape() {
        let gcn = GCNConv::new(72, 128);
        let x = filled(1.0, &[2, 7, 72]);
        let adj = filled(1.0, &[7, 7]);
        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        // With identity adjacency, each node only sees itself
        let gcn = GCNConv::new(4, 8);
        let x = filled(1.0, &[1, 3, 4]);
        let adj = from_data(identity(3), &[3, 3]);

        let output = gcn.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 3, 8]);

        // All nodes have same input, so all should produce same output
        let data = output.data().to_vec();
        for i in 0..3 {
            for f in 0..8 {
                assert!(
                    (data[i * 8 + f] - data[f]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 2); // weight + bias

        let total_params: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total_params, 16 * 32 + 32); // weight + bias
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        let gcn = GCNConv::without_bias(16, 32);
        let params = gcn.parameters();
        assert_eq!(params.len(), 1); // weight only
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let gcn = GCNConv::new(16, 32);
        let params = gcn.named_parameters();
        assert!(params.contains_key("weight"));
        assert!(params.contains_key("bias"));
    }

    #[test]
    fn test_gcn_batched_adjacency_matches_shared() {
        // Batch 0 uses the identity, batch 1 an empty graph.
        let gcn = GCNConv::new(4, 6);
        let x = filled(1.0, &[2, 3, 4]);

        let shared = gcn
            .forward_graph(&x, &from_data(identity(3), &[3, 3]))
            .data()
            .to_vec();

        let mut batched_adj = identity(3);
        batched_adj.extend(vec![0.0; 9]);
        let batched = gcn
            .forward_graph(&x, &from_data(batched_adj, &[2, 3, 3]))
            .data()
            .to_vec();

        let per_batch = 3 * 6;
        for k in 0..per_batch {
            assert!(
                (batched[k] - shared[k]).abs() < 1e-6,
                "Batch entry with identity adjacency should match the shared-identity result"
            );
            assert!(
                batched[per_batch + k].abs() < 1e-6,
                "Batch entry with empty adjacency should only contain the (zero) bias"
            );
        }
    }

    #[test]
    fn test_gcn_output_is_linear_in_adjacency() {
        // Bias starts at zero, so doubling adj doubles the output.
        let gcn = GCNConv::new(4, 5);
        let x = from_data((0..12).map(|v| v as f32 * 0.25).collect(), &[1, 3, 4]);

        let single = gcn
            .forward_graph(&x, &from_data(identity(3), &[3, 3]))
            .data()
            .to_vec();
        let doubled_adj: Vec<f32> = identity(3).iter().map(|v| v * 2.0).collect();
        let doubled = gcn
            .forward_graph(&x, &from_data(doubled_adj, &[3, 3]))
            .data()
            .to_vec();

        for (d, s) in doubled.iter().zip(&single) {
            assert!((d - 2.0 * s).abs() < 1e-5, "expected {} got {}", 2.0 * s, d);
        }
    }

    #[test]
    fn test_gat_conv_shape() {
        let gat = GATConv::new(72, 32, 4); // 4 heads
        let x = filled(1.0, &[2, 7, 72]);
        let adj = filled(1.0, &[7, 7]);
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 7, 128]); // 32 * 4 = 128
    }

    #[test]
    fn test_gat_conv_single_head() {
        let gat = GATConv::new(16, 8, 1);
        let x = filled(1.0, &[1, 5, 16]);
        let adj = filled(1.0, &[5, 5]);
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_parameters() {
        let gat = GATConv::new(16, 8, 4);
        let params = gat.parameters();
        assert_eq!(params.len(), 4); // w, attn_src, attn_dst, bias

        let named = gat.named_parameters();
        assert!(named.contains_key("w"));
        assert!(named.contains_key("attn_src"));
        assert!(named.contains_key("attn_dst"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_total_output() {
        let gat = GATConv::new(16, 32, 4);
        assert_eq!(gat.total_out_features(), 128);
    }

    #[test]
    fn test_gat_zero_adjacency_yields_bias_only() {
        // No edges: every softmax row is empty, leaving only the zero bias.
        let gat = GATConv::new(6, 3, 2);
        let x = filled(5.0, &[2, 4, 6]);
        let adj = filled(0.0, &[4, 4]);
        let output = gat.forward_graph(&x, &adj);
        assert_eq!(output.shape(), vec![2, 4, 6]);
        for val in output.data().to_vec() {
            assert!(val.abs() < 1e-6, "Zero adjacency should block all attention");
        }
    }

    #[test]
    fn test_gat_identity_adjacency_uniform_nodes() {
        // Each node attends only to itself; identical inputs give identical outputs.
        let gat = GATConv::new(4, 3, 2);
        let x = filled(1.0, &[1, 3, 4]);
        let adj = from_data(identity(3), &[3, 3]);
        let data = gat.forward_graph(&x, &adj).data().to_vec();
        let width = 3 * 2;
        for i in 0..3 {
            for f in 0..width {
                assert!((data[i * width + f] - data[f]).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        // Zero adjacency should produce only bias in output
        let gcn = GCNConv::new(4, 4);
        let x = filled(99.0, &[1, 3, 4]);
        let adj = filled(0.0, &[3, 3]);
        let output = gcn.forward_graph(&x, &adj);

        // With zero adjacency, output should be just bias (all zeros initially)
        let data = output.data().to_vec();
        for val in &data {
            assert!(val.abs() < 1e-6, "Zero adjacency should zero out message passing");
        }
    }
}