use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{Rng, RngExt};
use crate::error::{GraphError, Result};
use crate::gnn::gcn::CsrMatrix;
#[derive(Debug, Clone)]
pub struct RandomWalkPe {
pub walk_length: usize,
pub projection: Array2<f64>,
pub pe_dim: usize,
}
impl RandomWalkPe {
pub fn new(walk_length: usize, pe_dim: usize) -> Self {
let mut rng = scirs2_core::random::rng();
let scale = (6.0_f64 / (walk_length + pe_dim) as f64).sqrt();
let projection = Array2::from_shape_fn((walk_length, pe_dim), |_| {
(rng.random::<f64>() * 2.0 - 1.0) * scale
});
RandomWalkPe {
walk_length,
projection,
pe_dim,
}
}
pub fn compute_landing_probs(&self, adj: &CsrMatrix) -> Array2<f64> {
let n = adj.n_rows;
let row_sums = adj.row_sums();
let mut p_data: Vec<f64> = Vec::with_capacity(adj.nnz());
for (row, _col, val) in adj.iter() {
let d = row_sums[row];
if d > 0.0 {
p_data.push(val / d);
} else {
p_data.push(0.0);
}
}
let p = CsrMatrix {
n_rows: adj.n_rows,
n_cols: adj.n_cols,
indptr: adj.indptr.clone(),
indices: adj.indices.clone(),
data: p_data,
};
let mut landing = Array2::zeros((n, self.walk_length));
if n <= 500 {
let mut power = Array2::<f64>::eye(n);
for k in 0..self.walk_length {
let mut new_power = Array2::zeros((n, n));
for (row, col, val) in p.iter() {
for j in 0..n {
new_power[[row, j]] += val * power[[col, j]];
}
}
power = new_power;
for i in 0..n {
landing[[i, k]] = power[[i, i]];
}
}
} else {
for i in 0..n {
let mut vec_cur = vec![0.0f64; n];
vec_cur[i] = 1.0;
for k in 0..self.walk_length {
let mut vec_next = vec![0.0f64; n];
for (row, col, val) in p.iter() {
vec_next[row] += val * vec_cur[col];
}
landing[[i, k]] = vec_next[i];
vec_cur = vec_next;
}
}
}
landing
}
pub fn forward(&self, adj: &CsrMatrix) -> Array2<f64> {
let landing = self.compute_landing_probs(adj);
let n = adj.n_rows;
let mut pe = Array2::zeros((n, self.pe_dim));
for i in 0..n {
for j in 0..self.pe_dim {
let mut s = 0.0;
for k in 0..self.walk_length {
s += landing[[i, k]] * self.projection[[k, j]];
}
pe[[i, j]] = s;
}
}
pe
}
}
#[derive(Debug, Clone)]
pub struct LaplacianPe {
pub k: usize,
pub projection: Array2<f64>,
pub pe_dim: usize,
}
impl LaplacianPe {
pub fn new(k: usize, pe_dim: usize) -> Self {
let mut rng = scirs2_core::random::rng();
let scale = (6.0_f64 / (k + pe_dim) as f64).sqrt();
let projection =
Array2::from_shape_fn((k, pe_dim), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale);
LaplacianPe {
k,
projection,
pe_dim,
}
}
pub fn compute_eigenvectors(&self, adj: &CsrMatrix) -> Array2<f64> {
let n = adj.n_rows;
let actual_k = self.k.min(n.saturating_sub(1));
if actual_k == 0 || n < 2 {
return Array2::zeros((n, self.k));
}
let row_sums = adj.row_sums();
let mut lap = Array2::zeros((n, n));
for i in 0..n {
lap[[i, i]] = row_sums[i];
}
for (row, col, val) in adj.iter() {
lap[[row, col]] -= val;
}
let mut eigvecs = Array2::zeros((n, self.k));
let max_lambda_estimate = row_sums.iter().cloned().fold(0.0_f64, f64::max) * 2.0 + 1.0;
let mut m_mat = Array2::zeros((n, n));
for i in 0..n {
for j in 0..n {
m_mat[[i, j]] = -lap[[i, j]];
}
m_mat[[i, i]] += max_lambda_estimate;
}
let mut found_vecs: Vec<Vec<f64>> = Vec::new();
let trivial: Vec<f64> = vec![1.0 / (n as f64).sqrt(); n];
found_vecs.push(trivial);
let num_iters = 200;
for _ev_idx in 0..actual_k {
let mut rng = scirs2_core::random::rng();
let mut v: Vec<f64> = (0..n).map(|_| rng.random::<f64>() * 2.0 - 1.0).collect();
for fv in &found_vecs {
let dot: f64 = v.iter().zip(fv.iter()).map(|(a, b)| a * b).sum();
for (vi, fi) in v.iter_mut().zip(fv.iter()) {
*vi -= dot * fi;
}
}
let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
v.iter_mut().for_each(|x| *x /= norm);
for _ in 0..num_iters {
let mut v_new = vec![0.0f64; n];
for i in 0..n {
for j in 0..n {
v_new[i] += m_mat[[i, j]] * v[j];
}
}
for fv in &found_vecs {
let dot: f64 = v_new.iter().zip(fv.iter()).map(|(a, b)| a * b).sum();
for (vi, fi) in v_new.iter_mut().zip(fv.iter()) {
*vi -= dot * fi;
}
}
let norm: f64 = v_new.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
v_new.iter_mut().for_each(|x| *x /= norm);
v = v_new;
}
found_vecs.push(v);
}
for (idx, fv) in found_vecs.iter().skip(1).take(self.k).enumerate() {
for i in 0..n {
eigvecs[[i, idx]] = fv[i];
}
}
eigvecs
}
pub fn forward(&self, adj: &CsrMatrix) -> Array2<f64> {
let eigvecs = self.compute_eigenvectors(adj);
let n = adj.n_rows;
let mut pe = Array2::zeros((n, self.pe_dim));
for i in 0..n {
for j in 0..self.pe_dim {
let mut s = 0.0;
for m in 0..self.k {
s += eigvecs[[i, m]] * self.projection[[m, j]];
}
pe[[i, j]] = s;
}
}
pe
}
}
#[derive(Debug, Clone)]
struct GinLocal {
w1: Array2<f64>,
w2: Array2<f64>,
b1: Array1<f64>,
b2: Array1<f64>,
eps: f64,
hidden_dim: usize,
}
impl GinLocal {
fn new(hidden_dim: usize) -> Self {
let mut rng = scirs2_core::random::rng();
let scale = (6.0_f64 / (2 * hidden_dim) as f64).sqrt();
GinLocal {
w1: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
(rng.random::<f64>() * 2.0 - 1.0) * scale
}),
w2: Array2::from_shape_fn((hidden_dim, hidden_dim), |_| {
(rng.random::<f64>() * 2.0 - 1.0) * scale
}),
b1: Array1::zeros(hidden_dim),
b2: Array1::zeros(hidden_dim),
eps: 0.0,
hidden_dim,
}
}
fn forward(&self, x: &Array2<f64>, adj: &CsrMatrix) -> Array2<f64> {
let n = x.dim().0;
let d = self.hidden_dim;
let mut agg = Array2::zeros((n, d));
for i in 0..n {
for j in 0..d {
agg[[i, j]] = (1.0 + self.eps) * x[[i, j]];
}
}
for (row, col, _) in adj.iter() {
for j in 0..d {
agg[[row, j]] += x[[col, j]];
}
}
let mut h = Array2::zeros((n, d));
for i in 0..n {
for j in 0..d {
let mut s = self.b1[j];
for m in 0..d {
s += agg[[i, m]] * self.w1[[m, j]];
}
h[[i, j]] = s.max(0.0); }
}
let mut out = Array2::zeros((n, d));
for i in 0..n {
for j in 0..d {
let mut s = self.b2[j];
for m in 0..d {
s += h[[i, m]] * self.w2[[m, j]];
}
out[[i, j]] = s;
}
}
out
}
}
#[derive(Debug, Clone)]
struct GlobalAttention {
w_q: Array2<f64>,
w_k: Array2<f64>,
w_v: Array2<f64>,
w_o: Array2<f64>,
num_heads: usize,
hidden_dim: usize,
head_dim: usize,
}
impl GlobalAttention {
fn new(hidden_dim: usize, num_heads: usize) -> Result<Self> {
if !hidden_dim.is_multiple_of(num_heads) {
return Err(GraphError::InvalidParameter {
param: "hidden_dim".to_string(),
value: format!("{hidden_dim}"),
expected: format!("divisible by num_heads={num_heads}"),
context: "GlobalAttention::new".to_string(),
});
}
let head_dim = hidden_dim / num_heads;
let mut rng = scirs2_core::random::rng();
let scale = (6.0_f64 / (2 * hidden_dim) as f64).sqrt();
let mut init = |r, c| -> Array2<f64> {
Array2::from_shape_fn((r, c), |_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
};
Ok(GlobalAttention {
w_q: init(hidden_dim, hidden_dim),
w_k: init(hidden_dim, hidden_dim),
w_v: init(hidden_dim, hidden_dim),
w_o: init(hidden_dim, hidden_dim),
num_heads,
hidden_dim,
head_dim,
})
}
fn forward(&self, x: &Array2<f64>) -> Array2<f64> {
let n = x.dim().0;
let d = self.hidden_dim;
let h = self.num_heads;
let dk = self.head_dim;
let scale = 1.0 / (dk as f64).sqrt();
let mut q = Array2::zeros((n, d));
let mut k = Array2::zeros((n, d));
let mut v = Array2::zeros((n, d));
for i in 0..n {
for j in 0..d {
let mut sq = 0.0;
let mut sk = 0.0;
let mut sv = 0.0;
for m in 0..d {
let xi = x[[i, m]];
sq += xi * self.w_q[[m, j]];
sk += xi * self.w_k[[m, j]];
sv += xi * self.w_v[[m, j]];
}
q[[i, j]] = sq;
k[[i, j]] = sk;
v[[i, j]] = sv;
}
}
let mut output = Array2::<f64>::zeros((n, d));
for head in 0..h {
let offset = head * dk;
let mut scores = vec![vec![0.0f64; n]; n];
for i in 0..n {
for j in 0..n {
let mut dot = 0.0;
for m in 0..dk {
dot += q[[i, offset + m]] * k[[j, offset + m]];
}
scores[i][j] = dot * scale;
}
}
for i in 0..n {
let alphas = softmax_row(&scores[i]);
for j in 0..n {
let a = alphas[j];
for m in 0..dk {
output[[i, offset + m]] += a * v[[j, offset + m]];
}
}
}
}
let mut projected = Array2::zeros((n, d));
for i in 0..n {
for j in 0..d {
let mut s = 0.0;
for m in 0..d {
s += output[[i, m]] * self.w_o[[m, j]];
}
projected[[i, j]] = s;
}
}
projected
}
}
fn softmax_row(row: &[f64]) -> Vec<f64> {
if row.is_empty() {
return Vec::new();
}
let max_val = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = row.iter().map(|x| (x - max_val).exp()).collect();
let sum = exps.iter().sum::<f64>().max(1e-12);
exps.iter().map(|e| e / sum).collect()
}
fn layer_norm_vec(x: &mut [f64], eps: f64) {
let n = x.len();
if n == 0 {
return;
}
let mean = x.iter().sum::<f64>() / n as f64;
let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n as f64;
let inv_std = 1.0 / (var + eps).sqrt();
for v in x.iter_mut() {
*v = (*v - mean) * inv_std;
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalModel {
Gin,
Gat,
}
#[derive(Debug, Clone)]
pub struct GpsConfig {
pub in_dim: usize,
pub hidden_dim: usize,
pub num_heads: usize,
pub num_layers: usize,
pub ffn_dim: usize,
pub local_model: LocalModel,
pub pe_dim: usize,
pub rw_walk_length: usize,
pub layer_norm_eps: f64,
}
impl Default for GpsConfig {
fn default() -> Self {
GpsConfig {
in_dim: 64,
hidden_dim: 64,
num_heads: 4,
num_layers: 3,
ffn_dim: 256,
local_model: LocalModel::Gin,
pe_dim: 16,
rw_walk_length: 8,
layer_norm_eps: 1e-5,
}
}
}
#[derive(Debug, Clone)]
pub struct GpsLayer {
gin_local: GinLocal,
global_attn: GlobalAttention,
ffn_w1: Array2<f64>,
ffn_w2: Array2<f64>,
ffn_b1: Array1<f64>,
ffn_b2: Array1<f64>,
hidden_dim: usize,
layer_norm_eps: f64,
}
impl GpsLayer {
pub fn new(
hidden_dim: usize,
num_heads: usize,
ffn_dim: usize,
layer_norm_eps: f64,
) -> Result<Self> {
let mut rng = scirs2_core::random::rng();
let ffn_scale = (6.0_f64 / (hidden_dim + ffn_dim) as f64).sqrt();
Ok(GpsLayer {
gin_local: GinLocal::new(hidden_dim),
global_attn: GlobalAttention::new(hidden_dim, num_heads)?,
ffn_w1: Array2::from_shape_fn((hidden_dim, ffn_dim), |_| {
(rng.random::<f64>() * 2.0 - 1.0) * ffn_scale
}),
ffn_w2: Array2::from_shape_fn((ffn_dim, hidden_dim), |_| {
(rng.random::<f64>() * 2.0 - 1.0) * ffn_scale
}),
ffn_b1: Array1::zeros(ffn_dim),
ffn_b2: Array1::zeros(hidden_dim),
hidden_dim,
layer_norm_eps,
})
}
fn ffn(&self, x: &Array2<f64>) -> Array2<f64> {
let n = x.dim().0;
let d = self.hidden_dim;
let ffn_dim = self.ffn_w1.dim().1;
let mut h = Array2::zeros((n, ffn_dim));
for i in 0..n {
for j in 0..ffn_dim {
let mut s = self.ffn_b1[j];
for m in 0..d {
s += x[[i, m]] * self.ffn_w1[[m, j]];
}
h[[i, j]] = s.max(0.0); }
}
let mut out = Array2::zeros((n, d));
for i in 0..n {
for j in 0..d {
let mut s = self.ffn_b2[j];
for m in 0..ffn_dim {
s += h[[i, m]] * self.ffn_w2[[m, j]];
}
out[[i, j]] = s;
}
}
out
}
pub fn forward(&self, x: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
let (n, d) = x.dim();
if d != self.hidden_dim {
return Err(GraphError::InvalidParameter {
param: "x".to_string(),
value: format!("dim={d}"),
expected: format!("dim={}", self.hidden_dim),
context: "GpsLayer::forward".to_string(),
});
}
let local_out = self.gin_local.forward(x, adj);
let global_out = self.global_attn.forward(x);
let ffn_out = self.ffn(x);
let mut out = x.clone();
for i in 0..n {
for j in 0..d {
out[[i, j]] += local_out[[i, j]] + global_out[[i, j]] + ffn_out[[i, j]];
}
}
for i in 0..n {
let mut row: Vec<f64> = (0..d).map(|j| out[[i, j]]).collect();
layer_norm_vec(&mut row, self.layer_norm_eps);
for j in 0..d {
out[[i, j]] = row[j];
}
}
Ok(out)
}
}
#[derive(Debug, Clone)]
pub struct GpsModel {
pub input_proj: Array2<f64>,
pub rwpe: RandomWalkPe,
pub layers: Vec<GpsLayer>,
pub config: GpsConfig,
}
impl GpsModel {
pub fn new(config: GpsConfig) -> Result<Self> {
let mut rng = scirs2_core::random::rng();
let total_in = config.in_dim + config.pe_dim;
let proj_scale = (6.0_f64 / (total_in + config.hidden_dim) as f64).sqrt();
let input_proj = Array2::from_shape_fn((total_in, config.hidden_dim), |_| {
(rng.random::<f64>() * 2.0 - 1.0) * proj_scale
});
let rwpe = RandomWalkPe::new(config.rw_walk_length, config.pe_dim);
let mut layers = Vec::with_capacity(config.num_layers);
for _ in 0..config.num_layers {
layers.push(GpsLayer::new(
config.hidden_dim,
config.num_heads,
config.ffn_dim,
config.layer_norm_eps,
)?);
}
Ok(GpsModel {
input_proj,
rwpe,
layers,
config,
})
}
pub fn forward(&self, features: &Array2<f64>, adj: &CsrMatrix) -> Result<Array2<f64>> {
let (n, in_dim) = features.dim();
if in_dim != self.config.in_dim {
return Err(GraphError::InvalidParameter {
param: "features".to_string(),
value: format!("in_dim={in_dim}"),
expected: format!("in_dim={}", self.config.in_dim),
context: "GpsModel::forward".to_string(),
});
}
if adj.n_rows != n {
return Err(GraphError::InvalidParameter {
param: "adj".to_string(),
value: format!("n_rows={}", adj.n_rows),
expected: format!("n_rows={n}"),
context: "GpsModel::forward".to_string(),
});
}
let pe = self.rwpe.forward(adj);
let total_in = self.config.in_dim + self.config.pe_dim;
let mut concat = Array2::zeros((n, total_in));
for i in 0..n {
for j in 0..in_dim {
concat[[i, j]] = features[[i, j]];
}
for j in 0..self.config.pe_dim {
concat[[i, in_dim + j]] = pe[[i, j]];
}
}
let d = self.config.hidden_dim;
let mut h = Array2::zeros((n, d));
for i in 0..n {
for j in 0..d {
let mut s = 0.0;
for m in 0..total_in {
s += concat[[i, m]] * self.input_proj[[m, j]];
}
h[[i, j]] = s;
}
}
for layer in &self.layers {
h = layer.forward(&h, adj)?;
}
Ok(h)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn triangle_csr() -> CsrMatrix {
let coo = vec![
(0, 1, 1.0),
(1, 0, 1.0),
(1, 2, 1.0),
(2, 1, 1.0),
(0, 2, 1.0),
(2, 0, 1.0),
];
CsrMatrix::from_coo(3, 3, &coo).expect("triangle CSR")
}
fn path_csr() -> CsrMatrix {
let coo = vec![
(0, 1, 1.0),
(1, 0, 1.0),
(1, 2, 1.0),
(2, 1, 1.0),
(2, 3, 1.0),
(3, 2, 1.0),
];
CsrMatrix::from_coo(4, 4, &coo).expect("path CSR")
}
fn feats(n: usize, d: usize) -> Array2<f64> {
Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 * 0.1)
}
#[test]
fn test_rwpe_landing_probs_shape() {
let adj = triangle_csr();
let rwpe = RandomWalkPe::new(4, 8);
let landing = rwpe.compute_landing_probs(&adj);
assert_eq!(landing.dim(), (3, 4));
for &v in landing.iter() {
assert!(v.is_finite(), "landing prob should be finite, got {v}");
assert!(v >= 0.0, "landing prob should be non-negative, got {v}");
}
}
#[test]
fn test_rwpe_produces_correct_features() {
let adj = triangle_csr();
let rwpe = RandomWalkPe::new(3, 6);
let pe = rwpe.forward(&adj);
assert_eq!(pe.dim(), (3, 6));
let landing = rwpe.compute_landing_probs(&adj);
for k in 0..3 {
let val0 = landing[[0, k]];
let val1 = landing[[1, k]];
let val2 = landing[[2, k]];
assert!(
(val0 - val1).abs() < 1e-10 && (val1 - val2).abs() < 1e-10,
"symmetric graph should have equal landing probs at step {k}: {val0}, {val1}, {val2}"
);
}
}
#[test]
fn test_rwpe_path_graph_different_probs() {
let adj = path_csr();
let rwpe = RandomWalkPe::new(3, 4);
let landing = rwpe.compute_landing_probs(&adj);
assert_eq!(landing.dim(), (4, 3));
let end_prob = landing[[0, 0]]; let mid_prob = landing[[1, 0]];
assert!(end_prob.is_finite());
assert!(mid_prob.is_finite());
}
#[test]
fn test_laplacian_pe_shape() {
let adj = triangle_csr();
let lpe = LaplacianPe::new(2, 6);
let pe = lpe.forward(&adj);
assert_eq!(pe.dim(), (3, 6));
for &v in pe.iter() {
assert!(v.is_finite(), "Laplacian PE should be finite, got {v}");
}
}
#[test]
fn test_gps_hybrid_combines_local_and_global() {
let adj = triangle_csr();
let features = feats(3, 8);
let config = GpsConfig {
in_dim: 8,
hidden_dim: 8,
num_heads: 2,
num_layers: 1,
ffn_dim: 16,
local_model: LocalModel::Gin,
pe_dim: 4,
rw_walk_length: 3,
..Default::default()
};
let model = GpsModel::new(config).expect("GPS model");
let out = model.forward(&features, &adj).expect("GPS forward");
assert_eq!(out.dim(), (3, 8));
for &v in out.iter() {
assert!(v.is_finite(), "GPS output should be finite, got {v}");
}
let has_variation = out.iter().any(|&v| v.abs() > 1e-12);
assert!(has_variation, "GPS output should have non-trivial values");
}
#[test]
fn test_gps_layer_forward_shape() {
let adj = triangle_csr();
let x = feats(3, 8);
let layer = GpsLayer::new(8, 2, 16, 1e-5).expect("GPS layer");
let out = layer.forward(&x, &adj).expect("GPS layer forward");
assert_eq!(out.dim(), (3, 8));
for i in 0..3 {
let mean: f64 = (0..8).map(|j| out[[i, j]]).sum::<f64>() / 8.0;
assert!(
mean.abs() < 0.1,
"after layer norm, mean should be near 0, got {mean}"
);
}
}
#[test]
fn test_gps_multi_layer() {
let adj = path_csr();
let config = GpsConfig {
in_dim: 4,
hidden_dim: 8,
num_heads: 2,
num_layers: 3,
ffn_dim: 16,
pe_dim: 4,
rw_walk_length: 3,
..Default::default()
};
let model = GpsModel::new(config).expect("GPS model");
let features = feats(4, 4);
let out = model.forward(&features, &adj).expect("GPS forward");
assert_eq!(out.dim(), (4, 8));
for &v in out.iter() {
assert!(v.is_finite(), "multi-layer GPS output should be finite");
}
}
#[test]
fn test_gps_invalid_dim_error() {
let adj = triangle_csr();
let config = GpsConfig {
in_dim: 4,
hidden_dim: 7, num_heads: 4,
..Default::default()
};
let result = GpsModel::new(config);
assert!(result.is_err());
}
#[test]
fn test_gin_local_aggregation() {
let adj = triangle_csr();
let x = feats(3, 8);
let gin = GinLocal::new(8);
let out = gin.forward(&x, &adj);
assert_eq!(out.dim(), (3, 8));
for &v in out.iter() {
assert!(v.is_finite());
}
}
}