use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use crate::module::Module;
use crate::parameter::Parameter;
/// Graph convolutional layer: per-node linear projection preceded by
/// neighborhood aggregation through a caller-supplied adjacency matrix
/// (`out = adj @ x @ weight [+ bias]`, applied per batch sample).
pub struct GCNConv {
    /// Learnable projection weight, shape `(in_features, out_features)`.
    weight: Parameter,
    /// Optional learnable bias, shape `(out_features,)`; `None` when built
    /// via `without_bias`.
    bias: Option<Parameter>,
    /// Expected size of the input's last (feature) dimension.
    in_features: usize,
    /// Size of the produced feature dimension.
    out_features: usize,
}
impl GCNConv {
pub fn new(in_features: usize, out_features: usize) -> Self {
let scale = (2.0 / (in_features + out_features) as f32).sqrt();
let weight_data: Vec<f32> = (0..in_features * out_features)
.map(|i| {
let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
x * scale
})
.collect();
let weight = Parameter::named(
"weight",
Tensor::from_vec(weight_data, &[in_features, out_features])
.expect("tensor creation failed"),
true,
);
let bias_data = vec![0.0; out_features];
let bias = Some(Parameter::named(
"bias",
Tensor::from_vec(bias_data, &[out_features]).expect("tensor creation failed"),
true,
));
Self {
weight,
bias,
in_features,
out_features,
}
}
pub fn without_bias(in_features: usize, out_features: usize) -> Self {
let scale = (2.0 / (in_features + out_features) as f32).sqrt();
let weight_data: Vec<f32> = (0..in_features * out_features)
.map(|i| {
let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
x * scale
})
.collect();
let weight = Parameter::named(
"weight",
Tensor::from_vec(weight_data, &[in_features, out_features])
.expect("tensor creation failed"),
true,
);
Self {
weight,
bias: None,
in_features,
out_features,
}
}
pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
let shape = x.shape();
assert!(
shape.len() == 3,
"GCNConv expects input shape (batch, nodes, features), got {:?}",
shape
);
assert_eq!(shape[2], self.in_features, "Input features mismatch");
let batch = shape[0];
let adj_shape = adj.shape();
let weight = self.weight.variable();
let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
for b in 0..batch {
let x_b = x.select(0, b);
let adj_b = if adj_shape.len() == 3 {
adj.select(0, b)
} else {
adj.clone()
};
let msg_b = adj_b.matmul(&x_b);
let mut out_b = msg_b.matmul(&weight);
if let Some(bias) = &self.bias {
out_b = out_b.add_var(&bias.variable());
}
per_sample.push(out_b.unsqueeze(0));
}
let refs: Vec<&Variable> = per_sample.iter().collect();
Variable::cat(&refs, 0)
}
pub fn in_features(&self) -> usize {
self.in_features
}
pub fn out_features(&self) -> usize {
self.out_features
}
}
impl Module for GCNConv {
fn forward(&self, input: &Variable) -> Variable {
let n = input.shape()[0];
let mut eye_data = vec![0.0f32; n * n];
for i in 0..n {
eye_data[i * n + i] = 1.0;
}
let adj = Variable::new(
axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
.expect("identity matrix creation failed"),
false,
);
self.forward_graph(input, &adj)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = vec![self.weight.clone()];
if let Some(bias) = &self.bias {
params.push(bias.clone());
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.insert("weight".to_string(), self.weight.clone());
if let Some(bias) = &self.bias {
params.insert("bias".to_string(), bias.clone());
}
params
}
fn name(&self) -> &'static str {
"GCNConv"
}
}
/// Graph attention layer with multi-head additive attention; head outputs
/// are concatenated, so the layer emits `out_features * num_heads` features
/// per node.
pub struct GATConv {
    /// Input projection weight, shape `(in_features, out_features * num_heads)`.
    w: Parameter,
    /// Per-head attention vector applied to the source node, shape
    /// `(num_heads, out_features)`.
    attn_src: Parameter,
    /// Per-head attention vector applied to the destination node, shape
    /// `(num_heads, out_features)`.
    attn_dst: Parameter,
    /// Optional bias over the concatenated head outputs, shape
    /// `(out_features * num_heads,)`.
    bias: Option<Parameter>,
    /// Expected size of the input's last (feature) dimension.
    in_features: usize,
    /// Per-head output feature size.
    out_features: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Slope of the LeakyReLU applied to raw attention logits (0.2).
    negative_slope: f32,
}
impl GATConv {
    /// Creates a multi-head GAT layer with a zero-initialized bias.
    ///
    /// All parameters use deterministic quasi-random init (golden-ratio-style
    /// fractional sequences with distinct multipliers to decorrelate the
    /// three tensors); the projection uses Glorot scaling, the attention
    /// vectors `1/sqrt(out_features)` scaling.
    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
        let total_out = out_features * num_heads;
        let scale = (2.0 / (in_features + total_out) as f32).sqrt();
        let w_data: Vec<f32> = (0..in_features * total_out)
            .map(|i| {
                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
                x * scale
            })
            .collect();
        let w = Parameter::named(
            "w",
            Tensor::from_vec(w_data, &[in_features, total_out]).expect("tensor creation failed"),
            true,
        );
        let attn_scale = (1.0 / out_features as f32).sqrt();
        let attn_src_data: Vec<f32> = (0..total_out)
            .map(|i| {
                let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
                x * attn_scale
            })
            .collect();
        let attn_dst_data: Vec<f32> = (0..total_out)
            .map(|i| {
                let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
                x * attn_scale
            })
            .collect();
        let attn_src = Parameter::named(
            "attn_src",
            Tensor::from_vec(attn_src_data, &[num_heads, out_features])
                .expect("tensor creation failed"),
            true,
        );
        let attn_dst = Parameter::named(
            "attn_dst",
            Tensor::from_vec(attn_dst_data, &[num_heads, out_features])
                .expect("tensor creation failed"),
            true,
        );
        let bias_data = vec![0.0; total_out];
        let bias = Some(Parameter::named(
            "bias",
            Tensor::from_vec(bias_data, &[total_out]).expect("tensor creation failed"),
            true,
        ));
        Self {
            w,
            attn_src,
            attn_dst,
            bias,
            in_features,
            out_features,
            num_heads,
            negative_slope: 0.2,
        }
    }

    /// Forward pass with an explicit adjacency matrix.
    ///
    /// * `x` — node features, shape `(batch, nodes, in_features)`.
    /// * `adj` — adjacency, `(nodes, nodes)` shared or `(batch, nodes, nodes)`
    ///   per-sample; a nonzero entry marks an edge.
    ///
    /// Returns `(batch, nodes, out_features * num_heads)` (heads concatenated).
    ///
    /// NOTE(review): this computes eagerly on raw tensor data and wraps the
    /// result in a fresh `Variable`, so no autograd graph appears to link the
    /// output back to the parameters — gradients presumably do not flow
    /// through this op; confirm whether that is intended.
    ///
    /// # Panics
    /// If `x` is not rank 3, its feature dimension differs from
    /// `in_features`, or the adjacency node count does not match `x`.
    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
        let shape = x.shape();
        assert!(
            shape.len() == 3,
            "GATConv expects (batch, nodes, features), got {:?}",
            shape
        );
        // BUGFIX: validate the feature dimension (as GCNConv does). The raw
        // x_off indexing below assumes an in_features stride, so a mismatched
        // input would silently read wrong offsets.
        assert_eq!(shape[2], self.in_features, "Input features mismatch");
        let batch = shape[0];
        let nodes = shape[1];
        let total_out = self.out_features * self.num_heads;
        let x_data = x.data().to_vec();
        let adj_data = adj.data().to_vec();
        let w_data = self.w.data().to_vec();
        let attn_src_data = self.attn_src.data().to_vec();
        let attn_dst_data = self.attn_dst.data().to_vec();
        let adj_nodes = if adj.shape().len() == 3 {
            adj.shape()[1]
        } else {
            adj.shape()[0]
        };
        assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
        let mut output = vec![0.0f32; batch * nodes * total_out];
        for b in 0..batch {
            // Project this sample's node features: h = x_b @ w,
            // laid out as (nodes, total_out) row-major.
            let mut h = vec![0.0f32; nodes * total_out];
            for i in 0..nodes {
                let x_off = (b * nodes + i) * self.in_features;
                for o in 0..total_out {
                    let mut val = 0.0;
                    for f in 0..self.in_features {
                        val += x_data[x_off + f] * w_data[f * total_out + o];
                    }
                    h[i * total_out + o] = val;
                }
            }
            // 3-D adjacency indexes per sample; 2-D is shared (offset 0).
            let adj_off = if adj.shape().len() == 3 {
                b * nodes * nodes
            } else {
                0
            };
            for head in 0..self.num_heads {
                let head_off = head * self.out_features;
                // Additive attention logits e_ij = LeakyReLU(a_src.h_i + a_dst.h_j),
                // defined only where an edge exists; NEG_INFINITY marks "no edge"
                // so it drops out of the softmax below.
                let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
                for i in 0..nodes {
                    // Source term depends only on i — hoisted out of the j loop.
                    let mut src_score = 0.0;
                    for f in 0..self.out_features {
                        src_score += h[i * total_out + head_off + f]
                            * attn_src_data[head * self.out_features + f];
                    }
                    for j in 0..nodes {
                        let a_ij = adj_data[adj_off + i * nodes + j];
                        if a_ij != 0.0 {
                            let mut dst_score = 0.0;
                            for f in 0..self.out_features {
                                dst_score += h[j * total_out + head_off + f]
                                    * attn_dst_data[head * self.out_features + f];
                            }
                            let e = src_score + dst_score;
                            // LeakyReLU with the configured negative slope.
                            let e = if e > 0.0 { e } else { e * self.negative_slope };
                            attn_scores[i * nodes + j] = e;
                        }
                    }
                }
                // Row-wise masked softmax over neighbors, then the weighted
                // aggregation of projected neighbor features.
                for i in 0..nodes {
                    let row_start = i * nodes;
                    let row_end = row_start + nodes;
                    let row = &attn_scores[row_start..row_end];
                    // Max-subtraction for numerical stability.
                    let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                    if max_val == f32::NEG_INFINITY {
                        // Isolated node (no neighbors): output stays zero.
                        continue;
                    }
                    let mut sum_exp = 0.0f32;
                    let mut exps = vec![0.0; nodes];
                    for j in 0..nodes {
                        if row[j] > f32::NEG_INFINITY {
                            exps[j] = (row[j] - max_val).exp();
                            sum_exp += exps[j];
                        }
                    }
                    let out_off = (b * nodes + i) * total_out + head_off;
                    for j in 0..nodes {
                        if exps[j] > 0.0 {
                            let alpha = exps[j] / sum_exp;
                            for f in 0..self.out_features {
                                output[out_off + f] += alpha * h[j * total_out + head_off + f];
                            }
                        }
                    }
                }
            }
        }
        // Bias is applied over the concatenated head outputs.
        if let Some(bias) = &self.bias {
            let bias_data = bias.data().to_vec();
            for b in 0..batch {
                for i in 0..nodes {
                    let offset = (b * nodes + i) * total_out;
                    for o in 0..total_out {
                        output[offset + o] += bias_data[o];
                    }
                }
            }
        }
        Variable::new(
            Tensor::from_vec(output, &[batch, nodes, total_out]).expect("tensor creation failed"),
            x.requires_grad(),
        )
    }

    /// Concatenated output width: `out_features * num_heads`.
    pub fn total_out_features(&self) -> usize {
        self.out_features * self.num_heads
    }
}
impl Module for GATConv {
fn forward(&self, input: &Variable) -> Variable {
let n = input.shape()[0];
let mut eye_data = vec![0.0f32; n * n];
for i in 0..n {
eye_data[i * n + i] = 1.0;
}
let adj = Variable::new(
axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
.expect("identity matrix creation failed"),
false,
);
self.forward_graph(input, &adj)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
if let Some(bias) = &self.bias {
params.push(bias.clone());
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.insert("w".to_string(), self.w.clone());
params.insert("attn_src".to_string(), self.attn_src.clone());
params.insert("attn_dst".to_string(), self.attn_dst.clone());
if let Some(bias) = &self.bias {
params.insert("bias".to_string(), bias.clone());
}
params
}
fn name(&self) -> &'static str {
"GATConv"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-filled, non-differentiable `Variable` of the given shape.
    fn filled(value: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![value; numel], shape).expect("tensor creation failed"),
            false,
        )
    }

    /// Builds a non-differentiable identity adjacency of size `n x n`.
    fn identity(n: usize) -> Variable {
        let mut data = vec![0.0f32; n * n];
        for d in 0..n {
            data[d * n + d] = 1.0;
        }
        Variable::new(
            Tensor::from_vec(data, &[n, n]).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_gcn_conv_shape() {
        let layer = GCNConv::new(72, 128);
        let x = filled(1.0, &[2, 7, 72]);
        let adj = filled(1.0, &[7, 7]);
        let out = layer.forward_graph(&x, &adj);
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        let layer = GCNConv::new(4, 8);
        let x = filled(1.0, &[1, 3, 4]);
        let adj = identity(3);
        let out = layer.forward_graph(&x, &adj);
        assert_eq!(out.shape(), vec![1, 3, 8]);
        // With identity adjacency and identical node inputs, every node must
        // produce the same output as node 0.
        let data = out.data().to_vec();
        for node in 0..3 {
            for feat in 0..8 {
                assert!(
                    (data[node * 8 + feat] - data[feat]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let layer = GCNConv::new(16, 32);
        let params = layer.parameters();
        assert_eq!(params.len(), 2);
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 16 * 32 + 32);
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        let layer = GCNConv::without_bias(16, 32);
        assert_eq!(layer.parameters().len(), 1);
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let named = GCNConv::new(16, 32).named_parameters();
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_shape() {
        let layer = GATConv::new(72, 32, 4);
        let x = filled(1.0, &[2, 7, 72]);
        let adj = filled(1.0, &[7, 7]);
        let out = layer.forward_graph(&x, &adj);
        // Heads are concatenated: 32 * 4 = 128 features per node.
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gat_conv_single_head() {
        let layer = GATConv::new(16, 8, 1);
        let x = filled(1.0, &[1, 5, 16]);
        let adj = filled(1.0, &[5, 5]);
        let out = layer.forward_graph(&x, &adj);
        assert_eq!(out.shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_parameters() {
        let layer = GATConv::new(16, 8, 4);
        assert_eq!(layer.parameters().len(), 4);
        let named = layer.named_parameters();
        for key in ["w", "attn_src", "attn_dst", "bias"] {
            assert!(named.contains_key(key));
        }
    }

    #[test]
    fn test_gat_conv_total_output() {
        assert_eq!(GATConv::new(16, 32, 4).total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        let layer = GCNConv::new(4, 4);
        let x = filled(99.0, &[1, 3, 4]);
        let adj = filled(0.0, &[3, 3]);
        let out = layer.forward_graph(&x, &adj);
        for val in &out.data().to_vec() {
            assert!(
                val.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}