Module gnn

Graph Neural Network layers for learning on graph-structured data.

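This module provides neural network layers that operate on graphs: each layer updates a node's features by aggregating information from its neighbors, so the learned representations reflect both node features and graph topology.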

§Architecture

Node Features    Graph Structure
     │                 │
     ▼                 ▼
┌────────────────────────────┐
│       GNN Layer            │
│  (aggregate + transform)   │
└────────────────────────────┘
           │
           ▼
   Updated Node Features
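As a rough illustration of the "aggregate + transform" step in the diagram, here is a minimal, self-contained sketch in plain Rust (no aprender types). It assumes mean aggregation over incoming edges plus an implicit self-loop, a weight matrix stored as `[in_features][out_features]`, and a ReLU activation. The helper names `aggregate_mean`, `transform_relu`, and `gnn_layer` are hypothetical and not part of this module's API.

```rust
/// An edge as (source_node, target_node), mirroring this module's `EdgeIndex` description.
type Edge = (usize, usize);

/// Aggregate: for each node i, average its own features (an implicit self-loop)
/// with the features of every source node j that has an edge j -> i.
fn aggregate_mean(x: &[Vec<f32>], edges: &[Edge]) -> Vec<Vec<f32>> {
    let mut sum: Vec<Vec<f32>> = x.to_vec();
    let mut count = vec![1.0f32; x.len()];
    for &(src, dst) in edges {
        for k in 0..x[src].len() {
            sum[dst][k] += x[src][k];
        }
        count[dst] += 1.0;
    }
    for (row, &c) in sum.iter_mut().zip(&count) {
        for v in row.iter_mut() {
            *v /= c;
        }
    }
    sum
}

/// Transform: h' = ReLU(h W), with W stored as [in_features][out_features].
fn transform_relu(h: &[Vec<f32>], w: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let out_dim = w[0].len();
    h.iter()
        .map(|row| {
            (0..out_dim)
                .map(|o| {
                    // Dot product of the node's features with column o of W.
                    let z: f32 = row.iter().zip(w).map(|(xi, wrow)| xi * wrow[o]).sum();
                    z.max(0.0)
                })
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// One hypothetical GNN layer: aggregate neighbor features, then transform them.
fn gnn_layer(x: &[Vec<f32>], edges: &[Edge], w: &[Vec<f32>]) -> Vec<Vec<f32>> {
    transform_relu(&aggregate_mean(x, edges), w)
}
```

Real layers differ mainly in the aggregation (normalized sums, attention weights, max, and so on) and in how the self features are combined; the two-stage shape stays the same.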

§Layers

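  • GCNConv - Graph Convolutional Network (Kipf & Welling, 2017)
  • GATConv - Graph Attention Network (Velickovic et al., 2018)
  • GraphSAGEConv - GraphSAGE (Hamilton et al., 2017)
  • GINConv - Graph Isomorphism Network (Xu et al., 2019)
  • EdgeConv - Edge convolution (Wang et al., 2019)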
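The GCN propagation rule of Kipf & Welling (2017) is H' = Â X W with Â = D̂^{-1/2} (A + I) D̂^{-1/2}, where D̂ is the degree matrix of A + I. The sketch below computes this with dense matrices for readability. It is illustrative only: it treats every edge as undirected, omits bias and activation, and uses hypothetical helper names; it makes no claim about how `GCNConv` itself stores or normalizes the graph.

```rust
/// An edge as (source_node, target_node).
type Edge = (usize, usize);

/// Dense normalized adjacency D^{-1/2} (A + I) D^{-1/2}.
/// Each edge is treated as undirected, so both directions are set.
fn gcn_norm_adjacency(num_nodes: usize, edges: &[Edge]) -> Vec<Vec<f32>> {
    let mut a = vec![vec![0.0f32; num_nodes]; num_nodes];
    for i in 0..num_nodes {
        a[i][i] = 1.0; // self-loop: A + I
    }
    for &(s, t) in edges {
        a[s][t] = 1.0;
        a[t][s] = 1.0;
    }
    // Degree of each node in A + I.
    let deg: Vec<f32> = a.iter().map(|row| row.iter().sum::<f32>()).collect();
    for i in 0..num_nodes {
        for j in 0..num_nodes {
            a[i][j] /= (deg[i] * deg[j]).sqrt();
        }
    }
    a
}

/// Plain dense matrix product a [n][m] x b [m][p] -> [n][p].
fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let cols = b[0].len();
    a.iter()
        .map(|row| {
            (0..cols)
                .map(|c| row.iter().zip(b).map(|(x, brow)| x * brow[c]).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// One GCN layer without bias or activation: H' = Â X W.
fn gcn_forward(x: &[Vec<f32>], edges: &[Edge], w: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let a_hat = gcn_norm_adjacency(x.len(), edges);
    matmul(&matmul(&a_hat, x), w)
}
```

Dense matrices keep the arithmetic easy to follow; a practical layer would work on the sparse edge list directly, because the dense Â needs O(n²) memory. On the 4-node cycle used in the module example, every node has degree 3 in A + I, so each nonzero entry of Â is 1/3.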

§Example

use aprender::gnn::{GCNConv, GNNModule};
use aprender::autograd::Tensor;

// Create GCN layer: 16 input features → 32 output features
let gcn = GCNConv::new(16, 32);

// Node features [num_nodes, in_features]
let x = Tensor::ones(&[4, 16]);

// Graph connectivity as an edge list: one (source, target) pair per edge (here a 4-node cycle)
let edge_index = vec![(0, 1), (1, 2), (2, 3), (3, 0)];

let out = gcn.forward_gnn(&x, &edge_index);
assert_eq!(out.shape(), &[4, 32]);

§References

  • Kipf, T. N., & Welling, M. (2017). Semi-Supervised Classification with Graph Convolutional Networks. ICLR.
  • Velickovic, P., et al. (2018). Graph Attention Networks. ICLR.
  • Hamilton, W. L., et al. (2017). Inductive Representation Learning on Large Graphs (GraphSAGE). NeurIPS.

Structs§

EdgeConv
Edge convolution layer (Wang et al., 2019).
GATConv
Graph Attention Network layer (Velickovic et al., 2018).
GCNConv
Graph Convolutional Network layer (Kipf & Welling, 2017).
GINConv
Graph Isomorphism Network layer (Xu et al., 2019).
GraphSAGEConv
GraphSAGE layer (Hamilton et al., 2017).
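For `GINConv`, Xu et al. (2019) update each node as h_i' = MLP((1 + ε) · h_i + Σ_{j ∈ N(i)} h_j), using a sum rather than a mean or max so that different multisets of neighbor features remain distinguishable. The sketch below is a minimal plain-Rust version of that rule. It assumes a node's neighbors are the sources of its incoming edges and takes the MLP as a caller-supplied closure; `gin_layer` is a hypothetical name, not this crate's API.

```rust
/// An edge as (source_node, target_node).
type Edge = (usize, usize);

/// GIN update: h_i' = mlp((1 + eps) * h_i + sum over edges j -> i of h_j).
fn gin_layer<F>(x: &[Vec<f32>], edges: &[Edge], eps: f32, mlp: F) -> Vec<Vec<f32>>
where
    F: Fn(&[f32]) -> Vec<f32>,
{
    // Start from the (1 + eps)-scaled self term.
    let mut agg: Vec<Vec<f32>> = x
        .iter()
        .map(|row| row.iter().map(|v| (1.0 + eps) * v).collect::<Vec<f32>>())
        .collect();
    // Sum (not average) the features of every source node j for each edge j -> i.
    for &(src, dst) in edges {
        for (a, v) in agg[dst].iter_mut().zip(&x[src]) {
            *a += v;
        }
    }
    // Apply the MLP row by row.
    agg.iter().map(|row| mlp(row.as_slice())).collect()
}
```

Passing the identity as the MLP makes the aggregation easy to check by hand; in the paper ε can be a fixed constant or a learned parameter.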

Enums§

SAGEAggregation
Aggregation method for GraphSAGE.
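In GraphSAGE (Hamilton et al., 2017), a node's new representation comes from concatenating its own features with an aggregate of its neighbors' features and passing the result through a linear layer and nonlinearity; the choice of aggregator is what an enum like `SAGEAggregation` selects. The sketch below shows only the aggregate-and-concatenate part. Its `Aggregation` enum and its `Mean`, `Max`, and `Sum` variants are hypothetical, since the actual variants of `SAGEAggregation` are not listed here, and `sage_aggregate_concat` is likewise an illustrative name.

```rust
/// Hypothetical aggregator choices, for illustration only.
#[derive(Clone, Copy, Debug)]
enum Aggregation {
    Mean,
    Max,
    Sum,
}

/// An edge as (source_node, target_node).
type Edge = (usize, usize);

/// For each node i, pool the features of its in-neighbors (edges j -> i) with `agg`,
/// then return the concatenation [h_i || pooled_i] that feeds GraphSAGE's linear layer.
fn sage_aggregate_concat(x: &[Vec<f32>], edges: &[Edge], agg: Aggregation) -> Vec<Vec<f32>> {
    let d = x.first().map_or(0, |row| row.len());
    // Group source nodes by target node.
    let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); x.len()];
    for &(src, dst) in edges {
        neighbors[dst].push(src);
    }
    x.iter()
        .zip(&neighbors)
        .map(|(self_feat, nbrs)| {
            let pooled = (0..d).map(|k| {
                // Isolated node: use a zero neighbor message (an assumption of this sketch).
                if nbrs.is_empty() {
                    return 0.0;
                }
                let vals = nbrs.iter().map(|&j| x[j][k]);
                match agg {
                    Aggregation::Mean => vals.sum::<f32>() / nbrs.len() as f32,
                    Aggregation::Max => vals.fold(f32::NEG_INFINITY, f32::max),
                    Aggregation::Sum => vals.sum::<f32>(),
                }
            });
            // Concatenate self features with the pooled neighbor features.
            self_feat.iter().copied().chain(pooled).collect::<Vec<f32>>()
        })
        .collect()
}
```

Concatenating instead of summing keeps the node's own features separate from the neighborhood summary, so the following linear layer can weight the two differently.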

Traits§

GNNModule
Trait for GNN modules that process graph-structured data.

Functions§

global_max_pool
Global max pooling for graph-level predictions.
global_mean_pool
Global mean pooling for graph-level predictions.
global_sum_pool
Global sum pooling for graph-level predictions.
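Global pooling turns per-node features into one feature vector per graph, which a graph-level classifier or regressor can then consume. The sketch below shows the three reductions side by side, assuming several graphs are batched together and a `batch` vector gives the graph id of each node; that convention, the `Pool` enum, and `global_pool` are assumptions for illustration, since the actual signatures of `global_mean_pool`, `global_max_pool`, and `global_sum_pool` are not shown here.

```rust
/// Hypothetical reduction selector for this sketch.
#[derive(Clone, Copy)]
enum Pool {
    Sum,
    Mean,
    Max,
}

/// Reduce node features [num_nodes][d] to graph features [num_graphs][d].
/// `batch[i]` is the graph that node i belongs to.
fn global_pool(x: &[Vec<f32>], batch: &[usize], num_graphs: usize, op: Pool) -> Vec<Vec<f32>> {
    let d = x.first().map_or(0, |row| row.len());
    let init = match op {
        Pool::Max => f32::NEG_INFINITY,
        _ => 0.0,
    };
    let mut out = vec![vec![init; d]; num_graphs];
    let mut counts = vec![0usize; num_graphs];
    // Accumulate every node into its graph's row.
    for (row, &g) in x.iter().zip(batch) {
        counts[g] += 1;
        for (o, &v) in out[g].iter_mut().zip(row) {
            *o = match op {
                Pool::Max => (*o).max(v),
                _ => *o + v,
            };
        }
    }
    // Finish: divide for the mean; report zeros for graphs with no nodes (an assumption).
    for (row, &n) in out.iter_mut().zip(&counts) {
        for o in row.iter_mut() {
            if n == 0 {
                *o = 0.0;
            } else if let Pool::Mean = op {
                *o /= n as f32;
            }
        }
    }
    out
}
```

Sum pooling preserves information about graph size, mean pooling normalizes it away, and max pooling keeps only the strongest activation per feature; which one fits depends on whether size matters for the prediction.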

Type Aliases§

EdgeIndex
Edge index type: a (source_node, target_node) pair.
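A list of (source_node, target_node) pairs carries the same information as the two-row COO layout `edge_index[2, num_edges]` used by many GNN libraries. The sketch below converts between the two and adds reverse edges, a common preprocessing step when a directed edge list should be treated as undirected. It assumes node ids are `usize`, since the alias's element type is not shown here, and `to_coo`, `from_coo`, and `add_reverse_edges` are hypothetical helpers.

```rust
/// Mirrors the documented alias; the `usize` element type is an assumption.
type EdgeIndex = (usize, usize);

/// Split a list of edges into the 2 x num_edges COO layout: row 0 = sources, row 1 = targets.
fn to_coo(edges: &[EdgeIndex]) -> [Vec<usize>; 2] {
    let sources: Vec<usize> = edges.iter().map(|&(s, _)| s).collect();
    let targets: Vec<usize> = edges.iter().map(|&(_, t)| t).collect();
    [sources, targets]
}

/// Inverse of `to_coo`: zip the two rows back into (source, target) pairs.
fn from_coo(coo: &[Vec<usize>; 2]) -> Vec<EdgeIndex> {
    coo[0].iter().copied().zip(coo[1].iter().copied()).collect()
}

/// Append the reverse of every edge so messages flow both ways.
/// Duplicate edges are not removed.
fn add_reverse_edges(edges: &[EdgeIndex]) -> Vec<EdgeIndex> {
    edges.iter().flat_map(|&(s, t)| [(s, t), (t, s)]).collect()
}
```

Whether the layers in this module symmetrize the edge list themselves is not stated on this page, so adding reverse edges explicitly is the safer choice when undirected behavior is intended.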