use numrs2::new_modules::nn::graph::*;
use scirs2_core::ndarray::Array2;
/// Walks through a single GIN layer applied to two hand-built toy graphs.
///
/// For each graph the example builds a synthetic node-feature matrix, runs the
/// layer's forward pass, applies ReLU, pools the node representations into one
/// graph-level vector, and prints the results. It then compares the two graph
/// vectors and shows mean, max, and sum pooling side by side.
///
/// # Errors
///
/// Propagates any error returned by `GinLayer::new`,
/// `SparseAdjacency::from_edges`, `GinLayer::forward`, or the
/// `global_*_pool` helpers.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== GIN Graph Classification Example ===\n");

    // Model dimensions are declared first so the feature matrices below are
    // built with exactly the width the layer is constructed with.
    let input_dim = 8;
    let hidden_dim = 16;
    let epsilon = 0.0;

    // Molecule 1: three nodes connected as 0-1-2.
    println!("Molecule 1: Small non-toxic structure");
    let mol1_nodes = 3;
    let mol1_edges = vec![(0, 1), (1, 2)];
    let mol1_features = Array2::from_shape_fn((mol1_nodes, input_dim), |(i, j)| {
        ((i * 3 + j) % 10) as f64 / 10.0
    });
    println!(" Nodes: {}, Edges: {}", mol1_nodes, mol1_edges.len());

    // Molecule 2: six nodes whose edge list closes back on node 0.
    println!("Molecule 2: Larger toxic structure");
    let mol2_nodes = 6;
    let mol2_edges = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)];
    let mol2_features = Array2::from_shape_fn((mol2_nodes, input_dim), |(i, j)| {
        ((i * 5 + j + 3) % 10) as f64 / 10.0 + 0.3
    });
    println!(" Nodes: {}, Edges: {}", mol2_nodes, mol2_edges.len());
    println!();

    println!("Building GIN model:");
    println!(" Input features: {}", input_dim);
    println!(" Hidden dimension: {}", hidden_dim);
    println!(" Epsilon: {} (fixed)", epsilon);
    println!(" Aggregation: SUM (most expressive for WL test)");
    println!();
    let gin_layer = GinLayer::new(input_dim, hidden_dim, epsilon)?;

    // Element-wise ReLU shared by both molecules. Written as an explicit
    // comparison (not `f64::max`) so NaN and -0.0 pass through unchanged,
    // exactly as a `< 0.0` check would leave them.
    let relu = |v: f64| if v < 0.0 { 0.0 } else { v };

    println!("Processing Molecule 1...");
    let mol1_adj = SparseAdjacency::from_edges(mol1_nodes, &mol1_edges)?;
    let mol1_hidden = gin_layer.forward(&mol1_adj, &mol1_features.view())?;
    println!(" Node representations shape: {:?}", mol1_hidden.shape());
    // The raw activations are not needed again, so clamp them in place rather
    // than cloning first.
    // NOTE(review): assumes `forward` yields an owned ndarray `Array2<f64>`.
    let mol1_hidden_relu = mol1_hidden.mapv_into(relu);
    let mol1_graph_repr = global_mean_pool(&mol1_hidden_relu.view())?;
    println!(" After global pooling: {:?}", mol1_graph_repr.shape());
    println!(
        " Graph representation (first 5 dims): [{:.3}, {:.3}, {:.3}, {:.3}, {:.3}]",
        mol1_graph_repr[0],
        mol1_graph_repr[1],
        mol1_graph_repr[2],
        mol1_graph_repr[3],
        mol1_graph_repr[4]
    );

    println!("\nProcessing Molecule 2...");
    let mol2_adj = SparseAdjacency::from_edges(mol2_nodes, &mol2_edges)?;
    let mol2_hidden = gin_layer.forward(&mol2_adj, &mol2_features.view())?;
    println!(" Node representations shape: {:?}", mol2_hidden.shape());
    let mol2_hidden_relu = mol2_hidden.mapv_into(relu);
    let mol2_graph_repr = global_mean_pool(&mol2_hidden_relu.view())?;
    println!(" After global pooling: {:?}", mol2_graph_repr.shape());
    println!(
        " Graph representation (first 5 dims): [{:.3}, {:.3}, {:.3}, {:.3}, {:.3}]",
        mol2_graph_repr[0],
        mol2_graph_repr[1],
        mol2_graph_repr[2],
        mol2_graph_repr[3],
        mol2_graph_repr[4]
    );

    println!("\nGraph Representation Comparison:");
    // Euclidean distance between the two pooled vectors, accumulated in
    // element order.
    let l2_dist = mol1_graph_repr
        .iter()
        .zip(mol2_graph_repr.iter())
        .fold(0.0_f64, |acc, (a, b): (&f64, &f64)| {
            let diff = a - b;
            acc + diff * diff
        })
        .sqrt();
    println!(" L2 distance between representations: {:.4}", l2_dist);
    // Only claim the embeddings differ when the measured distance says so.
    if l2_dist > 0.0 {
        println!(" Different structures produce different embeddings!");
    } else {
        println!(" Embeddings coincide: this layer did not separate the two graphs.");
    }

    println!("\nPooling Strategy Comparison:");
    // Each pooled vector is reduced to one scalar purely for display.
    println!(" Mean pooling:");
    println!(
        " Mol1: {:.4}",
        mol1_graph_repr.iter().sum::<f64>() / hidden_dim as f64
    );
    println!(
        " Mol2: {:.4}",
        mol2_graph_repr.iter().sum::<f64>() / hidden_dim as f64
    );

    let mol1_max_pool = global_max_pool(&mol1_hidden_relu.view())?;
    let mol2_max_pool = global_max_pool(&mol2_hidden_relu.view())?;
    println!(" Max pooling:");
    println!(
        " Mol1: {:.4}",
        mol1_max_pool
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max)
    );
    println!(
        " Mol2: {:.4}",
        mol2_max_pool
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max)
    );

    let mol1_sum_pool = global_sum_pool(&mol1_hidden_relu.view())?;
    let mol2_sum_pool = global_sum_pool(&mol2_hidden_relu.view())?;
    println!(" Sum pooling:");
    println!(" Mol1: {:.4}", mol1_sum_pool.iter().sum::<f64>());
    println!(" Mol2: {:.4}", mol2_sum_pool.iter().sum::<f64>());

    println!("\nGIN Expressiveness:");
    println!(" GIN with sum aggregation is as powerful as WL test");
    println!(" It can distinguish non-isomorphic graphs");
    println!(" Epsilon parameter: (1+ε)·h_v + Σh_u");
    println!(" ε=0 (fixed): Standard GIN");
    println!(" ε>0 (learnable): Can adapt to specific tasks");

    println!("\nStructural Information:");
    // Average degree counts every listed edge once for each of its endpoints.
    println!(" Molecule 1 (linear):");
    println!(
        " Average degree: {:.2}",
        mol1_edges.len() as f64 * 2.0 / mol1_nodes as f64
    );
    println!(" Structure: Open chain");
    println!(" Molecule 2 (ring):");
    println!(
        " Average degree: {:.2}",
        mol2_edges.len() as f64 * 2.0 / mol2_nodes as f64
    );
    println!(" Structure: Closed cycle");

    println!("\n=== Example completed successfully! ===");
    println!("\nNext steps:");
    println!(" 1. Stack multiple GIN layers for deeper networks");
    println!(" 2. Add edge features for bond types");
    println!(" 3. Train with labeled molecular datasets");
    println!(" 4. Use different pooling strategies (DiffPool, TopK)");
    Ok(())
}