use crate::autograd::Tensor;
use crate::nn::{Linear, Module};
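/// Graph edge expressed as a `(source, target)` pair of node indices.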
pub type EdgeIndex = (usize, usize);
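/// A [`Module`] that consumes graph connectivity alongside node features.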
pub trait GNNModule: Module {
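/// Runs the layer on node features `x` of shape `[num_nodes, in_features]`,
/// propagating information along the edges in `edge_index`.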
fn forward_gnn(&self, x: &Tensor, edge_index: &[EdgeIndex]) -> Tensor;
}
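/// Graph convolutional layer (GCN-style): aggregates neighbor features with
/// symmetric degree normalization, then applies a learned linear transform.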
#[derive(Debug)]
#[allow(dead_code)]
pub struct GCNConv {
linear: Linear,
in_features: usize,
out_features: usize,
add_self_loops: bool,
use_bias: bool,
}
impl GCNConv {
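/// Creates a GCN layer with self-loops enabled and a bias term.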
#[must_use]
pub fn new(in_features: usize, out_features: usize) -> Self {
Self {
linear: Linear::new(in_features, out_features),
in_features,
out_features,
add_self_loops: true,
use_bias: true,
}
}
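/// Creates a GCN layer that does not add self-loops during aggregation.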
#[must_use]
pub fn without_self_loops(in_features: usize, out_features: usize) -> Self {
Self {
linear: Linear::new(in_features, out_features),
in_features,
out_features,
add_self_loops: false,
use_bias: true,
}
}
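/// Number of input features per node.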
#[must_use]
pub fn in_features(&self) -> usize {
self.in_features
}
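/// Number of output features per node.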
#[must_use]
pub fn out_features(&self) -> usize {
self.out_features
}
}
impl Module for GCNConv {
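/// Plain `forward` is unsupported because a GCN layer needs edge information;
/// call [`GNNModule::forward_gnn`] instead.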
fn forward(&self, _input: &Tensor) -> Tensor {
panic!("GCNConv requires graph structure. Use forward_gnn() instead.");
}
fn parameters(&self) -> Vec<&Tensor> {
self.linear.parameters()
}
fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
self.linear.parameters_mut()
}
}
impl GNNModule for GCNConv {
fn forward_gnn(&self, x: &Tensor, edge_index: &[EdgeIndex]) -> Tensor {
let num_nodes = x.shape()[0];
let in_features = x.shape()[1];
assert_eq!(
in_features, self.in_features,
"Expected {} input features, got {}",
self.in_features, in_features
);
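// Node degrees, treating each listed edge as undirected; a self-loop adds 1.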
let mut degrees = vec![0.0f32; num_nodes];
if self.add_self_loops {
for d in &mut degrees {
*d += 1.0;
}
}
for &(src, tgt) in edge_index {
degrees[src] += 1.0;
degrees[tgt] += 1.0;
}
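// Symmetric normalization coefficients 1 / sqrt(deg); the sqrt is clamped to avoid division by zero.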
let norm: Vec<f32> = degrees.iter().map(|&d| 1.0 / d.sqrt().max(1e-6)).collect();
let x_data = x.data();
let mut aggregated = vec![0.0f32; num_nodes * in_features];
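// Self-loop contribution: each node keeps its own features scaled by norm[i]^2.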
if self.add_self_loops {
for i in 0..num_nodes {
let norm_ii = norm[i] * norm[i];
for f in 0..in_features {
aggregated[i * in_features + f] += norm_ii * x_data[i * in_features + f];
}
}
}
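// Propagate features in both directions along each edge, scaled by norm[src] * norm[tgt].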
for &(src, tgt) in edge_index {
let norm_coeff = norm[src] * norm[tgt];
for f in 0..in_features {
aggregated[tgt * in_features + f] += norm_coeff * x_data[src * in_features + f];
}
for f in 0..in_features {
aggregated[src * in_features + f] += norm_coeff * x_data[tgt * in_features + f];
}
}
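// Apply the learned linear transform to the aggregated features.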
let agg_tensor = Tensor::new(&aggregated, &[num_nodes, in_features]);
self.linear.forward(&agg_tensor)
}
}
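/// Graph attention layer (GAT-style): projects node features per head, scores
/// each neighbor with additive attention, and aggregates with softmax weights.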
#[derive(Debug)]
pub struct GATConv {
linear: Linear,
attention_src: Tensor,
attention_tgt: Tensor,
in_features: usize,
out_features: usize,
num_heads: usize,
negative_slope: f32,
add_self_loops: bool,
}
impl GATConv {
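/// Creates a GAT layer with `num_heads` attention heads and self-loops enabled.
/// The attention vectors are initialized with a small deterministic pattern.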
#[must_use]
pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
let total_out = out_features * num_heads;
let attn_data: Vec<f32> = (0..total_out)
.map(|i| ((i % 5) as f32 - 2.0) * 0.1)
.collect();
Self {
linear: Linear::new(in_features, total_out),
attention_src: Tensor::new(&attn_data, &[num_heads, out_features]).requires_grad(),
attention_tgt: Tensor::new(&attn_data, &[num_heads, out_features]).requires_grad(),
in_features,
out_features,
num_heads,
negative_slope: 0.2,
add_self_loops: true,
}
}
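/// Number of attention heads.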
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
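/// Output features per attention head.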
#[must_use]
pub fn out_features(&self) -> usize {
self.out_features
}
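/// Total output width, `out_features * num_heads` (head outputs are concatenated).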
#[must_use]
pub fn total_out_features(&self) -> usize {
self.out_features * self.num_heads
}
}
impl Module for GATConv {
fn forward(&self, _input: &Tensor) -> Tensor {
panic!("GATConv requires graph structure. Use forward_gnn() instead.");
}
fn parameters(&self) -> Vec<&Tensor> {
let mut params = self.linear.parameters();
params.push(&self.attention_src);
params.push(&self.attention_tgt);
params
}
fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
let mut params = self.linear.parameters_mut();
params.push(&mut self.attention_src);
params.push(&mut self.attention_tgt);
params
}
}
impl GNNModule for GATConv {
fn forward_gnn(&self, x: &Tensor, edge_index: &[EdgeIndex]) -> Tensor {
let num_nodes = x.shape()[0];
let in_features = x.shape()[1];
assert_eq!(
in_features, self.in_features,
"Expected {} input features, got {}",
self.in_features, in_features
);
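// Project node features into num_heads * out_features dimensions.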
let h = self.linear.forward(x);
let h_data = h.data();
let total_out = self.num_heads * self.out_features;
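// Optionally append self-loops so every node attends to at least itself.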
let mut edges: Vec<EdgeIndex> = edge_index.to_vec();
if self.add_self_loops {
for i in 0..num_nodes {
edges.push((i, i));
}
}
let mut output = vec![0.0f32; num_nodes * total_out];
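// Group source nodes by target so attention is normalized over each in-neighborhood.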
let mut neighbors: Vec<Vec<usize>> = vec![vec![]; num_nodes];
for &(src, tgt) in &edges {
neighbors[tgt].push(src);
}
let attn_src_data = self.attention_src.data();
let attn_tgt_data = self.attention_tgt.data();
for tgt in 0..num_nodes {
if neighbors[tgt].is_empty() {
continue;
}
for head in 0..self.num_heads {
let head_offset = head * self.out_features;
let mut scores: Vec<f32> = Vec::with_capacity(neighbors[tgt].len());
let mut max_score = f32::NEG_INFINITY;
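// Additive attention score for each neighbor, passed through LeakyReLU.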
for &src in &neighbors[tgt] {
let mut score = 0.0;
for f in 0..self.out_features {
let h_src_f = h_data[src * total_out + head_offset + f];
let h_tgt_f = h_data[tgt * total_out + head_offset + f];
score += attn_src_data[head * self.out_features + f] * h_src_f
+ attn_tgt_data[head * self.out_features + f] * h_tgt_f;
}
if score < 0.0 {
score *= self.negative_slope;
}
scores.push(score);
max_score = max_score.max(score);
}
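// Softmax over the neighborhood, shifted by the max score for numerical stability.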
let mut exp_sum = 0.0;
for s in &mut scores {
*s = (*s - max_score).exp();
exp_sum += *s;
}
for s in &mut scores {
*s /= exp_sum.max(1e-8);
}
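// Weighted sum of projected neighbor features using the attention coefficients.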
for (idx, &src) in neighbors[tgt].iter().enumerate() {
let alpha = scores[idx];
for f in 0..self.out_features {
output[tgt * total_out + head_offset + f] +=
alpha * h_data[src * total_out + head_offset + f];
}
}
}
}
Tensor::new(&output, &[num_nodes, total_out])
}
}
mod gin_conv;
pub use gin_conv::*;
mod accumulate;
pub use accumulate::*;
#[cfg(test)]
mod tests;