use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;
use crate::utils::stable_softmax;
/// Settings for an [`EdgeFeaturedAttention`] layer.
///
/// Obtain one through [`EdgeFeaturedConfig::builder`], through `Default`, or
/// by filling in the public fields directly.
#[derive(Debug, Clone)]
pub struct EdgeFeaturedConfig {
    /// Length of a node feature vector; also the column count of each
    /// per-head node projection matrix.
    pub node_dim: usize,
    /// Length of an edge feature vector; also the column count of each
    /// per-head edge projection matrix.
    pub edge_dim: usize,
    /// Number of attention heads. Each head works in
    /// `node_dim / num_heads` dimensions (integer division).
    pub num_heads: usize,
    /// Dropout value carried by the configuration.
    /// NOTE(review): nothing in this file reads it — confirm whether it is
    /// consumed elsewhere.
    pub dropout: f32,
    /// `true` concatenates the per-head outputs, `false` averages them.
    pub concat_heads: bool,
    /// Self-loop flag carried by the configuration.
    /// NOTE(review): nothing in this file reads it — confirm whether it is
    /// consumed elsewhere.
    pub add_self_loops: bool,
    /// Factor applied to negative raw attention scores (leaky ReLU slope).
    pub negative_slope: f32,
}
impl Default for EdgeFeaturedConfig {
    /// Returns the stock configuration: 256-dimensional nodes,
    /// 64-dimensional edges, 4 heads, a dropout of 0.0, concatenated heads,
    /// `add_self_loops` set, and a negative slope of 0.2.
    fn default() -> Self {
        let (node_dim, edge_dim, num_heads) = (256, 64, 4);
        Self {
            node_dim,
            edge_dim,
            num_heads,
            dropout: 0.0,
            concat_heads: true,
            add_self_loops: true,
            negative_slope: 0.2,
        }
    }
}
impl EdgeFeaturedConfig {
pub fn builder() -> EdgeFeaturedConfigBuilder {
EdgeFeaturedConfigBuilder::default()
}
pub fn head_dim(&self) -> usize {
self.node_dim / self.num_heads
}
}
/// Fluent builder for [`EdgeFeaturedConfig`].
///
/// The derived `Default` starts from `EdgeFeaturedConfig::default()`, so any
/// field that is not overridden keeps its default value.
#[derive(Default)]
pub struct EdgeFeaturedConfigBuilder {
    /// The configuration being assembled; handed out by `build`.
    config: EdgeFeaturedConfig,
}
impl EdgeFeaturedConfigBuilder {
    /// Sets the length of node feature vectors.
    pub fn node_dim(mut self, d: usize) -> Self {
        self.config.node_dim = d;
        self
    }

    /// Sets the length of edge feature vectors.
    pub fn edge_dim(mut self, d: usize) -> Self {
        self.config.edge_dim = d;
        self
    }

    /// Sets the number of attention heads.
    ///
    /// No validation happens here: a value of zero makes
    /// `EdgeFeaturedConfig::head_dim` panic later on.
    pub fn num_heads(mut self, n: usize) -> Self {
        self.config.num_heads = n;
        self
    }

    /// Sets the dropout value stored in the configuration.
    pub fn dropout(mut self, d: f32) -> Self {
        self.config.dropout = d;
        self
    }

    /// Chooses between concatenating (`true`) and averaging (`false`) the
    /// per-head outputs.
    pub fn concat_heads(mut self, c: bool) -> Self {
        self.config.concat_heads = c;
        self
    }

    /// Sets the `add_self_loops` flag stored in the configuration.
    ///
    /// Without this setter the flag could only be changed by mutating the
    /// built configuration, unlike every other field.
    pub fn add_self_loops(mut self, a: bool) -> Self {
        self.config.add_self_loops = a;
        self
    }

    /// Sets the factor applied to negative raw attention scores.
    pub fn negative_slope(mut self, s: f32) -> Self {
        self.config.negative_slope = s;
        self
    }

    /// Finishes building and returns the assembled configuration as-is.
    pub fn build(self) -> EdgeFeaturedConfig {
        self.config
    }
}
/// Multi-head attention layer whose scores depend on node features and on
/// per-edge features.
///
/// All parameters live in flat `f32` buffers laid out head by head.
pub struct EdgeFeaturedAttention {
    /// Settings the layer was constructed with.
    config: EdgeFeaturedConfig,
    /// Node projection weights: `num_heads * head_dim * node_dim` values,
    /// indexed as `head * head_dim * node_dim + row * node_dim + col`.
    w_node: Vec<f32>,
    /// Edge projection weights: `num_heads * head_dim * edge_dim` values,
    /// indexed as `head * head_dim * edge_dim + row * edge_dim + col`.
    w_edge: Vec<f32>,
    /// Scoring vector applied to the projected source (query) node;
    /// `num_heads * head_dim` values indexed as `head * head_dim + i`.
    a_src: Vec<f32>,
    /// Scoring vector applied to the projected destination (key) node;
    /// same layout as `a_src`.
    a_dst: Vec<f32>,
    /// Scoring vector applied to the projected edge; same layout as `a_src`.
    a_edge: Vec<f32>,
}
impl EdgeFeaturedAttention {
    /// Creates a layer with deterministically initialised parameters.
    ///
    /// Every parameter is drawn from a linear congruential generator that
    /// always starts from the same seed, so two layers built from equal
    /// configurations hold identical weights. Values are uniform in
    /// `[-scale, scale]`, where the scale is `sqrt(2 / (in_dim + head_dim))`
    /// for the projection matrices and `sqrt(1 / head_dim)` for the scoring
    /// vectors.
    ///
    /// # Panics
    ///
    /// Panics if `config.num_heads` is zero, because `head_dim` divides by it.
    pub fn new(config: EdgeFeaturedConfig) -> Self {
        let per_head = config.head_dim();
        let heads = config.num_heads;

        let node_scale = (2.0 / (config.node_dim + per_head) as f32).sqrt();
        let edge_scale = (2.0 / (config.edge_dim + per_head) as f32).sqrt();
        let attn_scale = (1.0 / per_head as f32).sqrt();

        // Generator state; each call advances it and yields a value in
        // [-0.5, 0.5].
        let mut state = 42u64;
        let mut next_unit = || {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            (state as f32) / (u64::MAX as f32) - 0.5
        };
        let mut sample = |count: usize, scale: f32| -> Vec<f32> {
            (0..count).map(|_| next_unit() * 2.0 * scale).collect()
        };

        // The draw order below fixes which random values land in which
        // buffer; keep it stable.
        let w_node = sample(heads * per_head * config.node_dim, node_scale);
        let w_edge = sample(heads * per_head * config.edge_dim, edge_scale);
        let a_src = sample(heads * per_head, attn_scale);
        let a_dst = sample(heads * per_head, attn_scale);
        let a_edge = sample(heads * per_head, attn_scale);

        Self {
            config,
            w_node,
            w_edge,
            a_src,
            a_dst,
            a_edge,
        }
    }

    /// Multiplies `input` by the `head`-th `rows x cols` block of the flat
    /// `weights` buffer and returns the `rows` resulting dot products.
    ///
    /// Every element of `input` is used, so an `input` longer than `cols`
    /// indexes past the end of its weight row.
    fn project(
        weights: &[f32],
        input: &[f32],
        head: usize,
        rows: usize,
        cols: usize,
    ) -> Vec<f32> {
        let block_start = head * rows * cols;
        let mut projected = Vec::with_capacity(rows);
        for row in 0..rows {
            let row_start = block_start + row * cols;
            let dot: f32 = input
                .iter()
                .enumerate()
                .map(|(col, &x)| x * weights[row_start + col])
                .sum();
            projected.push(dot);
        }
        projected
    }

    /// Projects a node feature vector into the `head_dim`-sized space of
    /// the given head.
    fn transform_node(&self, node: &[f32], head: usize) -> Vec<f32> {
        Self::project(
            &self.w_node,
            node,
            head,
            self.config.head_dim(),
            self.config.node_dim,
        )
    }

    /// Projects an edge feature vector into the `head_dim`-sized space of
    /// the given head.
    fn transform_edge(&self, edge: &[f32], head: usize) -> Vec<f32> {
        Self::project(
            &self.w_edge,
            edge,
            head,
            self.config.head_dim(),
            self.config.edge_dim,
        )
    }

    /// Raw (un-normalised) attention score of one edge for one head.
    ///
    /// Sums the dot products of the projected source, destination and edge
    /// vectors with the head's scoring vectors, then scales a negative sum
    /// by `negative_slope` (leaky ReLU).
    ///
    /// # Panics
    ///
    /// Panics if any of the three slices is shorter than `head_dim`.
    fn attention_coeff(&self, src: &[f32], dst: &[f32], edge: &[f32], head: usize) -> f32 {
        let head_dim = self.config.head_dim();
        let base = head * head_dim;
        let raw = (0..head_dim).fold(0.0f32, |acc, i| {
            acc + src[i] * self.a_src[base + i]
                + dst[i] * self.a_dst[base + i]
                + edge[i] * self.a_edge[base + i]
        });
        if raw < 0.0 {
            self.config.negative_slope * raw
        } else {
            raw
        }
    }
}
impl EdgeFeaturedAttention {
    /// Attends from `query` over its neighbours, taking edge features into
    /// account.
    ///
    /// `keys[i]`, `values[i]` and `edges[i]` all describe the same
    /// neighbour. For every head the query, keys and edges are projected,
    /// a raw score is computed per neighbour, the scores are normalised with
    /// `stable_softmax`, and the projected values are combined with those
    /// weights. The per-head results are concatenated
    /// (`num_heads * head_dim` values) when `concat_heads` is set and
    /// averaged (`head_dim` values) otherwise.
    ///
    /// NOTE(review): an empty `keys` slice is not rejected here (unlike
    /// `Attention::compute`); the result then depends on how
    /// `stable_softmax` treats an empty slice — confirm.
    ///
    /// # Errors
    ///
    /// * `AttentionError::InvalidConfig` when `keys`, `values` and `edges`
    ///   do not all have the same length.
    /// * `AttentionError::DimensionMismatch` when the query, a key or a
    ///   value is not `node_dim` long, or an edge is not `edge_dim` long.
    pub fn compute_with_edges(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        edges: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        if keys.len() != edges.len() {
            return Err(AttentionError::InvalidConfig(
                "Keys and edges must have same length".to_string(),
            ));
        }
        // Without this check a short `values` slice panics further down.
        if keys.len() != values.len() {
            return Err(AttentionError::InvalidConfig(
                "Keys and values must have same length".to_string(),
            ));
        }

        // The projections index the weight buffers by input position, so a
        // vector of the wrong length would read another row's weights or
        // run off the end of the buffer.
        let node_dim = self.config.node_dim;
        let bad_node_len = std::iter::once(query)
            .chain(keys.iter().copied())
            .chain(values.iter().copied())
            .map(|v| v.len())
            .find(|&len| len != node_dim);
        if let Some(actual) = bad_node_len {
            return Err(AttentionError::DimensionMismatch {
                expected: node_dim,
                actual,
            });
        }
        let edge_dim = self.config.edge_dim;
        let bad_edge_len = edges
            .iter()
            .map(|e| e.len())
            .find(|&len| len != edge_dim);
        if let Some(actual) = bad_edge_len {
            return Err(AttentionError::DimensionMismatch {
                expected: edge_dim,
                actual,
            });
        }

        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim();

        let mut head_outputs: Vec<Vec<f32>> = Vec::with_capacity(num_heads);
        for h in 0..num_heads {
            let query_t = self.transform_node(query, h);

            // One raw score per neighbour for this head.
            let coeffs: Vec<f32> = keys
                .iter()
                .zip(edges.iter())
                .map(|(key, edge)| {
                    let key_t = self.transform_node(key, h);
                    let edge_t = self.transform_edge(edge, h);
                    self.attention_coeff(&query_t, &key_t, &edge_t, h)
                })
                .collect();
            let weights = stable_softmax(&coeffs);

            // Weighted sum of the projected values.
            let mut head_out = vec![0.0f32; head_dim];
            for (&w, value) in weights.iter().zip(values.iter()) {
                let value_t = self.transform_node(value, h);
                for (acc, &vj) in head_out.iter_mut().zip(value_t.iter()) {
                    *acc += w * vj;
                }
            }
            head_outputs.push(head_out);
        }

        if self.config.concat_heads {
            Ok(head_outputs.into_iter().flatten().collect())
        } else {
            let mut output = vec![0.0f32; head_dim];
            for head_out in &head_outputs {
                for (acc, &v) in output.iter_mut().zip(head_out.iter()) {
                    *acc += v / num_heads as f32;
                }
            }
            Ok(output)
        }
    }

    /// Length of the edge feature vectors this layer was configured for.
    pub fn edge_dim(&self) -> usize {
        self.config.edge_dim
    }
}
impl Attention for EdgeFeaturedAttention {
    /// Attention without edge information: every neighbour is given an
    /// all-zero edge feature vector and the call is forwarded to
    /// `compute_with_edges`.
    ///
    /// # Errors
    ///
    /// * `AttentionError::InvalidConfig` when `keys` is empty or when `keys`
    ///   and `values` differ in length.
    /// * `AttentionError::DimensionMismatch` when `query` is not `node_dim`
    ///   long.
    /// * Any error returned by `compute_with_edges`.
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        if keys.is_empty() {
            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
        }
        if query.len() != self.config.node_dim {
            return Err(AttentionError::DimensionMismatch {
                expected: self.config.node_dim,
                actual: query.len(),
            });
        }
        // Reject the mismatch here so a short `values` slice cannot cause an
        // out-of-bounds panic during the weighted sum.
        if keys.len() != values.len() {
            return Err(AttentionError::InvalidConfig(
                "Keys and values must have same length".to_string(),
            ));
        }
        let zero_edge = vec![0.0f32; self.config.edge_dim];
        let edges: Vec<&[f32]> = vec![zero_edge.as_slice(); keys.len()];
        self.compute_with_edges(query, keys, values, &edges)
    }

    /// Like `compute`, but only over the positions whose mask entry is
    /// `true`. With no mask this is exactly `compute`.
    ///
    /// # Errors
    ///
    /// * `AttentionError::DimensionMismatch` when the mask length differs
    ///   from the number of keys.
    /// * `AttentionError::InvalidConfig` when `keys` and `values` differ in
    ///   length, or when the mask keeps no position (reported by `compute`
    ///   as empty keys).
    /// * Any other error returned by `compute`.
    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        match mask {
            None => self.compute(query, keys, values),
            Some(m) => {
                // A longer mask used to index past `keys`; a shorter one
                // silently dropped the trailing keys.
                if m.len() != keys.len() {
                    return Err(AttentionError::DimensionMismatch {
                        expected: keys.len(),
                        actual: m.len(),
                    });
                }
                if keys.len() != values.len() {
                    return Err(AttentionError::InvalidConfig(
                        "Keys and values must have same length".to_string(),
                    ));
                }
                let (kept_keys, kept_values): (Vec<&[f32]>, Vec<&[f32]>) = keys
                    .iter()
                    .zip(values.iter())
                    .zip(m.iter())
                    .filter(|(_, keep)| **keep)
                    .map(|((k, v), _)| (*k, *v))
                    .unzip();
                self.compute(query, &kept_keys, &kept_values)
            }
        }
    }

    /// Length of the vector produced by `compute`.
    ///
    /// Concatenated heads yield `num_heads * head_dim` values, which equals
    /// `node_dim` only when `node_dim` is divisible by `num_heads`; averaged
    /// heads yield `head_dim` values.
    fn dim(&self) -> usize {
        let head_dim = self.config.head_dim();
        if self.config.concat_heads {
            self.config.num_heads * head_dim
        } else {
            head_dim
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows every row as a slice, the shape the attention API takes.
    fn as_refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    /// Concatenated heads with explicit edge features give `node_dim` values.
    #[test]
    fn test_edge_featured_attention() {
        let attn = EdgeFeaturedAttention::new(
            EdgeFeaturedConfig::builder()
                .node_dim(64)
                .edge_dim(16)
                .num_heads(4)
                .build(),
        );

        let query = vec![0.5; 64];
        let keys = vec![vec![0.3f32; 64]; 10];
        let values = vec![vec![1.0f32; 64]; 10];
        let edges = vec![vec![0.2f32; 16]; 10];

        let output = attn
            .compute_with_edges(
                &query,
                &as_refs(&keys),
                &as_refs(&values),
                &as_refs(&edges),
            )
            .unwrap();

        assert_eq!(output.len(), 64);
    }

    /// The trait entry point works without caller-supplied edges.
    #[test]
    fn test_without_edges() {
        let attn = EdgeFeaturedAttention::new(
            EdgeFeaturedConfig::builder()
                .node_dim(32)
                .edge_dim(8)
                .num_heads(2)
                .build(),
        );

        let query = vec![0.5; 32];
        let keys = vec![vec![0.3f32; 32]; 5];
        let values = vec![vec![1.0f32; 32]; 5];

        let output = attn
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();

        assert_eq!(output.len(), 32);
    }

    /// All-negative inputs with an explicit negative slope still produce a
    /// full-width output.
    #[test]
    fn test_leaky_relu() {
        let attn = EdgeFeaturedAttention::new(
            EdgeFeaturedConfig::builder()
                .node_dim(16)
                .edge_dim(4)
                .num_heads(1)
                .negative_slope(0.2)
                .build(),
        );

        let query = vec![-1.0; 16];
        let keys = vec![vec![-0.5f32; 16]; 3];
        let values = vec![vec![1.0f32; 16]; 3];

        let output = attn
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();

        assert_eq!(output.len(), 16);
    }
}