use crate::error::NumRs2Error;
use scirs2_core::ndarray::{s, Array, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::Float;
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::collections::HashMap;
pub type GraphResult<T> = Result<T, NumRs2Error>;
/// Dense adjacency representation of a graph.
///
/// Entry `adj[[src, dst]]` holds the weight of the directed edge `src -> dst`
/// (zero when the edge is absent).
#[derive(Debug, Clone)]
pub struct AdjacencyMatrix<T: Float> {
    /// Number of vertices in the graph.
    pub num_nodes: usize,
    /// Edge-weight matrix, expected to be `num_nodes x num_nodes`.
    pub adj: Array2<T>,
}
impl<T: Float> AdjacencyMatrix<T> {
    /// Builds an unweighted adjacency matrix from directed `(src, dst)` pairs.
    ///
    /// Every listed edge is stored with weight `1`; a repeated edge simply
    /// overwrites the same entry.
    ///
    /// # Errors
    /// Returns `NumRs2Error::ValueError` if any endpoint is `>= num_nodes`.
    pub fn from_edges(num_nodes: usize, edges: &[(usize, usize)]) -> GraphResult<Self> {
        let mut adj = Array2::zeros((num_nodes, num_nodes));
        for &(src, dst) in edges {
            if src >= num_nodes || dst >= num_nodes {
                return Err(NumRs2Error::ValueError(format!(
                    "Edge ({}, {}) out of bounds for {} nodes",
                    src, dst, num_nodes
                )));
            }
            adj[[src, dst]] = T::one();
        }
        Ok(Self { num_nodes, adj })
    }

    /// Returns a copy of the graph whose diagonal entries are all set to `1`.
    ///
    /// Existing self-loop weights are overwritten with `1`, not incremented.
    pub fn with_self_loops(&self) -> GraphResult<Self> {
        let mut adj = self.adj.clone();
        for i in 0..self.num_nodes {
            adj[[i, i]] = T::one();
        }
        Ok(Self {
            num_nodes: self.num_nodes,
            adj,
        })
    }

    /// Returns the per-node degree vector (row sums of `adj`).
    ///
    /// Despite the name this is the diagonal of the degree matrix, not a full
    /// matrix. Degrees below `1` (isolated nodes, or negative weights) are
    /// clamped to `1` so later normalisation never divides by zero.
    pub fn degree_matrix(&self) -> GraphResult<Array1<T>> {
        Ok(self
            .adj
            .sum_axis(Axis(1))
            .mapv(|deg| if deg < T::one() { T::one() } else { deg }))
    }

    /// Computes `D^{-1/2} A D^{-1/2}`, where `D` comes from [`Self::degree_matrix`].
    ///
    /// Because the degrees are clamped to `>= 1`, the `deg > 0` guard below only
    /// matters for NaN weights, which map to a zero scale factor.
    pub fn symmetric_normalize(&self) -> GraphResult<Array2<T>> {
        let degrees = self.degree_matrix()?;
        let d_inv_sqrt = degrees.mapv(|deg| {
            if deg > T::zero() {
                T::one() / deg.sqrt()
            } else {
                T::zero()
            }
        });
        let mut norm_adj = self.adj.clone();
        for ((i, j), entry) in norm_adj.indexed_iter_mut() {
            *entry = *entry * d_inv_sqrt[i] * d_inv_sqrt[j];
        }
        Ok(norm_adj)
    }
}
/// Coordinate-format (COO) graph: a flat list of weighted directed edges.
#[derive(Debug, Clone)]
pub struct EdgeList<T: Float> {
    /// Number of vertices; valid node ids are `0..num_nodes`.
    pub num_nodes: usize,
    /// Directed edges as `(src, dst, weight)` triples.
    pub edges: Vec<(usize, usize, T)>,
}
impl<T: Float> EdgeList<T> {
    /// Fails with a `ValueError` when either endpoint is not a valid node id.
    fn ensure_in_bounds(num_nodes: usize, src: usize, dst: usize) -> GraphResult<()> {
        if src < num_nodes && dst < num_nodes {
            Ok(())
        } else {
            Err(NumRs2Error::ValueError(format!(
                "Edge ({}, {}) out of bounds for {} nodes",
                src, dst, num_nodes
            )))
        }
    }

    /// Builds an edge list from unweighted `(src, dst)` pairs; every edge gets weight `1`.
    ///
    /// # Errors
    /// Returns `ValueError` for the first edge whose endpoint is `>= num_nodes`.
    pub fn from_edges(num_nodes: usize, edges: &[(usize, usize)]) -> GraphResult<Self> {
        let weighted = edges
            .iter()
            .map(|&(src, dst)| -> GraphResult<(usize, usize, T)> {
                Self::ensure_in_bounds(num_nodes, src, dst)?;
                Ok((src, dst, T::one()))
            })
            .collect::<GraphResult<Vec<_>>>()?;
        Ok(Self {
            num_nodes,
            edges: weighted,
        })
    }

    /// Wraps already-weighted `(src, dst, weight)` triples after validating endpoints.
    ///
    /// # Errors
    /// Returns `ValueError` for the first edge whose endpoint is `>= num_nodes`.
    pub fn from_weighted_edges(
        num_nodes: usize,
        edges: Vec<(usize, usize, T)>,
    ) -> GraphResult<Self> {
        edges
            .iter()
            .try_for_each(|&(src, dst, _)| Self::ensure_in_bounds(num_nodes, src, dst))?;
        Ok(Self { num_nodes, edges })
    }

    /// Converts this edge list into compressed-sparse-row form.
    pub fn to_csr(&self) -> GraphResult<SparseAdjacency<T>> {
        SparseAdjacency::from_edge_list(self)
    }
}
/// Compressed-sparse-row (CSR) adjacency.
///
/// The out-neighbours of node `i` are `col_indices[row_ptr[i]..row_ptr[i + 1]]`,
/// with matching edge weights in the same range of `values`.
#[derive(Debug, Clone)]
pub struct SparseAdjacency<T: Float> {
    /// Number of vertices.
    pub num_nodes: usize,
    /// Row offsets; length `num_nodes + 1` when built by `from_edge_list`.
    pub row_ptr: Vec<usize>,
    /// Destination node id of each stored edge.
    pub col_indices: Vec<usize>,
    /// Weight of each stored edge.
    pub values: Vec<T>,
}
impl<T: Float> SparseAdjacency<T> {
    /// Builds a CSR adjacency from an [`EdgeList`], keeping edge weights.
    ///
    /// Edges are grouped by source node; within one row, neighbours keep the
    /// order in which they appear in `edge_list.edges`.
    ///
    /// # Errors
    /// Returns `ValueError` if any edge endpoint is `>= edge_list.num_nodes`.
    /// `EdgeList`'s fields are public, so a list assembled without its
    /// validating constructors could otherwise panic here (bad `src`) or
    /// produce out-of-range neighbour ids that panic in later aggregations
    /// (bad `dst`).
    pub fn from_edge_list(edge_list: &EdgeList<T>) -> GraphResult<Self> {
        let num_nodes = edge_list.num_nodes;
        for &(src, dst, _) in &edge_list.edges {
            if src >= num_nodes || dst >= num_nodes {
                return Err(NumRs2Error::ValueError(format!(
                    "Edge ({}, {}) out of bounds for {} nodes",
                    src, dst, num_nodes
                )));
            }
        }

        // Count pass: tally each row's edges into row_ptr[src + 1], then
        // prefix-sum so row `i` occupies `row_ptr[i]..row_ptr[i + 1]`.
        let mut row_ptr = vec![0usize; num_nodes + 1];
        for &(src, _, _) in &edge_list.edges {
            row_ptr[src + 1] += 1;
        }
        for i in 1..=num_nodes {
            row_ptr[i] += row_ptr[i - 1];
        }

        // Scatter pass: `next_slot[src]` is the next free position in row `src`.
        let num_edges = edge_list.edges.len();
        let mut col_indices = vec![0usize; num_edges];
        let mut values = vec![T::zero(); num_edges];
        let mut next_slot = row_ptr[..num_nodes].to_vec();
        for &(src, dst, weight) in &edge_list.edges {
            let pos = next_slot[src];
            col_indices[pos] = dst;
            values[pos] = weight;
            next_slot[src] += 1;
        }

        Ok(Self {
            num_nodes,
            row_ptr,
            col_indices,
            values,
        })
    }

    /// Builds a CSR adjacency from unweighted `(src, dst)` pairs (all weights `1`).
    ///
    /// # Errors
    /// Returns `ValueError` if any endpoint is `>= num_nodes`.
    pub fn from_edges(num_nodes: usize, edges: &[(usize, usize)]) -> GraphResult<Self> {
        let edge_list = EdgeList::from_edges(num_nodes, edges)?;
        Self::from_edge_list(&edge_list)
    }

    /// Returns the out-neighbour ids of `node` and the matching edge weights.
    ///
    /// # Errors
    /// Returns `ValueError` if `node >= num_nodes`.
    pub fn neighbors(&self, node: usize) -> GraphResult<(&[usize], &[T])> {
        if node >= self.num_nodes {
            return Err(NumRs2Error::ValueError(format!(
                "Node {} out of bounds for {} nodes",
                node, self.num_nodes
            )));
        }
        let start = self.row_ptr[node];
        let end = self.row_ptr[node + 1];
        Ok((&self.col_indices[start..end], &self.values[start..end]))
    }

    /// Returns each node's out-degree (stored-edge count), clamped to at least `1`.
    ///
    /// Edge weights are ignored. Isolated nodes report `1`, mirroring
    /// `AdjacencyMatrix::degree_matrix`.
    pub fn degrees(&self) -> Array1<T> {
        let mut degrees = Array1::zeros(self.num_nodes);
        for i in 0..self.num_nodes {
            let start = self.row_ptr[i];
            let end = self.row_ptr[i + 1];
            let degree = T::from((end - start) as f64).unwrap_or(T::zero());
            degrees[i] = if degree < T::one() { T::one() } else { degree };
        }
        degrees
    }
}
/// A single graph sample: CSR structure plus node (and optional edge) features.
#[derive(Debug, Clone)]
pub struct GraphData<T: Float> {
    /// Graph connectivity in CSR form.
    pub adjacency: SparseAdjacency<T>,
    /// Per-node feature matrix; one row per node.
    pub node_features: Array2<T>,
    /// Optional per-edge feature matrix (not validated against the edge count).
    pub edge_features: Option<Array2<T>>,
}
impl<T: Float> GraphData<T> {
    /// Creates a graph sample from an edge list and a node-feature matrix.
    ///
    /// # Errors
    /// Returns `ValueError` if `node_features` does not have exactly
    /// `num_nodes` rows (checked first), or if any edge is out of bounds.
    pub fn new(
        num_nodes: usize,
        edges: &[(usize, usize)],
        node_features: Array2<T>,
    ) -> GraphResult<Self> {
        let feature_rows = node_features.nrows();
        if feature_rows != num_nodes {
            return Err(NumRs2Error::ValueError(format!(
                "Node features has {} rows but graph has {} nodes",
                feature_rows, num_nodes
            )));
        }
        SparseAdjacency::from_edges(num_nodes, edges).map(|adjacency| Self {
            adjacency,
            node_features,
            edge_features: None,
        })
    }

    /// Attaches (or replaces) the edge-feature matrix, builder style.
    ///
    /// The row count is not checked against the number of edges.
    pub fn with_edge_features(self, edge_features: Array2<T>) -> Self {
        Self {
            edge_features: Some(edge_features),
            ..self
        }
    }
}
pub fn mean_aggregation<T>(
adj: &SparseAdjacency<T>,
features: &ArrayView2<T>,
) -> GraphResult<Array2<T>>
where
T: Float + SimdUnifiedOps,
{
if features.nrows() != adj.num_nodes {
return Err(NumRs2Error::ValueError(format!(
"Features has {} rows but adjacency has {} nodes",
features.nrows(),
adj.num_nodes
)));
}
let (num_nodes, feat_dim) = (features.nrows(), features.ncols());
let mut aggregated = Array2::zeros((num_nodes, feat_dim));
for i in 0..num_nodes {
let (neighbors, _weights) = adj.neighbors(i)?;
if neighbors.is_empty() {
continue;
}
let num_neighbors = T::from(neighbors.len() as f64).unwrap_or(T::one());
for &neighbor in neighbors {
for j in 0..feat_dim {
aggregated[[i, j]] = aggregated[[i, j]] + features[[neighbor, j]];
}
}
for j in 0..feat_dim {
aggregated[[i, j]] = aggregated[[i, j]] / num_neighbors;
}
}
Ok(aggregated)
}
pub fn sum_aggregation<T>(
adj: &SparseAdjacency<T>,
features: &ArrayView2<T>,
) -> GraphResult<Array2<T>>
where
T: Float + SimdUnifiedOps,
{
if features.nrows() != adj.num_nodes {
return Err(NumRs2Error::ValueError(format!(
"Features has {} rows but adjacency has {} nodes",
features.nrows(),
adj.num_nodes
)));
}
let (num_nodes, feat_dim) = (features.nrows(), features.ncols());
let mut aggregated = Array2::zeros((num_nodes, feat_dim));
for i in 0..num_nodes {
let (neighbors, _weights) = adj.neighbors(i)?;
for &neighbor in neighbors {
for j in 0..feat_dim {
aggregated[[i, j]] = aggregated[[i, j]] + features[[neighbor, j]];
}
}
}
Ok(aggregated)
}
pub fn max_pooling_aggregation<T>(
adj: &SparseAdjacency<T>,
features: &ArrayView2<T>,
) -> GraphResult<Array2<T>>
where
T: Float + SimdUnifiedOps,
{
if features.nrows() != adj.num_nodes {
return Err(NumRs2Error::ValueError(format!(
"Features has {} rows but adjacency has {} nodes",
features.nrows(),
adj.num_nodes
)));
}
let (num_nodes, feat_dim) = (features.nrows(), features.ncols());
let mut aggregated = Array2::from_elem((num_nodes, feat_dim), T::neg_infinity());
for i in 0..num_nodes {
let (neighbors, _weights) = adj.neighbors(i)?;
if neighbors.is_empty() {
for j in 0..feat_dim {
aggregated[[i, j]] = T::zero();
}
continue;
}
for &neighbor in neighbors {
for j in 0..feat_dim {
let val = features[[neighbor, j]];
if val > aggregated[[i, j]] {
aggregated[[i, j]] = val;
}
}
}
}
Ok(aggregated)
}
/// Graph convolutional layer operating on a dense adjacency matrix.
#[derive(Debug, Clone)]
pub struct GcnLayer<T: Float> {
    /// Expected number of input feature columns.
    pub in_features: usize,
    /// Number of output feature columns.
    pub out_features: usize,
    /// Projection matrix of shape `(in_features, out_features)`.
    pub weight: Array2<T>,
    /// Optional per-output-column bias, added after the projection.
    pub bias: Option<Array1<T>>,
    /// Whether a bias was requested at construction time.
    pub use_bias: bool,
}
impl<T: Float + SimdUnifiedOps + 'static> GcnLayer<T> {
    /// Creates a layer with a (zero-initialised) bias.
    pub fn new(in_features: usize, out_features: usize) -> GraphResult<Self> {
        Self::new_with_bias(in_features, out_features, true)
    }

    /// Creates a layer with deterministic weights and an optional zero bias.
    ///
    /// Weights follow a fixed, non-random pattern in `[-1, 1)` scaled by
    /// `sqrt(6 / (in_features + out_features))`, the Glorot/Xavier uniform bound.
    pub fn new_with_bias(
        in_features: usize,
        out_features: usize,
        use_bias: bool,
    ) -> GraphResult<Self> {
        let fan_total = (in_features + out_features) as f64;
        let limit = T::from((6.0 / fan_total).sqrt()).unwrap_or(T::one());
        let weight = Array2::from_shape_fn((in_features, out_features), |(row, col)| {
            let bucket = ((row * out_features + col) % 100) as f64;
            T::from((bucket - 50.0) / 50.0).unwrap_or(T::zero()) * limit
        });
        Ok(Self {
            in_features,
            out_features,
            weight,
            bias: use_bias.then(|| Array1::zeros(out_features)),
            use_bias,
        })
    }

    /// Computes `N · X · W (+ b)`, where `N` is the symmetrically normalised
    /// adjacency obtained after setting every diagonal entry of `adj` to `1`.
    ///
    /// # Errors
    /// Returns `ValueError` if `features` has the wrong number of columns
    /// (checked first) or does not have one row per node.
    pub fn forward(
        &self,
        adj: &AdjacencyMatrix<T>,
        features: &ArrayView2<T>,
    ) -> GraphResult<Array2<T>> {
        let (rows, cols) = (features.nrows(), features.ncols());
        if cols != self.in_features {
            return Err(NumRs2Error::ValueError(format!(
                "Expected {} input features, got {}",
                self.in_features, cols
            )));
        }
        if rows != adj.num_nodes {
            return Err(NumRs2Error::ValueError(format!(
                "Features has {} rows but adjacency has {} nodes",
                rows, adj.num_nodes
            )));
        }
        let normalized = adj.with_self_loops()?.symmetric_normalize()?;
        let mut output = normalized.dot(features).dot(&self.weight);
        if let Some(bias) = &self.bias {
            for ((_, col), value) in output.indexed_iter_mut() {
                *value = *value + bias[col];
            }
        }
        Ok(output)
    }
}
/// Multi-head graph attention layer operating on a CSR adjacency.
#[derive(Debug, Clone)]
pub struct GatLayer<T: Float> {
    /// Expected number of input feature columns.
    pub in_features: usize,
    /// Output columns produced by each head.
    pub out_features: usize,
    /// Number of independent attention heads.
    pub num_heads: usize,
    /// `true`: concatenate head outputs; `false`: average them.
    pub concat: bool,
    /// Negative slope of the LeakyReLU applied to attention logits.
    pub alpha: T,
    /// Per-head projection matrices, each `(in_features, out_features)`.
    pub weights: Vec<Array2<T>>,
    /// Per-head attention vectors, each of length `2 * out_features`.
    pub attention_weights: Vec<Array1<T>>,
}
impl<T: Float + SimdUnifiedOps + scirs2_core::ndarray::ScalarOperand> GatLayer<T> {
pub fn new(
in_features: usize,
out_features: usize,
num_heads: usize,
concat: bool,
alpha: f64,
) -> GraphResult<Self> {
if num_heads == 0 {
return Err(NumRs2Error::InvalidOperation(
"Number of attention heads must be > 0".to_string(),
));
}
let mut weights = Vec::new();
let mut attention_weights = Vec::new();
for _ in 0..num_heads {
let w = Array2::from_shape_fn((in_features, out_features), |(i, j)| {
let val = (((i * out_features + j) % 100) as f64 - 50.0) / 100.0;
T::from(val).unwrap_or(T::zero())
});
weights.push(w);
let a = Array1::from_shape_fn(2 * out_features, |i| {
let val = ((i % 100) as f64 - 50.0) / 100.0;
T::from(val).unwrap_or(T::zero())
});
attention_weights.push(a);
}
Ok(Self {
in_features,
out_features,
num_heads,
concat,
alpha: T::from(alpha).unwrap_or(T::from(0.2).unwrap_or(T::zero())),
weights,
attention_weights,
})
}
fn compute_attention(
&self,
h: &Array2<T>,
adj: &SparseAdjacency<T>,
head_idx: usize,
) -> GraphResult<HashMap<(usize, usize), T>> {
let mut attention_map = HashMap::new();
let num_nodes = h.nrows();
for i in 0..num_nodes {
let (neighbors, _) = adj.neighbors(i)?;
if neighbors.is_empty() {
continue;
}
let mut logits = Vec::new();
for &j in neighbors {
let mut concat = Array1::zeros(2 * self.out_features);
for k in 0..self.out_features {
concat[k] = h[[i, k]];
concat[k + self.out_features] = h[[j, k]];
}
let mut e_ij = T::zero();
for k in 0..2 * self.out_features {
e_ij = e_ij + self.attention_weights[head_idx][k] * concat[k];
}
if e_ij < T::zero() {
e_ij = e_ij * self.alpha;
}
logits.push((j, e_ij));
}
let max_logit = logits
.iter()
.map(|(_, e)| *e)
.fold(T::neg_infinity(), |a, b| if a > b { a } else { b });
let mut sum_exp = T::zero();
let mut exp_logits = Vec::new();
for (j, e_ij) in logits {
let exp_val = (e_ij - max_logit).exp();
sum_exp = sum_exp + exp_val;
exp_logits.push((j, exp_val));
}
for (j, exp_val) in exp_logits {
let alpha_ij = exp_val / sum_exp;
attention_map.insert((i, j), alpha_ij);
}
}
Ok(attention_map)
}
pub fn forward(
&self,
adj: &SparseAdjacency<T>,
features: &ArrayView2<T>,
) -> GraphResult<Array2<T>> {
if features.ncols() != self.in_features {
return Err(NumRs2Error::ValueError(format!(
"Expected {} input features, got {}",
self.in_features,
features.ncols()
)));
}
let num_nodes = features.nrows();
let mut head_outputs = Vec::new();
for head in 0..self.num_heads {
let h = features.dot(&self.weights[head]);
let attention = self.compute_attention(&h, adj, head)?;
let mut output = Array2::zeros((num_nodes, self.out_features));
for i in 0..num_nodes {
for j in 0..num_nodes {
if let Some(&alpha_ij) = attention.get(&(i, j)) {
for k in 0..self.out_features {
output[[i, k]] = output[[i, k]] + alpha_ij * h[[j, k]];
}
}
}
}
head_outputs.push(output);
}
let final_output = if self.concat {
let total_dim = self.num_heads * self.out_features;
let mut combined = Array2::zeros((num_nodes, total_dim));
for (head_idx, head_out) in head_outputs.iter().enumerate() {
let start_col = head_idx * self.out_features;
for i in 0..num_nodes {
for j in 0..self.out_features {
combined[[i, start_col + j]] = head_out[[i, j]];
}
}
}
combined
} else {
let mut combined = Array2::zeros((num_nodes, self.out_features));
let num_heads_t = T::from(self.num_heads as f64).unwrap_or(T::one());
for head_out in &head_outputs {
combined = combined + head_out;
}
combined / num_heads_t
};
Ok(final_output)
}
}
/// Neighbourhood aggregation strategy used by `GraphSageLayer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SageAggregator {
    /// Element-wise mean of out-neighbour features.
    Mean,
    /// Element-wise max of out-neighbour features.
    Pool,
    /// Placeholder: `GraphSageLayer::forward` currently falls back to mean aggregation.
    Lstm,
}
/// GraphSAGE layer: projects `[self features || aggregated neighbour features]`.
#[derive(Debug, Clone)]
pub struct GraphSageLayer<T: Float> {
    /// Expected number of input feature columns.
    pub in_features: usize,
    /// Number of output feature columns.
    pub out_features: usize,
    /// How neighbour features are combined.
    pub aggregator: SageAggregator,
    /// Whether each output row is L2-normalised.
    pub normalize: bool,
    /// Projection matrix of shape `(2 * in_features, out_features)`.
    pub weight: Array2<T>,
}
impl<T: Float + SimdUnifiedOps + 'static> GraphSageLayer<T> {
pub fn new(
in_features: usize,
out_features: usize,
aggregator: SageAggregator,
normalize: bool,
) -> GraphResult<Self> {
let weight = Array2::from_shape_fn((2 * in_features, out_features), |(i, j)| {
let val = (((i * out_features + j) % 100) as f64 - 50.0) / 100.0;
T::from(val).unwrap_or(T::zero())
});
Ok(Self {
in_features,
out_features,
aggregator,
normalize,
weight,
})
}
pub fn forward(
&self,
adj: &SparseAdjacency<T>,
features: &ArrayView2<T>,
) -> GraphResult<Array2<T>> {
if features.ncols() != self.in_features {
return Err(NumRs2Error::ValueError(format!(
"Expected {} input features, got {}",
self.in_features,
features.ncols()
)));
}
let aggregated = match self.aggregator {
SageAggregator::Mean => mean_aggregation(adj, features)?,
SageAggregator::Pool => max_pooling_aggregation(adj, features)?,
SageAggregator::Lstm => mean_aggregation(adj, features)?, };
let num_nodes = features.nrows();
let mut concat = Array2::zeros((num_nodes, 2 * self.in_features));
for i in 0..num_nodes {
for j in 0..self.in_features {
concat[[i, j]] = features[[i, j]];
concat[[i, j + self.in_features]] = aggregated[[i, j]];
}
}
let mut output = concat.dot(&self.weight);
if self.normalize {
for i in 0..num_nodes {
let mut norm = T::zero();
for j in 0..self.out_features {
norm = norm + output[[i, j]] * output[[i, j]];
}
norm = norm.sqrt();
if norm > T::zero() {
for j in 0..self.out_features {
output[[i, j]] = output[[i, j]] / norm;
}
}
}
}
Ok(output)
}
}
/// Simple message-passing layer: sum messages, project them, then update.
#[derive(Debug, Clone)]
pub struct MpnnLayer<T: Float> {
    /// Expected number of input feature columns.
    pub in_features: usize,
    /// Number of output feature columns.
    pub out_features: usize,
    /// Message projection, shape `(in_features, out_features)`.
    pub message_weight: Array2<T>,
    /// Update projection, shape `(in_features + out_features, out_features)`.
    pub update_weight: Array2<T>,
}
impl<T: Float + SimdUnifiedOps + 'static> MpnnLayer<T> {
pub fn new(in_features: usize, out_features: usize) -> GraphResult<Self> {
let message_weight = Array2::from_shape_fn((in_features, out_features), |(i, j)| {
let val = (((i * out_features + j) % 100) as f64 - 50.0) / 100.0;
T::from(val).unwrap_or(T::zero())
});
let update_weight =
Array2::from_shape_fn((in_features + out_features, out_features), |(i, j)| {
let val = (((i * out_features + j + 13) % 100) as f64 - 50.0) / 100.0;
T::from(val).unwrap_or(T::zero())
});
Ok(Self {
in_features,
out_features,
message_weight,
update_weight,
})
}
pub fn forward(
&self,
adj: &SparseAdjacency<T>,
features: &ArrayView2<T>,
) -> GraphResult<Array2<T>> {
if features.ncols() != self.in_features {
return Err(NumRs2Error::ValueError(format!(
"Expected {} input features, got {}",
self.in_features,
features.ncols()
)));
}
let num_nodes = features.nrows();
let messages = sum_aggregation(adj, features)?;
let transformed_messages = messages.dot(&self.message_weight);
let mut concat = Array2::zeros((num_nodes, self.in_features + self.out_features));
for i in 0..num_nodes {
for j in 0..self.in_features {
concat[[i, j]] = features[[i, j]];
}
for j in 0..self.out_features {
concat[[i, j + self.in_features]] = transformed_messages[[i, j]];
}
}
let output = concat.dot(&self.update_weight);
Ok(output)
}
}
/// Graph Isomorphism Network layer with a single linear map in place of an MLP.
#[derive(Debug, Clone)]
pub struct GinLayer<T: Float> {
    /// Expected number of input feature columns.
    pub in_features: usize,
    /// Number of output feature columns.
    pub out_features: usize,
    /// Self-feature weighting term: node features are scaled by `1 + epsilon`.
    pub epsilon: T,
    /// Linear map of shape `(in_features, out_features)`.
    pub mlp_weight: Array2<T>,
}
impl<T: Float + SimdUnifiedOps + 'static> GinLayer<T> {
    /// Creates a layer with deterministic (non-random) weights in `[-0.5, 0.5)`.
    ///
    /// An `epsilon` that cannot be represented in `T` becomes zero.
    pub fn new(in_features: usize, out_features: usize, epsilon: f64) -> GraphResult<Self> {
        let mlp_weight = Array2::from_shape_fn((in_features, out_features), |(r, c)| {
            let bucket = ((r * out_features + c) % 100) as f64;
            T::from((bucket - 50.0) / 100.0).unwrap_or(T::zero())
        });
        let epsilon = T::from(epsilon).unwrap_or(T::zero());
        Ok(Self {
            in_features,
            out_features,
            epsilon,
            mlp_weight,
        })
    }

    /// Computes `((1 + ε) · X + Σ_neighbours X) · W`.
    ///
    /// # Errors
    /// Returns `ValueError` if `features` has the wrong number of columns, or
    /// propagates the aggregation's row-count check.
    pub fn forward(
        &self,
        adj: &SparseAdjacency<T>,
        features: &ArrayView2<T>,
    ) -> GraphResult<Array2<T>> {
        if features.ncols() != self.in_features {
            return Err(NumRs2Error::ValueError(format!(
                "Expected {} input features, got {}",
                self.in_features,
                features.ncols()
            )));
        }
        let neighbor_sum = sum_aggregation(adj, features)?;
        let self_scale = T::one() + self.epsilon;
        let mixed = Array2::from_shape_fn((features.nrows(), self.in_features), |(r, c)| {
            self_scale * features[[r, c]] + neighbor_sum[[r, c]]
        });
        Ok(mixed.dot(&self.mlp_weight))
    }
}
pub fn global_mean_pool<T>(node_features: &ArrayView2<T>) -> GraphResult<Array1<T>>
where
T: Float + SimdUnifiedOps + scirs2_core::ndarray::ScalarOperand,
{
let num_nodes = T::from(node_features.nrows() as f64).unwrap_or(T::one());
let mean = node_features.sum_axis(Axis(0)) / num_nodes;
Ok(mean)
}
pub fn global_max_pool<T>(node_features: &ArrayView2<T>) -> GraphResult<Array1<T>>
where
T: Float + SimdUnifiedOps,
{
let mut max_feat = Array1::from_elem(node_features.ncols(), T::neg_infinity());
for i in 0..node_features.nrows() {
for j in 0..node_features.ncols() {
let val = node_features[[i, j]];
if val > max_feat[j] {
max_feat[j] = val;
}
}
}
Ok(max_feat)
}
/// Sums node features column-wise into a single graph-level vector.
///
/// An empty graph yields an all-zero vector.
pub fn global_sum_pool<T>(node_features: &ArrayView2<T>) -> GraphResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    let column_totals = node_features.sum_axis(Axis(0));
    Ok(column_totals)
}
/// Keeps the feature rows of the `k` highest-scoring nodes, best first.
///
/// Ties keep their original node order (stable sort). NaN scores rank below
/// every number.
///
/// # Errors
/// Returns `ValueError` if `scores` does not have one entry per node, or if
/// `k` exceeds the number of nodes.
pub fn topk_pool<T>(
    node_features: &ArrayView2<T>,
    scores: &ArrayView1<T>,
    k: usize,
) -> GraphResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    /// Total order: descending by value, NaN after every number.
    fn descending_nan_last<T: Float>(a: T, b: T) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Never `None` once NaNs are excluded.
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    let num_nodes = node_features.nrows();
    if scores.len() != num_nodes {
        return Err(NumRs2Error::ValueError(format!(
            "Scores length {} doesn't match number of nodes {}",
            scores.len(),
            num_nodes
        )));
    }
    if k > num_nodes {
        return Err(NumRs2Error::ValueError(format!(
            "k={} exceeds number of nodes={}",
            k, num_nodes
        )));
    }
    let mut ranked: Vec<(usize, T)> = scores.iter().copied().enumerate().collect();
    // `slice::sort_by` may panic when the comparator is not a total order, which
    // the previous "NaN compares Equal to everything" comparator was not.
    ranked.sort_by(|a, b| descending_nan_last(a.1, b.1));
    let feat_dim = node_features.ncols();
    Ok(Array2::from_shape_fn((k, feat_dim), |(row, col)| {
        node_features[[ranked[row].0, col]]
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adjacency_matrix_creation() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let adj =
            AdjacencyMatrix::<f64>::from_edges(3, &edges).expect("test: valid adjacency matrix");
        assert_eq!(adj.num_nodes, 3);
        assert_eq!(adj.adj[[0, 1]], 1.0);
        assert_eq!(adj.adj[[1, 2]], 1.0);
        assert_eq!(adj.adj[[2, 0]], 1.0);
        assert_eq!(adj.adj[[0, 0]], 0.0);
    }

    #[test]
    fn test_adjacency_with_self_loops() {
        let edges = vec![(0, 1), (1, 2)];
        let adj =
            AdjacencyMatrix::<f64>::from_edges(3, &edges).expect("test: valid adjacency matrix");
        let adj_self = adj
            .with_self_loops()
            .expect("test: valid self-loop addition");
        assert_eq!(adj_self.adj[[0, 0]], 1.0);
        assert_eq!(adj_self.adj[[1, 1]], 1.0);
        assert_eq!(adj_self.adj[[2, 2]], 1.0);
    }

    #[test]
    fn test_degree_matrix() {
        let edges = vec![(0, 1), (0, 2), (1, 2)];
        let adj =
            AdjacencyMatrix::<f64>::from_edges(3, &edges).expect("test: valid adjacency matrix");
        let degrees = adj.degree_matrix().expect("test: valid degree matrix");
        assert_eq!(degrees[0], 2.0);
        assert_eq!(degrees[1], 1.0);
        // Node 2 has no outgoing edges; its degree is clamped to 1.
        assert_eq!(degrees[2], 1.0);
    }

    #[test]
    fn test_symmetric_normalization() {
        let edges = vec![(0, 1), (1, 0)];
        let adj =
            AdjacencyMatrix::<f64>::from_edges(2, &edges).expect("test: valid adjacency matrix");
        let norm = adj
            .symmetric_normalize()
            .expect("test: valid symmetric normalization");
        assert!((norm[[0, 1]] - 1.0).abs() < 1e-10);
        assert!((norm[[1, 0]] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_edge_list_creation() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let edge_list = EdgeList::<f64>::from_edges(3, &edges).expect("test: valid edge list");
        assert_eq!(edge_list.num_nodes, 3);
        assert_eq!(edge_list.edges.len(), 3);
    }

    #[test]
    fn test_edge_list_out_of_bounds() {
        // Node 5 does not exist in a 3-node graph.
        let edges = vec![(0, 5)];
        let result = EdgeList::<f64>::from_edges(3, &edges);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_adjacency_from_edges() {
        let edges = vec![(0, 1), (0, 2), (1, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        assert_eq!(sparse.num_nodes, 3);
        // row_ptr has num_nodes + 1 entries.
        assert_eq!(sparse.row_ptr.len(), 4);
        assert_eq!(sparse.col_indices.len(), 3);
    }

    #[test]
    fn test_sparse_adjacency_neighbors() {
        let edges = vec![(0, 1), (0, 2), (1, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let (neighbors, weights) = sparse.neighbors(0).expect("test: valid neighbor retrieval");
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&1));
        assert!(neighbors.contains(&2));
        // Unweighted construction stores weight 1 for every edge.
        assert_eq!(weights.len(), 2);
        assert!(weights.iter().all(|&w| w == 1.0));
    }

    #[test]
    fn test_sparse_adjacency_degrees() {
        let edges = vec![(0, 1), (0, 2), (1, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let degrees = sparse.degrees();
        assert_eq!(degrees[0], 2.0);
        assert_eq!(degrees[1], 1.0);
        assert_eq!(degrees[2], 1.0);
    }

    #[test]
    fn test_mean_aggregation() {
        let edges = vec![(0, 1), (0, 2), (1, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let agg =
            mean_aggregation(&sparse, &features.view()).expect("test: valid mean aggregation");
        assert_eq!(agg[[0, 0]], 4.0);
        assert_eq!(agg[[0, 1]], 5.0);
    }

    #[test]
    fn test_sum_aggregation() {
        let edges = vec![(0, 1), (0, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let agg = sum_aggregation(&sparse, &features.view()).expect("test: valid sum aggregation");
        assert_eq!(agg[[0, 0]], 8.0);
        assert_eq!(agg[[0, 1]], 10.0);
    }

    #[test]
    fn test_max_pooling_aggregation() {
        let edges = vec![(0, 1), (0, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = array![[1.0, 6.0], [3.0, 4.0], [5.0, 2.0]];
        let agg = max_pooling_aggregation(&sparse, &features.view())
            .expect("test: valid max pooling aggregation");
        assert_eq!(agg[[0, 0]], 5.0);
        assert_eq!(agg[[0, 1]], 4.0);
    }

    #[test]
    fn test_gcn_layer_creation() {
        let gcn = GcnLayer::<f64>::new(10, 20).expect("test: valid GCN layer");
        assert_eq!(gcn.in_features, 10);
        assert_eq!(gcn.out_features, 20);
        assert_eq!(gcn.weight.shape(), &[10, 20]);
        assert!(gcn.use_bias);
    }

    #[test]
    fn test_gcn_layer_forward() {
        let edges = vec![(0, 1), (1, 2)];
        let adj =
            AdjacencyMatrix::<f64>::from_edges(3, &edges).expect("test: valid adjacency matrix");
        let features = Array2::ones((3, 5));
        let gcn = GcnLayer::new(5, 10).expect("test: valid GCN layer");
        let output = gcn
            .forward(&adj, &features.view())
            .expect("test: valid GCN forward pass");
        assert_eq!(output.shape(), &[3, 10]);
    }

    #[test]
    fn test_gcn_layer_dimension_mismatch() {
        let edges = vec![(0, 1)];
        let adj =
            AdjacencyMatrix::<f64>::from_edges(2, &edges).expect("test: valid adjacency matrix");
        let features = Array2::ones((2, 10));
        let gcn = GcnLayer::new(5, 10).expect("test: valid GCN layer (expects 5 input features)");
        let result = gcn.forward(&adj, &features.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_gat_layer_creation() {
        let gat = GatLayer::<f64>::new(10, 8, 4, true, 0.2).expect("test: valid GAT layer");
        assert_eq!(gat.in_features, 10);
        assert_eq!(gat.out_features, 8);
        assert_eq!(gat.num_heads, 4);
        assert!(gat.concat);
    }

    #[test]
    fn test_gat_layer_zero_heads() {
        let result = GatLayer::<f64>::new(10, 8, 0, true, 0.2);
        assert!(result.is_err());
    }

    #[test]
    fn test_gat_layer_forward() {
        let edges = vec![(0, 1), (1, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = Array2::ones((3, 4));
        let gat = GatLayer::new(4, 2, 2, true, 0.2).expect("test: valid GAT layer");
        let output = gat
            .forward(&sparse, &features.view())
            .expect("test: valid GAT forward pass");
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_gat_layer_average_heads() {
        let edges = vec![(0, 1)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(2, &edges).expect("test: valid sparse adjacency");
        let features = Array2::ones((2, 4));
        let gat = GatLayer::new(4, 8, 2, false, 0.2).expect("test: valid GAT layer (concat=false)");
        let output = gat
            .forward(&sparse, &features.view())
            .expect("test: valid GAT forward pass");
        assert_eq!(output.shape(), &[2, 8]);
    }

    #[test]
    fn test_graphsage_layer_creation() {
        let sage = GraphSageLayer::<f64>::new(10, 20, SageAggregator::Mean, true)
            .expect("test: valid GraphSAGE layer");
        assert_eq!(sage.in_features, 10);
        assert_eq!(sage.out_features, 20);
        assert_eq!(sage.aggregator, SageAggregator::Mean);
        assert!(sage.normalize);
    }

    #[test]
    fn test_graphsage_layer_forward() {
        let edges = vec![(0, 1), (1, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = Array2::ones((3, 5));
        let sage = GraphSageLayer::new(5, 10, SageAggregator::Mean, false)
            .expect("test: valid GraphSAGE layer");
        let output = sage
            .forward(&sparse, &features.view())
            .expect("test: valid GraphSAGE forward pass");
        assert_eq!(output.shape(), &[3, 10]);
    }

    #[test]
    fn test_graphsage_pool_aggregator() {
        let edges = vec![(0, 1), (0, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = Array2::ones((3, 4));
        let sage = GraphSageLayer::new(4, 8, SageAggregator::Pool, true)
            .expect("test: valid GraphSAGE pool layer");
        let output = sage
            .forward(&sparse, &features.view())
            .expect("test: valid GraphSAGE forward pass");
        assert_eq!(output.shape(), &[3, 8]);
    }

    #[test]
    fn test_mpnn_layer_creation() {
        let mpnn = MpnnLayer::<f64>::new(10, 20).expect("test: valid MPNN layer");
        assert_eq!(mpnn.in_features, 10);
        assert_eq!(mpnn.out_features, 20);
    }

    #[test]
    fn test_mpnn_layer_forward() {
        let edges = vec![(0, 1), (1, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = Array2::ones((3, 5));
        let mpnn = MpnnLayer::new(5, 10).expect("test: valid MPNN layer");
        let output = mpnn
            .forward(&sparse, &features.view())
            .expect("test: valid MPNN forward pass");
        assert_eq!(output.shape(), &[3, 10]);
    }

    #[test]
    fn test_gin_layer_creation() {
        let gin = GinLayer::<f64>::new(10, 20, 0.0).expect("test: valid GIN layer");
        assert_eq!(gin.in_features, 10);
        assert_eq!(gin.out_features, 20);
        assert_eq!(gin.epsilon, 0.0);
    }

    #[test]
    fn test_gin_layer_forward() {
        let edges = vec![(0, 1), (1, 2)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = Array2::ones((3, 5));
        let gin = GinLayer::new(5, 10, 0.0).expect("test: valid GIN layer");
        let output = gin
            .forward(&sparse, &features.view())
            .expect("test: valid GIN forward pass");
        assert_eq!(output.shape(), &[3, 10]);
    }

    #[test]
    fn test_gin_layer_with_epsilon() {
        let edges = vec![(0, 1)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(2, &edges).expect("test: valid sparse adjacency");
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let gin = GinLayer::new(2, 2, 0.5).expect("test: valid GIN layer");
        let output = gin
            .forward(&sparse, &features.view())
            .expect("test: valid GIN forward pass");
        assert_eq!(output.shape(), &[2, 2]);
    }

    #[test]
    fn test_global_mean_pool() {
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let pooled = global_mean_pool(&features.view()).expect("test: valid global mean pool");
        assert_eq!(pooled.len(), 2);
        // (1 + 3 + 5) / 3 and (2 + 4 + 6) / 3.
        assert_eq!(pooled[0], 3.0);
        assert_eq!(pooled[1], 4.0);
    }

    #[test]
    fn test_global_max_pool() {
        let features = array![[1.0, 6.0], [3.0, 4.0], [5.0, 2.0]];
        let pooled = global_max_pool(&features.view()).expect("test: valid global max pool");
        assert_eq!(pooled.len(), 2);
        assert_eq!(pooled[0], 5.0);
        assert_eq!(pooled[1], 6.0);
    }

    #[test]
    fn test_global_sum_pool() {
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let pooled = global_sum_pool(&features.view()).expect("test: valid global sum pool");
        assert_eq!(pooled.len(), 2);
        assert_eq!(pooled[0], 9.0);
        assert_eq!(pooled[1], 12.0);
    }

    #[test]
    fn test_topk_pool() {
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let scores = array![0.1, 0.4, 0.2, 0.5];
        let pooled = topk_pool(&features.view(), &scores.view(), 2).expect("test: valid topk pool");
        assert_eq!(pooled.shape(), &[2, 2]);
        // Node 3 has the highest score (0.5).
        assert_eq!(pooled[[0, 0]], 7.0);
        assert_eq!(pooled[[0, 1]], 8.0);
    }

    #[test]
    fn test_topk_pool_ties_keep_original_order() {
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let scores = array![0.5, 0.5, 0.1];
        let pooled = topk_pool(&features.view(), &scores.view(), 2).expect("test: valid topk pool");
        assert_eq!(pooled[[0, 0]], 1.0);
        assert_eq!(pooled[[1, 0]], 3.0);
    }

    #[test]
    fn test_topk_pool_k_exceeds_nodes() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let scores = array![0.1, 0.4];
        let result = topk_pool(&features.view(), &scores.view(), 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_topk_pool_score_mismatch() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        // One score for two nodes.
        let scores = array![0.1];
        let result = topk_pool(&features.view(), &scores.view(), 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_data_creation() {
        let edges = vec![(0, 1), (1, 2)];
        let features = Array2::<f64>::ones((3, 5));
        let graph = GraphData::new(3, &edges, features).expect("test: valid graph data creation");
        assert_eq!(graph.adjacency.num_nodes, 3);
        assert_eq!(graph.node_features.shape(), &[3, 5]);
        assert!(graph.edge_features.is_none());
    }

    #[test]
    fn test_graph_data_feature_mismatch() {
        let edges = vec![(0, 1)];
        // Three feature rows for a two-node graph.
        let features = Array2::<f64>::ones((3, 5));
        let result = GraphData::new(2, &edges, features);
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_data_with_edge_features() {
        let edges = vec![(0, 1), (1, 2)];
        let features = Array2::<f64>::ones((3, 5));
        let edge_features = Array2::<f64>::ones((2, 3));
        let graph = GraphData::new(3, &edges, features)
            .expect("test: valid graph data creation")
            .with_edge_features(edge_features);
        assert!(graph.edge_features.is_some());
        assert_eq!(
            graph
                .edge_features
                .expect("test: edge features are some")
                .shape(),
            &[2, 3]
        );
    }

    #[test]
    fn test_empty_graph() {
        let edges: Vec<(usize, usize)> = vec![];
        let adj =
            AdjacencyMatrix::<f64>::from_edges(3, &edges).expect("test: valid adjacency matrix");
        assert_eq!(adj.num_nodes, 3);
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(adj.adj[[i, j]], 0.0);
            }
        }
    }

    #[test]
    fn test_self_loop_graph() {
        let edges = vec![(0, 0), (1, 1), (2, 2)];
        let adj =
            AdjacencyMatrix::<f64>::from_edges(3, &edges).expect("test: valid adjacency matrix");
        assert_eq!(adj.adj[[0, 0]], 1.0);
        assert_eq!(adj.adj[[1, 1]], 1.0);
        assert_eq!(adj.adj[[2, 2]], 1.0);
    }

    #[test]
    fn test_complete_graph() {
        let edges = vec![(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let degrees = sparse.degrees();
        assert_eq!(degrees[0], 2.0);
        assert_eq!(degrees[1], 2.0);
        assert_eq!(degrees[2], 2.0);
    }

    #[test]
    fn test_aggregation_isolated_node() {
        let edges = vec![(0, 1)];
        let sparse =
            SparseAdjacency::<f64>::from_edges(3, &edges).expect("test: valid sparse adjacency");
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let agg =
            mean_aggregation(&sparse, &features.view()).expect("test: valid mean aggregation");
        assert_eq!(agg[[2, 0]], 0.0);
        assert_eq!(agg[[2, 1]], 0.0);
    }
}