use crate::error::GnnError;
use ndarray::{Array1, Array2, ArrayView1};
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};
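/// Dense (fully connected) layer computing `y = W·x + b`.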
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Linear {
weights: Array2<f32>,
bias: Array1<f32>,
}
impl Linear {
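    /// Creates a layer with Xavier/Glorot-normal initialized weights
    /// (std = sqrt(2 / (input_dim + output_dim))) and a zero bias.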
pub fn new(input_dim: usize, output_dim: usize) -> Self {
let mut rng = rand::thread_rng();
let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
let normal = Normal::new(0.0, scale as f64).unwrap();
let weights =
Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32);
let bias = Array1::zeros(output_dim);
Self { weights, bias }
}
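    /// Applies the affine transform to `input`; the slice length must match the
    /// layer's input dimension.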
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
let x = ArrayView1::from(input);
let output = self.weights.dot(&x) + &self.bias;
output.to_vec()
}
pub fn output_dim(&self) -> usize {
self.weights.shape()[0]
}
}
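/// Layer normalization with learnable scale (`gamma`) and shift (`beta`).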
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
gamma: Array1<f32>,
beta: Array1<f32>,
eps: f32,
}
impl LayerNorm {
pub fn new(dim: usize, eps: f32) -> Self {
Self {
gamma: Array1::ones(dim),
beta: Array1::zeros(dim),
eps,
}
}
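    /// Normalizes `input` to zero mean and unit variance, then applies `gamma` and `beta`.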
pub fn forward(&self, input: &[f32]) -> Vec<f32> {
let x = ArrayView1::from(input);
let mean = x.mean().unwrap_or(0.0);
let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());
let output = &self.gamma * &normalized + &self.beta;
output.to_vec()
}
}
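/// Multi-head scaled dot-product attention over a single query vector and a set of
/// key/value vectors.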
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadAttention {
num_heads: usize,
head_dim: usize,
q_linear: Linear,
k_linear: Linear,
v_linear: Linear,
out_linear: Linear,
}
impl MultiHeadAttention {
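    /// Creates the attention module.
    ///
    /// # Errors
    /// Returns a configuration error if `embed_dim` is not divisible by `num_heads`.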
pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self, GnnError> {
if embed_dim % num_heads != 0 {
return Err(GnnError::layer_config(format!(
"Embedding dimension ({}) must be divisible by number of heads ({})",
embed_dim, num_heads
)));
}
let head_dim = embed_dim / num_heads;
Ok(Self {
num_heads,
head_dim,
q_linear: Linear::new(embed_dim, embed_dim),
k_linear: Linear::new(embed_dim, embed_dim),
v_linear: Linear::new(embed_dim, embed_dim),
out_linear: Linear::new(embed_dim, embed_dim),
})
}
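    /// Attends `query` over `keys`/`values` per head and projects the concatenated
    /// head outputs. Returns the query unchanged if there are no keys or values.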
pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
if keys.is_empty() || values.is_empty() {
return query.to_vec();
}
let q = self.q_linear.forward(query);
let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();
let q_heads = self.split_heads(&q);
let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect();
let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect();
let mut head_outputs = Vec::new();
for h in 0..self.num_heads {
let q_h = &q_heads[h];
let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();
let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
head_outputs.push(head_output);
}
let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();
self.out_linear.forward(&concat)
}
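    /// Splits a flat embedding into `num_heads` contiguous chunks of `head_dim` each.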
fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
let mut heads = Vec::new();
for h in 0..self.num_heads {
let start = h * self.head_dim;
let end = start + self.head_dim;
heads.push(x[start..end].to_vec());
}
heads
}
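    /// Softmax(q·k / sqrt(head_dim))-weighted sum of the value vectors for a single head.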
fn scaled_dot_product_attention(
&self,
query: &[f32],
keys: &[&Vec<f32>],
values: &[&Vec<f32>],
) -> Vec<f32> {
if keys.is_empty() {
return query.to_vec();
}
let scale = (self.head_dim as f32).sqrt();
let scores: Vec<f32> = keys
.iter()
.map(|k| {
let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
dot / scale
})
.collect();
let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
let attention_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();
let mut output = vec![0.0; self.head_dim];
for (weight, value) in attention_weights.iter().zip(values.iter()) {
for (out, &val) in output.iter_mut().zip(value.iter()) {
*out += weight * val;
}
}
output
}
}
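/// Gated recurrent unit cell with update gate `z`, reset gate `r`, and candidate state.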
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GRUCell {
w_z: Linear,
u_z: Linear,
w_r: Linear,
u_r: Linear,
w_h: Linear,
u_h: Linear,
}
impl GRUCell {
pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
Self {
w_z: Linear::new(input_dim, hidden_dim),
u_z: Linear::new(hidden_dim, hidden_dim),
w_r: Linear::new(input_dim, hidden_dim),
u_r: Linear::new(hidden_dim, hidden_dim),
w_h: Linear::new(input_dim, hidden_dim),
u_h: Linear::new(hidden_dim, hidden_dim),
}
}
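    /// Computes the new hidden state:
    /// `z = sigmoid(W_z x + U_z h)`, `r = sigmoid(W_r x + U_r h)`,
    /// `h_tilde = tanh(W_h x + U_h (r * h))`, `h_new = (1 - z) * h + z * h_tilde`.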
pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
let z =
self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden)));
let r =
self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden)));
let r_hidden = self.mul_vecs(&r, hidden);
let h_tilde =
self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden)));
let one_minus_z: Vec<f32> = z.iter().map(|&zval| 1.0 - zval).collect();
let term1 = self.mul_vecs(&one_minus_z, hidden);
let term2 = self.mul_vecs(&z, &h_tilde);
self.add_vecs(&term1, &term2)
}
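    /// Numerically stable logistic sigmoid (avoids `exp` overflow for large-magnitude inputs).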
fn sigmoid(&self, x: f32) -> f32 {
if x > 0.0 {
1.0 / (1.0 + (-x).exp())
} else {
let ex = x.exp();
ex / (1.0 + ex)
}
}
fn sigmoid_vec(&self, v: &[f32]) -> Vec<f32> {
v.iter().map(|&x| self.sigmoid(x)).collect()
}
fn tanh(&self, x: f32) -> f32 {
x.tanh()
}
fn tanh_vec(&self, v: &[f32]) -> Vec<f32> {
v.iter().map(|&x| self.tanh(x)).collect()
}
fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(x, y)| x * y).collect()
}
}
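/// Graph message-passing layer: projects node and neighbor embeddings, combines
/// multi-head attention with edge-weighted aggregation, applies a GRU update,
/// dropout scaling, and layer normalization.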
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvectorLayer {
w_msg: Linear,
w_agg: Linear,
w_update: GRUCell,
attention: MultiHeadAttention,
norm: LayerNorm,
dropout: f32,
}
impl RuvectorLayer {
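    /// Creates the layer.
    ///
    /// # Errors
    /// Returns a configuration error if `dropout` is outside `[0.0, 1.0]` or
    /// `hidden_dim` is not divisible by `heads`.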
pub fn new(
input_dim: usize,
hidden_dim: usize,
heads: usize,
dropout: f32,
) -> Result<Self, GnnError> {
if !(0.0..=1.0).contains(&dropout) {
return Err(GnnError::layer_config(format!(
"Dropout must be between 0.0 and 1.0, got {}",
dropout
)));
}
Ok(Self {
w_msg: Linear::new(input_dim, hidden_dim),
w_agg: Linear::new(hidden_dim, hidden_dim),
w_update: GRUCell::new(hidden_dim, hidden_dim),
attention: MultiHeadAttention::new(hidden_dim, heads)?,
norm: LayerNorm::new(hidden_dim, 1e-5),
dropout,
})
}
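    /// Updates `node_embedding` from its neighbors. With no neighbors, the node
    /// embedding is simply projected and normalized.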
pub fn forward(
&self,
node_embedding: &[f32],
neighbor_embeddings: &[Vec<f32>],
edge_weights: &[f32],
) -> Vec<f32> {
if neighbor_embeddings.is_empty() {
let projected = self.w_msg.forward(node_embedding);
return self.norm.forward(&projected);
}
let node_msg = self.w_msg.forward(node_embedding);
let neighbor_msgs: Vec<Vec<f32>> = neighbor_embeddings
.iter()
.map(|n| self.w_msg.forward(n))
.collect();
let attention_output = self
.attention
.forward(&node_msg, &neighbor_msgs, &neighbor_msgs);
let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights);
let combined = self.add_vecs(&attention_output, &weighted_msgs);
let aggregated = self.w_agg.forward(&combined);
let updated = self.w_update.forward(&aggregated, &node_msg);
let dropped = self.apply_dropout(&updated);
self.norm.forward(&dropped)
}
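    /// Weighted mean of neighbor messages; weights are normalized to sum to 1,
    /// falling back to uniform weights when they do not sum to a positive value.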
fn aggregate_messages(&self, messages: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
if messages.is_empty() || weights.is_empty() {
return vec![0.0; self.w_msg.output_dim()];
}
let weight_sum: f32 = weights.iter().sum();
let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
weights.iter().map(|&w| w / weight_sum).collect()
} else {
vec![1.0 / weights.len() as f32; weights.len()]
};
let dim = messages[0].len();
let mut aggregated = vec![0.0; dim];
for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) {
for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) {
*agg += weight * m;
}
}
aggregated
}
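    /// Deterministic dropout: scales every activation by `1 - dropout` instead of
    /// randomly zeroing elements.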
fn apply_dropout(&self, input: &[f32]) -> Vec<f32> {
let scale = 1.0 - self.dropout;
input.iter().map(|&x| x * scale).collect()
}
fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_linear_layer() {
let linear = Linear::new(4, 2);
let input = vec![1.0, 2.0, 3.0, 4.0];
let output = linear.forward(&input);
assert_eq!(output.len(), 2);
}
#[test]
fn test_layer_norm() {
let norm = LayerNorm::new(4, 1e-5);
let input = vec![1.0, 2.0, 3.0, 4.0];
let output = norm.forward(&input);
let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
assert!(mean.abs() < 1e-5);
}
#[test]
fn test_multihead_attention() {
let attention = MultiHeadAttention::new(8, 2).unwrap();
let query = vec![0.5; 8];
let keys = vec![vec![0.3; 8], vec![0.7; 8]];
let values = vec![vec![0.2; 8], vec![0.8; 8]];
let output = attention.forward(&query, &keys, &values);
assert_eq!(output.len(), 8);
}
#[test]
fn test_multihead_attention_invalid_dims() {
let result = MultiHeadAttention::new(10, 3);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("divisible"));
}
#[test]
fn test_gru_cell() {
let gru = GRUCell::new(4, 8);
let input = vec![1.0; 4];
let hidden = vec![0.5; 8];
let new_hidden = gru.forward(&input, &hidden);
assert_eq!(new_hidden.len(), 8);
}
#[test]
fn test_ruvector_layer() {
let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();
let node = vec![1.0, 2.0, 3.0, 4.0];
let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]];
let weights = vec![0.3, 0.7];
let output = layer.forward(&node, &neighbors, &weights);
assert_eq!(output.len(), 8);
}
#[test]
fn test_ruvector_layer_no_neighbors() {
let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap();
let node = vec![1.0, 2.0, 3.0, 4.0];
let neighbors: Vec<Vec<f32>> = vec![];
let weights: Vec<f32> = vec![];
let output = layer.forward(&node, &neighbors, &weights);
assert_eq!(output.len(), 8);
}
#[test]
fn test_ruvector_layer_invalid_dropout() {
let result = RuvectorLayer::new(4, 8, 2, 1.5);
assert!(result.is_err());
}
#[test]
fn test_ruvector_layer_invalid_heads() {
let result = RuvectorLayer::new(4, 7, 3, 0.1);
assert!(result.is_err());
}
}