use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, Array1, Array2, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Distribution, Uniform};
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
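/// Xavier/Glorot uniform initialization: samples each weight from
/// `U(-limit, limit)` with `limit = sqrt(6 / (fan_in + fan_out))`,
/// returning a `fan_in × fan_out` matrix.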
fn xavier_init<F, R>(fan_in: usize, fan_out: usize, rng: &mut R) -> Result<Array2<F>>
where
F: Float + Debug + NumAssign,
R: scirs2_core::random::Rng,
{
let limit = (6.0_f64 / (fan_in + fan_out) as f64).sqrt();
let dist = Uniform::new(-limit, limit).map_err(|e| {
NeuralError::InvalidArchitecture(format!("Failed to create Uniform distribution: {e}"))
})?;
let data: Vec<F> = (0..fan_in * fan_out)
.map(|_| F::from(dist.sample(rng)).unwrap_or(F::zero()))
.collect();
Array2::from_shape_vec((fan_in, fan_out), data).map_err(|e| {
NeuralError::InvalidArchitecture(format!("Failed to create weight matrix: {e}"))
})
}
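/// Graph Isomorphism Network (GIN) layer with a two-layer MLP.
///
/// For each node `v`, the layer computes
/// `h_v' = MLP((1 + ε) * h_v + Σ_{u ∈ N(v)} h_u)`,
/// where the MLP is `Linear(in, hidden) → ReLU → Linear(hidden, out)` and
/// `ε` can optionally be treated as a learnable parameter.
/// Gradient buffers are stored behind `Arc<RwLock<..>>` and read in `update`.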
#[derive(Debug, Clone)]
pub struct GinLayer<F: Float + Debug + Send + Sync + NumAssign> {
weight1: Array2<F>,
bias1: Array1<F>,
weight2: Array2<F>,
bias2: Array1<F>,
learn_epsilon: bool,
epsilon: F,
in_features: usize,
out_features: usize,
hidden_dim: usize,
dweight1: Arc<RwLock<Array2<F>>>,
dbias1: Arc<RwLock<Array1<F>>>,
dweight2: Arc<RwLock<Array2<F>>>,
dbias2: Arc<RwLock<Array1<F>>>,
depsilon: Arc<RwLock<F>>,
}
impl<F: Float + Debug + Send + Sync + ScalarOperand + NumAssign + 'static> GinLayer<F> {
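    /// Creates a GIN layer with Xavier-initialized MLP weights, zero biases,
    /// and `epsilon` set to `init_epsilon`. Set `learn_epsilon` to treat ε as
    /// a trainable parameter instead of a fixed per-call value.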
pub fn new<R: scirs2_core::random::Rng>(
in_features: usize,
out_features: usize,
hidden_dim: usize,
learn_epsilon: bool,
init_epsilon: f64,
rng: &mut R,
) -> Result<Self> {
if in_features == 0 || out_features == 0 || hidden_dim == 0 {
return Err(NeuralError::InvalidArchitecture(
"in_features, out_features, and hidden_dim must all be > 0".to_string(),
));
}
let weight1 = xavier_init(in_features, hidden_dim, rng)?;
let bias1 = Array1::zeros(hidden_dim);
let weight2 = xavier_init(hidden_dim, out_features, rng)?;
let bias2 = Array1::zeros(out_features);
let epsilon = F::from(init_epsilon).unwrap_or(F::zero());
Ok(Self {
weight1,
bias1,
weight2,
bias2,
learn_epsilon,
epsilon,
in_features,
out_features,
hidden_dim,
dweight1: Arc::new(RwLock::new(Array2::zeros((in_features, hidden_dim)))),
dbias1: Arc::new(RwLock::new(Array1::zeros(hidden_dim))),
dweight2: Arc::new(RwLock::new(Array2::zeros((hidden_dim, out_features)))),
dbias2: Arc::new(RwLock::new(Array1::zeros(out_features))),
depsilon: Arc::new(RwLock::new(F::zero())),
})
}
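    /// Element-wise ReLU: `max(x, 0)`.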
#[inline]
fn relu(x: F) -> F {
if x > F::zero() { x } else { F::zero() }
}
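    /// GIN sum aggregation: for every node `v`, computes
    /// `(1 + ε) * h_v + Σ_{u : adjacency[v, u] ≠ 0} h_u`.
    /// Returns an error if the adjacency matrix is not `n × n`.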
fn sum_aggregate(
features: &Array2<F>,
adjacency: &Array2<F>,
epsilon: F,
in_features: usize,
) -> Result<Array2<F>> {
let n = features.nrows();
if adjacency.nrows() != n || adjacency.ncols() != n {
return Err(NeuralError::InvalidArgument(format!(
"Adjacency must be {}×{} but got {}×{}",
n,
n,
adjacency.nrows(),
adjacency.ncols()
)));
}
let one_plus_eps = F::one() + epsilon;
let mut agg = Array2::<F>::zeros((n, in_features));
for v in 0..n {
for k in 0..in_features {
agg[[v, k]] = one_plus_eps * features[[v, k]];
}
for u in 0..n {
if adjacency[[v, u]] != F::zero() {
for k in 0..in_features {
agg[[v, k]] += features[[u, k]];
}
}
}
}
Ok(agg)
}
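    /// Applies the two-layer MLP row-wise: `ReLU(agg · W1 + b1) · W2 + b2`.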
fn apply_mlp(&self, agg: &Array2<F>) -> Array2<F> {
let n = agg.nrows();
let mut hidden = Array2::<F>::zeros((n, self.hidden_dim));
for v in 0..n {
for h_idx in 0..self.hidden_dim {
let mut val = self.bias1[h_idx];
for k in 0..self.in_features {
val += agg[[v, k]] * self.weight1[[k, h_idx]];
}
hidden[[v, h_idx]] = Self::relu(val);
}
}
let mut out = Array2::<F>::zeros((n, self.out_features));
for v in 0..n {
for j in 0..self.out_features {
let mut val = self.bias2[j];
for h_idx in 0..self.hidden_dim {
val += hidden[[v, h_idx]] * self.weight2[[h_idx, j]];
}
out[[v, j]] = val;
}
}
out
}
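    /// Forward pass over an explicit graph. `features` is `n × in_features`
    /// and `adjacency` is `n × n`. When `learn_epsilon` is true the stored,
    /// trainable ε is used and the `epsilon` argument is ignored; otherwise
    /// the `epsilon` argument is used.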
pub fn forward_graph(
&self,
features: &Array2<F>,
adjacency: &Array2<F>,
epsilon: f64,
) -> Result<Array2<F>> {
if features.ncols() != self.in_features {
return Err(NeuralError::DimensionMismatch(format!(
"Expected in_features={}, got {}",
self.in_features,
features.ncols()
)));
}
let eps = if self.learn_epsilon {
self.epsilon
} else {
F::from(epsilon).unwrap_or(F::zero())
};
let agg = Self::sum_aggregate(features, adjacency, eps, self.in_features)?;
Ok(self.apply_mlp(&agg))
}
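    /// Total trainable parameters: both MLP weight matrices and biases, plus
    /// one for ε when `learn_epsilon` is true.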
pub fn num_parameters(&self) -> usize {
        let mlp_params = self.in_features * self.hidden_dim
            + self.hidden_dim
            + self.hidden_dim * self.out_features
            + self.out_features;
        mlp_params + if self.learn_epsilon { 1 } else { 0 }
}
}
impl<F: Float + Debug + Send + Sync + ScalarOperand + NumAssign + 'static> Layer<F>
for GinLayer<F>
{
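    // The generic `Layer::forward` has no adjacency available, so it falls
    // back to an identity adjacency (self-loops only) with ε = 0.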
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let f2 = input
.clone()
.into_dimensionality::<scirs2_core::ndarray::Ix2>()
.map_err(|e| NeuralError::DimensionMismatch(format!("Expected 2D input: {e}")))?;
let n = f2.nrows();
let adj = Array2::<F>::eye(n);
self.forward_graph(&f2, &adj, 0.0).map(|a| a.into_dyn())
}
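    // Gradient pass-through: the incoming gradient is returned unchanged;
    // backpropagation through the MLP and aggregation is not performed here.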
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
Ok(grad.clone())
}
fn update(&mut self, lr: F) -> Result<()> {
{
let dw1 = self
.dweight1
.read()
.map_err(|e| NeuralError::ComputationError(format!("RwLock poisoned: {e}")))?;
for (w, g) in self.weight1.iter_mut().zip(dw1.iter()) {
*w -= lr * *g;
}
}
{
let db1 = self
.dbias1
.read()
.map_err(|e| NeuralError::ComputationError(format!("RwLock poisoned: {e}")))?;
for (b, g) in self.bias1.iter_mut().zip(db1.iter()) {
*b -= lr * *g;
}
}
{
let dw2 = self
.dweight2
.read()
.map_err(|e| NeuralError::ComputationError(format!("RwLock poisoned: {e}")))?;
for (w, g) in self.weight2.iter_mut().zip(dw2.iter()) {
*w -= lr * *g;
}
}
{
let db2 = self
.dbias2
.read()
.map_err(|e| NeuralError::ComputationError(format!("RwLock poisoned: {e}")))?;
for (b, g) in self.bias2.iter_mut().zip(db2.iter()) {
*b -= lr * *g;
}
}
if self.learn_epsilon {
let de = self
.depsilon
.read()
.map_err(|e| NeuralError::ComputationError(format!("RwLock poisoned: {e}")))?;
self.epsilon -= lr * *de;
}
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn layer_type(&self) -> &str {
"GIN"
}
fn parameter_count(&self) -> usize {
self.num_parameters()
}
fn layer_description(&self) -> String {
format!(
"type:GIN in:{} hidden:{} out:{} learn_eps:{} eps:{:?}",
self.in_features, self.hidden_dim, self.out_features, self.learn_epsilon, self.epsilon
)
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
vec![
self.weight1.clone().into_dyn(),
self.bias1.clone().into_dyn(),
self.weight2.clone().into_dyn(),
self.bias2.clone().into_dyn(),
]
}
}
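/// Simplified GIN convolution with a single linear layer followed by ReLU:
/// `h_v' = ReLU(((1 + ε) * h_v + Σ_{u ∈ N(v)} h_u) · W + b)`,
/// where ε is always supplied per call rather than learned.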
#[derive(Debug, Clone)]
pub struct GinConv<F: Float + Debug + Send + Sync + NumAssign> {
weight: Array2<F>,
bias: Array1<F>,
in_features: usize,
out_features: usize,
dweight: Arc<RwLock<Array2<F>>>,
dbias: Arc<RwLock<Array1<F>>>,
}
impl<F: Float + Debug + Send + Sync + ScalarOperand + NumAssign + 'static> GinConv<F> {
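    /// Creates a GinConv layer with a Xavier-initialized weight matrix and a
    /// zero bias vector.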
pub fn new<R: scirs2_core::random::Rng>(
in_features: usize,
out_features: usize,
rng: &mut R,
) -> Result<Self> {
if in_features == 0 || out_features == 0 {
return Err(NeuralError::InvalidArchitecture(
"in_features and out_features must be > 0".to_string(),
));
}
let weight = xavier_init(in_features, out_features, rng)?;
let bias = Array1::zeros(out_features);
Ok(Self {
weight,
bias,
in_features,
out_features,
dweight: Arc::new(RwLock::new(Array2::zeros((in_features, out_features)))),
dbias: Arc::new(RwLock::new(Array1::zeros(out_features))),
})
}
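    /// Forward pass over an explicit graph: sum-aggregates neighbors with the
    /// given ε, applies the linear transform, and clamps the result with ReLU.
    /// Errors if `features` is not `n × in_features` or `adjacency` is not `n × n`.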
pub fn forward_graph(
&self,
features: &Array2<F>,
adjacency: &Array2<F>,
epsilon: f64,
) -> Result<Array2<F>> {
let n = features.nrows();
if features.ncols() != self.in_features {
return Err(NeuralError::DimensionMismatch(format!(
"Expected in_features={}, got {}",
self.in_features,
features.ncols()
)));
}
if adjacency.nrows() != n || adjacency.ncols() != n {
return Err(NeuralError::InvalidArgument(format!(
"Adjacency must be {}×{} but got {}×{}",
n,
n,
adjacency.nrows(),
adjacency.ncols()
)));
}
let eps = F::from(epsilon).unwrap_or(F::zero());
let one_plus_eps = F::one() + eps;
let mut agg = Array2::<F>::zeros((n, self.in_features));
for v in 0..n {
for k in 0..self.in_features {
agg[[v, k]] = one_plus_eps * features[[v, k]];
}
for u in 0..n {
if adjacency[[v, u]] != F::zero() {
for k in 0..self.in_features {
agg[[v, k]] += features[[u, k]];
}
}
}
}
let mut out = Array2::<F>::zeros((n, self.out_features));
for v in 0..n {
for j in 0..self.out_features {
let mut val = self.bias[j];
for k in 0..self.in_features {
val += agg[[v, k]] * self.weight[[k, j]];
}
out[[v, j]] = if val > F::zero() { val } else { F::zero() };
}
}
Ok(out)
}
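    /// Total trainable parameters: `in_features * out_features` weights plus
    /// `out_features` biases.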
pub fn num_parameters(&self) -> usize {
self.in_features * self.out_features + self.out_features
}
}
impl<F: Float + Debug + Send + Sync + ScalarOperand + NumAssign + 'static> Layer<F>
for GinConv<F>
{
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let f2 = input
.clone()
.into_dimensionality::<scirs2_core::ndarray::Ix2>()
.map_err(|e| NeuralError::DimensionMismatch(format!("Expected 2D input: {e}")))?;
let n = f2.nrows();
let adj = Array2::<F>::eye(n);
self.forward_graph(&f2, &adj, 0.0).map(|a| a.into_dyn())
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
Ok(grad.clone())
}
fn update(&mut self, lr: F) -> Result<()> {
{
let dw = self
.dweight
.read()
.map_err(|e| NeuralError::ComputationError(format!("RwLock poisoned: {e}")))?;
for (w, g) in self.weight.iter_mut().zip(dw.iter()) {
*w -= lr * *g;
}
}
{
let db = self
.dbias
.read()
.map_err(|e| NeuralError::ComputationError(format!("RwLock poisoned: {e}")))?;
for (b, g) in self.bias.iter_mut().zip(db.iter()) {
*b -= lr * *g;
}
}
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn layer_type(&self) -> &str {
"GinConv"
}
fn parameter_count(&self) -> usize {
self.num_parameters()
}
fn layer_description(&self) -> String {
format!(
"type:GinConv in:{} out:{}",
self.in_features, self.out_features
)
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
vec![
self.weight.clone().into_dyn(),
self.bias.clone().into_dyn(),
]
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
use scirs2_core::random::rng;
#[test]
fn test_gin_layer_output_shape() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(4, 8, 16, false, 0.0, &mut rng)
.expect("Failed to create GIN layer");
let features = Array2::<f64>::from_elem((5, 4), 0.1);
let adj = Array2::<f64>::eye(5);
let out = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward failed");
assert_eq!(out.shape(), &[5, 8]);
}
#[test]
fn test_gin_layer_with_adjacency() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(3, 6, 12, false, 0.0, &mut rng)
.expect("Failed to create GIN layer");
let features = Array2::<f64>::from_elem((4, 3), 0.5);
        let adj = Array2::<f64>::from_shape_vec(
            (4, 4),
            vec![
                0.0_f64, 1.0, 0.0, 1.0, // node 0 <-> nodes 1, 3
                1.0, 0.0, 1.0, 0.0, // node 1 <-> nodes 0, 2
                0.0, 1.0, 0.0, 1.0, // node 2 <-> nodes 1, 3
                1.0, 0.0, 1.0, 0.0, // node 3 <-> nodes 0, 2
            ],
        )
        .expect("shape");
let out = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward failed");
assert_eq!(out.shape(), &[4, 6]);
assert!(out.iter().all(|&v| v.is_finite()));
}
#[test]
fn test_gin_layer_epsilon_effect() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(2, 4, 8, false, 0.0, &mut rng)
.expect("Failed to create GIN layer");
let features = Array2::<f64>::from_elem((3, 2), 1.0);
let adj = Array2::<f64>::eye(3);
let out_eps0 = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward eps0");
let out_eps1 = layer
.forward_graph(&features, &adj, 1.0)
.expect("Forward eps1");
let changed = out_eps0
.iter()
.zip(out_eps1.iter())
.any(|(a, b)| (a - b).abs() > 1e-10);
assert!(changed, "ε should change the output");
}
#[test]
fn test_gin_layer_learn_epsilon_uses_stored() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(2, 4, 8, true, 0.5, &mut rng)
.expect("Failed to create GIN layer");
let features = Array2::<f64>::from_elem((3, 2), 1.0);
let adj = Array2::<f64>::eye(3);
let out_stored = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward stored");
let out_ignored = layer
.forward_graph(&features, &adj, 99.0)
.expect("Forward ignored");
let all_equal = out_stored
.iter()
.zip(out_ignored.iter())
.all(|(a, b)| (a - b).abs() < 1e-10);
assert!(
all_equal,
"learn_epsilon=true should ignore the explicit epsilon argument"
);
}
#[test]
fn test_gin_layer_parameter_count() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(4, 8, 16, true, 0.0, &mut rng)
.expect("Failed to create GIN layer");
assert_eq!(layer.num_parameters(), 4 * 16 + 16 + 16 * 8 + 8 + 1);
}
#[test]
fn test_gin_layer_no_learn_epsilon_parameter_count() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(4, 8, 16, false, 0.0, &mut rng)
.expect("Failed to create GIN layer");
assert_eq!(layer.num_parameters(), 4 * 16 + 16 + 16 * 8 + 8);
}
#[test]
fn test_gin_layer_zero_input_relu_output() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(3, 6, 12, false, 0.0, &mut rng)
.expect("Failed to create GIN layer");
let features = Array2::<f64>::zeros((4, 3));
let adj = Array2::<f64>::from_elem((4, 4), 1.0);
let out = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward failed");
assert!(
out.iter().all(|&v| v == 0.0),
"Zero input with zero biases should yield zero output"
);
}
#[test]
fn test_gin_layer_wrong_features_error() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(4, 8, 16, false, 0.0, &mut rng)
.expect("Failed to create GIN layer");
        let features = Array2::<f64>::from_elem((5, 3), 0.1);
        let adj = Array2::<f64>::eye(5);
assert!(layer.forward_graph(&features, &adj, 0.0).is_err());
}
#[test]
fn test_gin_layer_trait_forward() {
let mut rng = rng();
let layer = GinLayer::<f64>::new(4, 4, 8, false, 0.0, &mut rng)
.expect("Failed to create GIN layer");
let input = Array2::<f64>::from_elem((3, 4), 0.3).into_dyn();
let out = layer.forward(&input).expect("Layer trait forward failed");
assert_eq!(out.shape(), &[3, 4]);
}
#[test]
fn test_gin_conv_output_shape() {
let mut rng = rng();
let layer =
GinConv::<f64>::new(4, 8, &mut rng).expect("Failed to create GinConv");
let features = Array2::<f64>::from_elem((6, 4), 0.2);
let adj = Array2::<f64>::eye(6);
let out = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward failed");
assert_eq!(out.shape(), &[6, 8]);
}
#[test]
fn test_gin_conv_non_negative_output() {
let mut rng = rng();
let layer =
GinConv::<f64>::new(3, 6, &mut rng).expect("Failed to create GinConv");
let features = Array2::<f64>::from_elem((5, 3), 0.5);
let adj = Array2::<f64>::from_elem((5, 5), 1.0);
let out = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward failed");
assert!(
out.iter().all(|&v| v >= 0.0),
"GinConv output must be non-negative (ReLU)"
);
}
#[test]
fn test_gin_conv_epsilon_changes_output() {
let mut rng = rng();
let layer =
GinConv::<f64>::new(2, 4, &mut rng).expect("Failed to create GinConv");
let features = Array2::<f64>::from_elem((3, 2), 1.0);
let adj = Array2::<f64>::eye(3);
let out0 = layer
.forward_graph(&features, &adj, 0.0)
.expect("fwd eps0");
let out1 = layer
.forward_graph(&features, &adj, 2.0)
.expect("fwd eps2");
let changed = out0
.iter()
.zip(out1.iter())
.any(|(a, b)| (a - b).abs() > 1e-10);
assert!(changed, "ε should change GinConv output");
}
#[test]
fn test_gin_conv_parameter_count() {
let mut rng = rng();
let layer =
GinConv::<f64>::new(4, 8, &mut rng).expect("Failed to create GinConv");
assert_eq!(layer.num_parameters(), 4 * 8 + 8);
}
#[test]
fn test_gin_conv_wrong_adj_error() {
let mut rng = rng();
let layer =
GinConv::<f64>::new(4, 8, &mut rng).expect("Failed to create GinConv");
let features = Array2::<f64>::from_elem((4, 4), 0.1);
        let adj = Array2::<f64>::eye(3);
        assert!(layer.forward_graph(&features, &adj, 0.0).is_err());
}
#[test]
fn test_gin_conv_layer_trait_forward() {
let mut rng = rng();
let layer =
GinConv::<f64>::new(4, 6, &mut rng).expect("Failed to create GinConv");
let input = Array2::<f64>::from_elem((5, 4), 0.2).into_dyn();
let out = layer.forward(&input).expect("Layer trait forward failed");
assert_eq!(out.shape(), &[5, 6]);
}
#[test]
fn test_gin_conv_fully_connected_graph() {
let mut rng = rng();
let layer =
GinConv::<f64>::new(3, 5, &mut rng).expect("Failed to create GinConv");
let features = Array2::<f64>::from_elem((4, 3), 0.3);
let adj = Array2::<f64>::from_elem((4, 4), 1.0);
let out = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward failed");
let row0 = out.row(0).to_owned();
for i in 1..4 {
let diff: f64 = row0
.iter()
.zip(out.row(i).iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f64, f64::max);
assert!(
diff < 1e-10,
"Row {i} differs from row 0 by {diff} in a fully connected uniform graph"
);
}
}
#[test]
fn test_sum_aggregation_correctness() {
let mut rng = rng();
let layer = GinConv::<f64>::new(2, 2, &mut rng).expect("GinConv");
let features = Array2::<f64>::from_shape_vec(
(3, 2),
vec![1.0_f64, 0.0, 0.0, 1.0, 1.0, 1.0],
)
.expect("shape");
let adj = Array2::<f64>::from_shape_vec(
(3, 3),
vec![0.0_f64, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
)
.expect("shape");
let out = layer
.forward_graph(&features, &adj, 0.0)
.expect("Forward failed");
assert_eq!(out.shape(), &[3, 2]);
assert!(out.iter().all(|&v| v.is_finite()));
}
}