use crate::autograd::Tensor;
use crate::primitives::Matrix;
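/// Sparse graph connectivity stored as a COO (coordinate) edge list.
///
/// Each edge is directed `src -> tgt`; an undirected graph is represented by
/// storing both directions explicitly. Optional per-edge weights default to
/// `1.0` when absent.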
#[derive(Debug, Clone)]
pub struct AdjacencyMatrix {
num_nodes: usize,
edge_src: Vec<usize>,
edge_tgt: Vec<usize>,
edge_weights: Option<Vec<f32>>,
has_self_loops: bool,
}
impl AdjacencyMatrix {
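    /// Builds the edge list from `[src, tgt]` index pairs.
    ///
    /// Node indices are not validated here; consumers bounds-check before
    /// indexing into node-sized arrays.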
#[must_use]
pub fn from_edge_index(edges: &[[usize; 2]], num_nodes: usize) -> Self {
let edge_src: Vec<usize> = edges.iter().map(|e| e[0]).collect();
let edge_tgt: Vec<usize> = edges.iter().map(|e| e[1]).collect();
Self {
num_nodes,
edge_src,
edge_tgt,
edge_weights: None,
has_self_loops: false,
}
}
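    /// Builds the edge list from separate source and target index vectors
    /// (COO format); the two vectors are expected to have equal length.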
#[must_use]
pub fn from_coo(src: Vec<usize>, tgt: Vec<usize>, num_nodes: usize) -> Self {
Self {
num_nodes,
edge_src: src,
edge_tgt: tgt,
edge_weights: None,
has_self_loops: false,
}
}
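    /// Appends one `(i, i)` edge per node (weight `1.0` when weights are
    /// present). Idempotent: calling it again is a no-op.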
#[must_use]
pub fn add_self_loops(mut self) -> Self {
if self.has_self_loops {
return self;
}
for i in 0..self.num_nodes {
self.edge_src.push(i);
self.edge_tgt.push(i);
}
if let Some(ref mut weights) = self.edge_weights {
            weights.extend(std::iter::repeat(1.0).take(self.num_nodes));
}
self.has_self_loops = true;
self
}
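    /// Attaches per-edge weights, in the same order as the edge lists.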
#[must_use]
    pub fn with_weights(mut self, weights: Vec<f32>) -> Self {
        debug_assert_eq!(
            weights.len(),
            self.edge_src.len(),
            "expected one weight per edge"
        );
        self.edge_weights = Some(weights);
        self
    }
#[must_use]
pub fn num_nodes(&self) -> usize {
self.num_nodes
}
#[must_use]
pub fn num_edges(&self) -> usize {
self.edge_src.len()
}
#[must_use]
pub fn edge_src(&self) -> &[usize] {
&self.edge_src
}
#[must_use]
pub fn edge_tgt(&self) -> &[usize] {
&self.edge_tgt
}
#[must_use]
pub fn has_self_loops(&self) -> bool {
self.has_self_loops
}
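    /// Incoming-edge count per node (unweighted), as `f32` for direct use
    /// in normalization.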
#[must_use]
pub fn in_degrees(&self) -> Vec<f32> {
let mut degrees = vec![0.0f32; self.num_nodes];
for &tgt in &self.edge_tgt {
if tgt < self.num_nodes {
degrees[tgt] += 1.0;
}
}
degrees
}
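    /// Outgoing-edge count per node (unweighted).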
#[must_use]
pub fn out_degrees(&self) -> Vec<f32> {
let mut degrees = vec![0.0f32; self.num_nodes];
for &src in &self.edge_src {
if src < self.num_nodes {
degrees[src] += 1.0;
}
}
degrees
}
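    /// Targets of every edge leaving `node`; a linear O(E) scan per call.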
#[must_use]
pub fn neighbors(&self, node: usize) -> Vec<usize> {
self.edge_src
.iter()
.zip(self.edge_tgt.iter())
.filter(|(&src, _)| src == node)
.map(|(_, &tgt)| tgt)
.collect()
}
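    /// Materializes the dense `n x n` matrix with entry `(src, tgt)` set to
    /// the edge weight (`1.0` when unweighted) and `0.0` elsewhere.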
#[must_use]
pub fn to_dense(&self) -> Matrix<f32> {
let n = self.num_nodes;
let mut data = vec![0.0f32; n * n];
for (i, (&src, &tgt)) in self.edge_src.iter().zip(self.edge_tgt.iter()).enumerate() {
if src < n && tgt < n {
let weight = self
.edge_weights
.as_ref()
.map_or(1.0, |w| w.get(i).copied().unwrap_or(1.0));
data[src * n + tgt] = weight;
}
}
Matrix::from_vec(n, n, data).expect("Valid matrix dimensions")
}
}
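/// Graph convolution layer after Kipf & Welling (2017),
/// `H' = D^(-1/2) (A + I) D^(-1/2) H W + b`, where `D` holds the degrees of
/// the self-loop-augmented adjacency.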
#[derive(Debug, Clone)]
pub struct GCNConv {
in_features: usize,
out_features: usize,
weight: Tensor,
bias: Option<Tensor>,
use_bias: bool,
add_self_loops: bool,
normalize: bool,
}
impl GCNConv {
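    /// Creates a layer with bias, self-loop augmentation, and symmetric
    /// normalization enabled. Weights use Xavier/Glorot scaling
    /// (`std = sqrt(2 / (fan_in + fan_out))`) with a deterministic
    /// sine-based pattern rather than random sampling.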
#[must_use]
pub fn new(in_features: usize, out_features: usize) -> Self {
let std = (2.0 / (in_features + out_features) as f32).sqrt();
let weight_data: Vec<f32> = (0..in_features * out_features)
.map(|i| (i as f32 * 0.1).sin() * std)
.collect();
let bias_data: Vec<f32> = vec![0.0; out_features];
Self {
in_features,
out_features,
weight: Tensor::new(&weight_data, &[in_features, out_features]),
bias: Some(Tensor::new(&bias_data, &[out_features])),
use_bias: true,
add_self_loops: true,
normalize: true,
}
}
#[must_use]
pub fn without_bias(mut self) -> Self {
self.use_bias = false;
self.bias = None;
self
}
#[must_use]
pub fn without_self_loops(mut self) -> Self {
self.add_self_loops = false;
self
}
#[must_use]
pub fn without_normalize(mut self) -> Self {
self.normalize = false;
self
}
#[must_use]
pub fn in_features(&self) -> usize {
self.in_features
}
#[must_use]
pub fn out_features(&self) -> usize {
self.out_features
}
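    /// Applies the layer to node features `x` of shape
    /// `[num_nodes, in_features]`, returning `[num_nodes, out_features]`.
    ///
    /// Note that normalization degrees are unweighted edge counts, so edge
    /// weights scale the messages but not the `D^(-1/2)` terms.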
#[must_use]
pub fn forward(&self, x: &Tensor, adj: &AdjacencyMatrix) -> Tensor {
let num_nodes = x.shape()[0];
let in_feat = x.shape()[1];
        assert_eq!(
            in_feat, self.in_features,
            "Input features mismatch: expected {}, got {}",
            self.in_features, in_feat
        );
        assert_eq!(
            num_nodes,
            adj.num_nodes(),
            "Node count mismatch: features have {} rows, adjacency has {} nodes",
            num_nodes,
            adj.num_nodes()
        );
let adj_normalized = if self.add_self_loops && !adj.has_self_loops() {
adj.clone().add_self_loops()
} else {
adj.clone()
};
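        // Symmetric normalization coefficients: 1 / sqrt(deg_in(i)), with 0
        // for isolated nodes to avoid division by zero.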
let degrees = adj_normalized.in_degrees();
let norm_coeffs: Vec<f32> = degrees
.iter()
.map(|&d| if d > 0.0 { 1.0 / d.sqrt() } else { 0.0 })
.collect();
let x_data = x.data();
let w_data = self.weight.data();
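        // Dense feature transform H = X * W (naive row-major triple loop).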
let mut h_data = vec![0.0f32; num_nodes * self.out_features];
for node in 0..num_nodes {
for out_f in 0..self.out_features {
let mut sum = 0.0f32;
for in_f in 0..self.in_features {
sum += x_data[node * in_feat + in_f] * w_data[in_f * self.out_features + out_f];
}
h_data[node * self.out_features + out_f] = sum;
}
}
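        // Sparse message passing: scatter-add h[src] into output[tgt] for
        // each edge, scaled by the normalization coefficients if enabled.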
let mut output = vec![0.0f32; num_nodes * self.out_features];
if self.normalize {
for (i, (&src, &tgt)) in adj_normalized
.edge_src()
.iter()
.zip(adj_normalized.edge_tgt().iter())
.enumerate()
{
if src < num_nodes && tgt < num_nodes {
let edge_weight = adj_normalized
.edge_weights
.as_ref()
.map_or(1.0, |w| w.get(i).copied().unwrap_or(1.0));
let norm = norm_coeffs[src] * norm_coeffs[tgt] * edge_weight;
for f in 0..self.out_features {
output[tgt * self.out_features + f] +=
norm * h_data[src * self.out_features + f];
}
}
}
} else {
for (&src, &tgt) in adj_normalized
.edge_src()
.iter()
.zip(adj_normalized.edge_tgt().iter())
{
if src < num_nodes && tgt < num_nodes {
for f in 0..self.out_features {
output[tgt * self.out_features + f] += h_data[src * self.out_features + f];
}
}
}
}
if self.use_bias {
if let Some(ref bias) = self.bias {
let bias_data = bias.data();
for node in 0..num_nodes {
for f in 0..self.out_features {
output[node * self.out_features + f] += bias_data[f];
}
}
}
}
Tensor::new(&output, &[num_nodes, self.out_features])
}
#[must_use]
pub fn weight(&self) -> &Tensor {
&self.weight
}
#[must_use]
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
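/// Neighbor aggregation schemes for GraphSAGE (Hamilton et al., 2017).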
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum SAGEAggregation {
#[default]
Mean,
Max,
Sum,
Lstm,
}
#[path = "sage_gat.rs"]
mod sage_gat;
pub use sage_gat::*;
#[path = "message_passing.rs"]
mod message_passing;
pub use message_passing::*;
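#[cfg(test)]
mod tests {
    use super::*;

    // A minimal usage sketch, not part of the original API surface: a 3-node
    // path graph (0 - 1 - 2) stored with both edge directions, pushed through
    // a small GCN layer. Only calls already used above (`Tensor::new`,
    // `shape`, the `AdjacencyMatrix` builders) are assumed; exact output
    // values depend on the deterministic weight init, so only shapes and
    // edge counts are asserted.
    #[test]
    fn gcn_forward_produces_expected_shape() {
        let adj = AdjacencyMatrix::from_edge_index(&[[0, 1], [1, 0], [1, 2], [2, 1]], 3);
        // 4 edges + 3 self-loops after augmentation.
        assert_eq!(adj.clone().add_self_loops().num_edges(), 7);

        let x = Tensor::new(&[0.5f32; 12], &[3, 4]); // 3 nodes, 4 features each
        let layer = GCNConv::new(4, 2);
        let out = layer.forward(&x, &adj);
        assert_eq!(out.shape()[0], 3);
        assert_eq!(out.shape()[1], 2);
    }
}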