torsh-graph
Graph neural network components for ToRSh - powered by SciRS2.
Overview
This crate provides comprehensive graph neural network (GNN) implementations with a PyTorch Geometric-compatible API. It leverages scirs2-graph for high-performance graph operations while maintaining full integration with ToRSh's autograd system and neural network modules.
Features
- Graph Representations: Adjacency matrices, edge lists, COO/CSR formats
- Message Passing Layers: GCN, GAT, GraphSAGE, GIN, EdgeConv
- Pooling Operations: Global pooling, TopK pooling, SAGPool, DiffPool
- Graph Convolutions: Spectral and spatial convolutions
- Attention Mechanisms: Graph attention, multi-head attention, transformer layers
- Graph Generation: Erdős-Rényi, Barabási-Albert, Watts-Strogatz
- Graph Utilities: Subgraph sampling, neighborhood aggregation, batching
- Heterogeneous Graphs: Support for multiple node/edge types
- Temporal Graphs: Dynamic graph neural networks
- Explainability: GNNExplainer, attention visualization
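The message-passing layers listed above share one core operation: every node aggregates feature vectors from its neighbors along the edge list. A minimal dependency-free sketch of that operation in plain Rust (names here are illustrative, not part of the crate's API):

```rust
// Minimal message passing: each node sums its in-neighbors' feature
// vectors along a COO edge list of (source, target) pairs.
fn message_pass_sum(edges: &[(usize, usize)], x: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let dim = x[0].len();
    let mut out = vec![vec![0.0; dim]; x.len()];
    for &(src, dst) in edges {
        for d in 0..dim {
            out[dst][d] += x[src][d]; // message from src flows to dst
        }
    }
    out
}
```

Real layers add a learned transform and a nonlinearity around this aggregation step; the aggregation itself is what the layer variants (GCN, GIN, GraphSAGE, ...) change.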
Usage
Basic Graph Construction
```rust
use torsh::prelude::*;
use torsh_graph::prelude::*;

// NOTE: signatures below are illustrative, following the PyTorch
// Geometric-style API this crate provides.

// Create a simple graph with 5 nodes
let num_nodes = 5;

// Define edges in COO format: row 0 holds source nodes, row 1 target nodes
let edge_index = tensor!([[0, 1, 1, 2, 3], [1, 0, 2, 3, 4]]);

// Node features (5 nodes, 3 features each)
let x = randn(&[num_nodes, 3])?;

// Create graph
let graph = GraphData::new(x, edge_index)?;

println!("nodes: {}", graph.num_nodes());
println!("edges: {}", graph.num_edges());
println!("feature shape: {:?}", graph.x().shape());
```
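The COO and CSR formats listed under Features are interchangeable views of the same edge set. A dependency-free sketch of the conversion (an illustrative helper, not one of the crate's own types):

```rust
// Convert a COO edge list to CSR: `offsets[i]..offsets[i+1]` indexes the
// slice of `cols` holding node i's out-neighbors.
fn coo_to_csr(num_nodes: usize, edges: &[(usize, usize)]) -> (Vec<usize>, Vec<usize>) {
    let mut offsets = vec![0usize; num_nodes + 1];
    // count out-degree of each node
    for &(src, _) in edges {
        offsets[src + 1] += 1;
    }
    // prefix-sum the counts into row offsets
    for i in 0..num_nodes {
        offsets[i + 1] += offsets[i];
    }
    // scatter column indices into place
    let mut cols = vec![0usize; edges.len()];
    let mut next = offsets.clone();
    for &(src, dst) in edges {
        cols[next[src]] = dst;
        next[src] += 1;
    }
    (offsets, cols)
}
```

CSR makes per-node neighbor iteration O(degree), which is why sparse GNN kernels prefer it over raw COO.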
Graph Neural Network Layers
Graph Convolutional Network (GCN)
```rust
use torsh_graph::nn::GCNConv;

// Single GCN layer: 16 input features -> 32 output features
// (illustrative signatures)
let gcn_layer = GCNConv::new(16, 32)?;

// Forward pass
let x = randn(&[num_nodes, 16])?;
let edge_index = graph.edge_index();
let output = gcn_layer.forward(&x, &edge_index)?;

// Multi-layer GCNs are built by stacking layers with a nonlinearity
// and dropout in between.
```
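Underneath, a GCN layer computes H' = D^{-1/2} (A + I) D^{-1/2} H W. A dependency-free sketch of the propagation step on a dense adjacency matrix (weights W and the nonlinearity omitted; all names illustrative):

```rust
// GCN propagation: H' = D^{-1/2} (A + I) D^{-1/2} H
fn gcn_propagate(adj: &[Vec<f64>], h: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = adj.len();
    let dim = h[0].len();
    // add self-loops: A + I
    let mut a_hat = adj.to_vec();
    for i in 0..n {
        a_hat[i][i] += 1.0;
    }
    // degree of each node in the self-loop-augmented graph
    let deg: Vec<f64> = a_hat.iter().map(|row| row.iter().sum()).collect();
    let mut out = vec![vec![0.0; dim]; n];
    for i in 0..n {
        for j in 0..n {
            // symmetric normalization of each edge weight
            let norm = a_hat[i][j] / (deg[i] * deg[j]).sqrt();
            for d in 0..dim {
                out[i][d] += norm * h[j][d];
            }
        }
    }
    out
}
```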
Graph Attention Network (GAT)
```rust
use torsh_graph::nn::GATConv;

// GAT layer with 8 attention heads: 16 -> 32 features per head
// (illustrative signatures)
let gat_layer = GATConv::new(16, 32, 8)?;
let output = gat_layer.forward(&x, &edge_index)?;
println!("{:?}", output.shape()); // [num_nodes, 32 * 8] if concat=true

// Multi-layer GATs typically concatenate heads in hidden layers and
// average them in the output layer.
```
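The heart of a GAT layer is a per-edge softmax: raw attention scores are normalized over each node's incoming edges, and neighbors are averaged with those weights. A dependency-free sketch (the scores would come from a learned function of the endpoint features; names illustrative):

```rust
// Attention-weighted aggregation: softmax the raw per-edge scores over
// each destination node's incoming edges, then mix neighbor features.
fn attention_aggregate(
    edges: &[(usize, usize)], // (src, dst)
    scores: &[f64],           // one raw score e_ij per edge
    x: &[Vec<f64>],
) -> Vec<Vec<f64>> {
    let n = x.len();
    let dim = x[0].len();
    // softmax normalizer per destination node
    let mut denom = vec![0.0; n];
    for (k, &(_, dst)) in edges.iter().enumerate() {
        denom[dst] += scores[k].exp();
    }
    let mut out = vec![vec![0.0; dim]; n];
    for (k, &(src, dst)) in edges.iter().enumerate() {
        let alpha = scores[k].exp() / denom[dst]; // attention coefficient
        for d in 0..dim {
            out[dst][d] += alpha * x[src][d];
        }
    }
    out
}
```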
GraphSAGE
```rust
use torsh_graph::nn::SAGEConv;

// GraphSAGE layer with mean aggregation (illustrative signatures)
let sage_mean = SAGEConv::new(16, 32, Aggregation::Mean)?;
let output = sage_mean.forward(&x, &edge_index)?;

// Using max aggregation
let sage_max = SAGEConv::new(16, 32, Aggregation::Max)?;

// Using LSTM aggregation
let sage_lstm = SAGEConv::new(16, 32, Aggregation::Lstm)?;
```
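What distinguishes GraphSAGE from GCN is that each node's own features are kept separate from the aggregated neighborhood: the update concatenates the two before the linear transform. A dependency-free sketch of the mean-aggregation variant (transform omitted; names illustrative):

```rust
// GraphSAGE-style update: concatenate each node's features with the
// mean of its in-neighbors' features (isolated nodes get a zero mean).
fn sage_mean_concat(edges: &[(usize, usize)], x: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = x.len();
    let dim = x[0].len();
    let mut sums = vec![vec![0.0; dim]; n];
    let mut counts = vec![0usize; n];
    for &(src, dst) in edges {
        for d in 0..dim {
            sums[dst][d] += x[src][d];
        }
        counts[dst] += 1;
    }
    (0..n)
        .map(|i| {
            let mut row = x[i].clone(); // self features first
            let c = counts[i].max(1) as f64;
            for d in 0..dim {
                row.push(sums[i][d] / c); // then the neighborhood mean
            }
            row
        })
        .collect()
}
```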
Graph Isomorphism Network (GIN)
```rust
use torsh_graph::nn::GINConv;

// GIN layer wraps an MLP that is applied after neighborhood aggregation
// (illustrative signatures)
let mlp = Sequential::new()
    .add(Linear::new(16, 32)?)
    .add(ReLU::new())
    .add(Linear::new(32, 32)?);
let gin_layer = GINConv::new(mlp)?;
let output = gin_layer.forward(&x, &edge_index)?;
```
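Before the MLP is applied, GIN computes (1 + ε)·h_i + Σ_{j∈N(i)} h_j, i.e. a sum aggregation with a weighted self term. A dependency-free sketch with scalar features (the MLP is omitted; names illustrative):

```rust
// GIN aggregation step: out_i = (1 + eps) * x_i + sum of in-neighbors.
fn gin_aggregate(edges: &[(usize, usize)], x: &[f64], eps: f64) -> Vec<f64> {
    // weighted self-contribution
    let mut out: Vec<f64> = x.iter().map(|&v| (1.0 + eps) * v).collect();
    // plus the plain sum over incoming edges
    for &(src, dst) in edges {
        out[dst] += x[src];
    }
    out
}
```

The sum (rather than mean or max) is what gives GIN its injectivity argument and its expressive power equal to the Weisfeiler-Lehman test.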
EdgeConv (Dynamic Graph CNN)
```rust
use torsh_graph::nn::EdgeConv;

// EdgeConv for point cloud processing: applies an inner network to
// (x_i, x_j - x_i) for each edge (illustrative signatures)
let edge_conv = EdgeConv::new(16, 32)?;
let output = edge_conv.forward(&x, &edge_index)?;
```
Graph Pooling
Global Pooling
```rust
use torsh_graph::pool::*;

// `batch` maps each node to the index of the graph it belongs to
// (illustrative signatures)

// Global mean pooling
let global_mean = global_mean_pool(&x, &batch)?;

// Global max pooling
let global_max = global_max_pool(&x, &batch)?;

// Global sum pooling
let global_sum = global_add_pool(&x, &batch)?;

// Global attention pooling with a learned gating network
let gate_nn = Linear::new(32, 1)?;
let global_attn = GlobalAttention::new(gate_nn)?;
let output = global_attn.forward(&x, &batch)?;
```
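Global pooling reduces a batch of graphs to one vector per graph, steered by the batch vector. A dependency-free sketch of the mean variant (names illustrative):

```rust
// Mean-pool node features per graph: `batch[i]` is the graph index of
// node i, so each graph's rows are averaged independently.
fn mean_pool_by_batch(x: &[Vec<f64>], batch: &[usize], num_graphs: usize) -> Vec<Vec<f64>> {
    let dim = x[0].len();
    let mut sums = vec![vec![0.0; dim]; num_graphs];
    let mut counts = vec![0usize; num_graphs];
    for (row, &g) in x.iter().zip(batch) {
        for d in 0..dim {
            sums[g][d] += row[d];
        }
        counts[g] += 1;
    }
    for g in 0..num_graphs {
        for d in 0..dim {
            sums[g][d] /= counts[g] as f64;
        }
    }
    sums
}
```

Sum and max pooling differ only in the reduction applied per graph.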
Hierarchical Pooling
```rust
use torsh_graph::pool::*;

// TopK pooling: keep the top 50% of nodes ranked by a learned score
// (illustrative signatures)
let topk_pool = TopKPooling::new(32, 0.5)?;
let (x, edge_index, batch, perm, score) =
    topk_pool.forward(&x, &edge_index, &batch)?;

// SAGPool (Self-Attention Graph Pooling)
let sag_pool = SAGPooling::new(32, 0.5)?;

// DiffPool (Differentiable Pooling) coarsens the graph to a fixed
// number of clusters and returns auxiliary losses
let diff_pool = DiffPool::new(32, 10)?;
let (x, adj, link_loss, entropy_loss) =
    diff_pool.forward(&x, &adj)?;
```
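TopK-style pooling reduces to a simple selection problem: score every node, keep the ceil(ratio·n) best, and induce the subgraph on the survivors. A dependency-free sketch of the selection step (names illustrative):

```rust
// Select the ceil(ratio * n) highest-scoring nodes; the returned indices
// play the role of the `perm` in TopK-style pooling.
fn topk_nodes(scores: &[f64], ratio: f64) -> Vec<usize> {
    let k = ((scores.len() as f64) * ratio).ceil() as usize;
    let mut idx: Vec<usize> = (0..scores.len()).collect();
    // sort node indices by descending score
    idx.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
    idx.truncate(k);
    idx.sort(); // restore original node order for the induced subgraph
    idx
}
```

SAGPool differs only in how the scores are produced (by a GNN over the graph itself rather than a plain projection).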
Heterogeneous Graphs
```rust
use torsh_graph::hetero::*;

// Create a heterogeneous graph with different node types
// (illustrative signatures; feature/edge tensors assumed defined)
let mut hetero_graph = HeteroGraph::new();

// Add node types
hetero_graph.add_node_type("user", user_features)?;
hetero_graph.add_node_type("item", item_features)?;

// Add edge types as (source, relation, target) triples
hetero_graph.add_edge_type(("user", "rates", "item"), rates_edges)?;
hetero_graph.add_edge_type(("item", "rated_by", "user"), rated_by_edges)?;

// Heterogeneous GNN layer: one convolution per edge type
let hetero_conv = HeteroConv::new()
    .add_conv(("user", "rates", "item"), GCNConv::new(16, 32)?)
    .add_conv(("item", "rated_by", "user"), GCNConv::new(16, 32)?);
let output = hetero_conv.forward(&hetero_graph)?;
```
Temporal Graphs
```rust
use torsh_graph::temporal::*;

// Temporal graph as a sequence of snapshots (illustrative signatures;
// `snapshot_at` stands in for user code that builds a graph per step)
let mut temporal_graph = TemporalGraph::new();
for t in 0..10 {
    temporal_graph.add_snapshot(t, snapshot_at(t)?)?;
}

// Temporal GNN (T-GCN: graph convolution combined with a recurrent cell)
let tgnn = TGCN::new(16, 32)?;

// Process the temporal sequence
let outputs = tgnn.forward(&temporal_graph)?;

// Dynamic edge RNN
let edge_rnn = DynamicEdgeRnn::new(16, 32)?;
```
Graph Generation
```rust
use torsh_graph::generate::*;

// Erdős-Rényi random graph: 100 nodes, edge probability 0.1
// (illustrative signatures)
let er_graph = erdos_renyi_graph(100, 0.1)?;

// Barabási-Albert preferential attachment: 3 edges per new node
let ba_graph = barabasi_albert_graph(100, 3)?;

// Watts-Strogatz small-world: 4 ring neighbors, rewiring probability 0.1
let ws_graph = watts_strogatz_graph(100, 4, 0.1)?;

// Stochastic block model: block sizes plus intra/inter-block probabilities
let sbm_graph = stochastic_block_model(&[50, 50], &[[0.3, 0.05], [0.05, 0.3]])?;
```
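The G(n, p) model underlying the Erdős-Rényi generator is a one-liner per node pair: include each of the n(n-1)/2 undirected edges independently with probability p. A dependency-free sketch using a tiny deterministic LCG so the example is reproducible without external crates (all names illustrative):

```rust
// Erdős-Rényi G(n, p): each undirected pair (i, j) becomes an edge
// with probability p. The LCG stands in for a real RNG.
fn erdos_renyi_edges(n: usize, p: f64, seed: u64) -> Vec<(usize, usize)> {
    let mut state = seed;
    let mut next = move || {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // top 31 bits mapped into [0, 1)
        (state >> 33) as f64 / (1u64 << 31) as f64
    };
    let mut edges = Vec::new();
    for i in 0..n {
        for j in (i + 1)..n {
            if next() < p {
                edges.push((i, j));
            }
        }
    }
    edges
}
```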
Graph Utilities
Batching
```rust
use torsh_graph::batch::*;

// Create a batch from multiple graphs (illustrative signatures)
let graphs = vec![graph1, graph2, graph3];
let batch = Batch::from_data_list(&graphs)?;

// The batch contains:
// - batch.x: concatenated node features
// - batch.edge_index: concatenated edges with offset node indices
// - batch.batch: tensor mapping each node to its source graph

// Unbatch
let graphs_recovered = batch.to_data_list()?;
```
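The offsetting behind batching is mechanical: each graph's node indices are shifted by the total node count of the graphs before it, so the batch is one big disconnected graph. A dependency-free sketch (names illustrative):

```rust
// Merge several (num_nodes, edge_list) graphs into one edge list plus a
// batch vector recording which graph each node came from.
fn batch_edges(
    graphs: &[(usize, Vec<(usize, usize)>)],
) -> (Vec<(usize, usize)>, Vec<usize>) {
    let mut edges = Vec::new();
    let mut batch = Vec::new(); // graph id per node
    let mut offset = 0;
    for (gid, (n, es)) in graphs.iter().enumerate() {
        for &(s, d) in es {
            // shift this graph's node indices past all earlier graphs
            edges.push((s + offset, d + offset));
        }
        batch.extend(std::iter::repeat(gid).take(*n));
        offset += n;
    }
    (edges, batch)
}
```

Unbatching inverts the process using the batch vector to slice nodes back out.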
Neighborhood Sampling
```rust
use torsh_graph::sample::*;

// Sample the k-hop neighborhood around a set of seed nodes
// (illustrative signatures)
let subgraph = k_hop_subgraph(&seed_nodes, 2, &edge_index)?;

// Random walk sampling: walks of length 10 from each start node
let walks = random_walk(&edge_index, &start_nodes, 10)?;

// Neighbor sampling (for GraphSAGE-style mini-batch training)
let neighbor_sampler = NeighborSampler::new(&graph, &[25, 10])?;
for batch in neighbor_sampler {
    // train on each sampled subgraph
}
```
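A uniform random walk, the primitive behind walk-based sampling, just picks a random neighbor at each step. A dependency-free sketch on an adjacency list, again with a small deterministic LCG in place of a real RNG (names illustrative):

```rust
// Uniform random walk of length `len` from `start` on an adjacency list.
fn uniform_random_walk(adj: &[Vec<usize>], start: usize, len: usize, seed: u64) -> Vec<usize> {
    let mut walk = vec![start];
    let mut state = seed;
    let mut cur = start;
    for _ in 0..len {
        if adj[cur].is_empty() {
            break; // dead end: stop the walk early
        }
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        // pick a neighbor uniformly at random
        cur = adj[cur][(state >> 33) as usize % adj[cur].len()];
        walk.push(cur);
    }
    walk
}
```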
Graph Classification
```rust
use torsh_graph::prelude::*;

// A typical graph classifier stacks GNN layers, applies global pooling
// to get one vector per graph, and finishes with a linear classifier
// (illustrative signatures)
let conv1 = GCNConv::new(num_features, 64)?;
let conv2 = GCNConv::new(64, 64)?;
let classifier = Linear::new(64, num_classes)?;

let h = conv2.forward(&conv1.forward(&x, &edge_index)?.relu(), &edge_index)?;
let pooled = global_mean_pool(&h, &batch)?;
let logits = classifier.forward(&pooled)?;
```
Node Classification
```rust
use torsh_graph::datasets::*;

// Load a citation network (e.g. Cora, CiteSeer, PubMed)
// (illustrative signatures)
let dataset = Planetoid::new("data/", "Cora")?;
let data = dataset.get(0)?;

// Split into train/val/test using the provided boolean masks
let train_mask = &data.train_mask;
let val_mask = &data.val_mask;
let test_mask = &data.test_mask;

// Define model
let model = GCN::new(dataset.num_features(), 16, dataset.num_classes())?;

// Training loop
for epoch in 0..200 {
    // forward pass, masked cross-entropy loss, backward, optimizer step
}
```
Explainability
```rust
use torsh_graph::explain::*;

// GNNExplainer for model interpretation (illustrative signatures)
let explainer = GnnExplainer::new(&model, 100)?;

// Explain the prediction for a specific node
let (edge_mask, feat_mask) = explainer.explain_node(node_idx, &x, &edge_index)?;

// Visualize the most important edges
visualize_subgraph(node_idx, &edge_index, &edge_mask)?;

// Attention-weight visualization for GAT
let attention_weights = gat_layer.get_attention_weights()?;
plot_attention_graph(&edge_index, &attention_weights)?;
```
Advanced Examples
Link Prediction
```rust
use torsh_graph::utils::*;

// Negative sampling: draw node pairs that are not connected by an edge
// (illustrative signatures)
let neg_edge_index = negative_sampling(&edge_index, num_nodes, num_neg_samples)?;

// A link-prediction model typically encodes nodes with a GNN and scores
// candidate pairs with a dot product or a small MLP on the embeddings.
```
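The idea behind negative sampling is rejection: draw random node pairs and keep only those that are not existing edges (and not self-loops). A dependency-free sketch with a deterministic LCG (names illustrative):

```rust
use std::collections::HashSet;

// Draw `num_samples` node pairs that are neither self-loops nor present
// in `edges`, by rejection sampling.
fn sample_negative_edges(
    edges: &[(usize, usize)],
    num_nodes: usize,
    num_samples: usize,
    seed: u64,
) -> Vec<(usize, usize)> {
    let existing: HashSet<(usize, usize)> = edges.iter().copied().collect();
    let mut state = seed;
    let mut rand = move || {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (state >> 33) as usize
    };
    let mut neg = Vec::new();
    while neg.len() < num_samples {
        let (u, v) = (rand() % num_nodes, rand() % num_nodes);
        if u != v && !existing.contains(&(u, v)) {
            neg.push((u, v));
        }
    }
    neg
}
```

Rejection is cheap when the graph is sparse, since most random pairs are non-edges.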
Integration with SciRS2
This crate leverages the SciRS2 ecosystem for:
- Graph algorithms and data structures through scirs2-graph
- Sparse matrix operations via scirs2-core
- Spatial indexing through scirs2-spatial
- Automatic differentiation via scirs2-autograd
All implementations follow the SciRS2 POLICY for optimal performance and maintainability.
License
Licensed under the Apache License, Version 2.0. See LICENSE for details.