#![allow(dead_code)]
use crate::{CooTensor, CsrTensor, SparseTensor, TorshResult};
use torsh_core::TorshError;
use torsh_tensor::{
creation::{randn, zeros},
Tensor,
};
/// Graph convolution (GCN) layer computing `H' = Â · (H · W) (+ b)`, where
/// `Â` is the adjacency matrix after optional self-loop insertion and
/// symmetric degree normalization (see `forward`).
#[derive(Debug, Clone)]
pub struct GraphConvolution {
    // Dense weight matrix of shape [in_features, out_features].
    weight: Tensor,
    // Optional bias of shape [out_features]; None when the layer is bias-free.
    bias: Option<Tensor>,
    // Expected feature dimension of incoming node features.
    in_features: usize,
    // Feature dimension of the produced node embeddings.
    out_features: usize,
    // When true, `forward` adds a unit self-loop to nodes lacking one.
    add_self_loops: bool,
    // When true, `forward` applies D^{-1/2} A D^{-1/2} normalization.
    normalize: bool,
}
impl GraphConvolution {
    /// Create a new GCN layer.
    ///
    /// Weights are drawn from a unit normal and scaled by the Glorot/Xavier
    /// factor `sqrt(2 / (in_features + out_features))` so activation variance
    /// stays roughly stable across layers. The bias (when enabled) starts at
    /// zero.
    ///
    /// # Errors
    /// Propagates tensor-creation or element-access failures.
    pub fn new(
        in_features: usize,
        out_features: usize,
        use_bias: bool,
        add_self_loops: bool,
        normalize: bool,
    ) -> TorshResult<Self> {
        let std_dev = (2.0 / (in_features + out_features) as f32).sqrt();
        let weight = randn::<f32>(&[in_features, out_features])?;
        // Apply the Glorot scale to the unit-variance samples. (Previously
        // `std_dev` was computed but never used, leaving unscaled weights.)
        for i in 0..in_features {
            for j in 0..out_features {
                let v = weight.get(&[i, j])?;
                weight.set(&[i, j], v * std_dev)?;
            }
        }
        let bias = if use_bias {
            Some(zeros::<f32>(&[out_features])?)
        } else {
            None
        };
        Ok(Self {
            weight,
            bias,
            in_features,
            out_features,
            add_self_loops,
            normalize,
        })
    }

    /// Forward pass: `output = Â · (X · W) (+ b)`.
    ///
    /// `node_features` must be `[num_nodes, in_features]`; `adjacency` must be
    /// a `[num_nodes, num_nodes]` CSR matrix. Self-loop insertion and
    /// symmetric normalization are applied according to the layer
    /// configuration.
    ///
    /// # Errors
    /// Returns `InvalidArgument` on shape mismatches; propagates sparse/dense
    /// tensor errors otherwise.
    pub fn forward(&self, node_features: &Tensor, adjacency: &CsrTensor) -> TorshResult<Tensor> {
        let feature_shape = node_features.shape();
        if feature_shape.ndim() != 2 {
            return Err(TorshError::InvalidArgument(
                "Node features must be 2D tensor (num_nodes x in_features)".to_string(),
            ));
        }
        let num_nodes = feature_shape.dims()[0];
        let input_features = feature_shape.dims()[1];
        if input_features != self.in_features {
            return Err(TorshError::InvalidArgument(format!(
                "Input features {} don't match layer input features {}",
                input_features, self.in_features
            )));
        }
        let adj_shape = adjacency.shape();
        if adj_shape.dims() != [num_nodes, num_nodes] {
            return Err(TorshError::InvalidArgument(
                "Adjacency matrix must be square and match number of nodes".to_string(),
            ));
        }
        // Preprocess the adjacency: optional self-loops, then normalization.
        let adj_processed = if self.add_self_loops {
            self.add_self_loops_to_adjacency(adjacency)?
        } else {
            adjacency.clone()
        };
        let adj_normalized = if self.normalize {
            self.normalize_adjacency(&adj_processed)?
        } else {
            adj_processed
        };
        // Dense linear transform: X · W -> [num_nodes, out_features].
        let transformed_features = zeros::<f32>(&[num_nodes, self.out_features])?;
        for i in 0..num_nodes {
            for j in 0..self.out_features {
                let mut sum = 0.0;
                for k in 0..self.in_features {
                    sum += node_features.get(&[i, k])? * self.weight.get(&[k, j])?;
                }
                transformed_features.set(&[i, j], sum)?;
            }
        }
        // Sparse aggregation: Â · (X · W), one CSR row at a time.
        let output = zeros::<f32>(&[num_nodes, self.out_features])?;
        for i in 0..num_nodes {
            let (neighbors, weights) = adj_normalized.get_row(i)?;
            for j in 0..self.out_features {
                let mut sum = 0.0;
                for (&neighbor, &weight) in neighbors.iter().zip(weights.iter()) {
                    sum += weight * transformed_features.get(&[neighbor, j])?;
                }
                output.set(&[i, j], sum)?;
            }
        }
        // Broadcast-add the bias row, if present.
        if let Some(ref bias) = self.bias {
            for i in 0..num_nodes {
                for j in 0..self.out_features {
                    let current = output.get(&[i, j])?;
                    output.set(&[i, j], current + bias.get(&[j])?)?;
                }
            }
        }
        Ok(output)
    }

    /// Return a copy of `adjacency` with a unit self-loop on every node that
    /// does not already have one. Existing self-loop weights are kept as-is.
    fn add_self_loops_to_adjacency(&self, adjacency: &CsrTensor) -> TorshResult<CsrTensor> {
        let coo = adjacency.to_coo()?;
        let mut triplets = coo.triplets();
        let num_nodes = adjacency.shape().dims()[0];
        // Record nodes that already carry a self-loop so we don't duplicate.
        let mut self_loop_set = std::collections::HashSet::new();
        for (row, col, _) in &triplets {
            if row == col {
                self_loop_set.insert(*row);
            }
        }
        for i in 0..num_nodes {
            if !self_loop_set.contains(&i) {
                triplets.push((i, i, 1.0));
            }
        }
        let (row_indices, col_indices, values) = Self::unzip_triplets(triplets);
        let new_coo = CooTensor::new(row_indices, col_indices, values, adjacency.shape().clone())?;
        CsrTensor::from_coo(&new_coo)
    }

    /// Symmetric degree normalization: `Â = D^{-1/2} A D^{-1/2}`, where `D` is
    /// the diagonal matrix of row sums. Zero-degree nodes map to zero rows
    /// instead of dividing by zero.
    fn normalize_adjacency(&self, adjacency: &CsrTensor) -> TorshResult<CsrTensor> {
        let num_nodes = adjacency.shape().dims()[0];
        let mut degrees = vec![0.0; num_nodes];
        let coo = adjacency.to_coo()?;
        let triplets = coo.triplets();
        for (row, _col, val) in &triplets {
            degrees[*row] += val;
        }
        let inv_sqrt_degrees: Vec<f32> = degrees
            .iter()
            .map(|&d| if d > 0.0 { 1.0 / d.sqrt() } else { 0.0 })
            .collect();
        let normalized_triplets: Vec<_> = triplets
            .into_iter()
            .map(|(row, col, val)| (row, col, inv_sqrt_degrees[row] * val * inv_sqrt_degrees[col]))
            .collect();
        let (row_indices, col_indices, values) = Self::unzip_triplets(normalized_triplets);
        let normalized_coo =
            CooTensor::new(row_indices, col_indices, values, adjacency.shape().clone())?;
        CsrTensor::from_coo(&normalized_coo)
    }

    /// Split `(row, col, value)` triplets into three parallel vectors,
    /// preallocating to avoid repeated growth. Replaces the verbose fold that
    /// was duplicated in the two adjacency-preprocessing helpers.
    fn unzip_triplets(triplets: Vec<(usize, usize, f32)>) -> (Vec<usize>, Vec<usize>, Vec<f32>) {
        let mut rows = Vec::with_capacity(triplets.len());
        let mut cols = Vec::with_capacity(triplets.len());
        let mut vals = Vec::with_capacity(triplets.len());
        for (r, c, v) in triplets {
            rows.push(r);
            cols.push(c);
            vals.push(v);
        }
        (rows, cols, vals)
    }

    /// Total number of trainable parameters (weight + optional bias).
    pub fn num_parameters(&self) -> usize {
        let weight_params = self.in_features * self.out_features;
        let bias_params = self.bias.as_ref().map_or(0, |b| b.shape().numel());
        weight_params + bias_params
    }

    /// Input feature dimension expected by `forward`.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Output feature dimension produced by `forward`.
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Whether `forward` inserts missing self-loops.
    pub fn adds_self_loops(&self) -> bool {
        self.add_self_loops
    }

    /// Whether `forward` applies symmetric normalization.
    pub fn normalizes(&self) -> bool {
        self.normalize
    }

    /// Borrow the weight matrix.
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    /// Borrow the bias vector, if present.
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}
/// Multi-head graph attention (GAT-style) layer.
///
/// NOTE(review): the current `forward` only applies the linear projection;
/// `attention_weights` and `dropout` are stored but not yet used there.
#[derive(Debug, Clone)]
pub struct GraphAttention {
    // Projection matrix of shape [in_features, out_features * num_heads].
    weight: Tensor,
    // Per-head attention vectors of shape [2 * out_features, num_heads].
    attention_weights: Tensor,
    // Optional bias of shape [out_features * num_heads].
    bias: Option<Tensor>,
    // Expected feature dimension of incoming node features.
    in_features: usize,
    // Per-head output feature dimension.
    out_features: usize,
    // Number of attention heads; validated > 0 in `new`.
    num_heads: usize,
    // Dropout probability; validated to lie in [0.0, 1.0] in `new`.
    dropout: f32,
}
impl GraphAttention {
    /// Create a new multi-head graph attention layer.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `dropout` is outside `[0.0, 1.0]` or
    /// `num_heads == 0`; propagates tensor-creation failures otherwise.
    pub fn new(
        in_features: usize,
        out_features: usize,
        num_heads: usize,
        dropout: f32,
        use_bias: bool,
    ) -> TorshResult<Self> {
        if !(0.0..=1.0).contains(&dropout) {
            return Err(TorshError::InvalidArgument(
                "Dropout must be between 0.0 and 1.0".to_string(),
            ));
        }
        if num_heads == 0 {
            return Err(TorshError::InvalidArgument(
                "Number of heads must be greater than 0".to_string(),
            ));
        }
        // One projection matrix covering all heads; per-head attention vectors
        // over concatenated (source, target) features.
        let weight = randn::<f32>(&[in_features, out_features * num_heads])?;
        let attention_weights = randn::<f32>(&[2 * out_features, num_heads])?;
        let bias = if use_bias {
            Some(zeros::<f32>(&[out_features * num_heads])?)
        } else {
            None
        };
        Ok(Self {
            weight,
            attention_weights,
            bias,
            in_features,
            out_features,
            num_heads,
            dropout,
        })
    }

    /// Forward pass.
    ///
    /// NOTE(review): currently a simplified implementation — it performs only
    /// the per-node linear projection `X · W (+ b)`; the adjacency structure,
    /// `attention_weights`, and `dropout` are not yet applied.
    ///
    /// # Errors
    /// Returns `InvalidArgument` on shape mismatches; propagates element
    /// access errors otherwise.
    pub fn forward(&self, node_features: &Tensor, _adjacency: &CsrTensor) -> TorshResult<Tensor> {
        let feature_shape = node_features.shape();
        if feature_shape.ndim() != 2 {
            return Err(TorshError::InvalidArgument(
                "Node features must be 2D tensor".to_string(),
            ));
        }
        let num_nodes = feature_shape.dims()[0];
        // Validate the feature dimension up front (mirrors GraphConvolution)
        // instead of failing mid-loop on an out-of-bounds element access.
        let input_features = feature_shape.dims()[1];
        if input_features != self.in_features {
            return Err(TorshError::InvalidArgument(format!(
                "Input features {} don't match layer input features {}",
                input_features, self.in_features
            )));
        }
        let total_out = self.out_features * self.num_heads;
        let output = zeros::<f32>(&[num_nodes, total_out])?;
        // Linear projection across all heads at once.
        for i in 0..num_nodes {
            for j in 0..total_out {
                let mut sum = 0.0;
                for k in 0..self.in_features {
                    sum += node_features.get(&[i, k])? * self.weight.get(&[k, j])?;
                }
                output.set(&[i, j], sum)?;
            }
        }
        // Broadcast-add the bias row, if present.
        if let Some(ref bias) = self.bias {
            for i in 0..num_nodes {
                for j in 0..total_out {
                    let current = output.get(&[i, j])?;
                    output.set(&[i, j], current + bias.get(&[j])?)?;
                }
            }
        }
        Ok(output)
    }

    /// Total number of trainable parameters (projection + attention + bias).
    pub fn num_parameters(&self) -> usize {
        let weight_params = self.in_features * self.out_features * self.num_heads;
        let attention_params = 2 * self.out_features * self.num_heads;
        let bias_params = self.bias.as_ref().map_or(0, |b| b.shape().numel());
        weight_params + attention_params + bias_params
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CooTensor, CsrTensor};
    use torsh_core::Shape;
    use torsh_tensor::creation::ones;

    // Constructor stores configuration and reports a positive parameter count.
    #[test]
    fn test_graph_convolution_creation() {
        let gcn = GraphConvolution::new(16, 32, true, true, true)
            .expect("Graph Convolution should succeed");
        assert_eq!(gcn.in_features(), 16);
        assert_eq!(gcn.out_features(), 32);
        assert!(gcn.adds_self_loops());
        assert!(gcn.normalizes());
        assert!(gcn.num_parameters() > 0);
    }

    // Plain forward pass (no loops/normalization) produces the expected shape.
    #[test]
    fn test_graph_convolution_forward() {
        let gcn = GraphConvolution::new(4, 2, false, false, false)
            .expect("Graph Convolution should succeed");
        let features = ones::<f32>(&[3, 4]).expect("operation should succeed");
        let row_indices = vec![0, 1, 2, 0, 1, 2];
        let col_indices = vec![1, 2, 0, 0, 1, 2];
        let values = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let shape = Shape::new(vec![3, 3]);
        let coo = CooTensor::new(row_indices, col_indices, values, shape)
            .expect("Coo Tensor should succeed");
        let adjacency = CsrTensor::from_coo(&coo).expect("Csr Tensor should succeed");
        let output = gcn
            .forward(&features, &adjacency)
            .expect("forward pass should succeed");
        assert_eq!(output.shape().dims(), &[3, 2]);
    }

    #[test]
    fn test_graph_attention_creation() {
        let gat = GraphAttention::new(8, 16, 4, 0.1, true).expect("Graph Attention should succeed");
        assert!(gat.num_parameters() > 0);
    }

    // Zero heads and out-of-range dropout must both be rejected.
    #[test]
    fn test_invalid_parameters() {
        assert!(GraphAttention::new(8, 16, 0, 0.1, true).is_err());
        assert!(GraphAttention::new(8, 16, 4, 1.5, true).is_err());
    }

    // Forward with self-loop insertion enabled runs on a loop-free adjacency.
    #[test]
    fn test_self_loop_addition() {
        let gcn = GraphConvolution::new(2, 2, false, true, false)
            .expect("Graph Convolution should succeed");
        let row_indices = vec![0, 1];
        let col_indices = vec![1, 0];
        let values = vec![1.0, 1.0];
        let shape = Shape::new(vec![2, 2]);
        let coo = CooTensor::new(row_indices, col_indices, values, shape)
            .expect("Coo Tensor should succeed");
        let adjacency = CsrTensor::from_coo(&coo).expect("Csr Tensor should succeed");
        let features = ones::<f32>(&[2, 2]).expect("operation should succeed");
        let _output = gcn
            .forward(&features, &adjacency)
            .expect("forward pass should succeed");
    }

    // Forward with symmetric normalization enabled runs without error.
    #[test]
    fn test_normalization() {
        let gcn = GraphConvolution::new(3, 3, false, false, true)
            .expect("Graph Convolution should succeed");
        let row_indices = vec![0, 0, 1, 1, 2, 2];
        let col_indices = vec![0, 1, 0, 1, 1, 2];
        let values = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let shape = Shape::new(vec![3, 3]);
        let coo = CooTensor::new(row_indices, col_indices, values, shape)
            .expect("Coo Tensor should succeed");
        let adjacency = CsrTensor::from_coo(&coo).expect("Csr Tensor should succeed");
        let features = ones::<f32>(&[3, 3]).expect("operation should succeed");
        let _output = gcn
            .forward(&features, &adjacency)
            .expect("forward pass should succeed");
    }

    // A feature matrix with the wrong width is rejected with an error.
    #[test]
    fn test_dimension_validation() {
        let gcn = GraphConvolution::new(4, 2, false, false, false)
            .expect("Graph Convolution should succeed");
        let wrong_features = ones::<f32>(&[3, 5]).expect("operation should succeed");
        let row_indices = vec![0, 1, 2];
        let col_indices = vec![1, 2, 0];
        let values = vec![1.0, 1.0, 1.0];
        let shape = Shape::new(vec![3, 3]);
        let coo = CooTensor::new(row_indices, col_indices, values, shape)
            .expect("Coo Tensor should succeed");
        let adjacency = CsrTensor::from_coo(&coo).expect("Csr Tensor should succeed");
        assert!(gcn.forward(&wrong_features, &adjacency).is_err());
    }
}