use crate::error::{GnnError, GnnResult};
/// Configuration of a Set2Set-style attention pooling layer.
///
/// The layer holds no learned parameters itself; the LSTM weights and bias
/// are supplied by the caller on every `forward` call.
#[derive(Debug, Clone)]
pub struct Set2Set {
    /// Number of LSTM / attention iterations run by `forward`.
    processing_steps: usize,
    /// Width of a single node feature vector.
    input_dim: usize,
    /// Width of the LSTM hidden and cell state (set to `2 * input_dim` by `new`).
    lstm_dim: usize,
}
impl Set2Set {
pub fn new(input_dim: usize, processing_steps: usize) -> GnnResult<Self> {
if input_dim == 0 {
return Err(GnnError::InvalidLayerConfig(
"input_dim must be > 0".to_string(),
));
}
if processing_steps == 0 {
return Err(GnnError::InvalidLayerConfig(
"processing_steps must be > 0".to_string(),
));
}
let lstm_dim = 2 * input_dim;
Ok(Self {
processing_steps,
input_dim,
lstm_dim,
})
}
pub fn output_dim(&self) -> usize {
2 * self.input_dim
}
pub fn forward(
&self,
x: &[f32],
n_nodes: usize,
lstm_weight: &[f32],
lstm_bias: &[f32],
) -> GnnResult<Vec<f32>> {
let d = self.input_dim;
let hd = self.lstm_dim;
if n_nodes == 0 {
return Err(GnnError::EmptyGraph);
}
if x.len() != n_nodes * d {
return Err(GnnError::DimensionMismatch {
expected: n_nodes * d,
got: x.len(),
});
}
let in_total = hd + d;
if lstm_weight.len() != 4 * hd * in_total {
return Err(GnnError::WeightShapeMismatch {
r: 4 * hd,
c: in_total,
d: in_total,
});
}
if lstm_bias.len() != 4 * hd {
return Err(GnnError::DimensionMismatch {
expected: 4 * hd,
got: lstm_bias.len(),
});
}
let mut h = vec![0.0_f32; hd]; let mut c = vec![0.0_f32; hd]; let mut q_star = vec![0.0_f32; hd];
for _ in 0..self.processing_steps {
let (h_new, c_new) = self.lstm_step(&q_star, &h, &c, lstm_weight, lstm_bias)?;
h = h_new;
c = c_new;
let mut scores = Vec::with_capacity(n_nodes);
for i in 0..n_nodes {
let score: f32 = (0..d).map(|k| h[k] * x[i * d + k]).sum();
scores.push(score);
}
let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
let sum_e: f32 = exps.iter().sum();
let alphas: Vec<f32> = if sum_e > 0.0 {
exps.iter().map(|&e| e / sum_e).collect()
} else {
vec![1.0 / n_nodes as f32; n_nodes]
};
let mut r = vec![0.0_f32; d];
for i in 0..n_nodes {
for k in 0..d {
r[k] += alphas[i] * x[i * d + k];
}
}
q_star = {
let mut qs = Vec::with_capacity(hd + d);
qs.extend_from_slice(&h);
qs.extend_from_slice(&r);
qs
};
}
let out: Vec<f32> = q_star[..hd].to_vec();
Ok(out)
}
fn lstm_step(
&self,
input: &[f32],
h: &[f32],
c: &[f32],
weight: &[f32],
bias: &[f32],
) -> GnnResult<(Vec<f32>, Vec<f32>)> {
let hd = self.lstm_dim;
let d = self.input_dim;
let in_total = hd + d;
let input_len = input.len().min(d); let concat_len = hd + d;
let _ = concat_len;
let mut concat = Vec::with_capacity(hd + input_len);
concat.extend_from_slice(h);
concat.extend_from_slice(&input[..input_len]);
let mut gates = vec![0.0_f32; 4 * hd];
for gate in 0..4 {
for k in 0..hd {
let row = gate * hd + k;
let mut val = bias[row];
let w_row_start = row * in_total;
for j in 0..concat.len().min(in_total) {
val += weight[w_row_start + j] * concat[j];
}
gates[row] = val;
}
}
let sigmoid = |v: f32| 1.0 / (1.0 + (-v).exp());
let tanh = |v: f32| v.tanh();
let mut c_new = vec![0.0_f32; hd];
let mut h_new = vec![0.0_f32; hd];
for k in 0..hd {
let i_gate = sigmoid(gates[k]); let f_gate = sigmoid(gates[hd + k]); let g_gate = tanh(gates[2 * hd + k]); let o_gate = sigmoid(gates[3 * hd + k]);
c_new[k] = f_gate * c[k] + i_gate * g_gate;
h_new[k] = o_gate * tanh(c_new[k]);
}
Ok((h_new, c_new))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a layer, panicking if the test passes an invalid configuration.
    fn layer(d: usize, steps: usize) -> Set2Set {
        Set2Set::new(d, steps).expect("test invariant: value must be valid")
    }

    /// Returns `(weight, bias)` sized for `s2s`: every weight entry equals
    /// `weight_value`, every bias entry is zero.
    fn params(s2s: &Set2Set, weight_value: f32) -> (Vec<f32>, Vec<f32>) {
        let hd = s2s.lstm_dim;
        let in_total = hd + s2s.input_dim;
        (
            vec![weight_value; 4 * hd * in_total],
            vec![0.0_f32; 4 * hd],
        )
    }

    #[test]
    fn output_dim_is_twice_input_dim() {
        assert_eq!(layer(8, 3).output_dim(), 16);
    }

    #[test]
    fn output_shape_correct() {
        let (d, n) = (4, 5);
        let s2s = layer(d, 2);
        let (w, b) = params(&s2s, 0.0);
        let out = s2s
            .forward(&vec![0.1_f32; n * d], n, &w, &b)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), s2s.output_dim());
    }

    #[test]
    fn single_node_graph() {
        let d = 4;
        let s2s = layer(d, 3);
        let (w, b) = params(&s2s, 0.0);
        let out = s2s
            .forward(&vec![1.0_f32; d], 1, &w, &b)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 2 * d);
    }

    #[test]
    fn output_finite() {
        let (d, n) = (3, 6);
        let s2s = layer(d, 4);
        let x: Vec<f32> = (0..n * d).map(|i| (i as f32) * 0.1).collect();
        let (w, b) = params(&s2s, 0.01);
        let out = s2s
            .forward(&x, n, &w, &b)
            .expect("test invariant: value must be valid");
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn zero_weights_output_zero() {
        let (d, n) = (4, 3);
        let s2s = layer(d, 2);
        let (w, b) = params(&s2s, 0.0);
        let out = s2s
            .forward(&vec![0.5_f32; n * d], n, &w, &b)
            .expect("test invariant: value must be valid");
        assert!(out.iter().all(|&v| v.abs() < 1e-5));
    }

    #[test]
    fn empty_graph_error() {
        let s2s = layer(4, 2);
        let (w, b) = params(&s2s, 0.0);
        let err = s2s.forward(&[], 0, &w, &b);
        assert!(matches!(err, Err(GnnError::EmptyGraph)));
    }

    #[test]
    fn multiple_processing_steps() {
        let (d, n, steps) = (2, 4, 5);
        let s2s = layer(d, steps);
        let (w, b) = params(&s2s, 0.0);
        let out = s2s
            .forward(&vec![1.0_f32; n * d], n, &w, &b)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 2 * d);
    }

    #[test]
    fn lstm_step_output_shapes() {
        let d = 4;
        let s2s = layer(d, 1);
        let hd = s2s.lstm_dim;
        let (w, b) = params(&s2s, 0.0);
        let (h_new, c_new) = s2s
            .lstm_step(&vec![0.0_f32; d], &vec![0.0_f32; hd], &vec![0.0_f32; hd], &w, &b)
            .expect("test invariant: value must be valid");
        assert_eq!(h_new.len(), hd);
        assert_eq!(c_new.len(), hd);
    }

    #[test]
    fn lstm_step_zero_params_keep_zero_state() {
        // All gate pre-activations are zero, so the candidate is tanh(0) = 0
        // and both the new cell and hidden state stay at zero.
        let d = 3;
        let s2s = layer(d, 1);
        let hd = s2s.lstm_dim;
        let (w, b) = params(&s2s, 0.0);
        let (h_new, c_new) = s2s
            .lstm_step(&vec![0.0_f32; d], &vec![0.0_f32; hd], &vec![0.0_f32; hd], &w, &b)
            .expect("test invariant: value must be valid");
        assert!(h_new.iter().chain(c_new.iter()).all(|&v| v == 0.0));
    }

    #[test]
    fn invalid_zero_input_dim() {
        assert!(Set2Set::new(0, 3).is_err());
    }

    #[test]
    fn invalid_zero_steps() {
        assert!(Set2Set::new(4, 0).is_err());
    }

    #[test]
    fn dimension_mismatch_error() {
        let d = 4;
        let s2s = layer(d, 2);
        let (w, b) = params(&s2s, 0.0);
        // Three nodes' worth of features, but five nodes claimed.
        let err = s2s.forward(&vec![0.1_f32; 3 * d], 5, &w, &b);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }

    #[test]
    fn weight_shape_mismatch_error() {
        let (d, n) = (4, 2);
        let s2s = layer(d, 2);
        let (w, b) = params(&s2s, 0.0);
        let err = s2s.forward(&vec![0.1_f32; n * d], n, &w[1..], &b);
        assert!(matches!(err, Err(GnnError::WeightShapeMismatch { .. })));
    }

    #[test]
    fn bias_length_mismatch_error() {
        let (d, n) = (4, 2);
        let s2s = layer(d, 2);
        let (w, b) = params(&s2s, 0.0);
        let err = s2s.forward(&vec![0.1_f32; n * d], n, &w, &b[1..]);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }
}