use crate::error::{GnnError, GnnResult};
use crate::graph::csr::CsrGraph;
/// Hyper-parameters of a [`Gcnii`] model.
///
/// The struct itself performs no validation; the range requirements noted on
/// each field are enforced by [`Gcnii::new`].
#[derive(Debug, Clone, Copy)]
pub struct GcniiConfig {
/// Feature width per node. Must be non-zero; every layer weight holds `dim * dim` values.
pub dim: usize,
/// Number of stacked layers. Must be at least 1 and equal to the number of weight matrices.
pub num_layers: usize,
/// Mixing coefficient between the propagated features and the initial features `h0`.
/// Must lie in `[0, 1]`.
pub alpha: f32,
/// Scale passed to [`gcnii_beta`] to derive the per-layer `beta`. Must be finite and `>= 0`.
pub theta: f32,
}
/// Dense normalized adjacency matrix of a graph, stored row-major.
struct NormAdj {
/// Number of nodes; the matrix is `n x n`.
n: usize,
/// Matrix entries; the value for the pair `(i, j)` lives at `p[i * n + j]`.
p: Vec<f32>,
}
impl NormAdj {
    /// Builds the dense matrix `D^{-1/2} (A + I) D^{-1/2}` for `graph`.
    ///
    /// Every listed edge is mirrored, so the result is symmetric even if the
    /// graph only stores one direction, and each node receives a self-loop.
    ///
    /// # Errors
    /// Propagates any error returned by `graph.neighbors`.
    fn build(graph: &CsrGraph) -> GnnResult<Self> {
        let n = graph.n_nodes();

        // Connectivity mask of the symmetrized graph including self-loops.
        let mut mask = vec![false; n * n];
        for src in 0..n {
            mask[src * n + src] = true;
            for &dst in graph.neighbors(src)? {
                mask[src * n + dst] = true;
                mask[dst * n + src] = true;
            }
        }

        // Per-node `1 / sqrt(degree)`, where the degree is the number of set
        // entries in the node's mask row (accumulated in f32).
        let inv_sqrt_deg: Vec<f32> = (0..n)
            .map(|row| {
                let degree = mask[row * n..(row + 1) * n]
                    .iter()
                    .fold(0.0_f32, |acc, &set| if set { acc + 1.0 } else { acc });
                if degree > 0.0 {
                    1.0 / degree.sqrt()
                } else {
                    0.0
                }
            })
            .collect();

        // `flat` enumerates the row-major cells; the mask is empty when
        // `n == 0`, so the division and remainder below never see a zero `n`.
        let p: Vec<f32> = mask
            .iter()
            .enumerate()
            .map(|(flat, &set)| {
                if set {
                    inv_sqrt_deg[flat / n] * inv_sqrt_deg[flat % n]
                } else {
                    0.0
                }
            })
            .collect();

        Ok(Self { n, p })
    }

    /// Computes `P * x`, where `x` is a row-major `n x dim` feature matrix.
    ///
    /// Zero entries of `P` are skipped. Panics if `x` is too short for a row
    /// that has to be read.
    fn propagate(&self, x: &[f32], dim: usize) -> Vec<f32> {
        let n = self.n;
        let mut result = vec![0.0_f32; n * dim];
        for node in 0..n {
            let coeffs = &self.p[node * n..(node + 1) * n];
            let target = &mut result[node * dim..(node + 1) * dim];
            for (other, &coeff) in coeffs.iter().enumerate() {
                if coeff == 0.0 {
                    continue;
                }
                let source = &x[other * dim..(other + 1) * dim];
                for (acc, &feat) in target.iter_mut().zip(source) {
                    *acc += coeff * feat;
                }
            }
        }
        result
    }

    /// Borrows the row-major matrix entries.
    #[inline]
    fn dense(&self) -> &[f32] {
        self.p.as_slice()
    }
}
/// Returns the layer-dependent coefficient `beta_l = ln(1 + theta / l)`.
///
/// `layer` is 1-based; a value of `0` is treated as `1`, so the function never
/// divides by zero. For `theta >= 0` the result is non-negative and shrinks as
/// `layer` grows.
#[must_use]
pub fn gcnii_beta(theta: f32, layer: usize) -> f32 {
    let depth = layer.max(1) as f32;
    // `ln_1p` evaluates ln(1 + x) without first forming `1 + x`. Forming the
    // sum in f32 rounds away any ratio below ~6e-8 and would yield exactly 0.
    (theta / depth).ln_1p()
}
/// A GCNII model: a validated configuration plus one weight matrix per layer.
pub struct Gcnii {
/// Hyper-parameters, validated by [`Gcnii::new`].
config: GcniiConfig,
/// One row-major `dim x dim` matrix per layer; layer `l` (1-based) uses `weights[l - 1]`.
weights: Vec<Vec<f32>>,
}
impl Gcnii {
    /// Creates a model from `config` and one weight matrix per layer.
    ///
    /// Each entry of `weights` is a row-major `dim x dim` matrix; layer `l`
    /// (1-based) uses `weights[l - 1]`.
    ///
    /// # Errors
    /// * `GnnError::InvalidLayerConfig` if `dim == 0`, `num_layers == 0`,
    ///   `alpha` is outside `[0, 1]` (NaN included), or `theta` is negative or
    ///   not finite.
    /// * `GnnError::DimensionMismatch` if `weights.len() != num_layers`.
    /// * `GnnError::WeightShapeMismatch` (every field set to `dim`) if any
    ///   weight does not hold exactly `dim * dim` values.
    pub fn new(config: GcniiConfig, weights: Vec<Vec<f32>>) -> GnnResult<Self> {
        if config.dim == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "GCNII: dim must be > 0".to_string(),
            ));
        }
        if config.num_layers == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "GCNII: num_layers must be >= 1".to_string(),
            ));
        }
        // `contains` is false for NaN, so a NaN alpha is rejected here as well.
        if !(0.0..=1.0).contains(&config.alpha) {
            return Err(GnnError::InvalidLayerConfig(format!(
                "GCNII: alpha must be in [0, 1], got {}",
                config.alpha
            )));
        }
        if !config.theta.is_finite() || config.theta < 0.0 {
            return Err(GnnError::InvalidLayerConfig(format!(
                "GCNII: theta must be finite and >= 0, got {}",
                config.theta
            )));
        }
        if weights.len() != config.num_layers {
            return Err(GnnError::DimensionMismatch {
                expected: config.num_layers,
                got: weights.len(),
            });
        }
        for w in &weights {
            if w.len() != config.dim * config.dim {
                return Err(GnnError::WeightShapeMismatch {
                    r: config.dim,
                    c: config.dim,
                    d: config.dim,
                });
            }
        }
        Ok(Self { config, weights })
    }

    /// Creates a model whose every layer weight is the `dim x dim` identity.
    ///
    /// # Errors
    /// Returns the same configuration errors as [`Gcnii::new`].
    pub fn with_identity_weights(config: GcniiConfig) -> GnnResult<Self> {
        let dim = config.dim;
        let mut id = vec![0.0_f32; dim * dim];
        for i in 0..dim {
            id[i * dim + i] = 1.0;
        }
        let weights = vec![id; config.num_layers];
        Self::new(config, weights)
    }

    /// Feature width per node.
    #[inline]
    pub fn dim(&self) -> usize {
        self.config.dim
    }

    /// Number of layers applied by [`Gcnii::forward`].
    #[inline]
    pub fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    /// Coefficient `beta` for the given 1-based `layer`, derived from the
    /// configured `theta` via [`gcnii_beta`].
    #[inline]
    pub fn beta(&self, layer: usize) -> f32 {
        gcnii_beta(self.config.theta, layer)
    }

    /// ReLU that lets NaN through unchanged.
    ///
    /// `f32::max` returns the non-NaN operand, so a plain `val.max(0.0)` maps
    /// NaN to `0.0` and would hide it from the finiteness check in `forward`.
    /// For every non-NaN input this is exactly `val.max(0.0)`.
    #[inline]
    fn relu(val: f32) -> f32 {
        if val.is_nan() {
            val
        } else {
            val.max(0.0)
        }
    }

    /// Applies one layer: `ReLU(M * ((1 - beta) * I + beta * W))` with
    /// `M = (1 - alpha) * P * h + alpha * h0`.
    ///
    /// `layer` is 1-based and must not exceed the number of weights; `h` and
    /// `h0` are row-major `n x dim` matrices.
    fn layer_forward(&self, norm_adj: &NormAdj, h: &[f32], h0: &[f32], layer: usize) -> Vec<f32> {
        let n = norm_adj.n;
        let dim = self.config.dim;
        let alpha = self.config.alpha;
        let beta = self.beta(layer);
        let w = &self.weights[layer - 1];

        // Initial-residual mix of the propagated features with `h0`.
        let prop = norm_adj.propagate(h, dim);
        let mut m = vec![0.0_f32; n * dim];
        for idx in 0..n * dim {
            m[idx] = (1.0 - alpha) * prop[idx] + alpha * h0[idx];
        }

        let mut out = vec![0.0_f32; n * dim];
        for i in 0..n {
            let m_row = &m[i * dim..(i + 1) * dim];
            let out_row = &mut out[i * dim..(i + 1) * dim];
            for k in 0..dim {
                // k-th component of `m_row * W`.
                let mut wk = 0.0_f32;
                for (j, &mj) in m_row.iter().enumerate() {
                    wk += mj * w[j * dim + k];
                }
                let val = (1.0 - beta) * m_row[k] + beta * wk;
                out_row[k] = Self::relu(val);
            }
        }
        out
    }

    /// Runs all layers on `graph`, starting from the row-major `n x dim`
    /// feature matrix `h0`, and returns the final `n x dim` features.
    ///
    /// # Errors
    /// * `GnnError::NodeFeatureMismatch` if `h0.len() != n_nodes * dim`.
    /// * Any error raised while building the normalized adjacency.
    /// * `GnnError::NonFiniteOutput` if any output value is NaN or infinite.
    pub fn forward(&self, graph: &CsrGraph, h0: &[f32]) -> GnnResult<Vec<f32>> {
        let n = graph.n_nodes();
        let dim = self.config.dim;
        if h0.len() != n * dim {
            return Err(GnnError::NodeFeatureMismatch(n, h0.len() / dim.max(1)));
        }
        let norm_adj = NormAdj::build(graph)?;
        let mut h = h0.to_vec();
        for layer in 1..=self.config.num_layers {
            h = self.layer_forward(&norm_adj, &h, h0, layer);
        }
        if h.iter().any(|v| !v.is_finite()) {
            return Err(GnnError::NonFiniteOutput("GCNII forward"));
        }
        Ok(h)
    }

    /// Returns an owned copy of the dense row-major `n x n` normalized
    /// adjacency that [`Gcnii::forward`] uses for `graph`.
    ///
    /// # Errors
    /// Propagates any error raised while building the matrix.
    pub fn normalized_adjacency_dense(&self, graph: &CsrGraph) -> GnnResult<Vec<f32>> {
        Ok(NormAdj::build(graph)?.dense().to_vec())
    }
}
// Unit tests for the GCNII layer: beta schedule, limiting cases of alpha/theta,
// the normalized adjacency, and constructor/forward input validation.
#[cfg(test)]
mod tests {
use super::*;
// Path 0-1-...-(n-1) with each edge listed in both directions.
// `n - 1` means `n` must be at least 1.
fn path_graph(n: usize) -> CsrGraph {
let mut edges = Vec::new();
for i in 0..n - 1 {
edges.push((i, i + 1));
edges.push((i + 1, i));
}
CsrGraph::from_edges(n, &edges).expect("path graph")
}
// Cycle over n nodes (node i linked to (i + 1) % n), edges in both directions.
fn ring_graph(n: usize) -> CsrGraph {
let mut edges = Vec::new();
for i in 0..n {
let j = (i + 1) % n;
edges.push((i, j));
edges.push((j, i));
}
CsrGraph::from_edges(n, &edges).expect("ring graph")
}
// Seeded pseudo-random `n * dim` values, each `next_f32() * 2.0 - 1.0`.
// NOTE(review): presumably in [-1, 1) if `next_f32` yields [0, 1) — confirm against `LcgRng`.
fn random_feats(n: usize, dim: usize, seed: u64) -> Vec<f32> {
let mut r = crate::handle::LcgRng::new(seed);
(0..n * dim).map(|_| r.next_f32() * 2.0 - 1.0).collect()
}
// Row-major `dim x dim` identity matrix.
fn identity(dim: usize) -> Vec<f32> {
let mut w = vec![0.0_f32; dim * dim];
for i in 0..dim {
w[i * dim + i] = 1.0;
}
w
}
// beta_l must be positive and strictly decreasing in l for theta = 1,
// and beta_1 must equal ln(2).
#[test]
fn beta_decreases_with_depth() {
let theta = 1.0_f32;
let mut prev = f32::INFINITY;
for l in 1..=32 {
let b = gcnii_beta(theta, l);
assert!(b > 0.0, "beta should be positive for theta>0");
assert!(b < prev, "beta_{l}={b} should be < beta_{}={prev}", l - 1);
prev = b;
}
assert!((gcnii_beta(1.0, 1) - std::f32::consts::LN_2).abs() < 1e-6);
}
// theta = 0 makes every beta zero, so the weight term drops out and random
// weights must give the same output as identity weights.
#[test]
fn identity_mapping_limit_weight_has_no_effect() {
let g = path_graph(6);
let dim = 4;
let h0 = random_feats(6, dim, 11);
let cfg = GcniiConfig {
dim,
num_layers: 3,
alpha: 0.1,
theta: 0.0, };
let weights: Vec<Vec<f32>> = (0..3).map(|s| random_feats(dim, dim, 100 + s)).collect();
let model_w = Gcnii::new(cfg, weights).expect("model");
let model_id = Gcnii::with_identity_weights(cfg).expect("model id");
let out_w = model_w.forward(&g, &h0).expect("fwd w");
let out_id = model_id.forward(&g, &h0).expect("fwd id");
for (a, b) in out_w.iter().zip(out_id.iter()) {
assert!((a - b).abs() < 1e-5, "β=0 must ignore W: {a} vs {b}");
}
}
// alpha = 1 discards the propagated term, so the output cannot depend on the
// graph and, with identity weights, equals ReLU(h0).
#[test]
fn initial_residual_alpha_one_ignores_propagation() {
let dim = 3;
let n = 5;
let h0 = random_feats(n, dim, 7);
let cfg = GcniiConfig {
dim,
num_layers: 4,
alpha: 1.0,
theta: 1.0,
};
let g_path = path_graph(n);
let g_ring = ring_graph(n);
let m_path = Gcnii::with_identity_weights(cfg).expect("m");
let m_ring = Gcnii::with_identity_weights(cfg).expect("m");
let out_path = m_path.forward(&g_path, &h0).expect("fwd");
let out_ring = m_ring.forward(&g_ring, &h0).expect("fwd");
for (a, b) in out_path.iter().zip(out_ring.iter()) {
assert!((a - b).abs() < 1e-6, "α=1 must ignore graph: {a} vs {b}");
}
let relu_h0: Vec<f32> = h0.iter().map(|&v| v.max(0.0)).collect();
for (a, b) in out_path.iter().zip(relu_h0.iter()) {
assert!((a - b).abs() < 1e-6, "α=1 output should be ReLU(h0)");
}
}
// Compares a 32-layer GCNII against 32 rounds of plain propagation on a ring:
// the per-feature variance across nodes must stay well above the vanilla one.
#[test]
fn anti_oversmoothing_beats_vanilla_propagation() {
let n = 20;
let dim = 8;
let depth = 32;
let g = ring_graph(n);
let h0: Vec<f32> = {
let mut r = crate::handle::LcgRng::new(2024);
(0..n * dim).map(|_| r.next_f32() * 2.0 + 0.1).collect()
};
let cfg = GcniiConfig {
dim,
num_layers: depth,
alpha: 0.2,
theta: 1.0,
};
let model = Gcnii::with_identity_weights(cfg).expect("model");
let out = model.forward(&g, &h0).expect("fwd");
let norm_adj = NormAdj::build(&g).expect("norm");
// Baseline: repeated propagation with no residual, weights or ReLU.
let mut hv = h0.clone();
for _ in 0..depth {
hv = norm_adj.propagate(&hv, dim);
}
// Population variance across nodes, averaged over the feature columns.
let feature_variance = |h: &[f32]| -> f32 {
let mut total = 0.0_f32;
for k in 0..dim {
let mut mean = 0.0_f32;
for i in 0..n {
mean += h[i * dim + k];
}
mean /= n as f32;
let mut var = 0.0_f32;
for i in 0..n {
let d = h[i * dim + k] - mean;
var += d * d;
}
total += var / n as f32;
}
total / dim as f32
};
let var_gcnii = feature_variance(&out);
let var_vanilla = feature_variance(&hv);
assert!(var_gcnii > 1e-3, "GCNII variance collapsed: {var_gcnii}");
assert!(
var_gcnii > 2.0 * var_vanilla,
"GCNII variance {var_gcnii} must exceed vanilla {var_vanilla}"
);
}
// Output has `n * dim` finite values and the accessors echo the config.
#[test]
fn output_shape_correct() {
let g = path_graph(7);
let dim = 5;
let cfg = GcniiConfig {
dim,
num_layers: 3,
alpha: 0.1,
theta: 0.5,
};
let model = Gcnii::with_identity_weights(cfg).expect("model");
let h0 = random_feats(7, dim, 3);
let out = model.forward(&g, &h0).expect("fwd");
assert_eq!(out.len(), 7 * dim);
assert!(out.iter().all(|v| v.is_finite()));
assert_eq!(model.dim(), dim);
assert_eq!(model.num_layers(), 3);
}
// On a 6-ring every node has two neighbours plus its self-loop (degree 3),
// so every non-zero entry of the normalized adjacency is 1/3.
#[test]
fn normalized_adjacency_symmetric_with_self_loops() {
let g = ring_graph(6);
let cfg = GcniiConfig {
dim: 2,
num_layers: 1,
alpha: 0.1,
theta: 1.0,
};
let model = Gcnii::with_identity_weights(cfg).expect("model");
let p = model.normalized_adjacency_dense(&g).expect("p");
let n = g.n_nodes();
for i in 0..n {
assert!(p[i * n + i] > 0.0, "diag[{i}]={} not >0", p[i * n + i]);
}
for i in 0..n {
for j in 0..n {
assert!(
(p[i * n + j] - p[j * n + i]).abs() < 1e-6,
"asymmetry at ({i},{j})"
);
}
}
for i in 0..n {
assert!((p[i * n + i] - 1.0 / 3.0).abs() < 1e-5);
}
// p[1] is the entry for the edge (0, 1).
assert!((p[1] - 1.0 / 3.0).abs() < 1e-5);
}
// With alpha = 0 and theta = 0 a single layer reduces to ReLU(P * h0).
#[test]
fn alpha_zero_identity_equals_single_propagation() {
let g = path_graph(5);
let dim = 3;
let h0 = random_feats(5, dim, 55);
let cfg = GcniiConfig {
dim,
num_layers: 1,
alpha: 0.0,
theta: 0.0, };
let model = Gcnii::with_identity_weights(cfg).expect("model");
let out = model.forward(&g, &h0).expect("fwd");
let norm_adj = NormAdj::build(&g).expect("norm");
let prop = norm_adj.propagate(&h0, dim);
let relu_prop: Vec<f32> = prop.iter().map(|&v| v.max(0.0)).collect();
for (a, b) in out.iter().zip(relu_prop.iter()) {
assert!((a - b).abs() < 1e-5, "{a} vs {b}");
}
}
// Constructor and forward must reject invalid inputs.
#[test]
fn rejects_bad_config_and_weights() {
let dim = 3;
let ok_w = vec![identity(dim)];
// dim == 0
assert!(
Gcnii::new(
GcniiConfig {
dim: 0,
num_layers: 1,
alpha: 0.1,
theta: 1.0
},
vec![]
)
.is_err()
);
// alpha outside [0, 1]
assert!(
Gcnii::new(
GcniiConfig {
dim,
num_layers: 1,
alpha: 2.0,
theta: 1.0
},
ok_w.clone()
)
.is_err()
);
// two layers requested but only one weight matrix supplied
assert!(
Gcnii::new(
GcniiConfig {
dim,
num_layers: 2,
alpha: 0.1,
theta: 1.0
},
ok_w.clone()
)
.is_err() );
// weight matrix with one element too many
assert!(
Gcnii::new(
GcniiConfig {
dim,
num_layers: 1,
alpha: 0.1,
theta: 1.0
},
vec![vec![0.0_f32; dim * dim + 1]]
)
.is_err()
);
// feature buffer whose length is not `n * dim`
let g = path_graph(4);
let model = Gcnii::with_identity_weights(GcniiConfig {
dim,
num_layers: 1,
alpha: 0.1,
theta: 1.0,
})
.expect("model");
assert!(model.forward(&g, &vec![0.0_f32; 4 * dim + 2]).is_err());
}
}