use ndarray::{Array1, Array2, Axis};
use rand::Rng;
use rand_distr::{Distribution, Uniform};
#[derive(Debug, thiserror::Error)]
pub enum AttentionError {
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
#[error("Invalid attention configuration: {0}")]
InvalidConfig(String),
#[error("Attention computation error: {0}")]
ComputationError(String),
}
pub type AttentionResult<T> = Result<T, AttentionError>;
#[derive(Debug, Clone)]
pub struct AttentionLayer {
query_weights: Array2<f32>,
key_weights: Array2<f32>,
value_weights: Array2<f32>,
attention_dim: usize,
input_dim: usize,
output_dim: usize,
scale: f32,
}
impl AttentionLayer {
#[must_use]
pub fn new(input_dim: usize, attention_dim: usize) -> Self {
let query_weights = xavier_init(input_dim, attention_dim);
let key_weights = xavier_init(input_dim, attention_dim);
let value_weights = xavier_init(input_dim, attention_dim);
let scale = (attention_dim as f32).sqrt();
Self {
query_weights,
key_weights,
value_weights,
attention_dim,
input_dim,
output_dim: attention_dim,
scale,
}
}
#[must_use]
pub fn output_dim(&self) -> usize {
self.output_dim
}
pub fn compute_attention(
&self,
query: &Array2<f32>,
key: &Array2<f32>,
mask: Option<&Array2<f32>>,
) -> AttentionResult<Array2<f32>> {
let scores = query.dot(&key.t()) / self.scale;
let scores = if let Some(mask) = mask {
let mut masked = scores;
for i in 0..masked.nrows() {
for j in 0..masked.ncols() {
if mask[[i, j]] == 0.0 {
masked[[i, j]] = f32::NEG_INFINITY;
}
}
}
masked
} else {
scores
};
let attention_weights = softmax_2d(&scores);
Ok(attention_weights)
}
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
if features.ncols() != self.input_dim {
return Err(AttentionError::DimensionMismatch {
expected: self.input_dim,
actual: features.ncols(),
});
}
let q = features.dot(&self.query_weights);
let k = features.dot(&self.key_weights);
let v = features.dot(&self.value_weights);
let attention_weights = self.compute_attention(&q, &k, None)?;
let output = attention_weights.dot(&v);
Ok(output)
}
pub fn forward_qkv(
&self,
query: &Array2<f32>,
key: &Array2<f32>,
value: &Array2<f32>,
mask: Option<&Array2<f32>>,
) -> AttentionResult<Array2<f32>> {
let q = query.dot(&self.query_weights);
let k = key.dot(&self.key_weights);
let v = value.dot(&self.value_weights);
let attention_weights = self.compute_attention(&q, &k, mask)?;
let output = attention_weights.dot(&v);
Ok(output)
}
pub fn graph_readout(&self, node_features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
let attended = self.forward(node_features)?;
let mean = attended.mean_axis(Axis(0)).unwrap();
Ok(mean)
}
pub fn update_weights(&mut self, _lr: f32, weight_decay: f32) {
self.query_weights -= &(&self.query_weights * weight_decay);
self.key_weights -= &(&self.key_weights * weight_decay);
self.value_weights -= &(&self.value_weights * weight_decay);
}
}
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
heads: Vec<AttentionLayer>,
output_projection: Array2<f32>,
num_heads: usize,
head_dim: usize,
output_dim: usize,
dropout: f32,
}
impl MultiHeadAttention {
#[must_use]
pub fn new(input_dim: usize, num_heads: usize, head_dim: usize, dropout: f32) -> Self {
let mut heads = Vec::with_capacity(num_heads);
for _ in 0..num_heads {
heads.push(AttentionLayer::new(input_dim, head_dim));
}
let total_dim = num_heads * head_dim;
let output_projection = xavier_init(total_dim, input_dim);
Self {
heads,
output_projection,
num_heads,
head_dim,
output_dim: input_dim,
dropout,
}
}
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
#[must_use]
pub fn output_dim(&self) -> usize {
self.output_dim
}
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
let n = features.nrows();
let mut head_outputs = Vec::with_capacity(self.num_heads);
for head in &self.heads {
let output = head.forward(features)?;
head_outputs.push(output);
}
let mut concat = Array2::zeros((n, self.num_heads * self.head_dim));
for (h, output) in head_outputs.iter().enumerate() {
let start = h * self.head_dim;
for i in 0..n {
for j in 0..self.head_dim {
concat[[i, start + j]] = output[[i, j]];
}
}
}
let output = concat.dot(&self.output_projection);
Ok(output)
}
pub fn forward_qkv(
&self,
query: &Array2<f32>,
key: &Array2<f32>,
value: &Array2<f32>,
mask: Option<&Array2<f32>>,
) -> AttentionResult<Array2<f32>> {
let n = query.nrows();
let mut head_outputs = Vec::with_capacity(self.num_heads);
for head in &self.heads {
let output = head.forward_qkv(query, key, value, mask)?;
head_outputs.push(output);
}
let mut concat = Array2::zeros((n, self.num_heads * self.head_dim));
for (h, output) in head_outputs.iter().enumerate() {
let start = h * self.head_dim;
for i in 0..n {
for j in 0..self.head_dim {
concat[[i, start + j]] = output[[i, j]];
}
}
}
let output = concat.dot(&self.output_projection);
Ok(output)
}
pub fn graph_readout(&self, node_features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
let attended = self.forward(node_features)?;
let mean = attended.mean_axis(Axis(0)).unwrap();
Ok(mean)
}
}
#[derive(Debug, Clone)]
pub struct CrossAttention {
query_proj: Array2<f32>,
key_proj: Array2<f32>,
value_proj: Array2<f32>,
attention_dim: usize,
source_dim: usize,
target_dim: usize,
}
impl CrossAttention {
#[must_use]
pub fn new(source_dim: usize, target_dim: usize, attention_dim: usize) -> Self {
Self {
query_proj: xavier_init(source_dim, attention_dim),
key_proj: xavier_init(target_dim, attention_dim),
value_proj: xavier_init(target_dim, attention_dim),
attention_dim,
source_dim,
target_dim,
}
}
pub fn forward(
&self,
source: &Array2<f32>,
target: &Array2<f32>,
) -> AttentionResult<Array2<f32>> {
let query = source.dot(&self.query_proj);
let key = target.dot(&self.key_proj);
let value = target.dot(&self.value_proj);
let scale = (self.attention_dim as f32).sqrt();
let scores = query.dot(&key.t()) / scale;
let attention_weights = softmax_2d(&scores);
let output = attention_weights.dot(&value);
Ok(output)
}
}
#[derive(Debug, Clone)]
pub struct SetAttention {
mha: MultiHeadAttention,
layer_norm_weight: Array1<f32>,
layer_norm_bias: Array1<f32>,
ffn_w1: Array2<f32>,
ffn_w2: Array2<f32>,
input_dim: usize,
hidden_dim: usize,
}
impl SetAttention {
#[must_use]
pub fn new(input_dim: usize, num_heads: usize, hidden_dim: usize) -> Self {
let head_dim = input_dim / num_heads;
let mha = MultiHeadAttention::new(input_dim, num_heads, head_dim, 0.0);
Self {
mha,
layer_norm_weight: Array1::ones(input_dim),
layer_norm_bias: Array1::zeros(input_dim),
ffn_w1: xavier_init(input_dim, hidden_dim),
ffn_w2: xavier_init(hidden_dim, input_dim),
input_dim,
hidden_dim,
}
}
pub fn forward(&self, features: &Array2<f32>) -> AttentionResult<Array2<f32>> {
let attended = self.mha.forward(features)?;
let residual1 = features + &attended;
let normed1 = layer_norm(&residual1, &self.layer_norm_weight, &self.layer_norm_bias);
let hidden = normed1.dot(&self.ffn_w1).mapv(|x| x.max(0.0)); let ffn_out = hidden.dot(&self.ffn_w2);
let residual2 = &normed1 + &ffn_out;
let output = layer_norm(&residual2, &self.layer_norm_weight, &self.layer_norm_bias);
Ok(output)
}
pub fn aggregate(&self, features: &Array2<f32>) -> AttentionResult<Array1<f32>> {
let transformed = self.forward(features)?;
Ok(transformed.mean_axis(Axis(0)).unwrap())
}
}
fn xavier_init(fan_in: usize, fan_out: usize) -> Array2<f32> {
let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
let uniform = Uniform::new(-limit, limit);
let mut rng = rand::thread_rng();
Array2::from_shape_fn((fan_in, fan_out), |_| uniform.sample(&mut rng))
}
fn softmax_2d(scores: &Array2<f32>) -> Array2<f32> {
let mut result = scores.clone();
for mut row in result.rows_mut() {
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0;
for val in row.iter_mut() {
if val.is_finite() {
*val = (*val - max).exp();
sum += *val;
} else {
*val = 0.0;
}
}
if sum > 0.0 {
row /= sum;
}
}
result
}
fn layer_norm(x: &Array2<f32>, weight: &Array1<f32>, bias: &Array1<f32>) -> Array2<f32> {
let eps = 1e-5;
let mut result = x.clone();
for mut row in result.rows_mut() {
let mean = row.mean().unwrap_or(0.0);
let variance = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / row.len() as f32;
let std = (variance + eps).sqrt();
for (i, val) in row.iter_mut().enumerate() {
*val = (*val - mean) / std * weight[i] + bias[i];
}
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_attention_layer() {
let layer = AttentionLayer::new(8, 16);
assert_eq!(layer.output_dim(), 16);
let features = Array2::from_elem((3, 8), 0.5);
let output = layer.forward(&features).unwrap();
assert_eq!(output.shape(), &[3, 16]);
}
#[test]
fn test_multi_head_attention() {
let mha = MultiHeadAttention::new(8, 4, 4, 0.0);
assert_eq!(mha.num_heads(), 4);
assert_eq!(mha.output_dim(), 8);
let features = Array2::from_elem((5, 8), 0.5);
let output = mha.forward(&features).unwrap();
assert_eq!(output.shape(), &[5, 8]);
}
#[test]
fn test_cross_attention() {
let cross = CrossAttention::new(8, 16, 32);
let source = Array2::from_elem((3, 8), 0.5);
let target = Array2::from_elem((5, 16), 0.5);
let output = cross.forward(&source, &target).unwrap();
assert_eq!(output.shape(), &[3, 32]);
}
#[test]
fn test_set_attention() {
let set_attn = SetAttention::new(16, 4, 64);
let features = Array2::from_elem((4, 16), 0.5);
let output = set_attn.forward(&features).unwrap();
assert_eq!(output.shape(), &[4, 16]);
let aggregated = set_attn.aggregate(&features).unwrap();
assert_eq!(aggregated.len(), 16);
}
#[test]
fn test_softmax() {
let scores = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let probs = softmax_2d(&scores);
for row in probs.rows() {
let sum: f32 = row.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
assert!(probs[[0, 2]] > probs[[0, 1]]);
assert!(probs[[0, 1]] > probs[[0, 0]]);
}
#[test]
fn test_layer_norm() {
let x = Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
let weight = Array1::ones(4);
let bias = Array1::zeros(4);
let normed = layer_norm(&x, &weight, &bias);
for row in normed.rows() {
let mean: f32 = row.iter().sum::<f32>() / row.len() as f32;
assert!(mean.abs() < 1e-5);
}
}
#[test]
fn test_graph_readout() {
let layer = AttentionLayer::new(8, 4);
let node_features = Array2::from_elem((5, 8), 0.5);
let readout = layer.graph_readout(&node_features).unwrap();
assert_eq!(readout.len(), 4);
}
#[test]
fn test_attention_with_mask() {
let layer = AttentionLayer::new(4, 4);
let features = Array2::from_elem((3, 4), 0.5);
let mut mask = Array2::ones((3, 3));
mask[[0, 1]] = 0.0;
mask[[0, 2]] = 0.0;
let query = features.dot(&layer.query_weights);
let key = features.dot(&layer.key_weights);
let attn_weights = layer.compute_attention(&query, &key, Some(&mask)).unwrap();
assert!(attn_weights[[0, 0]] > 0.99); assert!(attn_weights[[0, 1]] < 0.01);
assert!(attn_weights[[0, 2]] < 0.01);
}
}