use rand::Rng;
/// A single attention-based graph layer.
///
/// The layer stores a dense weight matrix used for the linear projection and
/// a second parameter vector reserved for attention. All fields are private;
/// construct instances through `GnnLayer::new`.
#[derive(Debug, Clone)]
pub struct GnnLayer {
    /// Width of node/neighbor embeddings and of the produced output vector.
    input_dim: usize,
    // Stored for completeness: it sizes `weights` at construction time but is
    // not read again afterwards, hence the lint suppression.
    #[allow(dead_code)]
    hidden_dim: usize,
    /// Number of attention heads the embedding is split into.
    num_heads: usize,
    /// Flat projection weights, `input_dim * hidden_dim` values, addressed as
    /// `row * input_dim + column`.
    weights: Vec<f32>,
    // Initialised with `num_heads * input_dim` values but not read by the
    // methods of this type as far as this file shows, hence the lint
    // suppression.
    #[allow(dead_code)]
    attention_weights: Vec<f32>,
}
impl GnnLayer {
    /// Creates a layer with randomly initialised parameters.
    ///
    /// Both parameter vectors are drawn uniformly from `[-s, s)` where
    /// `s = sqrt(2 / (input_dim + hidden_dim))`. `weights` receives
    /// `input_dim * hidden_dim` values and `attention_weights` receives
    /// `num_heads * input_dim` values.
    pub fn new(input_dim: usize, hidden_dim: usize, num_heads: usize) -> Self {
        let mut rng = rand::thread_rng();
        let scale = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
        // The projection weights are sampled first, then the attention
        // weights, so the order of RNG draws is fixed.
        let mut sample = |count: usize| -> Vec<f32> {
            (0..count).map(|_| rng.gen_range(-scale..scale)).collect()
        };
        let weights = sample(input_dim * hidden_dim);
        let attention_weights = sample(num_heads * input_dim);
        Self {
            input_dim,
            hidden_dim,
            num_heads,
            weights,
            attention_weights,
        }
    }

    /// Computes the layer output for one node.
    ///
    /// With no neighbors the result is the plain linear projection of
    /// `node_embedding` (no activation is applied on that path). Otherwise:
    ///
    /// 1. softmax attention scores between the node and every neighbor are
    ///    multiplied by the matching entry of `edge_weights`;
    /// 2. the products are normalised to sum to one, falling back to uniform
    ///    weights when their sum is not strictly positive;
    /// 3. neighbors are averaged with those weights, blended 50/50 with the
    ///    node embedding, projected, and passed through ReLU.
    ///
    /// The returned vector always has `input_dim` elements. Neighbors without
    /// a matching entry in `edge_weights` do not contribute unless the uniform
    /// fallback is taken. Embedding components at index `input_dim` or beyond
    /// are ignored.
    pub fn forward(
        &self,
        node_embedding: &[f32],
        neighbor_embeddings: &[Vec<f32>],
        edge_weights: &[f32],
    ) -> Vec<f32> {
        if neighbor_embeddings.is_empty() {
            return self.linear_transform(node_embedding);
        }

        let attention_scores = self.compute_attention(node_embedding, neighbor_embeddings);
        let combined_weights: Vec<f32> = attention_scores
            .iter()
            .zip(edge_weights)
            .map(|(a, e)| a * e)
            .collect();

        let weight_sum: f32 = combined_weights.iter().sum();
        let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
            combined_weights.iter().map(|w| w / weight_sum).collect()
        } else {
            // Also reached when the sum is NaN, because the comparison above
            // is then false.
            let count = neighbor_embeddings.len();
            vec![1.0 / count as f32; count]
        };

        let mut aggregated = vec![0.0f32; self.input_dim];
        for (neighbor, &weight) in neighbor_embeddings.iter().zip(&normalized_weights) {
            // `zip` stops at `input_dim`, so over-long neighbors are truncated.
            for (acc, &val) in aggregated.iter_mut().zip(neighbor) {
                *acc += val * weight;
            }
        }

        let combined: Vec<f32> = node_embedding
            .iter()
            .zip(&aggregated)
            .map(|(n, a)| 0.5 * n + 0.5 * a)
            .collect();

        self.linear_transform(&combined)
            .into_iter()
            .map(|x| x.max(0.0))
            .collect()
    }

    /// Returns softmax-normalised attention scores of `query` against `keys`.
    ///
    /// The embedding is split into `num_heads` contiguous chunks of
    /// `input_dim / num_heads` components. Each head contributes a dot product
    /// scaled by `1 / sqrt(head_dim)`, and the per-head values are averaged.
    /// Components past `num_heads * head_dim` are not used.
    ///
    /// Degenerate configurations are tolerated instead of panicking or
    /// producing NaN: a `num_heads` of zero is treated as one head, a head
    /// width of zero is treated as one, and a head whose range lies beyond the
    /// end of `query` or of a key contributes zero.
    fn compute_attention(&self, query: &[f32], keys: &[Vec<f32>]) -> Vec<f32> {
        let num_heads = self.num_heads.max(1);
        let head_dim = (self.input_dim / num_heads).max(1);
        let scale = (head_dim as f32).sqrt();

        let scores: Vec<f32> = keys
            .iter()
            .map(|key| {
                let limit = query.len().min(key.len());
                let mut score = 0.0f32;
                for h in 0..num_heads {
                    // Clamp both ends so that `start <= end` always holds;
                    // a head past the shorter vector becomes an empty slice.
                    let start = (h * head_dim).min(limit);
                    let end = (start + head_dim).min(limit);
                    let dot: f32 = query[start..end]
                        .iter()
                        .zip(&key[start..end])
                        .map(|(q, k)| q * k)
                        .sum();
                    score += dot / scale;
                }
                score / num_heads as f32
            })
            .collect();

        // Subtract the maximum before exponentiating for numerical stability.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum();
        exp_scores.iter().map(|e| e / sum_exp).collect()
    }

    /// Multiplies `input` by the layer's weight matrix.
    ///
    /// Output element `i` is the sum over `j` of
    /// `input[j] * weights[i * input_dim + j]`, using at most `input_dim`
    /// components of `input`. Indices that fall outside `weights` are skipped.
    /// The output always has `input_dim` elements.
    fn linear_transform(&self, input: &[f32]) -> Vec<f32> {
        let cols = input.len().min(self.input_dim);
        let used = &input[..cols];
        (0..self.input_dim)
            .map(|row| {
                // An out-of-range row start yields an empty row (sum stays 0).
                let row_weights = self.weights.get(row * self.input_dim..).unwrap_or(&[]);
                used.iter()
                    .zip(row_weights)
                    .fold(0.0f32, |acc, (x, w)| acc + x * w)
            })
            .collect()
    }

    /// Nudges every projection weight using `query` and `target_score`.
    ///
    /// Weight `i` is incremented by
    /// `learning_rate * query[i % query.len()] * (target_score - 1.0) * 0.01`.
    /// An empty `query` leaves the weights untouched.
    pub fn update(&mut self, query: &[f32], target_score: f32, learning_rate: f32) {
        if query.is_empty() {
            return;
        }
        let error = target_score - 1.0;
        // `cycle` reproduces the `i % query.len()` indexing without a modulo.
        for (weight, &q) in self.weights.iter_mut().zip(query.iter().cycle()) {
            let gradient = q * error;
            *weight += learning_rate * gradient * 0.01;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A node with two weighted neighbors yields one output value per input
    /// dimension.
    #[test]
    fn test_gnn_layer_forward() {
        let input_dim = 64;
        let layer = GnnLayer::new(input_dim, 128, 4);

        let node = vec![0.5f32; input_dim];
        let neighbors: Vec<Vec<f32>> = [0.3f32, 0.7]
            .iter()
            .map(|&value| vec![value; input_dim])
            .collect();
        let edge_weights = [0.8f32, 0.6];

        let output = layer.forward(&node, &neighbors, &edge_weights);

        assert_eq!(output.len(), input_dim);
    }

    /// An isolated node (no neighbors, no edge weights) still yields one
    /// output value per input dimension.
    #[test]
    fn test_gnn_layer_no_neighbors() {
        let input_dim = 32;
        let layer = GnnLayer::new(input_dim, 64, 2);

        let node = vec![0.5f32; input_dim];
        let output = layer.forward(&node, &[], &[]);

        assert_eq!(output.len(), input_dim);
    }
}