use crate::error::{GnnError, GnnResult};
/// Selects how per-edge messages are reduced onto their target nodes.
///
/// Used by [`aggregate`] to dispatch to the matching `aggregate_*` function.
/// Derives `Hash` (alongside `Eq`) so the selector can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregationType {
    /// Element-wise sum of incoming messages (see [`aggregate_sum`]).
    Sum,
    /// Element-wise mean of incoming messages (see [`aggregate_mean`]).
    Mean,
    /// Element-wise maximum of incoming messages (see [`aggregate_max`]).
    Max,
    /// Element-wise minimum of incoming messages (see [`aggregate_min`]).
    Min,
    /// Per-edge weighted sum. Needs explicit weights, so [`aggregate`] rejects it;
    /// call [`aggregate_softmax`] directly instead.
    SoftmaxWeighted,
}
pub fn aggregate(
messages: &[f32],
target_idx: &[usize],
n_nodes: usize,
feat_dim: usize,
agg_type: AggregationType,
) -> GnnResult<Vec<f32>> {
match agg_type {
AggregationType::Sum => aggregate_sum(messages, target_idx, n_nodes, feat_dim),
AggregationType::Mean => aggregate_mean(messages, target_idx, n_nodes, feat_dim),
AggregationType::Max => aggregate_max(messages, target_idx, n_nodes, feat_dim),
AggregationType::Min => aggregate_min(messages, target_idx, n_nodes, feat_dim),
AggregationType::SoftmaxWeighted => Err(GnnError::InvalidAggregation(
"SoftmaxWeighted requires explicit weights; use aggregate_softmax instead",
)),
}
}
/// Checks that a flat message buffer is consistent with its edge list and node count.
///
/// `messages` is interpreted as `target_idx.len()` rows of `feat_dim` values, where
/// row `e` belongs to edge `e`. Every entry of `target_idx` must be a node index in
/// `0..n_nodes`.
///
/// Returns the number of edges (`target_idx.len()`) on success.
///
/// # Errors
///
/// Errors are reported in this order:
/// - `GnnError::InvalidLayerConfig` if `feat_dim` is zero, or if
///   `n_edges * feat_dim` does not fit in `usize`.
/// - `GnnError::DimensionMismatch` if `messages.len() != n_edges * feat_dim`.
/// - `GnnError::NodeIndexOutOfRange` for the first target index `>= n_nodes`.
/// - `GnnError::InvalidLayerConfig` if `n_nodes * feat_dim` (the size of the output
///   buffer each aggregator allocates) does not fit in `usize`.
fn validate_messages(
    messages: &[f32],
    target_idx: &[usize],
    n_nodes: usize,
    feat_dim: usize,
) -> GnnResult<usize> {
    if feat_dim == 0 {
        return Err(GnnError::InvalidLayerConfig(
            "feat_dim must be > 0".to_string(),
        ));
    }
    let n_edges = target_idx.len();
    // Checked multiply: a plain `*` panics in debug builds and silently wraps in
    // release builds, where the wrapped value could let a short buffer pass this check.
    let expected = n_edges.checked_mul(feat_dim).ok_or_else(|| {
        GnnError::InvalidLayerConfig(format!(
            "n_edges ({}) * feat_dim ({}) overflows usize",
            n_edges, feat_dim
        ))
    })?;
    if messages.len() != expected {
        return Err(GnnError::DimensionMismatch {
            expected,
            got: messages.len(),
        });
    }
    if let Some(&idx) = target_idx.iter().find(|&&i| i >= n_nodes) {
        return Err(GnnError::NodeIndexOutOfRange { idx, n_nodes });
    }
    // Every caller allocates `n_nodes * feat_dim` outputs and indexes up to that
    // bound, so reject sizes that would overflow instead of panicking or wrapping later.
    if n_nodes.checked_mul(feat_dim).is_none() {
        return Err(GnnError::InvalidLayerConfig(format!(
            "n_nodes ({}) * feat_dim ({}) overflows usize",
            n_nodes, feat_dim
        )));
    }
    Ok(n_edges)
}
/// Sums every incoming message row onto its target node, element by element.
///
/// Nodes that receive no message keep an all-zero row. The output has
/// `n_nodes * feat_dim` values laid out node-major.
///
/// # Errors
///
/// Propagates validation failures: zero `feat_dim`, a `messages` length that is not
/// `target_idx.len() * feat_dim`, or a target index `>= n_nodes`.
pub fn aggregate_sum(
    messages: &[f32],
    target_idx: &[usize],
    n_nodes: usize,
    feat_dim: usize,
) -> GnnResult<Vec<f32>> {
    validate_messages(messages, target_idx, n_nodes, feat_dim)?;
    let mut acc = vec![0.0_f32; n_nodes * feat_dim];
    // `feat_dim > 0` is guaranteed by validation, so `chunks_exact` yields one row per edge.
    for (msg, &node) in messages.chunks_exact(feat_dim).zip(target_idx) {
        let row = &mut acc[node * feat_dim..(node + 1) * feat_dim];
        for (slot, &value) in row.iter_mut().zip(msg) {
            *slot += value;
        }
    }
    Ok(acc)
}
/// Averages the incoming message rows of each node, element by element.
///
/// Each node's row is the sum of its messages multiplied by `1 / in_degree`.
/// Nodes with no incoming message keep an all-zero row.
///
/// # Errors
///
/// Propagates validation failures: zero `feat_dim`, a `messages` length that is not
/// `target_idx.len() * feat_dim`, or a target index `>= n_nodes`.
pub fn aggregate_mean(
    messages: &[f32],
    target_idx: &[usize],
    n_nodes: usize,
    feat_dim: usize,
) -> GnnResult<Vec<f32>> {
    validate_messages(messages, target_idx, n_nodes, feat_dim)?;
    let mut totals = vec![0.0_f32; n_nodes * feat_dim];
    let mut degree = vec![0usize; n_nodes];
    for (msg, &node) in messages.chunks_exact(feat_dim).zip(target_idx) {
        degree[node] += 1;
        let row = &mut totals[node * feat_dim..(node + 1) * feat_dim];
        row.iter_mut().zip(msg).for_each(|(slot, &v)| *slot += v);
    }
    for (row, &deg) in totals.chunks_exact_mut(feat_dim).zip(&degree) {
        if deg == 0 {
            // Isolated node: leave the zero row untouched (avoids dividing by zero).
            continue;
        }
        let scale = 1.0 / deg as f32;
        row.iter_mut().for_each(|x| *x *= scale);
    }
    Ok(totals)
}
/// Takes the element-wise maximum over each node's incoming message rows.
///
/// Nodes with no incoming message get an all-zero row.
///
/// NaN handling: a message value replaces the running maximum only if it compares
/// greater. NaN never does, so NaN values are skipped. If every message for a node
/// has NaN at some position, that output element stays `f32::NEG_INFINITY`.
///
/// # Errors
///
/// Propagates validation failures: zero `feat_dim`, a `messages` length that is not
/// `target_idx.len() * feat_dim`, or a target index `>= n_nodes`.
pub fn aggregate_max(
    messages: &[f32],
    target_idx: &[usize],
    n_nodes: usize,
    feat_dim: usize,
) -> GnnResult<Vec<f32>> {
    validate_messages(messages, target_idx, n_nodes, feat_dim)?;
    let mut best = vec![f32::NEG_INFINITY; n_nodes * feat_dim];
    let mut seen = vec![false; n_nodes];
    for (msg, &node) in messages.chunks_exact(feat_dim).zip(target_idx) {
        seen[node] = true;
        let row = &mut best[node * feat_dim..(node + 1) * feat_dim];
        for (slot, &v) in row.iter_mut().zip(msg) {
            if v > *slot {
                *slot = v;
            }
        }
    }
    // Replace the -inf sentinel with zeros for nodes that never received a message.
    for (row, &hit) in best.chunks_exact_mut(feat_dim).zip(&seen) {
        if !hit {
            row.fill(0.0);
        }
    }
    Ok(best)
}
/// Takes the element-wise minimum over each node's incoming message rows.
///
/// Nodes with no incoming message get an all-zero row.
///
/// NaN handling: a message value replaces the running minimum only if it compares
/// less. NaN never does, so NaN values are skipped. If every message for a node has
/// NaN at some position, that output element stays `f32::INFINITY`.
///
/// # Errors
///
/// Propagates validation failures: zero `feat_dim`, a `messages` length that is not
/// `target_idx.len() * feat_dim`, or a target index `>= n_nodes`.
pub fn aggregate_min(
    messages: &[f32],
    target_idx: &[usize],
    n_nodes: usize,
    feat_dim: usize,
) -> GnnResult<Vec<f32>> {
    validate_messages(messages, target_idx, n_nodes, feat_dim)?;
    let mut lowest = vec![f32::INFINITY; n_nodes * feat_dim];
    let mut touched = vec![false; n_nodes];
    for (msg, &node) in messages.chunks_exact(feat_dim).zip(target_idx) {
        touched[node] = true;
        let start = node * feat_dim;
        for (slot, &v) in lowest[start..start + feat_dim].iter_mut().zip(msg) {
            if v < *slot {
                *slot = v;
            }
        }
    }
    // Isolated nodes would otherwise keep the +inf sentinel; report zeros instead.
    lowest
        .chunks_exact_mut(feat_dim)
        .zip(&touched)
        .filter(|(_, hit)| !**hit)
        .for_each(|(row, _)| row.fill(0.0));
    Ok(lowest)
}
/// Accumulates `weights[e] * message_row[e]` onto each edge's target node.
///
/// No softmax normalisation is performed here: `weights` are applied exactly as
/// given. NOTE(review): callers presumably pass weights that are already
/// softmax-normalised per target node (e.g. attention scores); confirm at call sites.
/// Nodes with no incoming edge keep an all-zero row.
///
/// # Errors
///
/// - Validation failures for `messages`/`target_idx` (zero `feat_dim`, length
///   mismatch, target index `>= n_nodes`).
/// - `GnnError::DimensionMismatch` if `weights.len()` differs from the edge count.
pub fn aggregate_softmax(
    messages: &[f32],
    weights: &[f32],
    target_idx: &[usize],
    n_nodes: usize,
    feat_dim: usize,
) -> GnnResult<Vec<f32>> {
    let edge_count = validate_messages(messages, target_idx, n_nodes, feat_dim)?;
    if weights.len() != edge_count {
        return Err(GnnError::DimensionMismatch {
            expected: edge_count,
            got: weights.len(),
        });
    }
    let mut out = vec![0.0_f32; n_nodes * feat_dim];
    let edges = messages.chunks_exact(feat_dim).zip(weights).zip(target_idx);
    for ((msg, &w), &node) in edges {
        let row = &mut out[node * feat_dim..(node + 1) * feat_dim];
        for (slot, &v) in row.iter_mut().zip(msg) {
            *slot += w * v;
        }
    }
    Ok(out)
}
/// Degree-normalised aggregation: each node's summed messages are scaled by `1 / in_degree`.
///
/// This is currently identical to [`aggregate_mean`], and it returns the same errors.
/// NOTE(review): this is not the symmetric `1 / sqrt(deg_i * deg_j)` normalisation
/// used by some GCN variants. If that is intended, this function needs source degrees too.
pub fn aggregate_degree_norm(
    messages: &[f32],
    target_idx: &[usize],
    n_nodes: usize,
    feat_dim: usize,
) -> GnnResult<Vec<f32>> {
    let normalized = aggregate_mean(messages, target_idx, n_nodes, feat_dim)?;
    Ok(normalized)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two nodes, three edges, two features per message.
    ///
    /// Edges 0 and 2 target node 0 (messages `[1, 2]` and `[5, 6]`);
    /// edge 1 targets node 1 (message `[3, 4]`).
    fn small_setup() -> (Vec<f32>, Vec<usize>, usize, usize) {
        let messages = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let target_idx = vec![0, 1, 0];
        (messages, target_idx, 2, 2)
    }

    #[test]
    fn sum_aggregate_correct() {
        let (msg, idx, n, d) = small_setup();
        let out = aggregate_sum(&msg, &idx, n, d).expect("test invariant: value must be valid");
        assert!((out[0] - 6.0).abs() < 1e-6);
        assert!((out[1] - 8.0).abs() < 1e-6);
        assert!((out[2] - 3.0).abs() < 1e-6);
        assert!((out[3] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn mean_aggregate_correct() {
        let (msg, idx, n, d) = small_setup();
        let out = aggregate_mean(&msg, &idx, n, d).expect("test invariant: value must be valid");
        assert!((out[0] - 3.0).abs() < 1e-6);
        assert!((out[1] - 4.0).abs() < 1e-6);
        assert!((out[2] - 3.0).abs() < 1e-6);
        assert!((out[3] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn max_aggregate_correct() {
        let (msg, idx, n, d) = small_setup();
        let out = aggregate_max(&msg, &idx, n, d).expect("test invariant: value must be valid");
        assert!((out[0] - 5.0).abs() < 1e-6);
        assert!((out[1] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn min_aggregate_correct() {
        let (msg, idx, n, d) = small_setup();
        let out = aggregate_min(&msg, &idx, n, d).expect("test invariant: value must be valid");
        assert!((out[0] - 1.0).abs() < 1e-6);
        assert!((out[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn isolated_node_produces_zero_sum() {
        let messages = vec![1.0_f32, 2.0];
        let target_idx = vec![0usize];
        let out = aggregate_sum(&messages, &target_idx, 3, 2)
            .expect("test invariant: value must be valid");
        assert!((out[4]).abs() < 1e-6);
        assert!((out[5]).abs() < 1e-6);
    }

    #[test]
    fn isolated_node_produces_zero_max() {
        let messages = vec![1.0_f32, 2.0];
        let target_idx = vec![1usize];
        let out = aggregate_max(&messages, &target_idx, 3, 2)
            .expect("test invariant: value must be valid");
        assert!((out[0]).abs() < 1e-6);
        assert!((out[4]).abs() < 1e-6);
    }

    #[test]
    fn isolated_node_produces_zero_min() {
        let messages = vec![-1.0_f32, -2.0];
        let target_idx = vec![1usize];
        let out = aggregate_min(&messages, &target_idx, 3, 2)
            .expect("test invariant: value must be valid");
        assert_eq!(out, vec![0.0, 0.0, -1.0, -2.0, 0.0, 0.0]);
    }

    #[test]
    fn isolated_node_produces_zero_mean() {
        let out = aggregate_mean(&[2.0_f32, 4.0], &[2usize], 3, 2)
            .expect("test invariant: value must be valid");
        assert_eq!(out, vec![0.0, 0.0, 0.0, 0.0, 2.0, 4.0]);
    }

    #[test]
    fn no_edges_yields_zeros() {
        let out = aggregate_max(&[], &[], 2, 3).expect("test invariant: value must be valid");
        assert_eq!(out, vec![0.0; 6]);
    }

    #[test]
    fn softmax_aggregate_weighted() {
        let messages = vec![1.0_f32, 2.0, 3.0, 4.0];
        let weights = vec![0.3_f32, 0.7];
        let target_idx = vec![0, 0];
        let out = aggregate_softmax(&messages, &weights, &target_idx, 1, 2)
            .expect("test invariant: value must be valid");
        assert!((out[0] - 2.4).abs() < 1e-5);
        assert!((out[1] - 3.4).abs() < 1e-5);
    }

    #[test]
    fn softmax_weight_length_mismatch_error() {
        let err = aggregate_softmax(&[1.0_f32, 2.0], &[0.5_f32, 0.5], &[0usize], 1, 2);
        assert!(matches!(
            err,
            Err(GnnError::DimensionMismatch {
                expected: 1,
                got: 2
            })
        ));
    }

    #[test]
    fn aggregate_dispatch_sum() {
        let (msg, idx, n, d) = small_setup();
        let out = aggregate(&msg, &idx, n, d, AggregationType::Sum)
            .expect("test invariant: value must be valid");
        assert!((out[0] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn aggregate_dispatch_matches_direct_calls() {
        type Direct = fn(&[f32], &[usize], usize, usize) -> GnnResult<Vec<f32>>;
        let (msg, idx, n, d) = small_setup();
        let cases: [(AggregationType, Direct); 4] = [
            (AggregationType::Sum, aggregate_sum),
            (AggregationType::Mean, aggregate_mean),
            (AggregationType::Max, aggregate_max),
            (AggregationType::Min, aggregate_min),
        ];
        for &(kind, direct) in cases.iter() {
            let via_dispatch =
                aggregate(&msg, &idx, n, d, kind).expect("test invariant: value must be valid");
            let via_direct = direct(&msg, &idx, n, d).expect("test invariant: value must be valid");
            assert_eq!(via_dispatch, via_direct, "dispatch mismatch for {:?}", kind);
        }
    }

    #[test]
    fn aggregate_dispatch_softmax_weighted_error() {
        let (msg, idx, n, d) = small_setup();
        let err = aggregate(&msg, &idx, n, d, AggregationType::SoftmaxWeighted);
        assert!(matches!(err, Err(GnnError::InvalidAggregation(_))));
    }

    #[test]
    fn dimension_mismatch_error() {
        let err = aggregate_sum(&[1.0_f32, 2.0], &[0, 1], 2, 2);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }

    #[test]
    fn zero_feat_dim_error() {
        let err = aggregate_sum(&[], &[], 2, 0);
        assert!(matches!(err, Err(GnnError::InvalidLayerConfig(_))));
    }

    #[test]
    fn degree_norm_equals_mean() {
        let (msg, idx, n, d) = small_setup();
        let mean_out =
            aggregate_mean(&msg, &idx, n, d).expect("test invariant: value must be valid");
        let deg_out =
            aggregate_degree_norm(&msg, &idx, n, d).expect("test invariant: value must be valid");
        for (a, b) in mean_out.iter().zip(deg_out.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn out_of_range_target_error() {
        let messages = vec![1.0_f32, 2.0];
        let target_idx = vec![10usize];
        let err = aggregate_sum(&messages, &target_idx, 3, 2);
        assert!(matches!(err, Err(GnnError::NodeIndexOutOfRange { .. })));
    }
}