use crate::error::{GnnError, GnnResult};
use crate::handle::LcgRng;
/// A basis operation mapping one channel of an `n x n` pair-feature matrix
/// `x` to another `n x n` matrix `out`.
///
/// In the variant docs below, `row[i]` is the sum of `x[i][*]`, `col[j]` is
/// the sum of `x[*][j]`, `trace` is the sum of the diagonal and `total` is
/// the sum of every entry. "ToDiag" variants write only the diagonal and
/// leave every off-diagonal entry at `0.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PairOp {
    /// `out[i][j] = x[i][j]`.
    Identity,
    /// `out[i][j] = x[j][i]`.
    Transpose,
    /// `out[i][i] = x[i][i]`.
    DiagToDiag,
    /// `out[i][i] = row[i]`.
    RowSumToDiag,
    /// `out[i][i] = col[i]`.
    ColSumToDiag,
    /// `out[i][i] = trace`.
    TraceToDiag,
    /// `out[i][i] = total`.
    TotalToDiag,
    /// `out[i][j] = x[i][i]`.
    DiagToRows,
    /// `out[i][j] = x[j][j]`.
    DiagToCols,
    /// `out[i][j] = row[i]`.
    RowSumToRows,
    /// `out[i][j] = col[j]`.
    ColSumToCols,
    /// `out[i][j] = row[j]`.
    RowSumToCols,
    /// `out[i][j] = col[i]`.
    ColSumToRows,
    /// `out[i][j] = trace`.
    TraceBroadcast,
    /// `out[i][j] = total`.
    TotalBroadcast,
}
impl PairOp {
    /// Every variant, listed in declaration order.
    ///
    /// The position of an op in this array is used as its index into the
    /// layer's flat weight buffer, so reordering the entries changes which
    /// weights are paired with which operation.
    pub const ALL: [PairOp; 15] = [
        PairOp::Identity,
        PairOp::Transpose,
        PairOp::DiagToDiag,
        PairOp::RowSumToDiag,
        PairOp::ColSumToDiag,
        PairOp::TraceToDiag,
        PairOp::TotalToDiag,
        PairOp::DiagToRows,
        PairOp::DiagToCols,
        PairOp::RowSumToRows,
        PairOp::ColSumToCols,
        PairOp::RowSumToCols,
        PairOp::ColSumToRows,
        PairOp::TraceBroadcast,
        PairOp::TotalBroadcast,
    ];
}
/// Pre-computed sums over a single channel of an `n x n` pair-feature matrix.
struct Reductions {
    /// `diag[i]` is the diagonal entry `(i, i)`.
    diag: Vec<f32>,
    /// `row[i]` is the sum of row `i` over all columns.
    row: Vec<f32>,
    /// `col[j]` is the sum of column `j` over all rows.
    col: Vec<f32>,
    /// Sum of the diagonal entries.
    trace: f32,
    /// Sum of every entry.
    total: f32,
}
impl Reductions {
    /// Gathers the diagonal, row sums, column sums, trace and total of
    /// channel `c` in a single row-major sweep over the `n x n` grid.
    ///
    /// `x` is read at `(i * n + j) * dim + c`; an index outside `x` panics.
    fn compute(x: &[f32], n: usize, dim: usize, c: usize) -> Self {
        let mut sums = Self {
            diag: vec![0.0_f32; n],
            row: vec![0.0_f32; n],
            col: vec![0.0_f32; n],
            trace: 0.0_f32,
            total: 0.0_f32,
        };
        for r in 0..n {
            for s in 0..n {
                let value = x[(r * n + s) * dim + c];
                sums.row[r] += value;
                sums.col[s] += value;
                sums.total += value;
                if r == s {
                    sums.diag[r] = value;
                    sums.trace += value;
                }
            }
        }
        sums
    }

    /// Returns entry `(i, j)` of `op` applied to channel `c` of `x`.
    ///
    /// `self` must hold the reductions of that same channel. Only
    /// `Identity` and `Transpose` read `x` directly; every other op is
    /// answered from the cached reductions.
    #[inline]
    fn op_value(
        &self,
        op: PairOp,
        x: &[f32],
        n: usize,
        dim: usize,
        c: usize,
        i: usize,
        j: usize,
    ) -> f32 {
        match op {
            PairOp::Identity => x[(i * n + j) * dim + c],
            PairOp::Transpose => x[(j * n + i) * dim + c],
            // Diagonal-only ops are zero everywhere off the diagonal; the
            // guard is checked before any reduction is indexed.
            PairOp::DiagToDiag
            | PairOp::RowSumToDiag
            | PairOp::ColSumToDiag
            | PairOp::TraceToDiag
            | PairOp::TotalToDiag
                if i != j =>
            {
                0.0
            }
            // From here on the diagonal-only variants are known to be on
            // the diagonal, so they share arms with their broadcast twins.
            PairOp::DiagToDiag | PairOp::DiagToRows => self.diag[i],
            PairOp::RowSumToDiag | PairOp::RowSumToRows => self.row[i],
            PairOp::ColSumToDiag | PairOp::ColSumToRows => self.col[i],
            PairOp::TraceToDiag | PairOp::TraceBroadcast => self.trace,
            PairOp::TotalToDiag | PairOp::TotalBroadcast => self.total,
            PairOp::DiagToCols => self.diag[j],
            PairOp::ColSumToCols => self.col[j],
            PairOp::RowSumToCols => self.row[j],
        }
    }
}
/// Applies a single basis operation to every channel of a pair-feature
/// tensor and returns the result as a new buffer of the same shape.
///
/// `x` is a flat `n * n * dim` buffer indexed as `(i * n + j) * dim + c`.
///
/// # Errors
///
/// * [`GnnError::EmptyGraph`] when `n == 0`.
/// * [`GnnError::InvalidLayerConfig`] when `dim == 0`.
/// * [`GnnError::DimensionMismatch`] when `x.len() != n * n * dim`,
///   including the case where that product does not fit in `usize`.
pub fn apply_pair_op(op: PairOp, x: &[f32], n: usize, dim: usize) -> GnnResult<Vec<f32>> {
    if n == 0 {
        return Err(GnnError::EmptyGraph);
    }
    if dim == 0 {
        return Err(GnnError::InvalidLayerConfig(
            "k-WL: dim must be > 0".to_string(),
        ));
    }
    // Saturate instead of overflowing: no `[f32]` can hold `usize::MAX`
    // elements, so an unrepresentable product is always reported as a
    // mismatch rather than panicking (debug) or wrapping (release).
    let expected = n.saturating_mul(n).saturating_mul(dim);
    if x.len() != expected {
        return Err(GnnError::DimensionMismatch {
            expected,
            got: x.len(),
        });
    }
    let mut out = vec![0.0_f32; expected];
    for c in 0..dim {
        // Reductions depend only on the channel, so compute them once per
        // channel and reuse them for all n * n output entries.
        let red = Reductions::compute(x, n, dim, c);
        for i in 0..n {
            for j in 0..n {
                out[(i * n + j) * dim + c] = red.op_value(op, x, n, dim, c, i, j);
            }
        }
    }
    Ok(out)
}
/// Construction parameters for [`KWlGnn`].
#[derive(Debug, Clone, Copy)]
pub struct KWlConfig {
    /// Number of input channels per node pair; `KWlGnn::new` rejects `0`.
    pub in_features: usize,
    /// Number of output channels per node pair; `KWlGnn::new` rejects `0`.
    pub out_features: usize,
    /// Seed handed to `LcgRng::new` when the weights and biases are drawn.
    pub seed: u64,
}
/// Layer that mixes all [`PairOp::ALL`] basis operations of a pair-feature
/// tensor with learned weights, adds biases and clamps negatives to zero.
pub struct KWlGnn {
    /// Configuration the layer was built from.
    config: KWlConfig,
    /// Flat weight buffer of length
    /// `PairOp::ALL.len() * in_features * out_features`, indexed as
    /// `(op_idx * in_features + c) * out_features + cp`.
    weight: Vec<f32>,
    /// Per-output-channel bias added to every pair `(i, j)`.
    bias_all: Vec<f32>,
    /// Additional per-output-channel bias added only where `i == j`.
    bias_diag: Vec<f32>,
}
impl KWlGnn {
    /// Builds a layer whose weights and biases are drawn from an `LcgRng`
    /// seeded with `config.seed`.
    ///
    /// Values are drawn in a fixed order — all weights (scaled by `0.3`),
    /// then `bias_all`, then `bias_diag` (both scaled by `0.1`) — so the
    /// same seed always yields the same layer.
    ///
    /// # Errors
    ///
    /// [`GnnError::InvalidLayerConfig`] when `in_features` or
    /// `out_features` is zero.
    pub fn new(config: KWlConfig) -> GnnResult<Self> {
        if config.in_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "k-WL: in_features must be > 0".to_string(),
            ));
        }
        if config.out_features == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "k-WL: out_features must be > 0".to_string(),
            ));
        }
        let n_w = PairOp::ALL.len() * config.in_features * config.out_features;
        let mut rng = LcgRng::new(config.seed);
        let mut draw = |count: usize, scale: f32| -> Vec<f32> {
            (0..count).map(|_| centered_unit(&mut rng) * scale).collect()
        };
        // The call order below fixes which random draws land in which
        // buffer; do not reorder.
        let weight = draw(n_w, 0.3);
        let bias_all = draw(config.out_features, 0.1);
        let bias_diag = draw(config.out_features, 0.1);
        Ok(Self {
            config,
            weight,
            bias_all,
            bias_diag,
        })
    }

    /// Number of output channels produced per node pair.
    #[inline]
    pub fn output_dim(&self) -> usize {
        self.config.out_features
    }

    /// Runs the layer on a flat `n_nodes * n_nodes * dim` pair-feature
    /// buffer (indexed `(i * n_nodes + j) * dim + c`) and returns a flat
    /// `n_nodes * n_nodes * out_features` buffer.
    ///
    /// For each pair `(i, j)` and output channel `cp` the result is
    /// `max(0, bias_all[cp] + [i == j] * bias_diag[cp] + Σ w * op(x))`,
    /// summed over every basis op and input channel.
    ///
    /// # Errors
    ///
    /// * [`GnnError::EmptyGraph`] when `n_nodes == 0`.
    /// * [`GnnError::DimensionMismatch`] when `dim` differs from the
    ///   configured `in_features`, or when `pair_features.len()` is not
    ///   `n_nodes * n_nodes * dim` (including when that product does not
    ///   fit in `usize`).
    /// * [`GnnError::NonFiniteOutput`] when any pre-activation is NaN or
    ///   `+inf`. A `-inf` pre-activation is clamped to `0.0` like any other
    ///   negative value.
    pub fn forward(
        &self,
        pair_features: &[f32],
        n_nodes: usize,
        dim: usize,
    ) -> GnnResult<Vec<f32>> {
        if n_nodes == 0 {
            return Err(GnnError::EmptyGraph);
        }
        if dim != self.config.in_features {
            return Err(GnnError::DimensionMismatch {
                expected: self.config.in_features,
                got: dim,
            });
        }
        // Saturate instead of overflowing: no `[f32]` can hold `usize::MAX`
        // elements, so an unrepresentable product is always a mismatch.
        let expected_len = n_nodes.saturating_mul(n_nodes).saturating_mul(dim);
        if pair_features.len() != expected_len {
            return Err(GnnError::DimensionMismatch {
                expected: expected_len,
                got: pair_features.len(),
            });
        }
        let n = n_nodes;
        let in_f = self.config.in_features;
        let out_f = self.config.out_features;
        let reductions: Vec<Reductions> = (0..in_f)
            .map(|c| Reductions::compute(pair_features, n, in_f, c))
            .collect();

        // Basis values depend on (i, j, op, c) but not on the output
        // channel, so they are evaluated once per pair into this scratch
        // buffer, laid out exactly like the leading axes of `self.weight`.
        let mut basis = vec![0.0_f32; PairOp::ALL.len() * in_f];
        let mut out = vec![0.0_f32; n * n * out_f];
        for i in 0..n {
            for j in 0..n {
                for (op_idx, &op) in PairOp::ALL.iter().enumerate() {
                    for (c, red) in reductions.iter().enumerate() {
                        basis[op_idx * in_f + c] =
                            red.op_value(op, pair_features, n, in_f, c, i, j);
                    }
                }
                let start = (i * n + j) * out_f;
                let cell = &mut out[start..start + out_f];
                for (cp, slot) in cell.iter_mut().enumerate() {
                    let mut acc = self.bias_all[cp];
                    if i == j {
                        acc += self.bias_diag[cp];
                    }
                    for (k, &value) in basis.iter().enumerate() {
                        let w = self.weight[k * out_f + cp];
                        // Skipping exact zeros keeps `0 * inf` from
                        // manufacturing a NaN for unused basis terms.
                        if w != 0.0 {
                            acc += w * value;
                        }
                    }
                    // `f32::max` returns the non-NaN operand, so the ReLU
                    // below would silently turn NaN into 0.0; reject it
                    // (and +inf) before clamping.
                    if acc.is_nan() || acc == f32::INFINITY {
                        return Err(GnnError::NonFiniteOutput("KWlGnn::forward"));
                    }
                    *slot = acc.max(0.0);
                }
            }
        }
        Ok(out)
    }

    /// Sums pair features over all `(i, j)` into one vector of length
    /// `dim`; a thin wrapper around [`graph_readout_sum`] that does not use
    /// any layer state.
    ///
    /// # Errors
    ///
    /// Whatever [`graph_readout_sum`] returns for the same arguments.
    pub fn graph_readout(
        &self,
        pair_features: &[f32],
        n_nodes: usize,
        dim: usize,
    ) -> GnnResult<Vec<f32>> {
        graph_readout_sum(pair_features, n_nodes, dim)
    }
}
/// Sums a flat `n_nodes * n_nodes * dim` pair-feature buffer over all
/// `(i, j)` pairs, returning one value per channel.
///
/// Pairs are accumulated in row-major order, i.e. in the order they are
/// stored in `pair_features`.
///
/// # Errors
///
/// * [`GnnError::EmptyGraph`] when `n_nodes == 0`.
/// * [`GnnError::InvalidLayerConfig`] when `dim == 0`.
/// * [`GnnError::DimensionMismatch`] when `pair_features.len()` is not
///   `n_nodes * n_nodes * dim`, including the case where that product does
///   not fit in `usize`.
pub fn graph_readout_sum(pair_features: &[f32], n_nodes: usize, dim: usize) -> GnnResult<Vec<f32>> {
    if n_nodes == 0 {
        return Err(GnnError::EmptyGraph);
    }
    if dim == 0 {
        return Err(GnnError::InvalidLayerConfig(
            "k-WL readout: dim must be > 0".to_string(),
        ));
    }
    // Saturate instead of overflowing: no `[f32]` can hold `usize::MAX`
    // elements, so an unrepresentable product is always a mismatch.
    let expected = n_nodes.saturating_mul(n_nodes).saturating_mul(dim);
    if pair_features.len() != expected {
        return Err(GnnError::DimensionMismatch {
            expected,
            got: pair_features.len(),
        });
    }
    let mut pooled = vec![0.0_f32; dim];
    // Each chunk is the `dim` channels of one (i, j) pair; the length check
    // above guarantees there is no remainder.
    for pair in pair_features.chunks_exact(dim) {
        for (acc, &value) in pooled.iter_mut().zip(pair) {
            *acc += value;
        }
    }
    Ok(pooled)
}
/// Draws one value from `rng` and maps it linearly onto `[-1.0, 1.0]`.
///
/// The draw is divided by 2^32 and then stretched and shifted. Because the
/// `u32 -> f32` conversion rounds to nearest, draws close to `u32::MAX`
/// land on exactly `1.0`, so the upper bound is inclusive.
// NOTE(review): presumably `next_u32` covers the full `u32` range — confirm
// against `LcgRng`.
#[inline]
fn centered_unit(rng: &mut LcgRng) -> f32 {
    /// 2^32, the number of distinct `u32` values.
    const U32_SPAN: f32 = 4_294_967_296.0_f32;
    let fraction = rng.next_u32() as f32 / U32_SPAN;
    fraction * 2.0 - 1.0
}
// Unit tests for the basis operations, the layer and the sum readout.
#[cfg(test)]
mod tests {
    use super::*;
    // Relabels nodes: `out[i][j][c] = x[perm[i]][perm[j]][c]`.
    fn permute_pairs(x: &[f32], n: usize, dim: usize, perm: &[usize]) -> Vec<f32> {
        let mut out = vec![0.0_f32; n * n * dim];
        for i in 0..n {
            for j in 0..n {
                for c in 0..dim {
                    out[(i * n + j) * dim + c] = x[(perm[i] * n + perm[j]) * dim + c];
                }
            }
        }
        out
    }
    // Deterministic ramp `0.07 * k - 0.5` over the flat buffer; contains both
    // negative and positive values and is not symmetric in (i, j).
    fn arange_pairs(n: usize, dim: usize) -> Vec<f32> {
        (0..n * n * dim).map(|i| (i as f32) * 0.07 - 0.5).collect()
    }
    // A valid config constructs and reports its output width.
    #[test]
    fn new_valid() {
        let layer = KWlGnn::new(KWlConfig {
            in_features: 2,
            out_features: 3,
            seed: 7,
        })
        .expect("test invariant: value must be valid");
        assert_eq!(layer.output_dim(), 3);
    }
    // Zero input channels are rejected.
    #[test]
    fn new_invalid_zero_in() {
        assert!(
            KWlGnn::new(KWlConfig {
                in_features: 0,
                out_features: 3,
                seed: 1,
            })
            .is_err()
        );
    }
    // Zero output channels are rejected.
    #[test]
    fn new_invalid_zero_out() {
        assert!(
            KWlGnn::new(KWlConfig {
                in_features: 2,
                out_features: 0,
                seed: 1,
            })
            .is_err()
        );
    }
    // Guards the size of the basis; the weight layout depends on it.
    #[test]
    fn fifteen_basis_ops() {
        assert_eq!(PairOp::ALL.len(), 15);
    }
    // forward(permute(x)) must equal permute(forward(x)).
    #[test]
    fn forward_permutation_equivariant() {
        let n = 4;
        let din = 2;
        let dout = 3;
        let layer = KWlGnn::new(KWlConfig {
            in_features: din,
            out_features: dout,
            seed: 2024,
        })
        .expect("test invariant: value must be valid");
        let x = arange_pairs(n, din);
        let perm = [2usize, 0, 3, 1];
        let y = layer
            .forward(&x, n, din)
            .expect("test invariant: value must be valid");
        let y_permuted = permute_pairs(&y, n, dout, &perm);
        let x_permuted = permute_pairs(&x, n, din, &perm);
        let y_from_permuted = layer
            .forward(&x_permuted, n, din)
            .expect("test invariant: value must be valid");
        for (a, b) in y_permuted.iter().zip(y_from_permuted.iter()) {
            assert!((a - b).abs() < 1e-4, "equivariance broken: {a} vs {b}");
        }
    }
    // Output has n * n * out_features finite entries.
    #[test]
    fn forward_shape_and_finite() {
        let n = 5;
        let din = 2;
        let dout = 2;
        let layer = KWlGnn::new(KWlConfig {
            in_features: din,
            out_features: dout,
            seed: 11,
        })
        .expect("test invariant: value must be valid");
        let x = arange_pairs(n, din);
        let y = layer
            .forward(&x, n, din)
            .expect("test invariant: value must be valid");
        assert_eq!(y.len(), n * n * dout);
        assert!(y.iter().all(|v| v.is_finite()));
    }
    // Transposing a symmetric matrix yields a symmetric matrix equal to the input.
    #[test]
    fn transpose_op_symmetric_input_symmetric_output() {
        let n = 3;
        let dim = 1;
        let mut x = vec![0.0_f32; n * n * dim];
        let vals = [[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [3.0, 5.0, 6.0]];
        for i in 0..n {
            for j in 0..n {
                x[(i * n + j) * dim] = vals[i][j];
            }
        }
        let y = apply_pair_op(PairOp::Transpose, &x, n, dim)
            .expect("test invariant: value must be valid");
        for i in 0..n {
            for j in 0..n {
                let a = y[(i * n + j) * dim];
                let b = y[(j * n + i) * dim];
                assert!((a - b).abs() < 1e-6, "not symmetric at ({i},{j})");
            }
        }
        for (a, b) in y.iter().zip(x.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }
    // Identity + Transpose of an arbitrary (non-symmetric) input is symmetric.
    #[test]
    fn symmetrization_always_symmetric() {
        let n = 3;
        let dim = 1;
        let x = arange_pairs(n, dim); let id = apply_pair_op(PairOp::Identity, &x, n, dim)
            .expect("test invariant: value must be valid");
        let tr = apply_pair_op(PairOp::Transpose, &x, n, dim)
            .expect("test invariant: value must be valid");
        let sym: Vec<f32> = id.iter().zip(tr.iter()).map(|(a, b)| a + b).collect();
        for i in 0..n {
            for j in 0..n {
                let a = sym[(i * n + j) * dim];
                let b = sym[(j * n + i) * dim];
                assert!((a - b).abs() < 1e-6, "symmetrised output not symmetric");
            }
        }
    }
    // Pins every op against hand-computed values on the 2x2 matrix [[1, 2], [3, 4]]
    // (row sums 3 and 7, column sums 4 and 6, trace 5, total 10).
    #[test]
    fn basis_ops_match_definition() {
        let n = 2;
        let dim = 1;
        let x = vec![1.0_f32, 2.0, 3.0, 4.0];
        let get = |op| apply_pair_op(op, &x, n, dim).expect("op");
        assert_eq!(get(PairOp::Identity), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(get(PairOp::Transpose), vec![1.0, 3.0, 2.0, 4.0]);
        assert_eq!(get(PairOp::DiagToDiag), vec![1.0, 0.0, 0.0, 4.0]);
        assert_eq!(get(PairOp::RowSumToDiag), vec![3.0, 0.0, 0.0, 7.0]);
        assert_eq!(get(PairOp::ColSumToDiag), vec![4.0, 0.0, 0.0, 6.0]);
        assert_eq!(get(PairOp::TraceToDiag), vec![5.0, 0.0, 0.0, 5.0]);
        assert_eq!(get(PairOp::TotalToDiag), vec![10.0, 0.0, 0.0, 10.0]);
        assert_eq!(get(PairOp::DiagToRows), vec![1.0, 1.0, 4.0, 4.0]);
        assert_eq!(get(PairOp::DiagToCols), vec![1.0, 4.0, 1.0, 4.0]);
        assert_eq!(get(PairOp::RowSumToRows), vec![3.0, 3.0, 7.0, 7.0]);
        assert_eq!(get(PairOp::ColSumToCols), vec![4.0, 6.0, 4.0, 6.0]);
        assert_eq!(get(PairOp::RowSumToCols), vec![3.0, 7.0, 3.0, 7.0]);
        assert_eq!(get(PairOp::ColSumToRows), vec![4.0, 4.0, 6.0, 6.0]);
        assert_eq!(get(PairOp::TraceBroadcast), vec![5.0, 5.0, 5.0, 5.0]);
        assert_eq!(get(PairOp::TotalBroadcast), vec![10.0, 10.0, 10.0, 10.0]);
    }
    // Each op on its own commutes with node relabelling.
    #[test]
    fn basis_ops_individually_equivariant() {
        let n = 4;
        let dim = 2;
        let x = arange_pairs(n, dim);
        let perm = [3usize, 1, 0, 2];
        for &op in PairOp::ALL.iter() {
            let y = apply_pair_op(op, &x, n, dim).expect("op");
            let y_perm = permute_pairs(&y, n, dim, &perm);
            let xp = permute_pairs(&x, n, dim, &perm);
            let y2 = apply_pair_op(op, &xp, n, dim).expect("op");
            for (a, b) in y_perm.iter().zip(y2.iter()) {
                assert!((a - b).abs() < 1e-4, "op {op:?} not equivariant");
            }
        }
    }
    // `dim` (3) differs from the configured in_features (2); the buffer length
    // itself is consistent with the passed shape.
    #[test]
    fn forward_dim_mismatch_errors() {
        let layer = KWlGnn::new(KWlConfig {
            in_features: 2,
            out_features: 2,
            seed: 1,
        })
        .expect("test invariant: value must be valid");
        let x = vec![0.0_f32; 3 * 3 * 3]; let err = layer.forward(&x, 3, 3);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }
    // `dim` matches, but the buffer holds 10 values instead of 3 * 3 * 2 = 18.
    #[test]
    fn forward_length_mismatch_errors() {
        let layer = KWlGnn::new(KWlConfig {
            in_features: 2,
            out_features: 2,
            seed: 1,
        })
        .expect("test invariant: value must be valid");
        let x = vec![0.0_f32; 10]; let err = layer.forward(&x, 3, 2);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }
    // Zero nodes is reported as EmptyGraph.
    #[test]
    fn forward_empty_graph_errors() {
        let layer = KWlGnn::new(KWlConfig {
            in_features: 2,
            out_features: 2,
            seed: 1,
        })
        .expect("test invariant: value must be valid");
        let err = layer.forward(&[], 0, 2);
        assert!(matches!(err, Err(GnnError::EmptyGraph)));
    }
    // Two values cannot fill a 2 x 2 x 1 tensor.
    #[test]
    fn apply_op_bad_len_errors() {
        let err = apply_pair_op(PairOp::Identity, &[1.0, 2.0], 2, 1);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }
    // The sum readout gives the same vector for relabelled input.
    #[test]
    fn readout_permutation_invariant() {
        let n = 4;
        let dim = 2;
        let layer = KWlGnn::new(KWlConfig {
            in_features: dim,
            out_features: dim,
            seed: 5,
        })
        .expect("test invariant: value must be valid");
        let x = arange_pairs(n, dim);
        let perm = [1usize, 3, 0, 2];
        let g1 = layer
            .graph_readout(&x, n, dim)
            .expect("test invariant: value must be valid");
        let xp = permute_pairs(&x, n, dim, &perm);
        let g2 = layer
            .graph_readout(&xp, n, dim)
            .expect("test invariant: value must be valid");
        for (a, b) in g1.iter().zip(g2.iter()) {
            assert!((a - b).abs() < 1e-4, "readout not invariant: {a} vs {b}");
        }
    }
    // 1 + 2 + 3 + 4 = 10 for the single channel.
    #[test]
    fn readout_sums_all_pairs() {
        let n = 2;
        let dim = 1;
        let x = vec![1.0_f32, 2.0, 3.0, 4.0];
        let g = graph_readout_sum(&x, n, dim).expect("test invariant: value must be valid");
        assert_eq!(g.len(), 1);
        assert!((g[0] - 10.0).abs() < 1e-6);
    }
    // Three values cannot fill a 2 x 2 x 1 tensor.
    #[test]
    fn readout_bad_len_errors() {
        let err = graph_readout_sum(&[1.0, 2.0, 3.0], 2, 1);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }
    // Two layers built from the same seed produce bit-identical output.
    #[test]
    fn forward_deterministic_with_seed() {
        let n = 3;
        let dim = 2;
        let x = arange_pairs(n, dim);
        let make = || {
            KWlGnn::new(KWlConfig {
                in_features: dim,
                out_features: dim,
                seed: 99,
            })
            .expect("test invariant: value must be valid")
            .forward(&x, n, dim)
            .expect("test invariant: value must be valid")
        };
        assert_eq!(make(), make());
    }
}