use crate::error::{GnnError, GnnResult};
use crate::graph::csr::CsrGraph;
/// Computes random-walk structural encodings (RWSE) for every node of `graph`.
///
/// A dense transition matrix `M` is built by dividing each edge weight by the
/// sum of the weights leaving its source node (rows whose weight sum is not
/// strictly positive are left all-zero). The diagonals of `M^1 ..= M^walk_length`
/// are then recorded, i.e. entry `i * walk_length + (k - 1)` of the result holds
/// `(M^k)[i][i]`.
///
/// The returned buffer is row-major with `n_nodes * walk_length` entries.
/// Two dense `n_nodes * n_nodes` scratch matrices are allocated internally.
///
/// # Errors
///
/// * [`GnnError::InvalidLayerConfig`] when `walk_length` is zero.
/// * Any error reported by `graph.neighbors` / `graph.edge_weights`.
/// * [`GnnError::NonFiniteOutput`] when any computed entry is NaN or infinite.
pub fn random_walk_se(graph: &CsrGraph, walk_length: usize) -> GnnResult<Vec<f32>> {
    /// Writes the diagonal of the `n x n` matrix `power` into column `step`
    /// of the row-major `n x stride` buffer `enc`.
    fn store_diagonal(enc: &mut [f32], power: &[f32], n: usize, stride: usize, step: usize) {
        for node in 0..n {
            enc[node * stride + step] = power[node * n + node];
        }
    }

    if walk_length == 0 {
        return Err(GnnError::InvalidLayerConfig(
            "RWSE: walk_length must be >= 1".to_string(),
        ));
    }

    let n = graph.n_nodes();

    // Reciprocal of each node's total outgoing weight; zero for nodes whose
    // total is not strictly positive so that their rows stay empty.
    let inv_deg = (0..n)
        .map(|node| -> GnnResult<f32> {
            let total: f32 = graph.edge_weights(node)?.iter().sum();
            Ok(if total > 0.0 { 1.0 / total } else { 0.0 })
        })
        .collect::<GnnResult<Vec<f32>>>()?;

    // `power` starts out as M itself (the first power).
    let mut power = vec![0.0_f32; n * n];
    for (src, &scale) in inv_deg.iter().enumerate() {
        let targets = graph.neighbors(src)?;
        let weights = graph.edge_weights(src)?;
        for (edge, &dst) in targets.iter().enumerate() {
            // `+=` so that repeated edges between the same pair accumulate.
            power[src * n + dst] += weights[edge] * scale;
        }
    }

    let mut enc = vec![0.0_f32; n * walk_length];
    store_diagonal(&mut enc, &power, n, walk_length, 0);

    // Each round computes `scratch = M * power` straight from the sparse
    // adjacency, then swaps the buffers so `power` holds the next power.
    let mut scratch = vec![0.0_f32; n * n];
    for step in 1..walk_length {
        scratch.fill(0.0);
        for (src, &scale) in inv_deg.iter().enumerate() {
            let targets = graph.neighbors(src)?;
            let weights = graph.edge_weights(src)?;
            let out_row = &mut scratch[src * n..(src + 1) * n];
            for (edge, &dst) in targets.iter().enumerate() {
                let m = weights[edge] * scale;
                if m == 0.0 {
                    // Zero transitions contribute nothing; skip the row pass.
                    continue;
                }
                let in_row = &power[dst * n..(dst + 1) * n];
                for (acc, &p) in out_row.iter_mut().zip(in_row) {
                    *acc += m * p;
                }
            }
        }
        std::mem::swap(&mut power, &mut scratch);
        store_diagonal(&mut enc, &power, n, walk_length, step);
    }

    if enc.iter().all(|v| v.is_finite()) {
        Ok(enc)
    } else {
        Err(GnnError::NonFiniteOutput("random_walk_se"))
    }
}
/// Configuration for random-walk structural encodings.
///
/// Implements the common comparison and hashing traits so configurations can
/// be compared in assertions and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RwseConfig {
    /// Number of random-walk steps to encode per node; also the width of the
    /// produced encoding. `RwseEncoder::new` and `random_walk_se` reject `0`.
    pub walk_length: usize,
}
/// Encoder that produces random-walk structural encodings for a graph using a
/// fixed [`RwseConfig`].
///
/// `Debug` and `Clone` are derived so the encoder can be logged, embedded in
/// other `Debug` types and duplicated cheaply (it only holds its config).
#[derive(Debug, Clone)]
pub struct RwseEncoder {
    /// Configuration supplied at construction time.
    config: RwseConfig,
}
impl RwseEncoder {
    /// Builds an encoder from `config`.
    ///
    /// # Errors
    ///
    /// Returns [`GnnError::InvalidLayerConfig`] when `config.walk_length` is zero.
    pub fn new(config: RwseConfig) -> GnnResult<Self> {
        match config.walk_length {
            0 => Err(GnnError::InvalidLayerConfig(
                "RWSE: walk_length must be >= 1".to_string(),
            )),
            _ => Ok(Self { config }),
        }
    }

    /// Number of encoding values produced per node (the configured walk length).
    #[inline]
    pub fn encoding_dim(&self) -> usize {
        self.config.walk_length
    }

    /// Computes the encoding for `graph` by delegating to [`random_walk_se`]
    /// with the configured walk length.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`random_walk_se`].
    pub fn encode(&self, graph: &CsrGraph) -> GnnResult<Vec<f32>> {
        let steps = self.config.walk_length;
        random_walk_se(graph, steps)
    }

    /// Appends each node's encoding to its feature row.
    ///
    /// `x` is read as a row-major `n_nodes x feat_dim` buffer. The result is a
    /// row-major `n_nodes x (feat_dim + walk_length)` buffer in which every row
    /// is the node's original features followed by its encoding.
    ///
    /// # Errors
    ///
    /// * [`GnnError::NodeFeatureMismatch`] when `x.len() != n_nodes * feat_dim`;
    ///   the payload is the graph's node count and `x.len() / max(feat_dim, 1)`.
    /// * Any error returned by [`RwseEncoder::encode`].
    pub fn augment(&self, graph: &CsrGraph, x: &[f32], feat_dim: usize) -> GnnResult<Vec<f32>> {
        let n = graph.n_nodes();
        if x.len() != n * feat_dim {
            // `max(1)` avoids dividing by zero when `feat_dim` is 0.
            let rows_in_x = x.len() / feat_dim.max(1);
            return Err(GnnError::NodeFeatureMismatch(n, rows_in_x));
        }

        let enc = self.encode(graph)?;
        let k = self.config.walk_length;
        let out_dim = feat_dim + k;

        let mut out = Vec::with_capacity(n * out_dim);
        for node in 0..n {
            out.extend_from_slice(&x[node * feat_dim..(node + 1) * feat_dim]);
            out.extend_from_slice(&enc[node * k..(node + 1) * k]);
        }
        Ok(out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panic message used wherever a fixture or call is expected to succeed.
    const VALID: &str = "test invariant: value must be valid";

    /// Two nodes connected by one directed edge each way.
    fn two_cycle() -> CsrGraph {
        CsrGraph::from_edges(2, &[(0, 1), (1, 0)]).expect(VALID)
    }

    /// Complete graph on three nodes, every edge present in both directions.
    fn triangle() -> CsrGraph {
        CsrGraph::from_edges(3, &[(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)]).expect(VALID)
    }

    /// Runs `random_walk_se` and unwraps the result.
    fn rwse(graph: &CsrGraph, walk_length: usize) -> Vec<f32> {
        random_walk_se(graph, walk_length).expect(VALID)
    }

    /// Slice of the `k` encoding values belonging to `node`.
    fn row(enc: &[f32], node: usize, k: usize) -> &[f32] {
        &enc[node * k..(node + 1) * k]
    }

    /// A positive walk length yields an encoder reporting that dimension.
    #[test]
    fn new_valid() {
        let encoder = RwseEncoder::new(RwseConfig { walk_length: 4 }).expect(VALID);
        assert_eq!(encoder.encoding_dim(), 4);
    }

    /// The constructor rejects a zero walk length.
    #[test]
    fn new_invalid_zero_walk() {
        let built = RwseEncoder::new(RwseConfig { walk_length: 0 });
        assert!(built.is_err());
    }

    /// The free function rejects a zero walk length as well.
    #[test]
    fn free_fn_zero_walk_errors() {
        let outcome = random_walk_se(&two_cycle(), 0);
        assert!(outcome.is_err());
    }

    /// A single node with a self loop returns to itself at every step.
    #[test]
    fn self_loop_returns_one() {
        let g = CsrGraph::from_edges(1, &[(0, 0)]).expect(VALID);
        let enc = rwse(&g, 4);
        assert_eq!(enc.len(), 4);
        enc.iter().enumerate().for_each(|(k, &p)| {
            assert!((p - 1.0).abs() < 1e-6, "step {k}: p={p}");
        });
    }

    /// On a two-cycle the walk is back home on every even step only.
    #[test]
    fn two_cycle_alternates() {
        let k = 4;
        let enc = rwse(&two_cycle(), k);
        let expected = [0.0_f32, 1.0, 0.0, 1.0];
        for node in 0..2 {
            let probs = row(&enc, node, k);
            for (step, (&p, &e)) in probs.iter().zip(expected.iter()).enumerate() {
                assert!(
                    (p - e).abs() < 1e-6,
                    "node {node} step {step}: got {p}, want {e}"
                );
            }
        }
    }

    /// Return probabilities on the triangle match the hand-computed values.
    #[test]
    fn triangle_analytic_return_probs() {
        let k = 3;
        let enc = rwse(&triangle(), k);
        let expected = [0.0_f32, 0.5, 0.25];
        for node in 0..3 {
            let probs = row(&enc, node, k);
            for (step, (&p, &e)) in probs.iter().zip(expected.iter()).enumerate() {
                assert!(
                    (p - e).abs() < 1e-5,
                    "node {node} step {step}: got {p}, want {e}"
                );
            }
        }
    }

    /// Every encoding value lies in `[0, 1]` up to a small tolerance.
    #[test]
    fn return_probs_in_unit_interval() {
        let enc = rwse(&triangle(), 6);
        let bounds = -1e-6_f32..=1.0_f32 + 1e-6;
        for p in enc.iter() {
            assert!(bounds.contains(p), "p out of range: {p}");
        }
    }

    /// Without self loops nobody can be back home after a single step.
    #[test]
    fn first_step_zero_without_self_loops() {
        let enc = rwse(&triangle(), 3);
        let first_steps: Vec<f32> = (0..3).map(|node| enc[node * 3]).collect();
        for (node, p) in first_steps.iter().enumerate() {
            assert!(p.abs() < 1e-6, "node {node} step 1 not zero");
        }
    }

    /// Relabelling the nodes permutes the encoding rows accordingly.
    #[test]
    fn permutation_permutes_rows() {
        let g = CsrGraph::from_edges(3, &[(0, 1), (1, 0), (1, 2), (2, 1)]).expect(VALID);
        let k = 4;
        let enc = rwse(&g, k);

        let perm = [2usize, 1, 0];
        let inv = |x: usize| perm.iter().position(|&p| p == x).expect("inv");
        let mut edges = Vec::new();
        for old in 0..3 {
            let nbrs = g.neighbors(old).expect("nb");
            edges.extend(nbrs.iter().map(|&j| (inv(old), inv(j))));
        }

        let g2 = CsrGraph::from_edges(3, &edges).expect(VALID);
        let enc2 = rwse(&g2, k);
        for (a, &src) in perm.iter().enumerate() {
            for step in 0..k {
                let got = enc2[a * k + step];
                let want = enc[src * k + step];
                assert!(
                    (got - want).abs() < 1e-5,
                    "row {a} step {step}: {got} vs {want}"
                );
            }
        }
    }

    /// A node without edges has an all-zero encoding.
    #[test]
    fn isolated_node_zero() {
        let g = CsrGraph::from_edges(3, &[(1, 2), (2, 1)]).expect(VALID);
        let k = 3;
        let enc = rwse(&g, k);
        enc.iter().take(k).enumerate().for_each(|(step, &p)| {
            assert!(p.abs() < 1e-6, "isolated node step {step} nonzero");
        });
    }

    /// `RwseEncoder::encode` and the free function agree exactly.
    #[test]
    fn encoder_encode_matches_free_fn() {
        let g = triangle();
        let encoder = RwseEncoder::new(RwseConfig { walk_length: 5 }).expect(VALID);
        let via_encoder = encoder.encode(&g).expect(VALID);
        let via_free_fn = rwse(&g, 5);
        assert_eq!(via_encoder, via_free_fn);
    }

    /// `augment` keeps the input features and appends the encoding per row.
    #[test]
    fn augment_concatenates() {
        let g = two_cycle();
        let feat_dim = 2;
        let walk_length = 3;
        let encoder = RwseEncoder::new(RwseConfig { walk_length }).expect(VALID);
        let x = vec![10.0_f32, 20.0, 30.0, 40.0];
        let out = encoder.augment(&g, &x, feat_dim).expect(VALID);

        let out_dim = feat_dim + walk_length;
        assert_eq!(out.len(), 2 * out_dim);

        let close = |got: f32, want: f32| (got - want).abs() < 1e-6;
        // Original features lead each row.
        assert!(close(out[0], 10.0));
        assert!(close(out[1], 20.0));
        assert!(close(out[out_dim], 30.0));
        assert!(close(out[out_dim + 1], 40.0));
        // Node 0's encoding follows its features: 0, 1, 0 on the two-cycle.
        assert!(close(out[feat_dim], 0.0));
        assert!(close(out[feat_dim + 1], 1.0));
        assert!(close(out[feat_dim + 2], 0.0));
    }

    /// A feature buffer of the wrong length is reported as a mismatch.
    #[test]
    fn augment_feature_mismatch_errors() {
        let g = two_cycle();
        let encoder = RwseEncoder::new(RwseConfig { walk_length: 2 }).expect(VALID);
        let x = vec![1.0_f32; 3];
        let err = encoder.augment(&g, &x, 2);
        assert!(matches!(err, Err(GnnError::NodeFeatureMismatch(..))));
    }

    /// Edge weights are normalised per source node before walking.
    #[test]
    fn weighted_graph_normalised() {
        let g =
            CsrGraph::from_edges_weighted(3, &[(0, 1, 3.0), (0, 2, 1.0), (1, 0, 1.0), (2, 0, 1.0)])
                .expect(VALID);
        let enc = rwse(&g, 2);
        let (step1, step2) = (enc[0], enc[1]);
        assert!(step1.abs() < 1e-6, "step1 should be 0");
        assert!((step2 - 1.0).abs() < 1e-6, "step2 should be 1");
    }
}