use crate::error::{GnnError, GnnResult};
use crate::graph::csr::CsrGraph;
use crate::message_passing::update::relu;
/// Hyper-parameters of a [`GcnLayer`].
///
/// Both dimensions are validated by `GcnLayer::new`, which rejects zero.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GcnConfig {
    /// Width of each input node-feature row; must be non-zero.
    pub in_features: usize,
    /// Width of each output node-feature row; must be non-zero.
    pub out_features: usize,
    /// Records whether the layer is intended to carry a bias vector.
    ///
    /// NOTE: `GcnLayer::forward` does not read this flag — a bias is applied
    /// exactly when one is passed to `forward`.
    pub bias: bool,
    /// When `true`, `forward` aggregates neighbours through
    /// `CsrGraph::normalized_adjacency`; otherwise it delegates to `CsrGraph::spmv`.
    pub normalize: bool,
}
/// A single graph-convolution layer.
///
/// The layer stores only its validated [`GcnConfig`]. The weight matrix and the
/// optional bias are supplied on every call to `forward`.
#[derive(Debug, Clone)]
pub struct GcnLayer {
    config: GcnConfig,
}
impl GcnLayer {
    /// Creates a layer from `config` after validating its dimensions.
    ///
    /// # Errors
    ///
    /// Returns [`GnnError::InvalidLayerConfig`] if `in_features` or
    /// `out_features` is zero.
    pub fn new(config: GcnConfig) -> GnnResult<Self> {
        if config.in_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "in_features must be > 0".to_string(),
            ));
        }
        if config.out_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "out_features must be > 0".to_string(),
            ));
        }
        Ok(Self { config })
    }

    /// Applies one graph convolution and returns `ReLU(A · (X · W) + b)`.
    ///
    /// * `node_features` — row-major `n × in_features` matrix `X`, where
    ///   `n = graph.n_nodes()`.
    /// * `weight` — row-major `in_features × out_features` matrix `W`.
    /// * `bias` — optional length-`out_features` vector `b`. It is added once per
    ///   node *after* aggregation, so it is not scaled by node degree or by the
    ///   normalisation coefficients. `config.bias` is not consulted; this
    ///   argument alone decides whether a bias is applied.
    ///
    /// `A` is the sparse matrix from `graph.normalized_adjacency()` when
    /// `config.normalize` is set. Otherwise the product is delegated to
    /// `graph.spmv`.
    ///
    /// The result is an `n × out_features` matrix laid out row-major, matching
    /// the layout of `X · W` (for the `spmv` path this assumes `spmv` preserves
    /// its input layout).
    ///
    /// # Errors
    ///
    /// * [`GnnError::NodeFeatureMismatch`] if `node_features.len() != n * in_features`.
    /// * [`GnnError::WeightShapeMismatch`] if `weight.len() != in_features * out_features`.
    /// * [`GnnError::DimensionMismatch`] if `bias` is given and its length is not
    ///   `out_features`.
    /// * Any error returned by `graph.spmv` on the non-normalised path.
    pub fn forward(
        &self,
        graph: &CsrGraph,
        node_features: &[f32],
        weight: &[f32],
        bias: Option<&[f32]>,
    ) -> GnnResult<Vec<f32>> {
        let n = graph.n_nodes();
        let in_f = self.config.in_features;
        let out_f = self.config.out_features;
        if node_features.len() != n * in_f {
            return Err(GnnError::NodeFeatureMismatch(
                n,
                node_features.len() / in_f.max(1),
            ));
        }
        if weight.len() != in_f * out_f {
            return Err(GnnError::WeightShapeMismatch {
                r: in_f,
                c: out_f,
                d: in_f,
            });
        }
        if let Some(b) = bias {
            if b.len() != out_f {
                return Err(GnnError::DimensionMismatch {
                    expected: out_f,
                    got: b.len(),
                });
            }
        }

        let h_proj = Self::project(node_features, weight, in_f, out_f);
        let mut h_aggr = if self.config.normalize {
            Self::aggregate_normalized(graph, &h_proj, n, out_f)
        } else {
            graph.spmv(&h_proj, out_f)?
        };

        // The bias goes on after aggregation. Adding it to `X · W` first would
        // let the adjacency multiply it, e.g. `deg(i) · b` under plain summation.
        if let Some(b) = bias {
            for row in h_aggr.chunks_exact_mut(out_f) {
                for (h, &bk) in row.iter_mut().zip(b) {
                    *h += bk;
                }
            }
        }
        Ok(relu(&h_aggr))
    }

    /// Width of every output row, i.e. `config.out_features`.
    pub fn output_dim(&self) -> usize {
        self.config.out_features
    }

    /// Dense product `X · W` of row-major `features` (`· × in_f`) and row-major
    /// `weight` (`in_f × out_f`).
    ///
    /// The loop walks both matrices row by row for contiguous access. Each output
    /// element still accumulates its terms in increasing `j` order, so the sums
    /// match a naive dot-product loop bit for bit.
    fn project(features: &[f32], weight: &[f32], in_f: usize, out_f: usize) -> Vec<f32> {
        let mut out = vec![0.0_f32; features.len() / in_f * out_f];
        for (x_row, out_row) in features
            .chunks_exact(in_f)
            .zip(out.chunks_exact_mut(out_f))
        {
            for (&x, w_row) in x_row.iter().zip(weight.chunks_exact(out_f)) {
                for (o, &w) in out_row.iter_mut().zip(w_row) {
                    *o += x * w;
                }
            }
        }
        out
    }

    /// Computes `Â · H`, where `Â` is read from `graph.normalized_adjacency()` as
    /// parallel (row, col, value) sequences and `h` is row-major `n × out_f`.
    fn aggregate_normalized(graph: &CsrGraph, h: &[f32], n: usize, out_f: usize) -> Vec<f32> {
        let (rows, cols, vals) = graph.normalized_adjacency();
        let mut out = vec![0.0_f32; n * out_f];
        for ((r, c), v) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
            let dst = &mut out[r * out_f..][..out_f];
            let src = &h[c * out_f..][..out_f];
            for (o, &x) in dst.iter_mut().zip(src) {
                *o += v * x;
            }
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Message for every `expect` whose failure would mean a broken fixture.
    const VALID: &str = "test invariant: value must be valid";

    /// Undirected path graph `0 - 1 - 2 - 3`, each edge listed in both directions.
    fn simple_graph() -> CsrGraph {
        CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]).expect(VALID)
    }

    /// A single node carrying one self-loop.
    fn self_loop_graph() -> CsrGraph {
        CsrGraph::from_edges(1, &[(0, 0)]).expect(VALID)
    }

    /// Row-major `d × d` identity matrix; diagonal entries sit at multiples of `d + 1`.
    fn identity_weight(d: usize) -> Vec<f32> {
        (0..d * d)
            .map(|idx| if idx % (d + 1) == 0 { 1.0 } else { 0.0 })
            .collect()
    }

    fn config(in_features: usize, out_features: usize, bias: bool, normalize: bool) -> GcnConfig {
        GcnConfig {
            in_features,
            out_features,
            bias,
            normalize,
        }
    }

    fn make_layer(in_features: usize, out_features: usize, bias: bool, normalize: bool) -> GcnLayer {
        GcnLayer::new(config(in_features, out_features, bias, normalize)).expect(VALID)
    }

    fn run(
        gcn: &GcnLayer,
        graph: &CsrGraph,
        feats: &[f32],
        weight: &[f32],
        bias: Option<&[f32]>,
    ) -> Vec<f32> {
        gcn.forward(graph, feats, weight, bias).expect(VALID)
    }

    #[test]
    fn output_shape_correct() {
        let graph = simple_graph();
        let gcn = make_layer(3, 5, false, true);
        let out = run(&gcn, &graph, &[1.0_f32; 4 * 3], &[0.1_f32; 3 * 5], None);
        assert_eq!(out.len(), 4 * 5);
    }

    #[test]
    fn zero_weights_zero_output() {
        let graph = simple_graph();
        let gcn = make_layer(2, 2, false, false);
        let out = run(&gcn, &graph, &[1.0_f32; 4 * 2], &[0.0_f32; 2 * 2], None);
        assert!(out.iter().all(|v| v.abs() < 1e-6));
    }

    #[test]
    fn relu_applied_no_negatives() {
        let graph = simple_graph();
        let gcn = make_layer(2, 2, false, true);
        let out = run(&gcn, &graph, &[-1.0_f32; 4 * 2], &[-1.0_f32; 2 * 2], None);
        assert!(out.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn one_node_graph() {
        let graph = self_loop_graph();
        let gcn = make_layer(2, 2, false, true);
        let out = run(&gcn, &graph, &[1.0_f32, 2.0], &identity_weight(2), None);
        assert_eq!(out.len(), 2);
        assert!(out.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn bias_added_correctly() {
        let graph = self_loop_graph();
        let gcn = make_layer(2, 2, true, false);
        let bias = [1.0_f32, 2.0];
        let out = run(&gcn, &graph, &[0.0_f32; 2], &[0.0_f32; 2 * 2], Some(&bias[..]));
        assert!(out[0] > 0.0 || out[1] > 0.0);
    }

    #[test]
    fn invalid_zero_in_features() {
        assert!(GcnLayer::new(config(0, 4, false, true)).is_err());
    }

    #[test]
    fn invalid_zero_out_features() {
        assert!(GcnLayer::new(config(4, 0, false, true)).is_err());
    }

    #[test]
    fn feature_mismatch_error() {
        let graph = simple_graph();
        let gcn = make_layer(3, 3, false, true);
        // Nine values only cover three of the graph's four nodes.
        let feats = [1.0_f32; 3 * 3];
        let result = gcn.forward(&graph, &feats, &identity_weight(3), None);
        assert!(matches!(result, Err(GnnError::NodeFeatureMismatch(..))));
    }

    #[test]
    fn normalize_and_nonnormalize_differ() {
        let graph = simple_graph();
        let feats = [1.0_f32; 4 * 2];
        let weight = [0.5_f32; 2 * 2];
        let normalized_layer = make_layer(2, 2, false, true);
        let plain_layer = make_layer(2, 2, false, false);
        let normalized = run(&normalized_layer, &graph, &feats, &weight, None);
        let plain = run(&plain_layer, &graph, &feats, &weight, None);
        let identical = normalized
            .iter()
            .zip(plain.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6);
        let normalized_is_zero = normalized.iter().all(|v| v.abs() < 1e-6);
        assert!(!identical || normalized_is_zero);
    }

    #[test]
    fn identity_weight_preserves_features_no_normalize() {
        let graph = CsrGraph::from_edges(3, &[(0, 1), (1, 0), (1, 2), (2, 1)]).expect(VALID);
        let gcn = make_layer(2, 2, false, false);
        let feats = [1.0_f32, 0.0, 0.0, 1.0, 1.0, 1.0];
        let out = run(&gcn, &graph, &feats, &identity_weight(2), None);
        assert_eq!(out.len(), 6);
        assert!(out.iter().all(|&v| v >= 0.0));
    }
}