#[allow(clippy::wildcard_imports)]
use super::*;
/// GraphSAGE convolution layer: combines each node's own features with an
/// aggregate of its in-neighbors' features through two learned weight
/// matrices (see `forward` in the accompanying `impl`).
#[derive(Debug, Clone)]
pub struct SAGEConv {
    /// Input feature dimensionality per node.
    in_features: usize,
    /// Output feature dimensionality per node.
    out_features: usize,
    /// `[in_features, out_features]` weight applied to the node's own features.
    weight_self: Tensor,
    /// `[in_features, out_features]` weight applied to the aggregated neighbor features.
    weight_neigh: Tensor,
    /// Optional additive bias of length `out_features`; `None` disables it.
    bias: Option<Tensor>,
    /// Strategy used to pool neighbor features (mean/sum/max/…).
    aggregation: SAGEAggregation,
    /// When true, each output row is L2-normalized after the transform.
    normalize: bool,
    /// When false, the `weight_self` (root) term is skipped entirely.
    root_weight: bool,
}
impl SAGEConv {
    /// Creates a new layer with deterministic Xavier-style weights, a zero
    /// bias, mean aggregation, no L2 normalization, and the root (self)
    /// term enabled.
    #[must_use]
    pub fn new(in_features: usize, out_features: usize) -> Self {
        // Xavier/Glorot scale for both weight matrices.
        let std = (2.0 / (in_features + out_features) as f32).sqrt();
        // Deterministic pseudo-random init: a sine sweep at a per-matrix
        // frequency avoids an RNG dependency while keeping the two
        // matrices decorrelated.
        let init = |freq: f32| -> Vec<f32> {
            (0..in_features * out_features)
                .map(|i| (i as f32 * freq).sin() * std)
                .collect()
        };
        let bias_data = vec![0.0f32; out_features];
        Self {
            in_features,
            out_features,
            weight_self: Tensor::new(&init(0.13), &[in_features, out_features]),
            weight_neigh: Tensor::new(&init(0.17), &[in_features, out_features]),
            bias: Some(Tensor::new(&bias_data, &[out_features])),
            aggregation: SAGEAggregation::Mean,
            normalize: false,
            root_weight: true,
        }
    }

    /// Sets the neighbor aggregation strategy (builder style).
    #[must_use]
    pub fn with_aggregation(mut self, agg: SAGEAggregation) -> Self {
        self.aggregation = agg;
        self
    }

    /// Enables L2 normalization of each output row (builder style).
    #[must_use]
    pub fn with_normalize(mut self) -> Self {
        self.normalize = true;
        self
    }

    /// Disables the root (self-feature) term so only aggregated neighbor
    /// features contribute to the output (builder style).
    #[must_use]
    pub fn without_root(mut self) -> Self {
        self.root_weight = false;
        self
    }

    /// Removes the additive bias term (builder style).
    #[must_use]
    pub fn without_bias(mut self) -> Self {
        self.bias = None;
        self
    }

    /// Input feature dimensionality.
    #[must_use]
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Output feature dimensionality.
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Currently configured neighbor aggregation strategy.
    #[must_use]
    pub fn aggregation(&self) -> SAGEAggregation {
        self.aggregation
    }

    /// Sums the feature rows of `neighbors` (node indices into `x_data`,
    /// viewed as a row-major `[num_nodes, in_features]` matrix) into `agg`.
    fn accumulate_neighbors(
        x_data: &[f32],
        neighbors: &[usize],
        in_features: usize,
        agg: &mut [f32],
    ) {
        for &neigh in neighbors {
            // Slicing the row once hoists the bounds check out of the
            // inner loop.
            let row = &x_data[neigh * in_features..(neigh + 1) * in_features];
            for (acc, &v) in agg.iter_mut().zip(row) {
                *acc += v;
            }
        }
    }

    /// Divides every entry of `agg` by `count`, turning a sum into a mean.
    /// `count` must be non-zero (guaranteed by the caller's empty check).
    fn scale_by_mean(agg: &mut [f32], count: usize) {
        let divisor = count as f32;
        for v in agg.iter_mut() {
            *v /= divisor;
        }
    }

    /// Pools the features of `neighbors` into one `in_features`-length
    /// vector according to `self.aggregation`.
    ///
    /// Isolated nodes (no neighbors) aggregate to the zero vector.
    /// `Lstm` currently falls back to mean aggregation; a learned LSTM
    /// aggregator is not implemented.
    fn aggregate_neighbors(&self, x_data: &[f32], neighbors: &[usize]) -> Vec<f32> {
        if neighbors.is_empty() {
            return vec![0.0f32; self.in_features];
        }
        match self.aggregation {
            SAGEAggregation::Mean | SAGEAggregation::Lstm => {
                let mut agg = vec![0.0f32; self.in_features];
                Self::accumulate_neighbors(x_data, neighbors, self.in_features, &mut agg);
                Self::scale_by_mean(&mut agg, neighbors.len());
                agg
            }
            SAGEAggregation::Sum => {
                let mut agg = vec![0.0f32; self.in_features];
                Self::accumulate_neighbors(x_data, neighbors, self.in_features, &mut agg);
                agg
            }
            SAGEAggregation::Max => {
                let mut agg = vec![f32::NEG_INFINITY; self.in_features];
                for &neigh in neighbors {
                    let row = &x_data[neigh * self.in_features..(neigh + 1) * self.in_features];
                    for (acc, &v) in agg.iter_mut().zip(row) {
                        *acc = acc.max(v);
                    }
                }
                // Any slot that never saw a finite value (e.g. all neighbor
                // entries were +/-inf) is reset to zero so infinities cannot
                // leak into downstream math.
                for v in &mut agg {
                    if v.is_infinite() {
                        *v = 0.0;
                    }
                }
                agg
            }
        }
    }

    /// Writes the transformed features for `node` into its row of `output`:
    /// `out = x[node] . W_self (if root_weight) + agg . W_neigh`.
    fn transform_node(
        &self,
        node: usize,
        x_data: &[f32],
        ws_data: &[f32],
        wn_data: &[f32],
        agg_features: &[f32],
        output: &mut [f32],
    ) {
        // Row slices hoist the bounds checks out of the inner loops. Using
        // self.in_features throughout (rather than a separate parameter)
        // removes the risk of the two ever diverging.
        let x_row = &x_data[node * self.in_features..(node + 1) * self.in_features];
        let out_row = &mut output[node * self.out_features..(node + 1) * self.out_features];
        for (out_f, slot) in out_row.iter_mut().enumerate() {
            let mut val = 0.0f32;
            if self.root_weight {
                // Root term: x[node] . W_self[:, out_f]
                for (in_f, &x) in x_row.iter().enumerate() {
                    val += x * ws_data[in_f * self.out_features + out_f];
                }
            }
            // Neighbor term: agg . W_neigh[:, out_f]
            for (in_f, &a) in agg_features.iter().enumerate() {
                val += a * wn_data[in_f * self.out_features + out_f];
            }
            *slot = val;
        }
    }

    /// Adds the bias vector to every node's output row.
    fn add_bias_to_output(
        bias_data: &[f32],
        num_nodes: usize,
        out_features: usize,
        output: &mut [f32],
    ) {
        for row in output.chunks_mut(out_features).take(num_nodes) {
            for (v, &b) in row.iter_mut().zip(bias_data) {
                *v += b;
            }
        }
    }

    /// L2-normalizes each node's output row in place; rows whose norm is
    /// at or below 1e-8 are left untouched to avoid division blow-up.
    fn l2_normalize_rows(num_nodes: usize, out_features: usize, output: &mut [f32]) {
        for row in output.chunks_mut(out_features).take(num_nodes) {
            let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 1e-8 {
                for v in row.iter_mut() {
                    *v /= norm;
                }
            }
        }
    }

    /// Runs the GraphSAGE forward pass.
    ///
    /// `x` is a `[num_nodes, in_features]` feature matrix and `adj` supplies
    /// the directed edge list; messages flow from edge source to edge target
    /// (each node aggregates over its in-neighbors). Returns a
    /// `[num_nodes, out_features]` tensor.
    ///
    /// # Panics
    /// Panics if `x`'s second dimension does not equal `in_features`.
    #[must_use]
    pub fn forward(&self, x: &Tensor, adj: &AdjacencyMatrix) -> Tensor {
        let num_nodes = x.shape()[0];
        assert_eq!(
            x.shape()[1],
            self.in_features,
            "input feature dimension mismatch"
        );
        let x_data = x.data();
        let ws_data = self.weight_self.data();
        let wn_data = self.weight_neigh.data();
        let neighbor_lists = Self::build_neighbor_lists(adj, num_nodes);
        let mut output = vec![0.0f32; num_nodes * self.out_features];
        for node in 0..num_nodes {
            let agg = self.aggregate_neighbors(x_data, &neighbor_lists[node]);
            self.transform_node(node, x_data, ws_data, wn_data, &agg, &mut output);
        }
        if let Some(ref bias) = self.bias {
            Self::add_bias_to_output(bias.data(), num_nodes, self.out_features, &mut output);
        }
        if self.normalize {
            Self::l2_normalize_rows(num_nodes, self.out_features, &mut output);
        }
        Tensor::new(&output, &[num_nodes, self.out_features])
    }

    /// Builds per-node in-neighbor lists from the edge list; edges with
    /// either endpoint outside `0..num_nodes` are silently skipped.
    fn build_neighbor_lists(adj: &AdjacencyMatrix, num_nodes: usize) -> Vec<Vec<usize>> {
        let mut lists: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
        for (&src, &tgt) in adj.edge_src().iter().zip(adj.edge_tgt()) {
            if src < num_nodes && tgt < num_nodes {
                lists[tgt].push(src);
            }
        }
        lists
    }
}
/// Graph attention convolution layer (GAT) parameters and configuration.
///
/// Only the field layout is visible in this chunk; the forward logic lives
/// in the accompanying `impl` block.
#[derive(Debug, Clone)]
pub struct GATConv {
    /// Input feature dimensionality per node.
    pub(crate) in_features: usize,
    /// Output feature dimensionality. NOTE(review): presumably per head —
    /// confirm against the `impl`.
    pub(crate) num_heads: usize,
    /// Number of attention heads — field name suggests multi-head attention;
    /// confirm usage in the `impl`.
    pub(crate) out_features: usize,
    /// Linear projection weights applied to node features.
    pub(crate) weight: Tensor,
    /// Attention parameters for the edge-source side — presumably the `a_src`
    /// vector of standard GAT; verify in the `impl`.
    pub(crate) att_src: Tensor,
    /// Attention parameters for the edge-target side — presumably `a_tgt`;
    /// verify in the `impl`.
    pub(crate) att_tgt: Tensor,
    /// Optional additive bias; `None` disables it.
    pub(crate) bias: Option<Tensor>,
    /// Negative slope — by convention the LeakyReLU slope used on attention
    /// logits; confirm in the `impl`.
    pub(crate) negative_slope: f32,
    /// Dropout probability applied somewhere in the attention path —
    /// confirm where in the `impl`.
    pub(crate) dropout: f32,
    /// Whether multi-head outputs are concatenated (true) or combined
    /// otherwise — presumably averaged; confirm in the `impl`.
    pub(crate) concat: bool,
    /// Whether self-loop edges are added before attention — confirm in the
    /// `impl`.
    pub(crate) add_self_loops: bool,
}