#[cfg(feature = "tensor-gnn")]
use rand_distr::{Distribution, StandardNormal};
#[cfg(feature = "tensor-gnn")]
use crate::tensor::traits::{TensorBase, TensorOps};
#[cfg(feature = "tensor-gnn")]
use crate::tensor::dense::DenseTensor;
#[cfg(feature = "tensor-gnn")]
use crate::tensor::sparse::SparseTensor;
#[cfg(all(feature = "tensor-gnn", not(feature = "std")))]
use rand::{rngs::StdRng, SeedableRng};
#[cfg(all(feature = "tensor-gnn", feature = "std"))]
use rand::thread_rng;
/// Builds the message carried along a single edge of a graph.
///
/// `H` is the tensor type used for node features, edge features and the
/// resulting message. Implementors must be `Send + Sync`.
pub trait MessageFunction<H: TensorBase>: Send + Sync {
/// Returns the message for one edge, given the features of its source node,
/// the optional features of the edge itself, and the features of its
/// destination node.
fn message(&self, src_features: &H, edge_features: Option<&H>, dst_features: &H) -> H;
}
/// Reduces a collection of messages into a single tensor.
///
/// `H` is the tensor type of both the messages and the aggregated result.
/// Implementors must be `Send + Sync`.
pub trait Aggregator<H: TensorBase>: Send + Sync {
/// Combines `messages` into one tensor. The slice may be empty; each
/// implementation decides what to return in that case.
fn aggregate(&self, messages: &[H]) -> H;
}
/// Computes a node's new state from its previous state and an aggregated
/// message.
///
/// `H` is the tensor type of the states and the message. Implementors must be
/// `Send + Sync`.
pub trait UpdateFunction<H: TensorBase>: Send + Sync {
/// Returns the updated state for one node given its `old_state` and the
/// `new_message` produced for it.
fn update(&self, old_state: &H, new_message: &H) -> H;
}
/// Stateless aggregator that adds all incoming messages together.
#[derive(Debug, Clone, Default)]
pub struct SumAggregator;
#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for SumAggregator {
    /// Adds the messages together from left to right using `add`.
    ///
    /// Starts from a clone of the first message. An empty slice yields
    /// `DenseTensor::zeros(vec![1])`.
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        match messages.split_first() {
            None => DenseTensor::zeros(vec![1]),
            Some((first, remaining)) => remaining
                .iter()
                .fold(first.clone(), |running_total, message| {
                    running_total.add(message)
                }),
        }
    }
}
/// Stateless aggregator that averages all incoming messages.
#[derive(Debug, Clone, Default)]
pub struct MeanAggregator;
#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for MeanAggregator {
    /// Sums the messages with [`SumAggregator`] and scales the sum by
    /// `1 / messages.len()`.
    ///
    /// An empty slice yields `DenseTensor::zeros(vec![1])`.
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        let count = messages.len();
        if count == 0 {
            return DenseTensor::zeros(vec![1]);
        }
        let scale = 1.0 / count as f64;
        SumAggregator.aggregate(messages).mul_scalar(scale)
    }
}
/// Stateless aggregator that keeps the element-wise maximum of all incoming
/// messages.
#[derive(Debug, Clone, Default)]
pub struct MaxAggregator;
#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for MaxAggregator {
    /// Folds the messages into their element-wise maximum.
    ///
    /// Starts from a clone of the first message; every further message is
    /// merged by pairing its data with the running result and keeping the
    /// larger value of each pair. The merged tensor reuses the running
    /// result's shape. An empty slice yields `DenseTensor::zeros(vec![1])`.
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        let (first, remaining) = match messages.split_first() {
            Some(parts) => parts,
            None => return DenseTensor::zeros(vec![1]),
        };
        remaining
            .iter()
            .fold(first.clone(), |running_max, message| {
                let merged: Vec<f64> = running_max
                    .data()
                    .iter()
                    .zip(message.data().iter())
                    .map(|(&current, &candidate)| current.max(candidate))
                    .collect();
                DenseTensor::new(merged, running_max.shape().to_vec())
            })
    }
}
/// Stateless message function that forwards the source node's features
/// unchanged.
#[derive(Debug, Clone, Default)]
pub struct IdentityMessage;
#[cfg(feature = "tensor-gnn")]
impl MessageFunction<DenseTensor> for IdentityMessage {
    /// Returns a copy of `src_features`; the edge and destination features
    /// are ignored.
    fn message(
        &self,
        src_features: &DenseTensor,
        _edge_features: Option<&DenseTensor>,
        _dst_features: &DenseTensor,
    ) -> DenseTensor {
        Clone::clone(src_features)
    }
}
/// Message function that applies a learned linear projection to the source
/// node's features.
///
/// Gated on `tensor-gnn` like its `impl` blocks, because `DenseTensor` is only
/// imported when that feature is enabled.
#[cfg(feature = "tensor-gnn")]
#[derive(Debug, Clone)]
pub struct LinearMessage {
    /// Projection matrix; `message` multiplies the source features by its
    /// transpose.
    weight: DenseTensor,
}
#[cfg(feature = "tensor-gnn")]
impl LinearMessage {
    /// Creates a projection from `in_features` to `out_features` values with
    /// Glorot/Xavier-normal initialised weights
    /// (`std = sqrt(2 / (in_features + out_features))`).
    ///
    /// The weight is stored with shape `[out_features, in_features]`: the
    /// `MessageFunction` implementation multiplies the source features by the
    /// *transpose* of the weight, so this layout is what makes
    /// `[_, in_features] x [in_features, out_features]` line up when the two
    /// sizes differ.
    ///
    /// With the `std` feature the weights are drawn from `thread_rng()`.
    /// Without it no OS entropy source is assumed, so a fixed-seed `StdRng`
    /// is used and the initial weights are reproducible.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        let std_dev = (2.0 / (in_features + out_features) as f64).sqrt();

        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        // `thread_rng` is only imported under `std`; fall back to the
        // seedable generator that is imported for the non-`std` build.
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0);

        let weight_data: Vec<f64> = (0..in_features * out_features)
            .map(|_| {
                let sample: f64 = StandardNormal.sample(&mut rng);
                sample * std_dev
            })
            .collect();

        Self {
            weight: DenseTensor::new(weight_data, vec![out_features, in_features]),
        }
    }
}
#[cfg(feature = "tensor-gnn")]
impl MessageFunction<DenseTensor> for LinearMessage {
    /// Multiplies `src_features` by the transpose of the stored weight.
    ///
    /// The edge and destination features are ignored.
    ///
    /// NOTE(review): the product presumably requires the width of
    /// `src_features` to equal the weight's second dimension - confirm
    /// against how the weight is constructed.
    fn message(
        &self,
        src_features: &DenseTensor,
        _edge_features: Option<&DenseTensor>,
        _dst_features: &DenseTensor,
    ) -> DenseTensor {
        let transposed_weight = self.weight.transpose(None);
        src_features.matmul(&transposed_weight)
    }
}
/// A message-passing layer assembled from three pluggable stages: a message
/// function, an aggregator and an update function.
pub struct MessagePassingLayer<M, A, U> {
// Produces the message for each edge.
message_fn: M,
// Reduces the messages received by a node into one tensor.
aggregator: A,
// Combines a node's previous state with its aggregated message.
update_fn: U,
}
// Gated like the other `DenseTensor`-based impls in this file: `DenseTensor`
// is only imported when `tensor-gnn` is enabled.
#[cfg(feature = "tensor-gnn")]
impl<M, A, U> MessagePassingLayer<M, A, U>
where
    M: MessageFunction<DenseTensor>,
    A: Aggregator<DenseTensor>,
    U: UpdateFunction<DenseTensor>,
{
    /// Assembles a layer from its message, aggregation and update stages.
    pub fn new(message_fn: M, aggregator: A, update_fn: U) -> Self {
        Self {
            message_fn,
            aggregator,
            update_fn,
        }
    }

    /// Runs one round of message passing.
    ///
    /// For every `(src, dst)` pair in `edge_index` a message is computed from
    /// the rows `src` and `dst` of `node_features` and delivered to `dst`.
    /// Each node that received at least one message is replaced by
    /// `update(old_state, aggregate(messages))`; nodes without incoming
    /// messages keep their row unchanged. The result is built with the same
    /// shape as `node_features`.
    ///
    /// NOTE(review): when `edge_features` is `Some`, the message function is
    /// handed a placeholder `DenseTensor::scalar(1.0)` rather than the row of
    /// the edge in question - the contents of `edge_features` are never read.
    ///
    /// # Panics
    ///
    /// Panics if a destination index in `edge_index` is not smaller than
    /// `node_features.shape()[0]`.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
        edge_features: Option<&DenseTensor>,
    ) -> DenseTensor {
        let num_nodes = node_features.shape()[0];
        let mut inbox: Vec<Vec<DenseTensor>> = vec![Vec::new(); num_nodes];

        // The placeholder does not depend on the edge, so build it once.
        let placeholder_edge_feat = edge_features.map(|_| DenseTensor::scalar(1.0));

        for &(src, dst) in edge_index {
            let src_feat = self.extract_node(node_features, src);
            let dst_feat = self.extract_node(node_features, dst);
            let msg = self
                .message_fn
                .message(&src_feat, placeholder_edge_feat.as_ref(), &dst_feat);
            inbox[dst].push(msg);
        }

        let mut updated_features = Vec::with_capacity(node_features.data().len());
        for (node_idx, node_msgs) in inbox.iter().enumerate() {
            let old_state = self.extract_node(node_features, node_idx);
            if node_msgs.is_empty() {
                updated_features.extend_from_slice(old_state.data());
            } else {
                let aggregated = self.aggregator.aggregate(node_msgs);
                let updated = self.update_fn.update(&old_state, &aggregated);
                updated_features.extend_from_slice(updated.data());
            }
        }

        DenseTensor::new(updated_features, node_features.shape().to_vec())
    }

    /// Slices row `node_idx` (all `shape()[1]` columns) out of `features`.
    fn extract_node(&self, features: &DenseTensor, node_idx: usize) -> DenseTensor {
        let num_features = features.shape()[1];
        features.slice(&[0, 1], &[node_idx..node_idx + 1, 0..num_features])
    }
}
/// Graph convolution layer holding a dense weight matrix and a bias vector.
///
/// Gated on `tensor-gnn` like its `impl` block, because `DenseTensor` is only
/// imported when that feature is enabled.
#[cfg(feature = "tensor-gnn")]
// `in_features`, `out_features` and `bias` are stored at construction but not
// read by the current forward pass, hence the lint exemption.
#[allow(dead_code)]
pub struct GCNConv {
    /// Number of input values per node.
    in_features: usize,
    /// Number of output values per node.
    out_features: usize,
    /// Projection matrix of shape `[in_features, out_features]`.
    weight: DenseTensor,
    /// Bias vector of length `out_features`, initialised to zero.
    bias: DenseTensor,
}
#[cfg(feature = "tensor-gnn")]
impl GCNConv {
    /// Creates a layer mapping `in_features` to `out_features` values per
    /// node.
    ///
    /// Weights are Glorot/Xavier-normal initialised. Because the samples come
    /// from a *normal* distribution the scale is
    /// `sqrt(2 / (in_features + out_features))`; the `sqrt(6 / ...)` constant
    /// is the bound of the *uniform* variant and would triple the variance.
    /// The bias starts at zero.
    ///
    /// With the `std` feature the weights are drawn from `thread_rng()`.
    /// Without it a fixed-seed `StdRng` is used, so the initial weights are
    /// reproducible.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        let std_dev = (2.0 / (in_features + out_features) as f64).sqrt();

        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        // `thread_rng` is only imported under `std`; fall back to the
        // seedable generator that is imported for the non-`std` build.
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0);

        let weight_data: Vec<f64> = (0..in_features * out_features)
            .map(|_| {
                let sample: f64 = StandardNormal.sample(&mut rng);
                sample * std_dev
            })
            .collect();
        let bias_data = vec![0.0; out_features];

        Self {
            in_features,
            out_features,
            weight: DenseTensor::new(weight_data, vec![in_features, out_features]),
            bias: DenseTensor::new(bias_data, vec![out_features]),
        }
    }

    /// Projects `node_features` through the weight matrix and multiplies the
    /// result by the (normalised) adjacency via `spmv`.
    ///
    /// NOTE(review): the bias is never added, and `normalize_adjacency`
    /// currently returns the adjacency unchanged (see there).
    ///
    /// # Panics
    ///
    /// Panics if the sparse product fails.
    pub fn forward(&self, node_features: &DenseTensor, adjacency: &SparseTensor) -> DenseTensor {
        let h_transformed = node_features.matmul(&self.weight);
        let normalized = self.normalize_adjacency(adjacency);
        normalized
            .spmv(&h_transformed)
            .expect("GCNConv::forward: sparse product of adjacency and node features failed")
    }

    /// Returns the adjacency to propagate over.
    ///
    /// NOTE(review): the inverse square-root degrees are computed but not yet
    /// applied; the adjacency is returned as an unmodified clone. Applying the
    /// scaling needs a way to rebuild a `SparseTensor` from scaled values,
    /// which is not visible from this file.
    fn normalize_adjacency(&self, adjacency: &SparseTensor) -> SparseTensor {
        let degrees = self.compute_degrees(adjacency);
        // Isolated nodes (degree ~ 0) get a factor of 0 instead of dividing by zero.
        let _inv_sqrt_degrees = degrees.map(|d: f64| if d > 1e-10 { 1.0 / d.sqrt() } else { 0.0 });
        adjacency.clone()
    }

    /// Counts, for every row of `adjacency`, how many row indices its COO form
    /// lists for that row, and returns the counts as a vector of length
    /// `adjacency.shape()[0]`. Each listed index counts as 1.
    fn compute_degrees(&self, adjacency: &SparseTensor) -> DenseTensor {
        let num_nodes = adjacency.shape()[0];
        let mut degrees = vec![0.0; num_nodes];
        let coo = adjacency.to_coo();
        for &row in coo.row_indices() {
            degrees[row] += 1.0;
        }
        DenseTensor::new(degrees, vec![num_nodes])
    }
}
/// Graph attention layer holding a single attention vector.
///
/// Gated on `tensor-gnn` like its `impl` block, because `DenseTensor` is only
/// imported when that feature is enabled.
#[cfg(feature = "tensor-gnn")]
// `num_heads` is stored at construction but not read by the current forward
// pass, hence the lint exemption.
#[allow(dead_code)]
pub struct GATConv {
    /// Number of input values per node.
    in_features: usize,
    /// Number of output values per node.
    out_features: usize,
    /// Requested number of attention heads.
    num_heads: usize,
    /// Attention parameters, a vector of length `2 * out_features`.
    attention_vec: DenseTensor,
}
#[cfg(feature = "tensor-gnn")]
impl GATConv {
    /// Creates an attention layer.
    ///
    /// The attention vector has `2 * out_features` entries drawn from a normal
    /// distribution scaled by `sqrt(2 / (in_features + out_features))`
    /// (Glorot/Xavier-normal; the `sqrt(6 / ...)` constant is the bound of the
    /// uniform variant and would triple the variance).
    ///
    /// With the `std` feature the values are drawn from `thread_rng()`.
    /// Without it a fixed-seed `StdRng` is used, so the initial values are
    /// reproducible.
    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
        let std_dev = (2.0 / (in_features + out_features) as f64).sqrt();

        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        // `thread_rng` is only imported under `std`; fall back to the
        // seedable generator that is imported for the non-`std` build.
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0);

        let attention_data: Vec<f64> = (0..out_features * 2)
            .map(|_| {
                let sample: f64 = StandardNormal.sample(&mut rng);
                sample * std_dev
            })
            .collect();

        Self {
            in_features,
            out_features,
            num_heads,
            attention_vec: DenseTensor::new(attention_data, vec![out_features * 2]),
        }
    }

    /// Computes attention-weighted sums of neighbour features.
    ///
    /// For every `(src, dst)` edge a raw score is computed, the scores are
    /// softmax-normalised per destination node, and each destination row
    /// accumulates `attention * features[src]`. The result has shape
    /// `[num_nodes, out_features]`.
    ///
    /// # Panics
    ///
    /// Panics if an index in `edge_index` addresses a row outside
    /// `node_features`.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        let h_transformed = node_features.matmul(&self.weight());
        let attention_scores = self.compute_attention(node_features, edge_index);
        let normalized_attention = self.softmax(&attention_scores, edge_index);
        self.aggregate_with_attention(&h_transformed, &normalized_attention, edge_index)
    }

    /// Feature projection used by `forward`; currently the
    /// `in_features x in_features` identity.
    fn weight(&self) -> DenseTensor {
        DenseTensor::eye(self.in_features)
    }

    /// Returns one raw score per edge: the dot product of the concatenated
    /// `[src row, dst row]` with the attention vector (repeated cyclically if
    /// it is shorter), clamped below at zero.
    fn compute_attention(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> Vec<f64> {
        let width = self.in_features;
        let features = node_features.data();
        let attention = self.attention_vec.data();

        edge_index
            .iter()
            .map(|&(src, dst)| {
                let src_row = &features[src * width..(src + 1) * width];
                let dst_row = &features[dst * width..(dst + 1) * width];
                // Chain the two rows instead of copying them into a temporary vector.
                let score: f64 = src_row
                    .iter()
                    .chain(dst_row.iter())
                    .zip(attention.iter().cycle())
                    .map(|(&a, &b)| a * b)
                    .sum();
                score.max(0.0)
            })
            .collect()
    }

    /// Softmax-normalises `scores` over the incoming edges of each
    /// destination node; see [`softmax_over_incoming_edges`].
    fn softmax(&self, scores: &[f64], edge_index: &[(usize, usize)]) -> Vec<f64> {
        softmax_over_incoming_edges(scores, edge_index)
    }

    /// Accumulates `attention[e] * node_features[src]` into row `dst` for
    /// every edge `e = (src, dst)`.
    ///
    /// Rows of `node_features` are `in_features` wide. Only the first
    /// `min(in_features, out_features)` columns are copied, so a wider output
    /// is zero-padded instead of reading into the following row.
    fn aggregate_with_attention(
        &self,
        node_features: &DenseTensor,
        attention: &[f64],
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        let num_nodes = node_features.shape()[0];
        let in_width = self.in_features;
        let out_width = self.out_features;
        let copied = out_width.min(in_width);
        let features = node_features.data();

        let mut result = vec![0.0; num_nodes * out_width];
        for (&(src, dst), &attn) in edge_index.iter().zip(attention.iter()) {
            let src_row = &features[src * in_width..src * in_width + copied];
            let dst_row = &mut result[dst * out_width..dst * out_width + copied];
            for (acc, &value) in dst_row.iter_mut().zip(src_row) {
                *acc += attn * value;
            }
        }
        DenseTensor::new(result, vec![num_nodes, out_width])
    }
}

/// Softmax-normalises per-edge `scores` over the incoming edges of each
/// destination node.
///
/// `scores[i]` belongs to `edge_index[i] = (src, dst)`; surplus entries of the
/// longer slice are ignored. The returned vector has `scores.len()` entries,
/// and the entries of all edges sharing a destination sum to 1. The maximum
/// score of each group is subtracted before exponentiating for numerical
/// stability.
///
/// Results are written back by edge position, so parallel edges with the same
/// `(src, dst)` pair each receive their own weight, and the whole computation
/// is linear in the number of edges.
// Uses only std types, so it is not feature-gated; it is merely unused when
// the `tensor-gnn` impl above is compiled out.
#[cfg_attr(not(feature = "tensor-gnn"), allow(dead_code))]
fn softmax_over_incoming_edges(scores: &[f64], edge_index: &[(usize, usize)]) -> Vec<f64> {
    use std::collections::HashMap;

    // destination node -> (edge position, raw score) of every incoming edge
    let mut incoming: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
    for (edge_pos, (&(_, dst), &score)) in edge_index.iter().zip(scores.iter()).enumerate() {
        incoming.entry(dst).or_default().push((edge_pos, score));
    }

    let mut normalized = vec![0.0; scores.len()];
    for group in incoming.values() {
        let max_score = group
            .iter()
            .map(|&(_, score)| score)
            .fold(f64::NEG_INFINITY, f64::max);
        let sum_exp: f64 = group
            .iter()
            .map(|&(_, score)| (score - max_score).exp())
            .sum();
        for &(edge_pos, score) in group {
            normalized[edge_pos] = (score - max_score).exp() / sum_exp;
        }
    }
    normalized
}
/// Configuration of a GraphSAGE-style layer: each node's own features are
/// combined with the mean of a bounded number of neighbour features.
pub struct GraphSAGE {
// Number of feature values stored per node in the input tensor.
in_features: usize,
// Number of values emitted per node by `forward`.
out_features: usize,
// Maximum number of neighbours taken into account per node.
num_samples: usize,
}
#[cfg(feature = "tensor-gnn")]
impl GraphSAGE {
    /// Creates a layer that reads `in_features` values per node, emits
    /// `out_features` values per node and considers at most `num_samples`
    /// neighbours per node.
    pub fn new(in_features: usize, out_features: usize, num_samples: usize) -> Self {
        Self {
            in_features,
            out_features,
            num_samples,
        }
    }

    /// Computes the output row of every node.
    ///
    /// For node `v` the neighbours are the targets of the first `num_samples`
    /// edges in `edge_index` whose source is `v`. Their feature rows are
    /// averaged with [`MeanAggregator`] (a zero vector is used when there are
    /// none), appended to `v`'s own row, and the first `out_features` values
    /// of that concatenation form the output row. If the concatenation is
    /// shorter than `out_features` the row is zero-padded, so the data always
    /// matches the returned shape `[num_nodes, out_features]`.
    ///
    /// # Panics
    ///
    /// Panics if a sampled neighbour index addresses a row outside
    /// `node_features`.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        let num_nodes = node_features.shape()[0];
        let width = self.in_features;
        let features = node_features.data();

        // Bucket the edges by source node in one pass instead of rescanning
        // the whole edge list for every node. Edge order is preserved, so
        // each bucket holds exactly the first `num_samples` targets.
        let mut sampled: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
        for &(src, dst) in edge_index {
            if let Some(bucket) = sampled.get_mut(src) {
                if bucket.len() < self.num_samples {
                    bucket.push(dst);
                }
            }
        }

        let mut result = Vec::with_capacity(num_nodes * self.out_features);
        for (node_idx, neighbors) in sampled.iter().enumerate() {
            let neighbor_features = if neighbors.is_empty() {
                DenseTensor::zeros(vec![width])
            } else {
                let gathered: Vec<DenseTensor> = neighbors
                    .iter()
                    .map(|&neighbor| {
                        let start = neighbor * width;
                        DenseTensor::new(features[start..start + width].to_vec(), vec![width])
                    })
                    .collect();
                MeanAggregator.aggregate(&gathered)
            };

            let self_features = &features[node_idx * width..(node_idx + 1) * width];
            let row_start = result.len();
            result.extend(
                self_features
                    .iter()
                    .chain(neighbor_features.data().iter())
                    .take(self.out_features)
                    .copied(),
            );
            // Pad rows that ran out of source values (out_features > 2 * in_features).
            result.resize(row_start + self.out_features, 0.0);
        }

        DenseTensor::new(result, vec![num_nodes, self.out_features])
    }
}
// The aggregator and message implementations under test only exist when
// `tensor-gnn` is enabled, so the tests are gated on it as well.
#[cfg(all(test, feature = "tensor-gnn"))]
mod tests {
    use super::*;

    /// Summing three 2-element messages adds them position by position.
    #[test]
    fn test_sum_aggregator() {
        let aggregator = SumAggregator;
        let messages = vec![
            DenseTensor::new(vec![1.0, 2.0], vec![2]),
            DenseTensor::new(vec![3.0, 4.0], vec![2]),
            DenseTensor::new(vec![5.0, 6.0], vec![2]),
        ];
        let result = aggregator.aggregate(&messages);
        assert_eq!(result.data(), &[9.0, 12.0]);
    }

    /// The mean is the sum divided by the number of messages.
    #[test]
    fn test_mean_aggregator() {
        let aggregator = MeanAggregator;
        let messages = vec![
            DenseTensor::new(vec![1.0, 2.0], vec![2]),
            DenseTensor::new(vec![3.0, 4.0], vec![2]),
            DenseTensor::new(vec![5.0, 6.0], vec![2]),
        ];
        let result = aggregator.aggregate(&messages);
        assert_eq!(result.data(), &[3.0, 4.0]);
    }

    /// The maximum is taken independently for every position.
    #[test]
    fn test_max_aggregator() {
        let aggregator = MaxAggregator;
        let messages = vec![
            DenseTensor::new(vec![1.0, 6.0], vec![2]),
            DenseTensor::new(vec![3.0, 4.0], vec![2]),
            DenseTensor::new(vec![2.0, 5.0], vec![2]),
        ];
        let result = aggregator.aggregate(&messages);
        assert_eq!(result.data(), &[3.0, 6.0]);
    }

    /// The identity message is a copy of the source features.
    #[test]
    fn test_identity_message() {
        let message_fn = IdentityMessage;
        let src = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        let dst = DenseTensor::new(vec![4.0, 5.0, 6.0], vec![3]);
        let result = message_fn.message(&src, None, &dst);
        assert_eq!(result.data(), src.data());
    }
}