use scivex_core::random::Rng;
use scivex_core::{Float, Tensor};
use crate::error::{NnError, Result};
use crate::init;
use crate::ops;
use crate::variable::Variable;
/// Graph convolution layer with a trainable weight matrix and an optional
/// trainable bias.
///
/// `forward` multiplies a normalised adjacency matrix by the node features and
/// then by `weight`, adding `bias` when one is present.
pub struct GCNConv<T: Float> {
/// Trainable weight, initialised by `init::kaiming_uniform` with dimensions
/// `[in_features, out_features]`.
weight: Variable<T>,
/// Trainable bias of length `out_features`, zero-initialised; `None` when the
/// layer was built with `use_bias == false`.
bias: Option<Variable<T>>,
}
impl<T: Float> GCNConv<T> {
    /// Builds a GCN layer mapping `in_features` inputs to `out_features` outputs.
    ///
    /// The weight is drawn from `init::kaiming_uniform` with dimensions
    /// `[in_features, out_features]`. When `use_bias` is true a zero-initialised
    /// bias of length `out_features` is created as well.
    pub fn new(in_features: usize, out_features: usize, use_bias: bool, rng: &mut Rng) -> Self {
        let weight = Variable::new(
            init::kaiming_uniform::<T>(&[in_features, out_features], rng),
            true,
        );
        let bias = if use_bias {
            Some(Variable::new(Tensor::zeros(vec![out_features]), true))
        } else {
            None
        };
        Self { weight, bias }
    }

    /// Applies the layer: `normalize_adjacency(adj) * x * weight (+ bias)`.
    ///
    /// `x` must be a 2-D `[n, in_features]` variable and `adj` an `[n, n]`
    /// tensor. The adjacency is normalised on every call and wrapped as a
    /// variable created with `false` as its second argument (presumably
    /// "no gradient required" — confirm against `Variable::new`).
    ///
    /// # Errors
    ///
    /// Returns `NnError::ShapeMismatch` when
    /// * `x` is not 2-D (`expected` is the placeholder `[0, 0]`),
    /// * the second dimension of `x` differs from the first dimension of the
    ///   weight, or
    /// * `adj` is not `[n, n]`.
    ///
    /// Errors from `normalize_adjacency` are propagated unchanged.
    pub fn forward(&self, x: &Variable<T>, adj: &Tensor<T>) -> Result<Variable<T>> {
        let x_shape = x.shape();
        if x_shape.len() != 2 {
            return Err(NnError::ShapeMismatch {
                expected: vec![0, 0],
                got: x_shape,
            });
        }
        let n = x_shape[0];

        // Reject a wrong feature width here: `ops::matmul` returns a bare
        // `Variable`, so it has no way to surface the mismatch as an `Err`.
        let in_features = self.weight.shape()[0];
        if x_shape[1] != in_features {
            return Err(NnError::ShapeMismatch {
                expected: vec![n, in_features],
                got: x_shape,
            });
        }

        let adj_shape = adj.shape();
        if adj_shape != [n, n] {
            return Err(NnError::ShapeMismatch {
                expected: vec![n, n],
                got: adj_shape.to_vec(),
            });
        }

        let adj_var = Variable::new(normalize_adjacency(adj)?, false);
        let aggregated = ops::matmul(&adj_var, x);
        let y = ops::matmul(&aggregated, &self.weight);
        Ok(match &self.bias {
            Some(b) => ops::add_bias(&y, b),
            None => y,
        })
    }

    /// Returns clones of the trainable variables: the weight first, followed by
    /// the bias when the layer has one.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        let mut params = vec![self.weight.clone()];
        params.extend(self.bias.iter().cloned());
        params
    }
}
/// Computes the symmetrically normalised adjacency `D^-1/2 (A + I) D^-1/2`.
///
/// One is added to every diagonal entry of `adj` (regardless of its current
/// value), each row of the result is summed to obtain a degree, and every
/// entry `(i, j)` is scaled by `1/sqrt(deg_i)` and `1/sqrt(deg_j)`. Rows whose
/// degree is not strictly positive get a scaling factor of zero.
///
/// The size is taken from `adj.shape()[0]` and the data is read as a row-major
/// `n * n` slice; callers are expected to have checked that `adj` is square.
fn normalize_adjacency<T: Float>(adj: &Tensor<T>) -> Result<Tensor<T>> {
    let size = adj.shape()[0];

    // A_hat = A + I
    let mut with_loops = adj.as_slice().to_vec();
    for diag in (0..size).map(|k| k * size + k) {
        with_loops[diag] += T::one();
    }

    // Per-node 1/sqrt(degree), where the degree is the row sum of A_hat.
    let inv_sqrt_degree: Vec<T> = (0..size)
        .map(|row| {
            let degree = with_loops[row * size..(row + 1) * size]
                .iter()
                .fold(T::zero(), |acc, &w| acc + w);
            if degree > T::zero() {
                T::one() / degree.sqrt()
            } else {
                T::zero()
            }
        })
        .collect();

    let mut scaled = Vec::with_capacity(size * size);
    for row in 0..size {
        for col in 0..size {
            scaled.push(inv_sqrt_degree[row] * with_loops[row * size + col] * inv_sqrt_degree[col]);
        }
    }

    Tensor::from_vec(scaled, vec![size, size]).map_err(NnError::CoreError)
}
/// Graph attention layer with a single attention head.
///
/// Holds a feature projection `weight` and two attention vectors that score
/// the "left" (target) and "right" (source) side of every node pair.
pub struct GATConv<T: Float> {
/// Trainable projection, initialised with dimensions `[in_features, out_features]`.
weight: Variable<T>,
/// Trainable attention vector of length `out_features` applied to node `i`
/// when scoring the pair `(i, j)`.
attn_left: Variable<T>,
/// Trainable attention vector of length `out_features` applied to node `j`
/// when scoring the pair `(i, j)`.
attn_right: Variable<T>,
// `new` always stores 1 here and nothing in this file reads the field, hence
// the lint suppression.
#[allow(dead_code)]
num_heads: usize,
/// Width of the projected features; used to index the projected data.
out_features: usize,
}
impl<T: Float> GATConv<T> {
    /// Builds a single-head attention layer mapping `in_features` inputs to
    /// `out_features` outputs.
    ///
    /// The random generator is consumed in a fixed order: projection weight,
    /// left attention vector, right attention vector.
    ///
    /// # Panics
    ///
    /// Panics if an attention vector cannot be rebuilt as a 1-D tensor of
    /// length `out_features`.
    pub fn new(in_features: usize, out_features: usize, rng: &mut Rng) -> Self {
        let weight = Variable::new(
            init::kaiming_uniform::<T>(&[in_features, out_features], rng),
            true,
        );
        let attn_left = Self::attention_vector(out_features, rng, "reshape attn_left");
        let attn_right = Self::attention_vector(out_features, rng, "reshape attn_right");
        Self {
            weight,
            attn_left,
            attn_right,
            num_heads: 1,
            out_features,
        }
    }

    /// Draws an `[out_features, 1]` Kaiming-uniform tensor and rebuilds it as a
    /// trainable 1-D variable of length `out_features`.
    fn attention_vector(out_features: usize, rng: &mut Rng, expect_msg: &str) -> Variable<T> {
        let drawn = init::kaiming_uniform::<T>(&[out_features, 1], rng);
        Variable::new(
            Tensor::from_vec(drawn.as_slice().to_vec(), vec![out_features]).expect(expect_msg),
            true,
        )
    }

    /// Applies the layer to node features `x` (`[n, in_features]`) on the graph
    /// described by `adj` (`[n, n]`).
    ///
    /// Steps:
    /// 1. Project the features: `h = x * weight`.
    /// 2. Score every node with both attention vectors (dot products).
    /// 3. For each pair `(i, j)` that is connected (`adj[i, j] > 0`) or is the
    ///    diagonal, set the logit to `leaky_relu(left[i] + right[j])` with
    ///    slope `0.2`; all other pairs get `-1e9`.
    /// 4. Soft-max every row (max-subtracted for numerical stability).
    /// 5. Return `alpha * h`.
    ///
    /// NOTE(review): the attention coefficients are computed from raw tensor
    /// data and wrapped with `Variable::new(.., false)`, so it looks like no
    /// gradient reaches `attn_left` / `attn_right` through this pass even
    /// though `parameters` exposes them — confirm against the autograd
    /// semantics of `Variable`.
    ///
    /// # Errors
    ///
    /// Returns `NnError::ShapeMismatch` when `x` is not 2-D (`expected` is the
    /// placeholder `[0, 0]`), when the second dimension of `x` differs from the
    /// first dimension of the weight, or when `adj` is not `[n, n]`. Returns
    /// `NnError::CoreError` if the attention tensor cannot be built.
    pub fn forward(&self, x: &Variable<T>, adj: &Tensor<T>) -> Result<Variable<T>> {
        let x_shape = x.shape();
        if x_shape.len() != 2 {
            return Err(NnError::ShapeMismatch {
                expected: vec![0, 0],
                got: x_shape,
            });
        }
        let n = x_shape[0];

        // Reject a wrong feature width here: `ops::matmul` returns a bare
        // `Variable`, so it has no way to surface the mismatch as an `Err`.
        let in_features = self.weight.shape()[0];
        if x_shape[1] != in_features {
            return Err(NnError::ShapeMismatch {
                expected: vec![n, in_features],
                got: x_shape,
            });
        }

        let adj_shape = adj.shape();
        if adj_shape != [n, n] {
            return Err(NnError::ShapeMismatch {
                expected: vec![n, n],
                got: adj_shape.to_vec(),
            });
        }

        let h = ops::matmul(x, &self.weight);
        let h_data = h.data();
        let h_slice = h_data.as_slice();
        let out_f = self.out_features;

        let al = self.attn_left.data();
        let ar = self.attn_right.data();
        let al_slice = al.as_slice();
        let ar_slice = ar.as_slice();

        // Per-node attention scores: dot(h_i, attn_left) and dot(h_i, attn_right).
        let mut left_scores = vec![T::zero(); n];
        let mut right_scores = vec![T::zero(); n];
        for i in 0..n {
            let projected = &h_slice[i * out_f..(i + 1) * out_f];
            let mut sl = T::zero();
            let mut sr = T::zero();
            for (f, &value) in projected.iter().enumerate() {
                sl += value * al_slice[f];
                sr += value * ar_slice[f];
            }
            left_scores[i] = sl;
            right_scores[i] = sr;
        }

        let adj_slice = adj.as_slice();
        let neg_slope = T::from_f64(0.2);
        // Large negative logit so masked pairs vanish after the soft-max.
        let masked = T::from_f64(-1e9);

        let mut attn_scores = vec![T::zero(); n * n];
        for i in 0..n {
            let row = &mut attn_scores[i * n..(i + 1) * n];

            // Masked leaky-ReLU logits; the diagonal always counts as connected.
            for (j, slot) in row.iter_mut().enumerate() {
                let connected = i == j || adj_slice[i * n + j] > T::zero();
                *slot = if connected {
                    let e = left_scores[i] + right_scores[j];
                    if e > T::zero() {
                        e
                    } else {
                        neg_slope * e
                    }
                } else {
                    masked
                };
            }

            // Row-wise soft-max, shifted by the row maximum.
            let max = row.iter().copied().fold(T::neg_infinity(), T::max);
            let mut sum = T::zero();
            for v in row.iter_mut() {
                *v = (*v - max).exp();
                sum += *v;
            }
            if sum > T::zero() {
                for v in row.iter_mut() {
                    *v /= sum;
                }
            }
        }

        let alpha_tensor =
            Tensor::from_vec(attn_scores, vec![n, n]).map_err(NnError::CoreError)?;
        let alpha_var = Variable::new(alpha_tensor, false);
        Ok(ops::matmul(&alpha_var, &h))
    }

    /// Returns clones of the layer's variables in the order weight, left
    /// attention vector, right attention vector.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        vec![
            self.weight.clone(),
            self.attn_left.clone(),
            self.attn_right.clone(),
        ]
    }
}
/// GraphSAGE-style layer that concatenates each node's own features with the
/// row-normalised aggregate of its neighbours' features and applies a single
/// linear map.
pub struct SAGEConv<T: Float> {
/// Trainable weight, initialised with dimensions `[2 * in_features, out_features]`.
weight: Variable<T>,
/// Trainable bias of length `out_features`, zero-initialised; `None` when the
/// layer was built with `use_bias == false`.
bias: Option<Variable<T>>,
/// Input feature width given to `new`; used to split the gradient of the
/// concatenation in the backward pass.
in_features: usize,
}
impl<T: Float> SAGEConv<T> {
    /// Builds a SAGE layer mapping `in_features` inputs to `out_features` outputs.
    ///
    /// The weight is drawn from `init::kaiming_uniform` with dimensions
    /// `[2 * in_features, out_features]` because the layer multiplies the
    /// concatenation of self and neighbour features. When `use_bias` is true a
    /// zero-initialised bias of length `out_features` is created as well.
    pub fn new(in_features: usize, out_features: usize, use_bias: bool, rng: &mut Rng) -> Self {
        let weight = Variable::new(
            init::kaiming_uniform::<T>(&[2 * in_features, out_features], rng),
            true,
        );
        let bias = if use_bias {
            Some(Variable::new(Tensor::zeros(vec![out_features]), true))
        } else {
            None
        };
        Self {
            weight,
            bias,
            in_features,
        }
    }

    /// Applies the layer: `concat(x, row_normalize(adj) * x) * weight (+ bias)`.
    ///
    /// `x` must be `[n, in_features]` and `adj` must be `[n, n]`. The
    /// concatenation is built by hand and registered through
    /// `Variable::from_op` with a backward closure that splits the incoming
    /// gradient into its self half and its neighbour half.
    ///
    /// # Errors
    ///
    /// Returns `NnError::ShapeMismatch` when `x` is not 2-D (`expected` is the
    /// placeholder `[0, 0]`), when its second dimension differs from the
    /// layer's `in_features`, or when `adj` is not `[n, n]`. Returns
    /// `NnError::CoreError` if the normalised adjacency or the concatenated
    /// tensor cannot be built.
    ///
    /// # Panics
    ///
    /// The backward closure panics if a split gradient cannot be built as a
    /// `[rows, in_features]` tensor.
    pub fn forward(&self, x: &Variable<T>, adj: &Tensor<T>) -> Result<Variable<T>> {
        let x_shape = x.shape();
        if x_shape.len() != 2 {
            return Err(NnError::ShapeMismatch {
                expected: vec![0, 0],
                got: x_shape,
            });
        }
        let n = x_shape[0];

        // The concatenation and the backward split must use the same width;
        // fail early instead of building a tensor the weight cannot multiply.
        let in_f = self.in_features;
        if x_shape[1] != in_f {
            return Err(NnError::ShapeMismatch {
                expected: vec![n, in_f],
                got: x_shape,
            });
        }

        let adj_shape = adj.shape();
        if adj_shape != [n, n] {
            return Err(NnError::ShapeMismatch {
                expected: vec![n, n],
                got: adj_shape.to_vec(),
            });
        }

        let adj_var = Variable::new(row_normalize(adj)?, false);
        let neigh = ops::matmul(&adj_var, x);

        // Row-wise concatenation: [x_i | neigh_i] for every node i.
        let x_data = x.data();
        let neigh_data = neigh.data();
        let x_slice = x_data.as_slice();
        let neigh_slice = neigh_data.as_slice();
        let mut concat_data = Vec::with_capacity(n * 2 * in_f);
        for i in 0..n {
            let row = i * in_f..(i + 1) * in_f;
            concat_data.extend_from_slice(&x_slice[row.clone()]);
            concat_data.extend_from_slice(&neigh_slice[row]);
        }
        let concat_tensor =
            Tensor::from_vec(concat_data, vec![n, 2 * in_f]).map_err(NnError::CoreError)?;

        let concat_var = Variable::from_op(
            concat_tensor,
            vec![x.clone(), neigh],
            Box::new(move |g: &Tensor<T>| {
                // Undo the concatenation: the first `in_f` columns of every
                // gradient row belong to `x`, the remaining `in_f` to `neigh`.
                let g_slice = g.as_slice();
                let rows = g.shape()[0];
                let mut gx = Vec::with_capacity(rows * in_f);
                let mut gn = Vec::with_capacity(rows * in_f);
                for i in 0..rows {
                    let start = i * 2 * in_f;
                    let (self_part, neigh_part) =
                        g_slice[start..start + 2 * in_f].split_at(in_f);
                    gx.extend_from_slice(self_part);
                    gn.extend_from_slice(neigh_part);
                }
                let grad_x = Tensor::from_vec(gx, vec![rows, in_f]).expect("grad shape matches");
                let grad_neigh =
                    Tensor::from_vec(gn, vec![rows, in_f]).expect("grad shape matches");
                vec![grad_x, grad_neigh]
            }),
        );

        let y = ops::matmul(&concat_var, &self.weight);
        Ok(match &self.bias {
            Some(b) => ops::add_bias(&y, b),
            None => y,
        })
    }

    /// Returns clones of the trainable variables: the weight first, followed by
    /// the bias when the layer has one.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        let mut params = vec![self.weight.clone()];
        params.extend(self.bias.iter().cloned());
        params
    }
}
/// Divides every row of `adj` by its sum so that each row with a strictly
/// positive sum adds up to one.
///
/// Rows whose sum is not strictly positive are left untouched. The size is
/// taken from `adj.shape()[0]` and the data is read as a row-major `n * n`
/// slice; callers are expected to have checked that `adj` is square.
fn row_normalize<T: Float>(adj: &Tensor<T>) -> Result<Tensor<T>> {
    let size = adj.shape()[0];
    let mut weights = adj.as_slice().to_vec();

    for row in 0..size {
        let span = row * size..(row + 1) * size;
        let total = weights[span.clone()]
            .iter()
            .fold(T::zero(), |acc, &w| acc + w);
        if total > T::zero() {
            for w in &mut weights[span] {
                *w /= total;
            }
        }
    }

    Tensor::from_vec(weights, vec![size, size]).map_err(NnError::CoreError)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scivex_core::Tensor;

    /// Adjacency of a fully connected 3-node graph without self-loops.
    fn simple_adj() -> Tensor<f64> {
        let entries = vec![
            0.0, 1.0, 1.0, // node 0
            1.0, 0.0, 1.0, // node 1
            1.0, 1.0, 0.0, // node 2
        ];
        Tensor::from_vec(entries, vec![3, 3]).unwrap()
    }

    /// All-ones features for 3 nodes with 4 features each.
    fn ones_features() -> Variable<f64> {
        Variable::new(Tensor::ones(vec![3, 4]), true)
    }

    #[test]
    fn test_gcn_forward_shape() {
        let mut rng = Rng::new(42);
        let layer = GCNConv::<f64>::new(4, 2, true, &mut rng);
        let out = layer.forward(&ones_features(), &simple_adj()).unwrap();
        assert_eq!(out.shape(), vec![3, 2]);
    }

    #[test]
    fn test_gcn_parameters() {
        let mut rng = Rng::new(42);
        let with_bias = GCNConv::<f64>::new(4, 2, true, &mut rng);
        assert_eq!(with_bias.parameters().len(), 2);
        let without_bias = GCNConv::<f64>::new(4, 2, false, &mut rng);
        assert_eq!(without_bias.parameters().len(), 1);
    }

    #[test]
    fn test_gat_forward_shape() {
        let mut rng = Rng::new(42);
        let layer = GATConv::<f64>::new(4, 2, &mut rng);
        let out = layer.forward(&ones_features(), &simple_adj()).unwrap();
        assert_eq!(out.shape(), vec![3, 2]);
    }

    #[test]
    fn test_sage_forward_shape() {
        let mut rng = Rng::new(42);
        let layer = SAGEConv::<f64>::new(4, 2, true, &mut rng);
        let out = layer.forward(&ones_features(), &simple_adj()).unwrap();
        assert_eq!(out.shape(), vec![3, 2]);
    }

    #[test]
    fn test_gcn_self_loops() {
        let mut rng = Rng::new(42);
        let layer = GCNConv::<f64>::new(4, 2, true, &mut rng);
        let features = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 4]).unwrap(),
            true,
        );
        let identity = Tensor::eye(2);

        let out = layer.forward(&features, &identity).unwrap();
        assert_eq!(out.shape(), vec![2, 2]);

        let out_tensor = out.data();
        let all_zero = out_tensor.as_slice().iter().all(|&v| v.abs() < 1e-15);
        assert!(
            !all_zero,
            "GCN output should not be all zeros with identity adjacency"
        );
    }
}