use crate::error::{GnnError, GnnResult};
/// Configuration for a [`SortPool`] layer.
///
/// `SortPool::new` rejects a zero `feat_dim` or a zero `k` with
/// `GnnError::InvalidLayerConfig`.
#[derive(Debug, Clone, Copy)]
pub struct SortPoolConfig {
/// Number of feature values per node, i.e. the width of one row in the flat
/// row-major feature buffer handed to `SortPool::forward`.
pub feat_dim: usize,
/// Number of node rows in the pooled output; graphs with more nodes are
/// truncated to `k` rows and graphs with fewer are zero-padded up to `k`.
pub k: usize,
}
/// Pooling layer that orders node feature rows and emits a fixed number of
/// them (`k`), producing an output of `k * feat_dim` values.
///
/// Build it with `SortPool::new`, which validates the configuration.
#[derive(Debug, Clone)]
pub struct SortPool {
// Layer settings (`feat_dim`, `k`) as accepted by `SortPool::new`.
config: SortPoolConfig,
}
impl SortPool {
    /// Creates a pooling layer from `config`.
    ///
    /// # Errors
    ///
    /// Returns [`GnnError::InvalidLayerConfig`] when `feat_dim` is zero, when
    /// `k` is zero, or when `k * feat_dim` does not fit in a `usize` (the
    /// output buffer could not be sized).
    pub fn new(config: SortPoolConfig) -> GnnResult<Self> {
        if config.feat_dim == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "SortPool: feat_dim must be > 0".to_string(),
            ));
        }
        if config.k == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "SortPool: k must be > 0".to_string(),
            ));
        }
        // Reject sizes whose product would wrap, so `output_len` and the
        // allocation in `forward` can multiply without overflow.
        if config.k.checked_mul(config.feat_dim).is_none() {
            return Err(GnnError::InvalidLayerConfig(
                "SortPool: k * feat_dim overflows usize".to_string(),
            ));
        }
        Ok(Self { config })
    }

    /// Compares two feature values so that the larger one sorts first.
    ///
    /// NaN is ranked below every number (and all NaNs compare equal to each
    /// other), which keeps the comparison a total order. Mapping "unordered"
    /// to `Equal` would not be transitive and the sort result would be
    /// unspecified (and may panic on recent toolchains).
    fn cmp_desc(va: f32, vb: f32) -> std::cmp::Ordering {
        match vb.partial_cmp(&va) {
            Some(ordering) => ordering,
            // At least one side is NaN: the NaN side goes after the other.
            None => va.is_nan().cmp(&vb.is_nan()),
        }
    }

    /// Returns node indices ordered for pooling.
    ///
    /// Rows are compared in descending order, using the last channel as the
    /// primary key and earlier channels as successive tie-breakers. Rows that
    /// are equal in every channel keep ascending node-index order.
    ///
    /// `x` must hold at least `n_nodes * feat_dim` values (row-major); the
    /// caller (`forward`) validates this before calling.
    fn sorted_order(&self, x: &[f32], n_nodes: usize) -> Vec<usize> {
        let fd = self.config.feat_dim;
        let mut order: Vec<usize> = (0..n_nodes).collect();
        // The index tie-break means no two distinct elements ever compare
        // equal, so an unstable sort yields the same result as a stable one.
        order.sort_unstable_by(|&a, &b| {
            let row_a = &x[a * fd..(a + 1) * fd];
            let row_b = &x[b * fd..(b + 1) * fd];
            row_a
                .iter()
                .rev()
                .zip(row_b.iter().rev())
                .map(|(&va, &vb)| Self::cmp_desc(va, vb))
                .find(|ordering| ordering.is_ne())
                .unwrap_or_else(|| a.cmp(&b))
        });
        order
    }

    /// Pools a graph's node features into a fixed-size `k * feat_dim` buffer.
    ///
    /// `x` is a row-major buffer of `n_nodes` rows with `feat_dim` values
    /// each. Rows are sorted (see `sorted_order`), the first `k` are copied to
    /// the output in that order, and any remaining output rows stay zero.
    ///
    /// # Errors
    ///
    /// * [`GnnError::EmptyGraph`] when `n_nodes` is zero.
    /// * [`GnnError::NodeFeatureMismatch`] when `x.len()` is not exactly
    ///   `n_nodes * feat_dim` (including when that product overflows).
    pub fn forward(&self, x: &[f32], n_nodes: usize) -> GnnResult<Vec<f32>> {
        let fd = self.config.feat_dim;
        let k = self.config.k;
        if n_nodes == 0 {
            return Err(GnnError::EmptyGraph);
        }
        // checked_mul: a wrapped product could accidentally equal `x.len()`
        // and let a mismatched buffer through to out-of-bounds indexing.
        if n_nodes.checked_mul(fd) != Some(x.len()) {
            return Err(GnnError::NodeFeatureMismatch(n_nodes, x.len() / fd.max(1)));
        }
        let order = self.sorted_order(x, n_nodes);
        // `k * fd` cannot overflow and `fd > 0`: both are enforced by `new`.
        let mut out = vec![0.0_f32; k * fd];
        // `zip` stops after min(k, n_nodes) rows; untouched rows stay zero.
        for (dst, &node) in out.chunks_exact_mut(fd).zip(order.iter()) {
            dst.copy_from_slice(&x[node * fd..(node + 1) * fd]);
        }
        Ok(out)
    }

    /// Length of the buffer returned by `forward`: `k * feat_dim`.
    pub fn output_len(&self) -> usize {
        self.config.k * self.config.feat_dim
    }

    /// Number of node rows kept in the pooled output.
    pub fn k(&self) -> usize {
        self.config.k
    }

    /// Number of feature values per node.
    pub fn feat_dim(&self) -> usize {
        self.config.feat_dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a layer from a configuration the tests know to be valid.
    fn pool(feat_dim: usize, k: usize) -> SortPool {
        SortPool::new(SortPoolConfig { feat_dim, k }).expect("test invariant: value must be valid")
    }

    #[test]
    fn build_and_accessors() {
        let p = pool(3, 4);
        assert_eq!(p.feat_dim(), 3);
        assert_eq!(p.k(), 4);
        assert_eq!(p.output_len(), 12);
    }

    #[test]
    fn zero_config_errors() {
        assert!(SortPool::new(SortPoolConfig { feat_dim: 0, k: 4 }).is_err());
        assert!(SortPool::new(SortPoolConfig { feat_dim: 3, k: 0 }).is_err());
    }

    #[test]
    fn output_shape_is_k_times_d() {
        let p = pool(2, 3);
        let x = vec![1.0_f32; 5 * 2];
        let out = p.forward(&x, 5).expect("forward");
        assert_eq!(out.len(), 3 * 2);
    }

    #[test]
    fn sorts_descending_by_last_channel() {
        let p = pool(1, 3);
        let x = vec![3.0_f32, 1.0, 2.0];
        let out = p.forward(&x, 3).expect("forward");
        assert_eq!(out, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn truncates_to_k_when_more_nodes() {
        let p = pool(1, 2);
        let x = vec![10.0_f32, 40.0, 20.0, 30.0];
        let out = p.forward(&x, 4).expect("forward");
        assert_eq!(out, vec![40.0, 30.0]);
    }

    #[test]
    fn zero_pads_when_fewer_nodes() {
        let p = pool(2, 4);
        let x = vec![1.0_f32, 5.0, 2.0, 9.0];
        let out = p.forward(&x, 2).expect("forward");
        assert_eq!(out.len(), 8);
        assert_eq!(&out[0..2], &[2.0, 9.0]);
        assert_eq!(&out[2..4], &[1.0, 5.0]);
        assert!(
            out[4..].iter().all(|&v| v == 0.0),
            "tail must be zero-padded"
        );
    }

    #[test]
    fn last_channel_is_primary_key() {
        let p = pool(2, 2);
        let x = vec![9.0_f32, 1.0, 0.0, 8.0];
        let out = p.forward(&x, 2).expect("forward");
        assert_eq!(&out[0..2], &[0.0, 8.0], "node1 (last=8) must come first");
        assert_eq!(&out[2..4], &[9.0, 1.0]);
    }

    #[test]
    fn ties_broken_by_earlier_channel() {
        let p = pool(2, 2);
        let x = vec![1.0_f32, 5.0, 9.0, 5.0];
        let out = p.forward(&x, 2).expect("forward");
        assert_eq!(&out[0..2], &[9.0, 5.0]);
        assert_eq!(&out[2..4], &[1.0, 5.0]);
    }

    #[test]
    fn full_tie_broken_by_node_index() {
        let p = pool(1, 3);
        let x = vec![7.0_f32, 7.0, 7.0];
        let out = p.forward(&x, 3).expect("forward");
        assert_eq!(out, vec![7.0, 7.0, 7.0]);
        // Identical rows are indistinguishable in the pooled output, so check
        // the ordering itself: fully tied nodes keep ascending index order.
        assert_eq!(p.sorted_order(&x, 3), vec![0, 1, 2]);
    }

    #[test]
    fn empty_graph_errors() {
        let p = pool(2, 3);
        assert!(matches!(p.forward(&[], 0), Err(GnnError::EmptyGraph)));
    }

    #[test]
    fn feature_mismatch_errors() {
        let p = pool(3, 2);
        // 7 values cannot be 3 nodes of 3 features each.
        let err = p.forward(&[1.0_f32; 7], 3);
        assert!(matches!(err, Err(GnnError::NodeFeatureMismatch(..))));
    }

    #[test]
    fn exact_k_nodes_no_padding() {
        let p = pool(2, 3);
        let x: Vec<f32> = (0..3 * 2).map(|i| i as f32).collect();
        let out = p.forward(&x, 3).expect("forward");
        assert_eq!(out.len(), 6);
        // Rows (0,1), (2,3), (4,5) sorted descending by last channel; every
        // output slot comes from an input row, none from padding.
        assert_eq!(out, vec![4.0, 5.0, 2.0, 3.0, 0.0, 1.0]);
    }

    #[test]
    fn permutation_invariant() {
        let p = pool(2, 3);
        let x1 = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        // Same three rows as `x1`, presented in a different node order.
        let x2 = vec![5.0_f32, 6.0, 1.0, 2.0, 3.0, 4.0];
        let o1 = p.forward(&x1, 3).expect("forward");
        let o2 = p.forward(&x2, 3).expect("forward");
        assert_eq!(o1, o2, "SortPool must be permutation invariant");
    }

    #[test]
    fn output_finite_and_negative_values() {
        let p = pool(3, 4);
        let x: Vec<f32> = (0..5 * 3).map(|i| (i as f32) * 0.5 - 3.0).collect();
        let out = p.forward(&x, 5).expect("forward");
        assert!(out.iter().all(|v| v.is_finite()));
    }
}