use crate::error::{GnnError, GnnResult};
use crate::graph::csr::CsrGraph;
/// Configuration of a top-k node pooling layer.
///
/// The number of nodes kept by `forward` comes either from a fixed count (`k`)
/// or, when no fixed count is set, from a fraction of the input node count (`ratio`).
#[derive(Debug, Clone)]
pub struct TopKPool {
/// Fraction of nodes to keep; only consulted when `k` is `None`.
ratio: f32,
/// Fixed number of nodes to keep; when `Some`, it takes precedence over `ratio`.
k: Option<usize>,
/// Number of features per node, i.e. the row width of the flat feature buffer.
feat_dim: usize,
}
/// Output of `TopKPool::forward`: the kept nodes, their features and their subgraph.
#[derive(Debug, Clone)]
pub struct TopKPoolResult {
/// Indices of the kept nodes in the input graph, in ascending order.
pub node_indices: Vec<usize>,
/// Flat features of the kept nodes (one row per entry of `node_indices`),
/// each row multiplied by that node's score.
pub x: Vec<f32>,
/// Graph over the kept nodes only; an edge is present when both endpoints were
/// kept, with endpoints renumbered to positions in `node_indices`.
pub graph: CsrGraph,
}
impl TopKPoolResult {
    /// Returns how many nodes were kept, i.e. the length of `node_indices`.
    pub fn n_nodes(&self) -> usize {
        let Self { node_indices, .. } = self;
        node_indices.len()
    }
}
impl TopKPool {
    /// Creates a pooling layer that keeps a fraction of the input nodes.
    ///
    /// On each [`TopKPool::forward`] call the number of kept nodes is
    /// `ceil(ratio * n_nodes)`, but at least 1.
    ///
    /// # Errors
    ///
    /// Returns [`GnnError::InvalidLayerConfig`] when `feat_dim` is zero or when
    /// `ratio` is not in `(0, 1]` (a NaN ratio is rejected as well).
    pub fn new_ratio(feat_dim: usize, ratio: f32) -> GnnResult<Self> {
        if feat_dim == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "feat_dim must be > 0".to_string(),
            ));
        }
        // Written as a negated conjunction so that NaN fails the check too.
        if !(0.0 < ratio && ratio <= 1.0) {
            return Err(GnnError::InvalidLayerConfig(
                "ratio must be in (0, 1]".to_string(),
            ));
        }
        Ok(Self {
            ratio,
            k: None,
            feat_dim,
        })
    }

    /// Creates a pooling layer that keeps exactly `k` nodes.
    ///
    /// No validation happens here; [`TopKPool::forward`] reports
    /// [`GnnError::TopKExceedsGraphSize`] when `k` is larger than the number of
    /// nodes of the graph it is given.
    pub fn new_k(feat_dim: usize, k: usize) -> Self {
        Self {
            ratio: 1.0,
            k: Some(k),
            feat_dim,
        }
    }

    /// Scores every node, keeps the `k` best and returns the pooled graph.
    ///
    /// The score of node `i` is `tanh(dot(x_i, proj) / max(‖proj‖, 1e-12))`, where
    /// `x_i` is the `i`-th row of `x`. The `k` highest-scoring nodes are kept
    /// (ties are resolved in favour of the lower node index, nodes whose score is
    /// NaN are ranked last). The kept rows of `x` are multiplied by their score
    /// and the edges between kept nodes are re-indexed into the new graph.
    ///
    /// # Arguments
    ///
    /// * `graph` - input graph; its node count determines the expected size of `x`.
    /// * `x` - flat node features, `feat_dim` consecutive values per node.
    /// * `proj` - projection vector of length `feat_dim`.
    ///
    /// # Errors
    ///
    /// * [`GnnError::NodeFeatureMismatch`] when `x.len() != n_nodes * feat_dim`.
    /// * [`GnnError::DimensionMismatch`] when `proj.len() != feat_dim`.
    /// * [`GnnError::TopKExceedsGraphSize`] when the number of nodes to keep
    ///   exceeds the node count of `graph`.
    /// * Any error returned by `graph.neighbors`, `CsrGraph::new` or
    ///   `CsrGraph::from_edges` is propagated unchanged.
    pub fn forward(&self, graph: &CsrGraph, x: &[f32], proj: &[f32]) -> GnnResult<TopKPoolResult> {
        let n = graph.n_nodes();
        let fd = self.feat_dim;
        if x.len() != n * fd {
            return Err(GnnError::NodeFeatureMismatch(n, x.len() / fd.max(1)));
        }
        if proj.len() != fd {
            return Err(GnnError::DimensionMismatch {
                expected: fd,
                got: proj.len(),
            });
        }

        let k = match self.k {
            Some(fixed_k) => fixed_k,
            None => {
                let from_ratio = (n as f32 * self.ratio).ceil() as usize;
                // `n as f32` may round up for large `n`, so cap at `n`: a ratio in
                // (0, 1] must never fail with "k exceeds graph size". The lower
                // bound is applied last so an empty graph is still rejected below.
                from_ratio.min(n).max(1)
            }
        };
        if k > n {
            return Err(GnnError::TopKExceedsGraphSize { k, n });
        }

        // Guard the norm so an all-zero projection does not divide by zero.
        let norm = proj.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
        let scores: Vec<f32> = (0..n)
            .map(|i| {
                let row = &x[i * fd..(i + 1) * fd];
                let dot: f32 = row.iter().zip(proj).map(|(&xv, &pv)| xv * pv).sum();
                (dot / norm).tanh()
            })
            .collect();

        // Rank node ids by descending score. NaN scores are handled explicitly
        // (ranked last): mapping incomparable pairs to `Equal` is not a total
        // order, which makes `sort_by` unreliable. The sort is stable, so equal
        // scores keep ascending node order.
        let mut ranked: Vec<usize> = (0..n).collect();
        ranked.sort_by(|&a, &b| {
            let (score_a, score_b) = (scores[a], scores[b]);
            match (score_a.is_nan(), score_b.is_nan()) {
                (false, false) => score_b
                    .partial_cmp(&score_a)
                    .unwrap_or(std::cmp::Ordering::Equal),
                // `false < true`, so a NaN entry compares as greater and sorts last.
                (a_nan, b_nan) => a_nan.cmp(&b_nan),
            }
        });
        ranked.truncate(k);
        let mut node_indices = ranked;
        node_indices.sort_unstable();

        // Dense global -> local index table; `None` marks dropped nodes.
        let mut local_of: Vec<Option<usize>> = vec![None; n];
        for (local, &global) in node_indices.iter().enumerate() {
            local_of[global] = Some(local);
        }

        // Gate the kept feature rows by their score.
        let mut new_x: Vec<f32> = Vec::with_capacity(k * fd);
        for &global in &node_indices {
            let gate = scores[global];
            new_x.extend(x[global * fd..(global + 1) * fd].iter().map(|&v| v * gate));
        }

        // Keep only edges whose both endpoints survived, in local numbering.
        let mut new_edges: Vec<(usize, usize)> = Vec::new();
        for (local_i, &global_i) in node_indices.iter().enumerate() {
            let neighbors = graph.neighbors(global_i)?;
            for &global_j in neighbors {
                if let Some(local_j) = local_of.get(global_j).copied().flatten() {
                    new_edges.push((local_i, local_j));
                }
            }
        }

        let new_graph = if new_edges.is_empty() {
            // No surviving edges: build the edgeless CSR structure directly.
            let row_ptr = vec![0usize; k + 1];
            let col_idx = vec![];
            CsrGraph::new(k, row_ptr, col_idx)?
        } else {
            CsrGraph::from_edges(k, &new_edges)?
        };

        Ok(TopKPoolResult {
            node_indices,
            x: new_x,
            graph: new_graph,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an undirected path graph `0 - 1 - ... - (n - 1)` with both edge directions.
    fn line_graph(n: usize) -> CsrGraph {
        // `saturating_sub` keeps the range empty instead of underflowing for `n == 0`.
        let edges: Vec<(usize, usize)> = (0..n.saturating_sub(1))
            .flat_map(|i| [(i, i + 1), (i + 1, i)])
            .collect();
        CsrGraph::from_edges(n, &edges).expect("test invariant: value must be valid")
    }

    #[test]
    fn k_never_exceeds_n() {
        let g = line_graph(5);
        let pool = TopKPool::new_ratio(3, 0.6).expect("test invariant: value must be valid");
        let x = vec![1.0_f32; 5 * 3];
        let proj = vec![1.0_f32, 0.0, 0.0];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        assert!(res.n_nodes() <= 5);
        assert_eq!(res.n_nodes(), 3);
    }

    #[test]
    fn fixed_k_returns_exactly_k() {
        let g = line_graph(6);
        let pool = TopKPool::new_k(2, 2);
        let x = vec![1.0_f32; 6 * 2];
        let proj = vec![1.0_f32, 0.5];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        assert_eq!(res.n_nodes(), 2);
    }

    #[test]
    fn output_graph_consistent_node_count() {
        let g = line_graph(8);
        let pool = TopKPool::new_k(3, 3);
        let x: Vec<f32> = (0..8 * 3).map(|i| i as f32 * 0.1).collect();
        let proj = vec![1.0_f32, 1.0, 1.0];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        assert_eq!(res.graph.n_nodes(), 3);
    }

    #[test]
    fn selected_indices_are_sorted() {
        let g = line_graph(5);
        let pool = TopKPool::new_k(2, 2);
        let x: Vec<f32> = (0..5 * 2).map(|i| i as f32).collect();
        let proj = vec![1.0_f32, 0.0];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        let sorted = {
            let mut v = res.node_indices.clone();
            v.sort_unstable();
            v
        };
        assert_eq!(res.node_indices, sorted);
    }

    #[test]
    fn x_length_correct() {
        let g = line_graph(5);
        let k = 3;
        let fd = 4;
        let pool = TopKPool::new_k(fd, k);
        let x = vec![0.5_f32; 5 * fd];
        let proj = vec![1.0_f32; fd];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        assert_eq!(res.x.len(), k * fd);
    }

    #[test]
    fn score_ordering_selects_highest() {
        let g = line_graph(5);
        let pool = TopKPool::new_k(1, 2);
        // Scores are tanh(0), tanh(1), ..., tanh(4): strictly increasing with the node id.
        let x: Vec<f32> = (0..5).map(|i| i as f32).collect();
        let proj = vec![1.0_f32];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        // Both of the two best-scoring nodes must be kept, not just one of them.
        assert_eq!(res.node_indices, vec![3, 4]);
    }

    #[test]
    fn k_exceeds_n_error() {
        let g = line_graph(3);
        let pool = TopKPool::new_k(2, 5);
        let x = vec![0.1_f32; 3 * 2];
        let proj = vec![1.0_f32, 0.0];
        let err = pool.forward(&g, &x, &proj);
        assert!(matches!(err, Err(GnnError::TopKExceedsGraphSize { .. })));
    }

    #[test]
    fn invalid_ratio_error() {
        let err = TopKPool::new_ratio(4, 0.0);
        assert!(err.is_err());
        let err = TopKPool::new_ratio(4, 1.5);
        assert!(err.is_err());
    }

    #[test]
    fn output_features_finite() {
        let g = line_graph(6);
        let pool = TopKPool::new_k(3, 3);
        let x: Vec<f32> = (0..6 * 3).map(|i| i as f32 * 0.5).collect();
        let proj = vec![1.0_f32, -0.5, 0.5];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        assert!(res.x.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn induced_subgraph_edges_in_selected_set() {
        let g = line_graph(4);
        let pool = TopKPool::new_k(2, 2);
        let x: Vec<f32> = (0..4 * 2).map(|i| i as f32).collect();
        let proj = vec![1.0_f32, 0.0];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        let k = res.n_nodes();
        // Every edge target of the pooled graph must be a local index below `k`.
        for e in 0..res.graph.n_edges() {
            let col = res.graph.col_idx()[e];
            assert!(col < k);
        }
    }
}