use crate::error::{GnnError, GnnResult};
use crate::graph::csr::CsrGraph;
/// How the feature vector of an edge is built from its endpoint features `h_i` (the node
/// being updated) and `h_j` (the neighbor), as implemented by `edge_feature`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EdgeConvMode {
    /// `[h_i, h_j - h_i]` — length `2 * in_features`.
    CenterDiff,
    /// `h_j - h_i` — length `in_features`.
    DiffOnly,
    /// `[h_i, h_j]` — length `2 * in_features`.
    Concat,
}
#[derive(Debug, Clone)]
pub struct EdgeConvConfig {
pub in_features: usize,
pub hidden_features: usize,
pub out_features: usize,
pub self_loop: bool,
pub mode: EdgeConvMode,
}
impl Default for EdgeConvConfig {
fn default() -> Self {
Self {
in_features: 16,
hidden_features: 64,
out_features: 16,
self_loop: true,
mode: EdgeConvMode::CenterDiff,
}
}
}
/// EdgeConv message-passing layer: a two-layer MLP is applied to the feature of every
/// edge and the resulting messages are combined per node by element-wise max.
///
/// The layer stores only its configuration; the weights are passed to `forward` on each
/// call. `Debug`/`Clone` are derived so the layer can be logged and duplicated like its
/// configuration.
#[derive(Debug, Clone)]
pub struct EdgeConvLayer {
    /// Layer dimensions, self-loop flag and edge-feature mode.
    pub config: EdgeConvConfig,
}
impl EdgeConvLayer {
    /// Creates a layer from `config` after checking that every feature dimension is
    /// non-zero.
    ///
    /// # Errors
    ///
    /// Returns `GnnError::InvalidLayerConfig` if `in_features`, `hidden_features` or
    /// `out_features` is `0`.
    pub fn new(config: EdgeConvConfig) -> GnnResult<Self> {
        if config.in_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "in_features must be > 0".to_string(),
            ));
        }
        if config.hidden_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "hidden_features must be > 0".to_string(),
            ));
        }
        if config.out_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "out_features must be > 0".to_string(),
            ));
        }
        Ok(Self { config })
    }

    /// Length of the edge feature vector for this layer's mode, i.e. the number of
    /// columns of `w1` in [`forward`](Self::forward).
    ///
    /// `2 * in_features` for [`EdgeConvMode::CenterDiff`] and [`EdgeConvMode::Concat`],
    /// `in_features` for [`EdgeConvMode::DiffOnly`].
    #[inline]
    pub fn edge_feature_dim(&self) -> usize {
        match self.config.mode {
            EdgeConvMode::CenterDiff | EdgeConvMode::Concat => 2 * self.config.in_features,
            EdgeConvMode::DiffOnly => self.config.in_features,
        }
    }

    /// Runs one EdgeConv pass and returns the new node features.
    ///
    /// For each node `i`, every `j` in `graph.neighbors(i)` (plus `i` itself when
    /// `config.self_loop` is set) contributes the message
    /// `W2 · relu(W1 · e_ij + b1) + b2` with `e_ij = edge_feature(h_i, h_j, mode)`.
    /// Messages are combined by element-wise max.
    ///
    /// All matrices are flat and row-major:
    /// - `x`: `n_nodes × in_features`
    /// - `w1`: `hidden_features × edge_feature_dim()`, `b1`: `hidden_features`
    /// - `w2`: `out_features × hidden_features`, `b2`: `out_features`
    ///
    /// The result is `n_nodes × out_features`. An output component that no message raised
    /// above `-∞` is `0.0`. This covers every component of a node with no neighbors and no
    /// self-loop. NaN message components never win the max comparison.
    ///
    /// # Errors
    ///
    /// - `GnnError::NodeFeatureMismatch` if `x.len() != n_nodes * in_features`.
    /// - `GnnError::WeightShapeMismatch` if `w1` or `w2` has the wrong length.
    /// - `GnnError::DimensionMismatch` if `b1` or `b2` has the wrong length.
    /// - Any error returned by `graph.neighbors`.
    ///
    /// # Panics
    ///
    /// Panics if a neighbor index is `>= n_nodes` (out-of-bounds slice of `x`).
    /// Presumably `CsrGraph` rules this out — TODO confirm.
    pub fn forward(
        &self,
        graph: &CsrGraph,
        x: &[f32],
        w1: &[f32],
        b1: &[f32],
        w2: &[f32],
        b2: &[f32],
    ) -> GnnResult<Vec<f32>> {
        /// Folds `msg` into `acc` by element-wise max. NaN components of `msg` never win
        /// the `>` comparison, so they leave `acc` unchanged.
        fn max_into(acc: &mut [f32], msg: &[f32]) {
            for (a, &m) in acc.iter_mut().zip(msg) {
                if m > *a {
                    *a = m;
                }
            }
        }

        let n = graph.n_nodes();
        let in_f = self.config.in_features;
        let hid = self.config.hidden_features;
        let out_f = self.config.out_features;
        let edge_dim = self.edge_feature_dim();
        if x.len() != n * in_f {
            // `max(1)` keeps the division defined even if the public `in_features` field
            // was set to 0 after construction.
            return Err(GnnError::NodeFeatureMismatch(n, x.len() / in_f.max(1)));
        }
        if w1.len() != hid * edge_dim {
            return Err(GnnError::WeightShapeMismatch {
                r: hid,
                c: edge_dim,
                d: edge_dim,
            });
        }
        if b1.len() != hid {
            return Err(GnnError::DimensionMismatch {
                expected: hid,
                got: b1.len(),
            });
        }
        if w2.len() != out_f * hid {
            return Err(GnnError::WeightShapeMismatch {
                r: out_f,
                c: hid,
                d: hid,
            });
        }
        if b2.len() != out_f {
            return Err(GnnError::DimensionMismatch {
                expected: out_f,
                got: b2.len(),
            });
        }

        let mode = self.config.mode;
        let self_loop = self.config.self_loop;
        // MLP message for the edge whose centre features are `h_i` and neighbor features `h_j`.
        let message = |h_i: &[f32], h_j: &[f32]| -> Vec<f32> {
            let e = edge_feature(h_i, h_j, mode);
            apply_mlp(&e, edge_dim, w1, b1, hid, w2, b2, out_f)
        };

        let mut output = vec![0.0_f32; n * out_f];
        for i in 0..n {
            let neighbors = graph.neighbors(i)?;
            let h_i = &x[i * in_f..(i + 1) * in_f];
            // Aggregate directly into this node's output row. -inf is the "no message yet"
            // sentinel for the max.
            let row = &mut output[i * out_f..(i + 1) * out_f];
            row.fill(f32::NEG_INFINITY);
            for &j in neighbors {
                let h_j = &x[j * in_f..(j + 1) * in_f];
                max_into(row, &message(h_i, h_j));
            }
            if self_loop {
                max_into(row, &message(h_i, h_i));
            }
            // Components still at the sentinel (no messages at all, or only -inf/NaN ones)
            // are reported as 0. Compare against -inf specifically: `is_infinite()` would
            // also zero a genuine +inf maximum produced by an overflowing message.
            for v in row.iter_mut().filter(|v| **v == f32::NEG_INFINITY) {
                *v = 0.0;
            }
        }
        Ok(output)
    }
}
/// Builds the feature vector of the edge from node `i` (features `h_i`) to its neighbor
/// `j` (features `h_j`).
///
/// With `d = h_i.len()`:
/// - [`EdgeConvMode::CenterDiff`] → `[h_i, h_j - h_i]` (length `2d`)
/// - [`EdgeConvMode::DiffOnly`] → `h_j - h_i` (length `d`)
/// - [`EdgeConvMode::Concat`] → `[h_i, h_j]` (length `d + h_j.len()`)
///
/// In the difference modes only the first `d` entries of `h_j` are used.
///
/// # Panics
///
/// In the difference modes, panics if `h_j` is shorter than `h_i`.
pub fn edge_feature(h_i: &[f32], h_j: &[f32], mode: EdgeConvMode) -> Vec<f32> {
    let width = h_i.len();
    match mode {
        EdgeConvMode::Concat => [h_i, h_j].concat(),
        EdgeConvMode::DiffOnly => h_i
            .iter()
            .zip(&h_j[..width])
            .map(|(&centre, &nbr)| nbr - centre)
            .collect(),
        EdgeConvMode::CenterDiff => h_i
            .iter()
            .copied()
            .chain(
                h_i.iter()
                    .zip(&h_j[..width])
                    .map(|(&centre, &nbr)| nbr - centre),
            )
            .collect(),
    }
}
/// Evaluates the shared edge MLP on one edge feature: `W2 · relu(W1 · e + b1) + b2`.
///
/// `w1` is read as a row-major `hid × edge_dim` matrix and `w2` as a row-major
/// `out_f × hid` matrix. Only the first `edge_dim` entries of `e` are used. Each unit's
/// value starts from its bias and adds the products in column order.
///
/// # Panics
///
/// Panics if a weight, bias or input slice is shorter than these dimensions require.
/// `EdgeConvLayer::forward` validates the weight and bias lengths before calling this.
#[allow(clippy::too_many_arguments)] // flat weight arguments mirror `forward`'s signature
fn apply_mlp(
    e: &[f32],
    edge_dim: usize,
    w1: &[f32],
    b1: &[f32],
    hid: usize,
    w2: &[f32],
    b2: &[f32],
    out_f: usize,
) -> Vec<f32> {
    /// One dense unit: `bias + row · input`, accumulated left to right. Iterating
    /// pre-sliced rows avoids a bounds check on every element.
    fn unit(bias: f32, row: &[f32], input: &[f32]) -> f32 {
        row.iter().zip(input).fold(bias, |acc, (&w, &v)| acc + w * v)
    }

    let hidden: Vec<f32> = (0..hid)
        .map(|k| {
            let row = &w1[k * edge_dim..(k + 1) * edge_dim];
            // ReLU; note `NaN.max(0.0)` is 0.0.
            unit(b1[k], row, &e[..edge_dim]).max(0.0)
        })
        .collect();
    (0..out_f)
        .map(|k| unit(b2[k], &w2[k * hid..(k + 1) * hid], &hidden))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Complete directed graph on three nodes (every ordered pair, no self-edges).
    fn triangle_graph() -> CsrGraph {
        CsrGraph::from_edges(3, &[(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)])
            .expect("test invariant: valid graph")
    }

    /// Uniform weights (0.1) and zero biases with the shapes `forward` expects.
    fn make_weights(
        edge_dim: usize,
        hid: usize,
        out: usize,
    ) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
        let w1 = vec![0.1_f32; hid * edge_dim];
        let b1 = vec![0.0_f32; hid];
        let w2 = vec![0.1_f32; out * hid];
        let b2 = vec![0.0_f32; out];
        (w1, b1, w2, b2)
    }

    #[test]
    fn edge_feature_center_diff_length() {
        let h_i = vec![1.0_f32, 2.0];
        let h_j = vec![3.0_f32, 4.0];
        let e = edge_feature(&h_i, &h_j, EdgeConvMode::CenterDiff);
        assert_eq!(e.len(), 4);
    }

    #[test]
    fn edge_feature_center_diff_values() {
        let h_i = vec![1.0_f32, 2.0];
        let h_j = vec![3.0_f32, 4.0];
        let e = edge_feature(&h_i, &h_j, EdgeConvMode::CenterDiff);
        assert!((e[0] - 1.0).abs() < 1e-6, "e[0] should be 1, got {}", e[0]);
        assert!((e[1] - 2.0).abs() < 1e-6, "e[1] should be 2, got {}", e[1]);
        assert!(
            (e[2] - 2.0).abs() < 1e-6,
            "e[2] should be h_j[0]-h_i[0]=2, got {}",
            e[2]
        );
        assert!(
            (e[3] - 2.0).abs() < 1e-6,
            "e[3] should be h_j[1]-h_i[1]=2, got {}",
            e[3]
        );
    }

    #[test]
    fn edge_feature_diff_only_values() {
        let h_i = vec![1.0_f32, 2.0];
        let h_j = vec![3.0_f32, 4.0];
        let e = edge_feature(&h_i, &h_j, EdgeConvMode::DiffOnly);
        assert_eq!(e.len(), 2, "DiffOnly length should be in_features");
        assert!((e[0] - 2.0).abs() < 1e-6, "e[0] should be 2, got {}", e[0]);
        assert!((e[1] - 2.0).abs() < 1e-6, "e[1] should be 2, got {}", e[1]);
    }

    #[test]
    fn edge_feature_concat_values() {
        let h_i = vec![1.0_f32, 2.0];
        let h_j = vec![3.0_f32, 4.0];
        let e = edge_feature(&h_i, &h_j, EdgeConvMode::Concat);
        assert_eq!(e.len(), 4, "Concat length should be 2 * in_features");
        assert!((e[0] - 1.0).abs() < 1e-6);
        assert!((e[1] - 2.0).abs() < 1e-6);
        assert!((e[2] - 3.0).abs() < 1e-6);
        assert!((e[3] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn edge_feature_dim_center_diff() {
        let cfg = EdgeConvConfig {
            in_features: 8,
            hidden_features: 16,
            out_features: 4,
            self_loop: true,
            mode: EdgeConvMode::CenterDiff,
        };
        let layer = EdgeConvLayer::new(cfg).expect("valid config");
        assert_eq!(layer.edge_feature_dim(), 16);
    }

    #[test]
    fn edge_feature_dim_diff_only() {
        let cfg = EdgeConvConfig {
            in_features: 8,
            hidden_features: 16,
            out_features: 4,
            self_loop: true,
            mode: EdgeConvMode::DiffOnly,
        };
        let layer = EdgeConvLayer::new(cfg).expect("valid config");
        assert_eq!(layer.edge_feature_dim(), 8);
    }

    #[test]
    fn edge_feature_dim_concat() {
        let cfg = EdgeConvConfig {
            in_features: 8,
            hidden_features: 16,
            out_features: 4,
            self_loop: false,
            mode: EdgeConvMode::Concat,
        };
        let layer = EdgeConvLayer::new(cfg).expect("valid config");
        assert_eq!(layer.edge_feature_dim(), 16);
    }

    #[test]
    fn edgeconv_new_invalid_in_features() {
        let cfg = EdgeConvConfig {
            in_features: 0,
            ..Default::default()
        };
        assert!(EdgeConvLayer::new(cfg).is_err());
    }

    #[test]
    fn edgeconv_new_invalid_hidden_features() {
        let cfg = EdgeConvConfig {
            hidden_features: 0,
            ..Default::default()
        };
        assert!(EdgeConvLayer::new(cfg).is_err());
    }

    #[test]
    fn edgeconv_new_invalid_out_features() {
        let cfg = EdgeConvConfig {
            out_features: 0,
            ..Default::default()
        };
        assert!(EdgeConvLayer::new(cfg).is_err());
    }

    #[test]
    fn edgeconv_forward_output_shape() {
        let graph = triangle_graph();
        let cfg = EdgeConvConfig {
            in_features: 2,
            hidden_features: 4,
            out_features: 3,
            self_loop: true,
            mode: EdgeConvMode::CenterDiff,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let (w1, b1, w2, b2) = make_weights(edge_dim, cfg.hidden_features, cfg.out_features);
        let x = vec![0.5_f32; 3 * cfg.in_features];
        let out = layer
            .forward(&graph, &x, &w1, &b1, &w2, &b2)
            .expect("forward should succeed");
        assert_eq!(out.len(), 3 * cfg.out_features);
    }

    #[test]
    fn edgeconv_forward_self_loop_true_isolated_node_nonzero() {
        // Node 2 has no edges; with `self_loop` it still receives its own message.
        let graph = CsrGraph::from_edges(3, &[(0, 1), (1, 0)]).expect("valid graph");
        let cfg = EdgeConvConfig {
            in_features: 2,
            hidden_features: 4,
            out_features: 2,
            self_loop: true,
            mode: EdgeConvMode::CenterDiff,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let w1 = vec![1.0_f32; cfg.hidden_features * edge_dim];
        let b1 = vec![0.1_f32; cfg.hidden_features];
        let w2 = vec![1.0_f32; cfg.out_features * cfg.hidden_features];
        let b2 = vec![0.1_f32; cfg.out_features];
        let x = vec![1.0_f32; 3 * cfg.in_features];
        let out = layer
            .forward(&graph, &x, &w1, &b1, &w2, &b2)
            .expect("forward should succeed");
        assert_eq!(out.len(), 3 * cfg.out_features);
        let node2_out = &out[2 * cfg.out_features..3 * cfg.out_features];
        assert!(
            node2_out.iter().all(|v| v.is_finite()),
            "self-loop output must be finite"
        );
        let max_val: f32 = node2_out.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        assert!(
            max_val > 0.0,
            "self-loop should produce nonzero output, got {max_val}"
        );
    }

    #[test]
    fn edgeconv_forward_self_loop_false_isolated_node_zeros() {
        // Node 2 has no edges and no self-loop, so it receives no messages at all.
        let graph = CsrGraph::from_edges(3, &[(0, 1), (1, 0)]).expect("valid graph");
        let cfg = EdgeConvConfig {
            in_features: 2,
            hidden_features: 4,
            out_features: 2,
            self_loop: false,
            mode: EdgeConvMode::CenterDiff,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let (w1, b1, w2, b2) = make_weights(edge_dim, cfg.hidden_features, cfg.out_features);
        let x = vec![1.0_f32; 3 * cfg.in_features];
        let out = layer
            .forward(&graph, &x, &w1, &b1, &w2, &b2)
            .expect("forward should succeed");
        let node2_out = &out[2 * cfg.out_features..3 * cfg.out_features];
        assert!(
            node2_out.iter().all(|&v| v == 0.0),
            "isolated node with no self_loop should be zeros, got {node2_out:?}"
        );
    }

    #[test]
    fn edgeconv_forward_output_finite() {
        let graph = triangle_graph();
        let cfg = EdgeConvConfig::default();
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let (w1, b1, w2, b2) = make_weights(edge_dim, cfg.hidden_features, cfg.out_features);
        let x: Vec<f32> = (0..3 * cfg.in_features).map(|i| i as f32 * 0.1).collect();
        let out = layer
            .forward(&graph, &x, &w1, &b1, &w2, &b2)
            .expect("forward should succeed");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "all outputs should be finite"
        );
    }

    #[test]
    fn edgeconv_forward_wrong_w1_shape_error() {
        let graph = triangle_graph();
        let cfg = EdgeConvConfig {
            in_features: 2,
            hidden_features: 4,
            out_features: 2,
            self_loop: true,
            mode: EdgeConvMode::CenterDiff,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let x = vec![0.0_f32; 3 * cfg.in_features];
        // Deliberately wrong: should be hidden_features * edge_feature_dim() = 16.
        let w1 = vec![0.1_f32; 3];
        let b1 = vec![0.0_f32; cfg.hidden_features];
        let w2 = vec![0.1_f32; cfg.out_features * cfg.hidden_features];
        let b2 = vec![0.0_f32; cfg.out_features];
        assert!(layer.forward(&graph, &x, &w1, &b1, &w2, &b2).is_err());
    }

    #[test]
    fn edgeconv_forward_wrong_x_shape_error() {
        let graph = triangle_graph();
        let cfg = EdgeConvConfig {
            in_features: 4,
            hidden_features: 8,
            out_features: 4,
            self_loop: true,
            mode: EdgeConvMode::CenterDiff,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let (w1, b1, w2, b2) = make_weights(edge_dim, cfg.hidden_features, cfg.out_features);
        // Features for only 2 of the graph's 3 nodes.
        let x = vec![0.0_f32; 2 * cfg.in_features];
        assert!(layer.forward(&graph, &x, &w1, &b1, &w2, &b2).is_err());
    }

    #[test]
    fn edgeconv_center_diff_identical_features_diff_zero() {
        // With identical node features every edge has the same feature, so all rows match.
        let graph = triangle_graph();
        let cfg = EdgeConvConfig {
            in_features: 2,
            hidden_features: 4,
            out_features: 2,
            self_loop: false,
            mode: EdgeConvMode::CenterDiff,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let (w1, b1, w2, b2) = make_weights(edge_dim, cfg.hidden_features, cfg.out_features);
        let x = vec![1.0_f32; 3 * cfg.in_features];
        let out = layer
            .forward(&graph, &x, &w1, &b1, &w2, &b2)
            .expect("forward ok");
        let row0 = &out[0..cfg.out_features];
        let row1 = &out[cfg.out_features..2 * cfg.out_features];
        let row2 = &out[2 * cfg.out_features..3 * cfg.out_features];
        for k in 0..cfg.out_features {
            assert!(
                (row0[k] - row1[k]).abs() < 1e-5,
                "rows 0 and 1 should match"
            );
            assert!(
                (row0[k] - row2[k]).abs() < 1e-5,
                "rows 0 and 2 should match"
            );
        }
    }

    #[test]
    fn edgeconv_single_node_self_loop_works() {
        let graph = CsrGraph::from_edges(1, &[(0, 0)]).expect("valid graph");
        let cfg = EdgeConvConfig {
            in_features: 2,
            hidden_features: 4,
            out_features: 2,
            self_loop: true,
            mode: EdgeConvMode::CenterDiff,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let w1 = vec![1.0_f32; cfg.hidden_features * edge_dim];
        let b1 = vec![0.1_f32; cfg.hidden_features];
        let w2 = vec![1.0_f32; cfg.out_features * cfg.hidden_features];
        let b2 = vec![0.1_f32; cfg.out_features];
        let x = vec![1.0_f32, 2.0];
        let out = layer
            .forward(&graph, &x, &w1, &b1, &w2, &b2)
            .expect("forward ok");
        assert_eq!(out.len(), cfg.out_features);
        assert!(
            out.iter().all(|v| v.is_finite()),
            "single-node self-loop output must be finite"
        );
    }

    #[test]
    fn edgeconv_diff_only_translation_invariance() {
        // DiffOnly edge features depend only on h_j - h_i, so shifting every node feature
        // by the same constant must leave the output unchanged.
        let graph = triangle_graph();
        let cfg = EdgeConvConfig {
            in_features: 2,
            hidden_features: 4,
            out_features: 2,
            self_loop: false,
            mode: EdgeConvMode::DiffOnly,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let (w1, b1, w2, b2) = make_weights(edge_dim, cfg.hidden_features, cfg.out_features);
        let x_base: Vec<f32> = (0..3 * cfg.in_features).map(|i| i as f32).collect();
        let shift = 100.0_f32;
        let x_shifted: Vec<f32> = x_base.iter().map(|v| v + shift).collect();
        let out_base = layer
            .forward(&graph, &x_base, &w1, &b1, &w2, &b2)
            .expect("forward ok");
        let out_shifted = layer
            .forward(&graph, &x_shifted, &w1, &b1, &w2, &b2)
            .expect("forward ok");
        for (a, b) in out_base.iter().zip(out_shifted.iter()) {
            assert!(
                (a - b).abs() < 1e-4,
                "DiffOnly should be translation-invariant; diff={}",
                (a - b).abs()
            );
        }
    }

    #[test]
    fn edgeconv_concat_mode_output_shape() {
        let graph = triangle_graph();
        let cfg = EdgeConvConfig {
            in_features: 3,
            hidden_features: 6,
            out_features: 4,
            self_loop: true,
            mode: EdgeConvMode::Concat,
        };
        let layer = EdgeConvLayer::new(cfg.clone()).expect("valid config");
        let edge_dim = layer.edge_feature_dim();
        let (w1, b1, w2, b2) = make_weights(edge_dim, cfg.hidden_features, cfg.out_features);
        let x: Vec<f32> = (0..3 * cfg.in_features).map(|i| i as f32 * 0.5).collect();
        let out = layer
            .forward(&graph, &x, &w1, &b1, &w2, &b2)
            .expect("forward ok");
        assert_eq!(out.len(), 3 * cfg.out_features);
    }

    #[test]
    fn edgeconv_default_config() {
        let cfg = EdgeConvConfig::default();
        assert_eq!(cfg.mode, EdgeConvMode::CenterDiff);
        assert!(cfg.self_loop);
        assert!(EdgeConvLayer::new(cfg).is_ok());
    }

    #[test]
    fn apply_mlp_known_values() {
        // h = relu([1, 2.5, -1]) = [1, 2.5, 0]
        // f = [0.5 + 1 + 2.5 + 0, -1 + 2 + 0 - 0] = [4, 1]
        let e = [1.0_f32, 2.0];
        let w1 = [1.0_f32, 0.0, 0.0, 1.0, -1.0, 0.0];
        let b1 = [0.0_f32, 0.5, 0.0];
        let w2 = [1.0_f32, 1.0, 1.0, 2.0, 0.0, -1.0];
        let b2 = [0.5_f32, -1.0];
        let f = apply_mlp(&e, 2, &w1, &b1, 3, &w2, &b2, 2);
        assert_eq!(f, vec![4.0_f32, 1.0]);
    }

    #[test]
    fn edgeconv_forward_max_aggregation() {
        // Identity MLP on DiffOnly features: message(i <- j) = relu(x_j - x_i).
        // Edges are symmetric, so the result does not depend on the CSR edge direction.
        let graph = CsrGraph::from_edges(3, &[(0, 1), (0, 2), (1, 0), (2, 0)])
            .expect("valid graph");
        let cfg = EdgeConvConfig {
            in_features: 1,
            hidden_features: 1,
            out_features: 1,
            self_loop: false,
            mode: EdgeConvMode::DiffOnly,
        };
        let layer = EdgeConvLayer::new(cfg).expect("valid config");
        let (w1, b1, w2, b2) = (vec![1.0_f32], vec![0.0_f32], vec![1.0_f32], vec![0.0_f32]);
        let x = vec![0.0_f32, 1.0, 3.0];
        let out = layer
            .forward(&graph, &x, &w1, &b1, &w2, &b2)
            .expect("forward ok");
        // Node 0: max(relu(1), relu(3)) = 3; nodes 1 and 2: relu(negative) = 0.
        assert_eq!(out, vec![3.0_f32, 0.0, 0.0]);
    }
}