use crate::autograd::Tensor;
use crate::nn::{Linear, Module};
use super::{EdgeIndex, GNNModule};
/// Graph Isomorphism Network (GIN) convolution layer.
///
/// Computes `h_i = MLP((1 + eps) * x_i + sum_{j in N(i)} x_j)`, where the MLP
/// is two linear layers with a ReLU in between (Xu et al., 2019, "How
/// Powerful are Graph Neural Networks?").
#[derive(Debug)]
pub struct GINConv {
    linear1: Linear,
    linear2: Linear,
    eps: f32,
    train_eps: bool,
    in_features: usize,
    hidden_features: usize,
    out_features: usize,
}
impl GINConv {
    /// Creates a new GIN layer whose MLP maps `in_features` to
    /// `hidden_features` to `out_features`.
    #[must_use]
    pub fn new(in_features: usize, hidden_features: usize, out_features: usize) -> Self {
        Self {
            linear1: Linear::new(in_features, hidden_features),
            linear2: Linear::new(hidden_features, out_features),
            eps: 0.0,
            // Note: `eps` is a plain f32 that is not registered in
            // `parameters()`, so this flag is currently advisory; optimizers
            // will not update `eps`. Use `set_eps` to change it explicitly.
            train_eps: true,
            in_features,
            hidden_features,
            out_features,
        }
    }
    #[must_use]
    pub fn eps(&self) -> f32 {
        self.eps
    }
    pub fn set_eps(&mut self, eps: f32) {
        self.eps = eps;
    }
    #[must_use]
    pub fn train_eps(&self) -> bool {
        self.train_eps
    }
    #[must_use]
    pub fn in_features(&self) -> usize {
        self.in_features
    }
    #[must_use]
    pub fn hidden_features(&self) -> usize {
        self.hidden_features
    }
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.out_features
    }
}
impl Module for GINConv {
    fn forward(&self, _input: &Tensor) -> Tensor {
        panic!("GINConv requires graph structure. Use forward_gnn() instead.");
    }
    fn parameters(&self) -> Vec<&Tensor> {
        let mut params = self.linear1.parameters();
        params.extend(self.linear2.parameters());
        params
    }
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut params = self.linear1.parameters_mut();
        params.extend(self.linear2.parameters_mut());
        params
    }
}
impl GNNModule for GINConv {
    fn forward_gnn(&self, x: &Tensor, edge_index: &[EdgeIndex]) -> Tensor {
        let num_nodes = x.shape()[0];
        let in_features = x.shape()[1];
        assert_eq!(
            in_features, self.in_features,
            "Expected {} input features, got {}",
            self.in_features, in_features
        );
        let x_data = x.data();
        // Self-contribution: (1 + eps) * x_i.
        let mut aggregated = vec![0.0f32; num_nodes * in_features];
        let self_weight = 1.0 + self.eps;
        for i in 0..num_nodes {
            for f in 0..in_features {
                aggregated[i * in_features + f] = self_weight * x_data[i * in_features + f];
            }
        }
        // Sum-aggregate neighbor features. Each listed edge is applied in
        // both directions (treated as undirected), so callers should list
        // every undirected edge exactly once.
        for &(src, tgt) in edge_index {
            for f in 0..in_features {
                aggregated[tgt * in_features + f] += x_data[src * in_features + f];
                aggregated[src * in_features + f] += x_data[tgt * in_features + f];
            }
        }
        // Two-layer MLP: linear1 -> ReLU -> linear2.
        let agg_tensor = Tensor::new(&aggregated, &[num_nodes, in_features]);
        let h1 = self.linear1.forward(&agg_tensor);
        let h1_data = h1.data();
        let h1_relu: Vec<f32> = h1_data.iter().map(|&v| v.max(0.0)).collect();
        let h1_relu_tensor = Tensor::new(&h1_relu, h1.shape());
        self.linear2.forward(&h1_relu_tensor)
    }
}
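// A minimal smoke test for GINConv; a sketch assuming `Tensor::new`, `shape`,
// and `data` behave as used above. The node counts, feature sizes, and toy
// graph are illustrative only.
#[cfg(test)]
mod gin_tests {
    use super::*;
    #[test]
    fn gin_forward_preserves_node_count_and_maps_features() {
        // 3-node path graph 0-1-2, each undirected edge listed once.
        let mut conv = GINConv::new(2, 4, 3);
        conv.set_eps(0.1);
        let x = Tensor::new(&[1.0, 0.0, 0.0, 1.0, 1.0, 1.0], &[3, 2]);
        let edges: Vec<EdgeIndex> = vec![(0, 1), (1, 2)];
        let out = conv.forward_gnn(&x, &edges);
        assert_eq!(out.shape()[0], 3);
        assert_eq!(out.shape()[1], 3);
    }
}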
/// GraphSAGE convolution layer (Hamilton et al., 2017).
///
/// Concatenates each node's own features with an aggregation of its
/// neighbors' features, applies a linear projection, and optionally
/// L2-normalizes each output row.
#[derive(Debug)]
pub struct GraphSAGEConv {
    linear: Linear,
    aggregation: SAGEAggregation,
    in_features: usize,
    out_features: usize,
    normalize: bool,
    sample_size: Option<usize>,
}
/// Neighbor aggregation strategy used by [`GraphSAGEConv`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SAGEAggregation {
    Mean,
    Max,
    Sum,
}
impl GraphSAGEConv {
    /// Creates a new GraphSAGE layer with mean aggregation and L2
    /// normalization enabled. The linear layer takes the concatenated
    /// `[self || aggregated]` features, hence `in_features * 2` inputs.
    #[must_use]
    pub fn new(in_features: usize, out_features: usize) -> Self {
        Self {
            linear: Linear::new(in_features * 2, out_features),
            aggregation: SAGEAggregation::Mean,
            in_features,
            out_features,
            normalize: true,
            sample_size: None,
        }
    }
    #[must_use]
    pub fn with_aggregation(mut self, agg: SAGEAggregation) -> Self {
        self.aggregation = agg;
        self
    }
    /// Caps each node's neighborhood at `size` neighbors.
    ///
    /// Note: neighbors are truncated in insertion order rather than drawn at
    /// random, so this is a deterministic cap, not uniform sampling.
    #[must_use]
    pub fn with_sample_size(mut self, size: usize) -> Self {
        self.sample_size = Some(size);
        self
    }
    #[must_use]
    pub fn without_normalize(mut self) -> Self {
        self.normalize = false;
        self
    }
    #[must_use]
    pub fn aggregation(&self) -> SAGEAggregation {
        self.aggregation
    }
    #[must_use]
    pub fn sample_size(&self) -> Option<usize> {
        self.sample_size
    }
}
impl Module for GraphSAGEConv {
    fn forward(&self, _input: &Tensor) -> Tensor {
        panic!("GraphSAGEConv requires graph structure. Use forward_gnn() instead.");
    }
    fn parameters(&self) -> Vec<&Tensor> {
        self.linear.parameters()
    }
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        self.linear.parameters_mut()
    }
}
impl GNNModule for GraphSAGEConv {
    #[allow(clippy::needless_range_loop)]
    fn forward_gnn(&self, x: &Tensor, edge_index: &[EdgeIndex]) -> Tensor {
        let num_nodes = x.shape()[0];
        let in_features = x.shape()[1];
        assert_eq!(
            in_features, self.in_features,
            "Expected {} input features, got {}",
            self.in_features, in_features
        );
        let x_data = x.data();
        // Build adjacency lists, treating each listed edge as undirected.
        let mut neighbors: Vec<Vec<usize>> = vec![vec![]; num_nodes];
        for &(src, tgt) in edge_index {
            neighbors[tgt].push(src);
            neighbors[src].push(tgt);
        }
        // Cap neighborhood size if requested (deterministic truncation).
        if let Some(sample_size) = self.sample_size {
            for nbrs in &mut neighbors {
                if nbrs.len() > sample_size {
                    nbrs.truncate(sample_size);
                }
            }
        }
        // Each row is [x_i || agg(neighbors)], hence 2 * in_features columns.
        let mut concat_features = vec![0.0f32; num_nodes * in_features * 2];
        for i in 0..num_nodes {
            // First half: the node's own features.
            for f in 0..in_features {
                concat_features[i * in_features * 2 + f] = x_data[i * in_features + f];
            }
            let nbrs = &neighbors[i];
            // Isolated nodes keep zeros in the aggregation half.
            if nbrs.is_empty() {
                continue;
            }
            // Second half: aggregated neighbor features.
            match self.aggregation {
                SAGEAggregation::Mean => {
                    for &n in nbrs {
                        for f in 0..in_features {
                            concat_features[i * in_features * 2 + in_features + f] +=
                                x_data[n * in_features + f];
                        }
                    }
                    let count = nbrs.len() as f32;
                    for f in 0..in_features {
                        concat_features[i * in_features * 2 + in_features + f] /= count;
                    }
                }
                SAGEAggregation::Max => {
                    // Seed with the first neighbor, then take element-wise max.
                    if let Some(&first) = nbrs.first() {
                        for f in 0..in_features {
                            concat_features[i * in_features * 2 + in_features + f] =
                                x_data[first * in_features + f];
                        }
                    }
                    for &n in nbrs.iter().skip(1) {
                        for f in 0..in_features {
                            let current = concat_features[i * in_features * 2 + in_features + f];
                            let neighbor = x_data[n * in_features + f];
                            concat_features[i * in_features * 2 + in_features + f] =
                                current.max(neighbor);
                        }
                    }
                }
                SAGEAggregation::Sum => {
                    for &n in nbrs {
                        for f in 0..in_features {
                            concat_features[i * in_features * 2 + in_features + f] +=
                                x_data[n * in_features + f];
                        }
                    }
                }
            }
        }
        let concat_tensor = Tensor::new(&concat_features, &[num_nodes, in_features * 2]);
        let mut out = self.linear.forward(&concat_tensor);
        // Optionally L2-normalize each output row, flooring the norm to
        // avoid division by zero.
        if self.normalize {
            let out_data = out.data();
            let mut normalized = Vec::with_capacity(out_data.len());
            for i in 0..num_nodes {
                let mut norm = 0.0f32;
                for f in 0..self.out_features {
                    norm += out_data[i * self.out_features + f].powi(2);
                }
                norm = norm.sqrt().max(1e-8);
                for f in 0..self.out_features {
                    normalized.push(out_data[i * self.out_features + f] / norm);
                }
            }
            out = Tensor::new(&normalized, &[num_nodes, self.out_features]);
        }
        out
    }
}
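// A minimal smoke test for GraphSAGEConv and its builder options; a sketch
// under the same assumptions about `Tensor` as the GIN test above.
#[cfg(test)]
mod sage_tests {
    use super::*;
    #[test]
    fn sage_forward_shape_with_max_aggregation() {
        // Triangle graph on 3 nodes, each undirected edge listed once.
        let conv = GraphSAGEConv::new(2, 4)
            .with_aggregation(SAGEAggregation::Max)
            .with_sample_size(8)
            .without_normalize();
        assert_eq!(conv.aggregation(), SAGEAggregation::Max);
        let x = Tensor::new(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]);
        let edges: Vec<EdgeIndex> = vec![(0, 1), (1, 2), (2, 0)];
        let out = conv.forward_gnn(&x, &edges);
        assert_eq!(out.shape()[0], 3);
        assert_eq!(out.shape()[1], 4);
    }
}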
/// EdgeConv layer from Dynamic Graph CNN (Wang et al., 2019).
///
/// For each neighbor j of node i, feeds the edge feature `[x_i || x_j - x_i]`
/// through a two-layer MLP and max-pools the results over the neighborhood.
#[derive(Debug)]
pub struct EdgeConv {
    linear1: Linear,
    linear2: Linear,
    in_features: usize,
    hidden_features: usize,
    out_features: usize,
}
impl EdgeConv {
    /// Creates a new EdgeConv layer. `linear1` takes the concatenated edge
    /// feature `[x_i || x_j - x_i]`, hence `in_features * 2` inputs.
    #[must_use]
    pub fn new(in_features: usize, hidden_features: usize, out_features: usize) -> Self {
        Self {
            linear1: Linear::new(in_features * 2, hidden_features),
            linear2: Linear::new(hidden_features, out_features),
            in_features,
            hidden_features,
            out_features,
        }
    }
    #[must_use]
    pub fn in_features(&self) -> usize {
        self.in_features
    }
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.out_features
    }
}
impl Module for EdgeConv {
    fn forward(&self, _input: &Tensor) -> Tensor {
        panic!("EdgeConv requires graph structure. Use forward_gnn() instead.");
    }
    fn parameters(&self) -> Vec<&Tensor> {
        let mut params = self.linear1.parameters();
        params.extend(self.linear2.parameters());
        params
    }
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut params = self.linear1.parameters_mut();
        params.extend(self.linear2.parameters_mut());
        params
    }
}
impl GNNModule for EdgeConv {
    #[allow(clippy::needless_range_loop)]
    fn forward_gnn(&self, x: &Tensor, edge_index: &[EdgeIndex]) -> Tensor {
        let num_nodes = x.shape()[0];
        let in_features = x.shape()[1];
        assert_eq!(
            in_features, self.in_features,
            "Expected {} input features, got {}",
            self.in_features, in_features
        );
        let x_data = x.data();
        // Build adjacency lists, treating each listed edge as undirected.
        let mut neighbors: Vec<Vec<usize>> = vec![vec![]; num_nodes];
        for &(src, tgt) in edge_index {
            neighbors[tgt].push(src);
            neighbors[src].push(tgt);
        }
        // Max-pool accumulator; NEG_INFINITY is the identity for max.
        let mut output = vec![f32::NEG_INFINITY; num_nodes * self.out_features];
        for i in 0..num_nodes {
            // Give isolated nodes a self-loop so they still produce output
            // (their edge feature is [x_i || 0]).
            if neighbors[i].is_empty() {
                neighbors[i].push(i);
            }
            for &j in &neighbors[i] {
                // Edge feature: [x_i || x_j - x_i].
                let mut edge_feat = Vec::with_capacity(in_features * 2);
                for f in 0..in_features {
                    edge_feat.push(x_data[i * in_features + f]);
                }
                for f in 0..in_features {
                    edge_feat.push(x_data[j * in_features + f] - x_data[i * in_features + f]);
                }
                // One MLP pass per edge keeps the code simple; batching all
                // edge features into a single forward would be faster.
                let edge_tensor = Tensor::new(&edge_feat, &[1, in_features * 2]);
                let h1 = self.linear1.forward(&edge_tensor);
                let h1_relu: Vec<f32> = h1.data().iter().map(|&v| v.max(0.0)).collect();
                let h1_tensor = Tensor::new(&h1_relu, &[1, self.hidden_features]);
                let h2 = self.linear2.forward(&h1_tensor);
                let h2_data = h2.data();
                for f in 0..self.out_features {
                    output[i * self.out_features + f] =
                        output[i * self.out_features + f].max(h2_data[f]);
                }
            }
        }
        // Defensive cleanup: the self-loop fallback means every node is
        // visited, so no NEG_INFINITY should survive, but guard regardless.
        for o in &mut output {
            if *o == f32::NEG_INFINITY {
                *o = 0.0;
            }
        }
        Tensor::new(&output, &[num_nodes, self.out_features])
    }
}
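// A minimal smoke test for EdgeConv, including the isolated-node fallback;
// a sketch under the same `Tensor` assumptions as the tests above.
#[cfg(test)]
mod edge_conv_tests {
    use super::*;
    #[test]
    fn edge_conv_handles_isolated_nodes() {
        // Nodes 0 and 1 share an edge; node 2 is isolated and should still
        // get an output row via the self-loop fallback.
        let conv = EdgeConv::new(2, 4, 3);
        let x = Tensor::new(&[0.5, -0.5, 1.0, 1.0, 2.0, 0.0], &[3, 2]);
        let edges: Vec<EdgeIndex> = vec![(0, 1)];
        let out = conv.forward_gnn(&x, &edges);
        assert_eq!(out.shape()[0], 3);
        assert_eq!(out.shape()[1], 3);
    }
}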