use crate::error::{GnnError, GnnResult};
use crate::graph::csr::CsrGraph;
/// Hyper-parameters for [`AppnpLayer`].
///
/// The values are not checked here; [`AppnpLayer::new`] validates them and rejects
/// out-of-range settings.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AppnpConfig {
    /// Number of features per node. `AppnpLayer::forward` expects a row-major input of
    /// `n_nodes * feat_dim` values. Must be > 0.
    pub feat_dim: usize,
    /// Weight given to the original (step-0) features at every propagation step; the
    /// propagated features receive `1 - alpha`. Must lie strictly inside `(0, 1)`.
    pub alpha: f32,
    /// Number of propagation steps performed by `forward`. Must be >= 1.
    pub k: usize,
}
/// APPNP propagation layer: repeatedly mixes graph-propagated features with the
/// original input features according to an [`AppnpConfig`].
///
/// Construct it with `AppnpLayer::new`, which validates the configuration.
#[derive(Debug, Clone)]
pub struct AppnpLayer {
    /// Validated hyper-parameters.
    config: AppnpConfig,
}
impl AppnpLayer {
    /// Creates an APPNP layer after validating `config`.
    ///
    /// # Errors
    ///
    /// Returns [`GnnError::InvalidLayerConfig`] when:
    /// * `feat_dim` is zero,
    /// * `alpha` is not strictly inside `(0, 1)` (NaN is rejected too), or
    /// * `k` is zero.
    pub fn new(config: AppnpConfig) -> GnnResult<Self> {
        if config.feat_dim == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "APPNP: feat_dim must be > 0".to_string(),
            ));
        }
        // Every comparison against NaN is false, so a plain range test would let a NaN
        // alpha through; check for it explicitly.
        if config.alpha.is_nan() || config.alpha <= 0.0 || config.alpha >= 1.0 {
            return Err(GnnError::InvalidLayerConfig(format!(
                "APPNP: alpha must be in (0, 1) exclusive, got {:.6}",
                config.alpha
            )));
        }
        if config.k == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "APPNP: k must be >= 1".to_string(),
            ));
        }
        Ok(Self { config })
    }

    /// Runs `k` APPNP propagation steps over `graph`.
    ///
    /// `h` is read as a row-major `n_nodes x feat_dim` matrix `H0`. Each step computes
    /// `H <- (1 - alpha) * (A_hat * H) + alpha * H0`. Here `A_hat * H` is the weighted
    /// edge list `(rows, cols, vals)` from [`CsrGraph::normalized_adjacency`], applied
    /// as a sparse-dense product. The result has the same layout as `h`.
    ///
    /// # Errors
    ///
    /// * [`GnnError::NodeFeatureMismatch`] if `h.len() != n_nodes * feat_dim`. The error
    ///   carries `n_nodes` and `h.len() / feat_dim`.
    /// * [`GnnError::NonFiniteOutput`] if any output value is NaN or infinite.
    pub fn forward(&self, graph: &CsrGraph, h: &[f32]) -> GnnResult<Vec<f32>> {
        let n = graph.n_nodes();
        let f = self.config.feat_dim;
        let alpha = self.config.alpha;
        let one_minus_alpha = 1.0_f32 - alpha;
        // `checked_mul` prevents a wrapped `n * f` from accidentally matching `h.len()`.
        if n.checked_mul(f) != Some(h.len()) {
            return Err(GnnError::NodeFeatureMismatch(n, h.len() / f.max(1)));
        }
        let (rows, cols, vals) = graph.normalized_adjacency();
        debug_assert!(
            rows.len() == cols.len() && rows.len() == vals.len(),
            "normalized_adjacency returned edge arrays of different lengths"
        );

        // `h` itself is the teleport term H0, so it is not copied. Only the running state
        // and one scratch buffer are allocated, and the scratch buffer is reused each step.
        let mut h_cur = h.to_vec();
        let mut h_prop = vec![0.0_f32; h.len()];
        for _step in 0..self.config.k {
            h_prop.fill(0.0);
            // Sparse-dense product: for each edge (i, j, v), add v * row j of h_cur to
            // row i of h_prop.
            for ((&i, &j), &v) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
                let src = &h_cur[j * f..(j + 1) * f];
                let dst = &mut h_prop[i * f..(i + 1) * f];
                for (acc, &x) in dst.iter_mut().zip(src) {
                    *acc += v * x;
                }
            }
            for ((cur, &prop), &h0) in h_cur.iter_mut().zip(&h_prop).zip(h) {
                *cur = one_minus_alpha * prop + alpha * h0;
            }
        }
        if h_cur.iter().any(|v| !v.is_finite()) {
            return Err(GnnError::NonFiniteOutput("APPNP forward"));
        }
        Ok(h_cur)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected path 0 - 1 - 2 - 3, with each edge listed in both directions.
    fn line_graph() -> CsrGraph {
        CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)])
            .expect("test invariant: value must be valid")
    }

    /// Undirected triangle on nodes 0, 1, 2, with each edge listed in both directions.
    fn triangle_graph() -> CsrGraph {
        CsrGraph::from_edges(3, &[(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)])
            .expect("test invariant: value must be valid")
    }

    /// Shorthand for building an `AppnpConfig`.
    fn config(feat_dim: usize, alpha: f32, k: usize) -> AppnpConfig {
        AppnpConfig { feat_dim, alpha, k }
    }

    /// Builds a layer whose configuration the test expects to be valid.
    fn layer(feat_dim: usize, alpha: f32, k: usize) -> AppnpLayer {
        AppnpLayer::new(config(feat_dim, alpha, k)).expect("test invariant: value must be valid")
    }

    /// Runs a forward pass that the test expects to succeed.
    fn run(layer: &AppnpLayer, graph: &CsrGraph, h: &[f32]) -> Vec<f32> {
        layer
            .forward(graph, h)
            .expect("test invariant: value must be valid")
    }

    #[test]
    fn new_valid() {
        assert!(AppnpLayer::new(config(4, 0.1, 5)).is_ok());
    }

    #[test]
    fn new_invalid_feat_dim_zero() {
        assert!(AppnpLayer::new(config(0, 0.1, 5)).is_err());
    }

    #[test]
    fn new_invalid_alpha_zero() {
        assert!(AppnpLayer::new(config(4, 0.0, 5)).is_err());
    }

    #[test]
    fn new_invalid_alpha_one() {
        assert!(AppnpLayer::new(config(4, 1.0, 5)).is_err());
    }

    #[test]
    fn new_invalid_alpha_gt_1() {
        assert!(AppnpLayer::new(config(4, 1.5, 5)).is_err());
    }

    #[test]
    fn new_invalid_k_zero() {
        assert!(AppnpLayer::new(config(4, 0.1, 0)).is_err());
    }

    #[test]
    fn forward_output_shape() {
        let h: Vec<f32> = (0..4 * 3).map(|i| i as f32 * 0.1).collect();
        let out = run(&layer(3, 0.15, 5), &line_graph(), &h);
        assert_eq!(out.len(), 4 * 3);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn forward_high_alpha_stays_near_h0() {
        let h = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let out = run(&layer(2, 0.99, 5), &line_graph(), &h);
        let max_diff = out
            .iter()
            .zip(&h)
            .map(|(o, hi)| (o - hi).abs())
            .fold(0.0_f32, f32::max);
        assert!(max_diff < 0.1, "max_diff={max_diff}");
    }

    #[test]
    fn forward_single_node_graph() {
        let g = CsrGraph::from_edges(1, &[(0, 0)]).expect("test invariant: value must be valid");
        let h = vec![3.0_f32, 7.0];
        let out = run(&layer(2, 0.1, 5), &g, &h);
        assert_eq!(out.len(), 2);
        assert!((out[0] - 3.0_f32).abs() < 1e-5, "out[0]={}", out[0]);
        assert!((out[1] - 7.0_f32).abs() < 1e-5, "out[1]={}", out[1]);
    }

    #[test]
    fn forward_star_graph_k1_is_finite() {
        // Star centred on node 0.
        let g = CsrGraph::from_edges(4, &[(0, 1), (0, 2), (0, 3), (1, 0), (2, 0), (3, 0)])
            .expect("test invariant: value must be valid");
        let h = vec![1.0_f32; 4 * 2];
        let out = run(&layer(2, 0.1, 1), &g, &h);
        assert_eq!(out.len(), 4 * 2);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn forward_k_steps_applied() {
        let g = line_graph();
        let h: Vec<f32> = (0..4 * 2).map(|i| i as f32).collect();
        let out1 = run(&layer(2, 0.1, 1), &g, &h);
        let out10 = run(&layer(2, 0.1, 10), &g, &h);
        let diff: f32 = out1.iter().zip(&out10).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "k=1 and k=10 outputs should differ, diff={diff}"
        );
    }

    #[test]
    fn forward_feature_mean_preserved_approx() {
        let (n, f) = (3, 3);
        let h: Vec<f32> = (0..n * f).map(|i| i as f32 * 0.3).collect();
        let out = run(&layer(f, 0.1, 5), &triangle_graph(), &h);
        assert!(out.iter().all(|v| v.is_finite()));
        // The triangle is a regular graph. Symmetric or row-wise normalization, with or
        // without self-loops, then gives a doubly stochastic matrix. Each step therefore
        // keeps every feature's mean over the nodes unchanged.
        for d in 0..f {
            let mean_in = (0..n).map(|i| h[i * f + d]).sum::<f32>() / n as f32;
            let mean_out = (0..n).map(|i| out[i * f + d]).sum::<f32>() / n as f32;
            assert!(
                (mean_in - mean_out).abs() < 1e-4,
                "feature {d}: mean_in={mean_in} mean_out={mean_out}"
            );
        }
    }

    #[test]
    fn forward_disconnected_graph_no_panic() {
        let g = CsrGraph::from_edges(3, &[(0, 0), (1, 1), (2, 2)])
            .expect("test invariant: value must be valid");
        let h = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = layer(2, 0.1, 5).forward(&g, &h);
        assert!(out.is_ok());
        assert!(
            out.expect("test invariant: value must be valid")
                .iter()
                .all(|v| v.is_finite())
        );
    }

    #[test]
    fn forward_err_node_feature_mismatch() {
        let g = line_graph();
        // Features for 3 nodes, but the graph has 4.
        let h = vec![1.0_f32; 3 * 3];
        let err = layer(3, 0.1, 3).forward(&g, &h);
        assert!(matches!(err, Err(GnnError::NodeFeatureMismatch(..))));
    }

    #[test]
    fn forward_three_nodes_triangle() {
        let h = vec![1.0_f32, 0.0, 0.0, 1.0, 0.5, 0.5];
        let out = run(&layer(2, 0.2, 3), &triangle_graph(), &h);
        assert_eq!(out.len(), 3 * 2);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn forward_propagation_smooths_extremes() {
        let h = vec![10.0_f32, 0.0, 0.0];
        let out = run(&layer(1, 0.1, 20), &triangle_graph(), &h);
        assert!(
            out[0] < 10.0,
            "node 0 feature should decrease, out[0]={}",
            out[0]
        );
        assert!(
            out[1] > 0.0 || out[2] > 0.0,
            "neighbors should gain feature mass"
        );
    }

    #[test]
    fn forward_k1_matches_manual() {
        let g = CsrGraph::from_edges(2, &[(0, 1), (1, 0)])
            .expect("test invariant: value must be valid");
        let alpha = 0.2_f32;
        let h = vec![4.0_f32, 2.0_f32];
        let out = run(&layer(1, alpha, 1), &g, &h);
        let mean = (h[0] + h[1]) * 0.5_f32;
        let expected_0 = (1.0 - alpha) * mean + alpha * h[0];
        let expected_1 = (1.0 - alpha) * mean + alpha * h[1];
        assert!(
            (out[0] - expected_0).abs() < 1e-5,
            "out[0]={} expected={expected_0}",
            out[0]
        );
        assert!(
            (out[1] - expected_1).abs() < 1e-5,
            "out[1]={} expected={expected_1}",
            out[1]
        );
    }
}