use crate::error::{GnnError, GnnResult};
use crate::graph::csr::CsrGraph;
use crate::message_passing::update::relu;
/// Precomputes the SIGN hop features `[X, ÂX, Â²X, …, Â^{r_max}X]`.
///
/// `Â` is whatever `graph.normalized_adjacency()` returns, consumed as
/// `(rows, cols, vals)` triplets: each propagation step adds
/// `vals[e] * prev[cols[e]]` into row `rows[e]` of the next hop.
///
/// # Arguments
/// * `graph` – graph whose normalized adjacency drives the propagation.
/// * `x` – row-major node features of length `graph.n_nodes() * feat_dim`.
/// * `feat_dim` – features per node; must be non-zero.
/// * `r_max` – number of propagation steps; the result holds `r_max + 1` matrices,
///   with `x` itself (copied) at index 0.
///
/// # Errors
/// * [`GnnError::InvalidLayerConfig`] when `feat_dim == 0`.
/// * [`GnnError::NodeFeatureMismatch`] when `x.len() != n_nodes * feat_dim`.
/// * [`GnnError::NonFiniteOutput`] when a propagated hop (index ≥ 1) contains NaN/∞.
pub fn sign_precompute(
    graph: &CsrGraph,
    x: &[f32],
    feat_dim: usize,
    r_max: usize,
) -> GnnResult<Vec<Vec<f32>>> {
    if feat_dim == 0 {
        return Err(GnnError::InvalidLayerConfig(
            "SIGN: feat_dim must be > 0".to_string(),
        ));
    }

    let n_nodes = graph.n_nodes();
    let expected_len = n_nodes * feat_dim;
    if x.len() != expected_len {
        return Err(GnnError::NodeFeatureMismatch(n_nodes, x.len() / feat_dim));
    }

    let (rows, cols, vals) = graph.normalized_adjacency();

    let mut hops = Vec::with_capacity(r_max + 1);
    hops.push(x.to_vec());

    for step in 0..r_max {
        // Hop `step + 1` is derived from hop `step`, which is already stored.
        let src = &hops[step];
        let mut dst = vec![0.0_f32; expected_len];

        for (edge, &row) in rows.iter().enumerate() {
            let col = cols[edge];
            let weight = vals[edge];
            // Slicing panics exactly when the original per-element indexing would
            // (i.e. when `row` or `col` is not a valid node index).
            let src_row = &src[col * feat_dim..(col + 1) * feat_dim];
            let dst_row = &mut dst[row * feat_dim..(row + 1) * feat_dim];
            for (out, &inp) in dst_row.iter_mut().zip(src_row) {
                *out += weight * inp;
            }
        }

        if !dst.iter().all(|v| v.is_finite()) {
            return Err(GnnError::NonFiniteOutput("sign_precompute"));
        }
        hops.push(dst);
    }

    Ok(hops)
}
/// Hyper-parameters of a [`SignConv`] layer.
///
/// Validation (all widths non-zero) happens in `SignConv::new`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SignConfig {
    /// Features per node in every input hop matrix.
    pub in_features: usize,
    /// Width of each per-hop projection; the concatenated hidden layer is
    /// `(r_max + 1) * hop_features` wide.
    pub hop_features: usize,
    /// Width of the layer output.
    pub out_features: usize,
    /// Highest propagation power; the layer consumes `r_max + 1` hop matrices.
    pub r_max: usize,
}

/// SIGN convolution layer.
///
/// The layer stores only its configuration; all weights are supplied by the
/// caller on every forward call.
#[derive(Debug, Clone)]
pub struct SignConv {
    config: SignConfig,
}
impl SignConv {
    /// Builds a SIGN layer after validating `config`.
    ///
    /// # Errors
    /// Returns [`GnnError::InvalidLayerConfig`] if `in_features`, `hop_features`
    /// or `out_features` is zero. `r_max == 0` is accepted (a single hop).
    pub fn new(config: SignConfig) -> GnnResult<Self> {
        if config.in_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "SIGN: in_features must be > 0".to_string(),
            ));
        }
        if config.hop_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "SIGN: hop_features must be > 0".to_string(),
            ));
        }
        if config.out_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "SIGN: out_features must be > 0".to_string(),
            ));
        }
        Ok(Self { config })
    }

    /// Number of hop matrices the layer consumes: `r_max + 1`.
    pub fn n_hops(&self) -> usize {
        self.config.r_max + 1
    }

    /// Width of the concatenated hidden representation: `n_hops() * hop_features`.
    pub fn concat_dim(&self) -> usize {
        self.n_hops() * self.config.hop_features
    }

    /// Runs the layer on precomputed hop features.
    ///
    /// For each hop `r`, `hops[r]` (row-major `n × in_features`) is multiplied by
    /// `hop_weights[r]` (row-major `in_features × hop_features`). The results are
    /// written side by side in hop order into an `n × concat_dim()` matrix.
    /// That matrix is passed through [`relu`] and multiplied by `out_weight`
    /// (row-major `concat_dim() × out_features`).
    ///
    /// Returns the row-major `n × out_features` output.
    ///
    /// # Errors
    /// * [`GnnError::DimensionMismatch`] if `hops` or `hop_weights` does not hold
    ///   exactly `n_hops()` entries.
    /// * [`GnnError::NodeFeatureMismatch`] if `hops[0]` is empty or not a multiple of
    ///   `in_features` long, or if any hop differs in length from `hops[0]`.
    /// * [`GnnError::WeightShapeMismatch`] if any hop weight or `out_weight` has the
    ///   wrong number of elements.
    /// * [`GnnError::NonFiniteOutput`] if the output contains NaN or infinity.
    pub fn forward(
        &self,
        hops: &[Vec<f32>],
        hop_weights: &[Vec<f32>],
        out_weight: &[f32],
    ) -> GnnResult<Vec<f32>> {
        let d_in = self.config.in_features;
        let d_hop = self.config.hop_features;
        let d_out = self.config.out_features;
        let n_hops = self.n_hops();

        if hops.len() != n_hops {
            return Err(GnnError::DimensionMismatch {
                expected: n_hops,
                got: hops.len(),
            });
        }
        if hop_weights.len() != n_hops {
            return Err(GnnError::DimensionMismatch {
                expected: n_hops,
                got: hop_weights.len(),
            });
        }

        // `new` guarantees `d_in > 0`, so the remainder/divisions below cannot trap.
        let first_len = hops[0].len();
        if first_len == 0 || first_len % d_in != 0 {
            return Err(GnnError::NodeFeatureMismatch(0, first_len));
        }
        let n = first_len / d_in;
        if let Some(bad) = hops.iter().find(|h| h.len() != first_len) {
            return Err(GnnError::NodeFeatureMismatch(n, bad.len() / d_in));
        }

        if hop_weights.iter().any(|w| w.len() != d_in * d_hop) {
            return Err(GnnError::WeightShapeMismatch {
                r: d_in,
                c: d_hop,
                d: d_in,
            });
        }
        let concat_dim = self.concat_dim();
        if out_weight.len() != concat_dim * d_out {
            return Err(GnnError::WeightShapeMismatch {
                r: concat_dim,
                c: d_out,
                d: concat_dim,
            });
        }

        // Stage 1: project each hop into its own column band of the concat matrix.
        let mut concat = vec![0.0_f32; n * concat_dim];
        for (r, (hop, w)) in hops.iter().zip(hop_weights).enumerate() {
            let col_off = r * d_hop;
            for (x_row, c_row) in hop
                .chunks_exact(d_in)
                .zip(concat.chunks_exact_mut(concat_dim))
            {
                Self::project_row(x_row, w, &mut c_row[col_off..col_off + d_hop]);
            }
        }
        let concat = relu(&concat);

        // Stage 2: output projection of the activated concatenation.
        let mut out = vec![0.0_f32; n * d_out];
        for (c_row, o_row) in concat
            .chunks_exact(concat_dim)
            .zip(out.chunks_exact_mut(d_out))
        {
            Self::project_row(c_row, out_weight, o_row);
        }

        if out.iter().any(|v| !v.is_finite()) {
            return Err(GnnError::NonFiniteOutput("SignConv::forward"));
        }
        Ok(out)
    }

    /// Output width of the layer (`out_features`).
    pub fn output_dim(&self) -> usize {
        self.config.out_features
    }

    /// Writes the row-vector product `x · w` into `dst`.
    ///
    /// `w` is read as a row-major `x.len() × dst.len()` matrix. Each `dst[k]` is
    /// overwritten with `Σ_j x[j] * w[j * dst.len() + k]`. The sum starts at `0.0`
    /// and accumulates in increasing `j`, so results are bit-identical to a
    /// naive triple loop.
    fn project_row(x: &[f32], w: &[f32], dst: &mut [f32]) {
        let cols = dst.len();
        for (k, slot) in dst.iter_mut().enumerate() {
            *slot = x
                .iter()
                .enumerate()
                .fold(0.0_f32, |acc, (j, &xj)| acc + xj * w[j * cols + k]);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected ring on `n` nodes (each edge inserted in both directions).
    fn ring(n: usize) -> CsrGraph {
        let edges: Vec<(usize, usize)> = (0..n)
            .flat_map(|i| [(i, (i + 1) % n), ((i + 1) % n, i)])
            .collect();
        CsrGraph::from_edges(n, &edges).expect("test invariant: value must be valid")
    }

    #[test]
    fn precompute_returns_r_plus_one_hops() {
        let g = ring(4);
        let x = vec![1.0_f32; 4 * 2];
        let hops = sign_precompute(&g, &x, 2, 3).expect("precompute");
        assert_eq!(hops.len(), 4); // r_max + 1
        for h in &hops {
            assert_eq!(h.len(), 4 * 2);
        }
    }

    #[test]
    fn precompute_hop0_is_identity() {
        let g = ring(3);
        let x = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let hops = sign_precompute(&g, &x, 2, 2).expect("precompute");
        assert_eq!(hops[0], x);
    }

    #[test]
    fn precompute_zero_rmax_only_hop0() {
        let g = ring(3);
        let x = vec![1.0_f32; 3];
        let hops = sign_precompute(&g, &x, 1, 0).expect("precompute");
        assert_eq!(hops.len(), 1);
    }

    #[test]
    fn precompute_feat_dim_zero_errors() {
        let g = ring(3);
        let err = sign_precompute(&g, &[1.0_f32; 3], 0, 1);
        assert!(matches!(err, Err(GnnError::InvalidLayerConfig(_))));
    }

    #[test]
    fn precompute_feature_mismatch_errors() {
        let g = ring(4);
        let err = sign_precompute(&g, &[1.0_f32; 5], 2, 1);
        assert!(matches!(err, Err(GnnError::NodeFeatureMismatch(..))));
    }

    #[test]
    fn build_and_dims() {
        let conv = SignConv::new(SignConfig {
            in_features: 3,
            hop_features: 4,
            out_features: 5,
            r_max: 2,
        })
        .expect("build");
        assert_eq!(conv.n_hops(), 3);
        assert_eq!(conv.concat_dim(), 3 * 4);
        assert_eq!(conv.output_dim(), 5);
    }

    #[test]
    fn build_zero_dims_error() {
        assert!(SignConv::new(SignConfig {
            in_features: 0,
            hop_features: 4,
            out_features: 5,
            r_max: 1,
        })
        .is_err());
        assert!(SignConv::new(SignConfig {
            in_features: 3,
            hop_features: 0,
            out_features: 5,
            r_max: 1,
        })
        .is_err());
        assert!(SignConv::new(SignConfig {
            in_features: 3,
            hop_features: 4,
            out_features: 0,
            r_max: 1,
        })
        .is_err());
    }

    #[test]
    fn forward_output_shape() {
        let g = ring(5);
        let d_in = 3;
        let d_hop = 4;
        let d_out = 2;
        let r_max = 2;
        let conv = SignConv::new(SignConfig {
            in_features: d_in,
            hop_features: d_hop,
            out_features: d_out,
            r_max,
        })
        .expect("build");
        let x = vec![0.1_f32; 5 * d_in];
        let hops = sign_precompute(&g, &x, d_in, r_max).expect("precompute");
        let hop_weights: Vec<Vec<f32>> =
            (0..=r_max).map(|_| vec![0.1_f32; d_in * d_hop]).collect();
        let out_weight = vec![0.1_f32; conv.concat_dim() * d_out];
        let out = conv
            .forward(&hops, &hop_weights, &out_weight)
            .expect("forward");
        assert_eq!(out.len(), 5 * d_out);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn forward_zero_weights_zero_output() {
        let g = ring(4);
        let d_in = 2;
        let d_hop = 3;
        let d_out = 2;
        let r_max = 1;
        let conv = SignConv::new(SignConfig {
            in_features: d_in,
            hop_features: d_hop,
            out_features: d_out,
            r_max,
        })
        .expect("build");
        let x = vec![1.0_f32; 4 * d_in];
        let hops = sign_precompute(&g, &x, d_in, r_max).expect("precompute");
        let hop_weights: Vec<Vec<f32>> =
            (0..=r_max).map(|_| vec![0.0_f32; d_in * d_hop]).collect();
        let out_weight = vec![0.5_f32; conv.concat_dim() * d_out];
        let out = conv
            .forward(&hops, &hop_weights, &out_weight)
            .expect("forward");
        assert!(out.iter().all(|&v| v.abs() < 1e-7));
    }

    #[test]
    fn forward_wrong_hop_count_errors() {
        let g = ring(4);
        let conv = SignConv::new(SignConfig {
            in_features: 2,
            hop_features: 3,
            out_features: 2,
            r_max: 2,
        })
        .expect("build");
        let x = vec![1.0_f32; 4 * 2];
        // Precompute only 2 hops while the layer expects r_max + 1 = 3.
        let hops = sign_precompute(&g, &x, 2, 1).expect("precompute");
        let hop_weights: Vec<Vec<f32>> = (0..3).map(|_| vec![0.1_f32; 2 * 3]).collect();
        let out_weight = vec![0.1_f32; conv.concat_dim() * 2];
        let err = conv.forward(&hops, &hop_weights, &out_weight);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_wrong_weight_shape_errors() {
        let g = ring(4);
        let r_max = 1;
        let conv = SignConv::new(SignConfig {
            in_features: 2,
            hop_features: 3,
            out_features: 2,
            r_max,
        })
        .expect("build");
        let x = vec![1.0_f32; 4 * 2];
        let hops = sign_precompute(&g, &x, 2, r_max).expect("precompute");
        let hop_weights: Vec<Vec<f32>> = (0..=r_max).map(|_| vec![0.1_f32; 99]).collect();
        let out_weight = vec![0.1_f32; conv.concat_dim() * 2];
        let err = conv.forward(&hops, &hop_weights, &out_weight);
        assert!(matches!(err, Err(GnnError::WeightShapeMismatch { .. })));
    }

    #[test]
    fn forward_wrong_out_weight_shape_errors() {
        let g = ring(4);
        let r_max = 1;
        let conv = SignConv::new(SignConfig {
            in_features: 2,
            hop_features: 3,
            out_features: 2,
            r_max,
        })
        .expect("build");
        let x = vec![1.0_f32; 4 * 2];
        let hops = sign_precompute(&g, &x, 2, r_max).expect("precompute");
        let hop_weights: Vec<Vec<f32>> = (0..=r_max).map(|_| vec![0.1_f32; 2 * 3]).collect();
        // One element too many for a concat_dim × out_features matrix.
        let out_weight = vec![0.1_f32; conv.concat_dim() * 2 + 1];
        let err = conv.forward(&hops, &hop_weights, &out_weight);
        assert!(matches!(err, Err(GnnError::WeightShapeMismatch { .. })));
    }

    #[test]
    fn forward_relu_clamps_negatives() {
        let g = ring(4);
        let d_in = 2;
        let d_hop = 2;
        let d_out = 1;
        let r_max = 1;
        let conv = SignConv::new(SignConfig {
            in_features: d_in,
            hop_features: d_hop,
            out_features: d_out,
            r_max,
        })
        .expect("build");
        let x = vec![1.0_f32; 4 * d_in];
        let hops = sign_precompute(&g, &x, d_in, r_max).expect("precompute");
        let hop_weights: Vec<Vec<f32>> =
            (0..=r_max).map(|_| vec![-1.0_f32; d_in * d_hop]).collect();
        let out_weight = vec![1.0_f32; conv.concat_dim() * d_out];
        let out = conv
            .forward(&hops, &hop_weights, &out_weight)
            .expect("forward");
        assert!(
            out.iter().all(|&v| v.abs() < 1e-6),
            "ReLU should zero negatives"
        );
    }

    #[test]
    fn different_rmax_changes_output() {
        let g = ring(6);
        let d_in = 2;
        let d_hop = 2;
        let d_out = 2;
        let x: Vec<f32> = (0..6 * d_in).map(|i| (i as f32) * 0.1).collect();
        let build = |r_max: usize| {
            let conv = SignConv::new(SignConfig {
                in_features: d_in,
                hop_features: d_hop,
                out_features: d_out,
                r_max,
            })
            .expect("build");
            let hops = sign_precompute(&g, &x, d_in, r_max).expect("precompute");
            let hw: Vec<Vec<f32>> = (0..=r_max).map(|_| vec![0.3_f32; d_in * d_hop]).collect();
            let ow = vec![0.2_f32; conv.concat_dim() * d_out];
            conv.forward(&hops, &hw, &ow).expect("forward")
        };
        let o1 = build(1);
        let o2 = build(2);
        assert!(o1.iter().all(|v| v.is_finite()));
        assert!(o2.iter().all(|v| v.is_finite()));
        // With non-negative features and positive weights, the extra hop adds a
        // non-zero contribution, so the outputs must actually differ.
        let diff: f32 = o1.iter().zip(o2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(diff > 0.0, "adding a hop should change the output, diff = {diff}");
    }
}