use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use super::graphsage::{Graph, Lcg};
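/// Configuration for a Graph Attention Network (GAT) embedder.
///
/// Hidden layers use `hidden_num_heads` attention heads of width
/// `hidden_head_dim`; the final layer uses `output_num_heads` heads of width
/// `output_head_dim`. `concat_hidden` selects concatenation (vs. averaging)
/// of head outputs in hidden layers, and `avg_output` selects averaging for
/// the output layer. `alpha` is the LeakyReLU negative slope used in
/// attention scoring, and `dropout` randomly zeroes attention weights during
/// `embed` (deterministically, under `seed`).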
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATConfig {
pub input_dim: usize,
pub hidden_head_dim: usize,
pub hidden_num_heads: usize,
pub output_head_dim: usize,
pub output_num_heads: usize,
pub num_layers: usize,
pub dropout: f64,
pub alpha: f64,
pub concat_hidden: bool,
pub avg_output: bool,
pub normalize_output: bool,
pub seed: u64,
}
impl Default for GATConfig {
fn default() -> Self {
Self {
input_dim: 64,
hidden_head_dim: 8,
hidden_num_heads: 8,
output_head_dim: 8,
output_num_heads: 1,
num_layers: 2,
dropout: 0.6,
alpha: 0.2,
concat_hidden: true,
avg_output: true,
normalize_output: true,
seed: 42,
}
}
}
impl GATConfig {
pub fn output_dim(&self) -> usize {
if self.avg_output {
self.output_head_dim
} else {
self.output_head_dim * self.output_num_heads
}
}
pub fn hidden_layer_out_dim(&self) -> usize {
if self.concat_hidden {
self.hidden_head_dim * self.hidden_num_heads
} else {
self.hidden_head_dim
}
}
}
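/// A single attention head: a linear projection `w` plus the two halves of
/// the attention vector (`a_src` weights the center node, `a_dst` the
/// neighbor), using GAT-style additive scoring.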
#[derive(Debug, Clone)]
struct AttentionHead {
    w: Vec<Vec<f64>>,
    a_src: Vec<f64>,
a_dst: Vec<f64>,
head_dim: usize,
alpha: f64,
}
impl AttentionHead {
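    /// Glorot-style initialization: `w` is scaled by
    /// sqrt(6 / (in_dim + head_dim)) and the attention vectors by
    /// sqrt(2 / head_dim), with the sampling range supplied by
    /// `Lcg::next_f64_range`.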
fn new(in_dim: usize, head_dim: usize, alpha: f64, rng: &mut Lcg) -> Self {
let w_scale = (6.0 / (in_dim + head_dim) as f64).sqrt();
let w = (0..head_dim)
.map(|_| (0..in_dim).map(|_| rng.next_f64_range(w_scale)).collect())
.collect();
let a_scale = (2.0 / head_dim as f64).sqrt();
let a_src = (0..head_dim).map(|_| rng.next_f64_range(a_scale)).collect();
let a_dst = (0..head_dim).map(|_| rng.next_f64_range(a_scale)).collect();
Self {
w,
a_src,
a_dst,
head_dim,
alpha,
}
}
fn linear(&self, x: &[f64]) -> Vec<f64> {
self.w
.iter()
.map(|row| row.iter().zip(x.iter()).map(|(&w, &xi)| w * xi).sum())
.collect()
}
fn leaky_relu(&self, x: f64) -> f64 {
if x >= 0.0 {
x
} else {
self.alpha * x
}
}
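    /// Numerically stable softmax: shifts scores by their maximum before
    /// exponentiating, and falls back to a uniform distribution if the
    /// exponentials all underflow.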
fn softmax(scores: &[f64]) -> Vec<f64> {
if scores.is_empty() {
return Vec::new();
}
let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = scores.iter().map(|&s| (s - max).exp()).collect();
let sum: f64 = exps.iter().sum();
if sum < 1e-12 {
return vec![1.0 / scores.len() as f64; scores.len()];
}
exps.iter().map(|&e| e / sum).collect()
}
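    /// One head's output for node `v`: score each candidate `u` (the node
    /// itself plus its neighbors) as LeakyReLU(a_src . h_v + a_dst . h_u)
    /// over the pre-transformed embeddings, softmax-normalize, drop masked
    /// candidates, take the attention-weighted sum, then apply ELU.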
fn forward(
&self,
v: usize,
        all_transformed: &[Vec<f64>],
        neighbors: &[usize],
        dropout_mask: &[bool],
    ) -> Vec<f64> {
let mut candidates: Vec<usize> = vec![v];
candidates.extend_from_slice(neighbors);
let h_v = &all_transformed[v];
let scores: Vec<f64> = candidates
.iter()
.map(|&u| {
let h_u = &all_transformed[u];
let src: f64 = self
.a_src
.iter()
.zip(h_v.iter())
.map(|(&a, &h)| a * h)
.sum();
let dst: f64 = self
.a_dst
.iter()
.zip(h_u.iter())
.map(|(&a, &h)| a * h)
.sum();
self.leaky_relu(src + dst)
})
.collect();
let weights = Self::softmax(&scores);
let mut out = vec![0.0f64; self.head_dim];
for (k, (&u, &w)) in candidates.iter().zip(weights.iter()).enumerate() {
let keep = dropout_mask.get(k).copied().unwrap_or(true);
let effective_w = if keep { w } else { 0.0 };
let h_u = &all_transformed[u];
for (j, &val) in h_u.iter().enumerate() {
out[j] += effective_w * val;
}
}
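        // ELU activation: identity for non-negative values, exp(x) - 1 otherwise.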
out.iter_mut().for_each(|x| {
if *x < 0.0 {
*x = (*x).exp() - 1.0;
}
});
out
}
}
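/// A multi-head GAT layer. Head outputs are concatenated when `concat` is
/// true (output dimension `head_dim * num_heads`) and averaged otherwise
/// (output dimension `head_dim`).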
pub struct GATLayer {
heads: Vec<AttentionHead>,
in_dim: usize,
head_dim: usize,
num_heads: usize,
concat: bool,
dropout_rate: f64,
}
impl std::fmt::Debug for GATLayer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GATLayer")
.field("in_dim", &self.in_dim)
.field("num_heads", &self.num_heads)
.field("head_dim", &self.head_dim)
.field("concat", &self.concat)
.finish()
}
}
impl GATLayer {
pub fn new(
in_dim: usize,
head_dim: usize,
num_heads: usize,
alpha: f64,
dropout: f64,
concat: bool,
rng: &mut Lcg,
) -> Result<Self> {
if in_dim == 0 {
return Err(anyhow!("GATLayer: in_dim must be > 0"));
}
if head_dim == 0 {
return Err(anyhow!("GATLayer: head_dim must be > 0"));
}
if num_heads == 0 {
return Err(anyhow!("GATLayer: num_heads must be > 0"));
}
let heads = (0..num_heads)
.map(|_| AttentionHead::new(in_dim, head_dim, alpha, rng))
.collect();
Ok(Self {
heads,
in_dim,
head_dim,
num_heads,
concat,
dropout_rate: dropout,
})
}
pub fn out_dim(&self) -> usize {
if self.concat {
self.head_dim * self.num_heads
} else {
self.head_dim
}
}
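    /// Runs the layer over all nodes: each head first linearly transforms
    /// every embedding, then each node attends over itself and its neighbors.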
pub fn forward(
&self,
graph: &Graph,
current_embeddings: &[Vec<f64>],
rng: &mut Lcg,
) -> Vec<Vec<f64>> {
let n = graph.num_nodes();
let all_transformed: Vec<Vec<Vec<f64>>> = self
.heads
.iter()
.map(|head| {
current_embeddings
.iter()
.map(|emb| head.linear(emb))
.collect()
})
.collect();
(0..n)
.map(|v| {
let neighbors = graph.neighbors(v);
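                // Attention dropout: one keep/drop decision per candidate
                // (self + neighbors), shared across all heads of this layer.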
                let num_candidates = 1 + neighbors.len();
                let dropout_mask: Vec<bool> = (0..num_candidates)
.map(|_| rng.next_f64() > self.dropout_rate)
.collect();
let head_outputs: Vec<Vec<f64>> = self
.heads
.iter()
.enumerate()
.map(|(k, head)| head.forward(v, &all_transformed[k], neighbors, &dropout_mask))
.collect();
if self.concat {
head_outputs.into_iter().flatten().collect()
} else {
let mut avg = vec![0.0f64; self.head_dim];
for h in &head_outputs {
                        for (i, &val) in h.iter().enumerate() {
                            avg[i] += val;
}
}
let k = self.num_heads as f64;
avg.iter_mut().for_each(|x| *x /= k);
avg
}
})
.collect()
}
}
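/// A stack of GAT layers that produces per-node embeddings.
///
/// A minimal usage sketch (marked `ignore` since module paths depend on the
/// enclosing crate; `Graph::new` takes node features and an adjacency list,
/// as in the tests below):
///
/// ```ignore
/// let config = GATConfig { input_dim: 4, ..Default::default() };
/// let model = GATModel::new(config)?;
/// let features = vec![vec![0.1, 0.2, 0.3, 0.4], vec![0.5, 0.6, 0.7, 0.8]];
/// let adjacency = vec![vec![1], vec![0]];
/// let graph = Graph::new(features, adjacency)?;
/// let embeddings = model.embed(&graph)?;
/// let sim = embeddings.cosine_similarity(0, 1);
/// ```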
pub struct GATModel {
layers: Vec<GATLayer>,
config: GATConfig,
}
impl std::fmt::Debug for GATModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GATModel")
.field("num_layers", &self.layers.len())
.field("output_dim", &self.config.output_dim())
.finish()
}
}
impl GATModel {
pub fn new(config: GATConfig) -> Result<Self> {
if config.input_dim == 0 {
return Err(anyhow!("GATConfig: input_dim must be > 0"));
}
if config.num_layers == 0 {
return Err(anyhow!("GATConfig: num_layers must be > 0"));
}
if config.hidden_head_dim == 0 {
return Err(anyhow!("GATConfig: hidden_head_dim must be > 0"));
}
if config.output_head_dim == 0 {
return Err(anyhow!("GATConfig: output_head_dim must be > 0"));
}
if config.hidden_num_heads == 0 || config.output_num_heads == 0 {
return Err(anyhow!("GATConfig: num_heads must be > 0"));
}
let mut rng = Lcg::new(config.seed);
let mut layers = Vec::with_capacity(config.num_layers);
let mut current_in_dim = config.input_dim;
for layer_idx in 0..config.num_layers {
let is_last = layer_idx == config.num_layers - 1;
let (head_dim, num_heads, concat) = if is_last {
(
config.output_head_dim,
config.output_num_heads,
!config.avg_output,
)
} else {
(
config.hidden_head_dim,
config.hidden_num_heads,
config.concat_hidden,
)
};
let layer = GATLayer::new(
current_in_dim,
head_dim,
num_heads,
config.alpha,
config.dropout,
concat,
&mut rng,
)?;
current_in_dim = layer.out_dim();
layers.push(layer);
}
Ok(Self { layers, config })
}
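    /// Embeds every node in `graph`. The dropout RNG is reseeded from the
    /// config seed (offset by a constant), so repeated calls on the same
    /// graph are deterministic.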
pub fn embed(&self, graph: &Graph) -> Result<GATEmbeddings> {
if graph.num_nodes() == 0 {
return Err(anyhow!("GATModel: graph has no nodes"));
}
let mut rng = Lcg::new(self.config.seed.wrapping_add(0xcafe_babe));
let mut current: Vec<Vec<f64>> = graph.node_features.clone();
for layer in &self.layers {
            current = layer.forward(graph, &current, &mut rng);
}
if self.config.normalize_output {
for emb in &mut current {
l2_normalize_inplace(emb);
}
}
let dim = self.config.output_dim();
let num_nodes = graph.num_nodes();
Ok(GATEmbeddings {
embeddings: current,
num_nodes,
dim,
})
}
}
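/// Per-node embeddings produced by `GATModel::embed`, with similarity
/// helpers.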
#[derive(Debug, Clone)]
pub struct GATEmbeddings {
pub embeddings: Vec<Vec<f64>>,
pub num_nodes: usize,
pub dim: usize,
}
impl GATEmbeddings {
pub fn get(&self, v: usize) -> Option<&[f64]> {
self.embeddings.get(v).map(|e| e.as_slice())
}
pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
let ea = self.embeddings.get(a)?;
let eb = self.embeddings.get(b)?;
Some(cosine_similarity_vecs(ea, eb))
}
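    /// Returns up to `k` nodes most cosine-similar to `query_node`, sorted
    /// by descending similarity; the query node itself is excluded.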
pub fn top_k_similar(&self, query_node: usize, k: usize) -> Vec<(usize, f64)> {
let qe = match self.embeddings.get(query_node) {
Some(e) => e,
None => return Vec::new(),
};
let mut sims: Vec<(usize, f64)> = self
.embeddings
.iter()
.enumerate()
.filter(|(i, _)| *i != query_node)
.map(|(i, e)| (i, cosine_similarity_vecs(qe, e)))
.collect();
sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
sims.truncate(k);
sims
}
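    /// Averages all node embeddings into a single graph-level vector.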
pub fn mean_embedding(&self) -> Vec<f64> {
if self.embeddings.is_empty() {
return Vec::new();
}
let mut mean = vec![0.0f64; self.dim];
for emb in &self.embeddings {
for (i, &v) in emb.iter().enumerate().take(self.dim) {
mean[i] += v;
}
}
let n = self.embeddings.len() as f64;
mean.iter_mut().for_each(|v| *v /= n);
mean
}
}
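/// Cosine similarity clamped to [-1, 1]; returns 0.0 if either vector has
/// near-zero norm.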
fn cosine_similarity_vecs(a: &[f64], b: &[f64]) -> f64 {
let dot: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
let na: f64 = a.iter().map(|&x| x * x).sum::<f64>().sqrt();
let nb: f64 = b.iter().map(|&x| x * x).sum::<f64>().sqrt();
if na < 1e-12 || nb < 1e-12 {
return 0.0;
}
(dot / (na * nb)).clamp(-1.0, 1.0)
}
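/// Scales `v` to unit L2 norm in place; near-zero vectors are left unchanged.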
fn l2_normalize_inplace(v: &mut [f64]) {
let norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
if norm > 1e-12 {
v.iter_mut().for_each(|x| *x /= norm);
}
}
#[cfg(test)]
mod tests {
use super::super::graphsage::{Graph, Lcg};
use super::*;
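    /// Builds a path graph 0 - 1 - ... - (n-1) with random node features.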
fn line_graph(n: usize, feat_dim: usize, seed: u64) -> Graph {
let mut rng = Lcg::new(seed);
let features: Vec<Vec<f64>> = (0..n)
.map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
.collect();
let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
for i in 0..n.saturating_sub(1) {
adjacency[i].push(i + 1);
adjacency[i + 1].push(i);
}
Graph::new(features, adjacency).expect("line graph construction should succeed")
}
#[test]
fn test_gat_config_default() {
let config = GATConfig::default();
assert_eq!(config.num_layers, 2);
assert_eq!(config.hidden_num_heads, 8);
assert_eq!(config.output_dim(), config.output_head_dim);
}
#[test]
fn test_gat_config_concat_hidden() {
let config = GATConfig {
hidden_head_dim: 8,
hidden_num_heads: 4,
concat_hidden: true,
..Default::default()
};
        assert_eq!(config.hidden_layer_out_dim(), 32);
    }
#[test]
fn test_gat_config_avg_hidden() {
let config = GATConfig {
hidden_head_dim: 8,
hidden_num_heads: 4,
concat_hidden: false,
..Default::default()
};
        assert_eq!(config.hidden_layer_out_dim(), 8);
    }
#[test]
fn test_gat_layer_construction() {
let mut rng = Lcg::new(42);
let layer =
GATLayer::new(8, 4, 2, 0.2, 0.0, true, &mut rng).expect("layer should construct");
        assert_eq!(layer.out_dim(), 8);
    }
#[test]
fn test_gat_layer_avg() {
let mut rng = Lcg::new(43);
let layer =
GATLayer::new(8, 4, 3, 0.2, 0.0, false, &mut rng).expect("layer should construct");
        assert_eq!(layer.out_dim(), 4);
    }
#[test]
fn test_gat_layer_invalid() {
let mut rng = Lcg::new(1);
assert!(GATLayer::new(0, 4, 2, 0.2, 0.0, true, &mut rng).is_err());
assert!(GATLayer::new(8, 0, 2, 0.2, 0.0, true, &mut rng).is_err());
assert!(GATLayer::new(8, 4, 0, 0.2, 0.0, true, &mut rng).is_err());
}
#[test]
fn test_gat_model_embed_shape() {
let config = GATConfig {
input_dim: 8,
hidden_head_dim: 4,
hidden_num_heads: 2,
output_head_dim: 4,
output_num_heads: 1,
num_layers: 2,
dropout: 0.0,
concat_hidden: true,
avg_output: true,
normalize_output: false,
..Default::default()
};
let model = GATModel::new(config.clone()).expect("GAT model should construct");
let g = line_graph(5, 8, 100);
let embs = model.embed(&g).expect("embed should succeed");
assert_eq!(embs.num_nodes, 5);
assert_eq!(embs.dim, config.output_dim());
for i in 0..5 {
assert_eq!(
embs.get(i).expect("embedding should exist").len(),
config.output_dim()
);
}
}
#[test]
fn test_gat_model_single_layer() {
let config = GATConfig {
input_dim: 4,
hidden_head_dim: 8,
hidden_num_heads: 2,
output_head_dim: 8,
output_num_heads: 2,
num_layers: 1,
dropout: 0.0,
concat_hidden: true,
avg_output: false,
normalize_output: false,
..Default::default()
};
let model = GATModel::new(config.clone()).expect("GAT model should construct");
let g = line_graph(4, 4, 200);
let embs = model.embed(&g).expect("embed should succeed");
assert_eq!(embs.dim, 16);
}
#[test]
fn test_gat_model_normalized_output() {
let config = GATConfig {
input_dim: 4,
hidden_head_dim: 4,
hidden_num_heads: 2,
output_head_dim: 4,
output_num_heads: 1,
num_layers: 1,
dropout: 0.0,
concat_hidden: false,
avg_output: true,
normalize_output: true,
..Default::default()
};
let model = GATModel::new(config).expect("GAT model should construct");
let g = line_graph(5, 4, 300);
let embs = model.embed(&g).expect("embed should succeed");
for i in 0..5 {
let emb = embs.get(i).expect("embedding exists");
let norm: f64 = emb.iter().map(|&x| x * x).sum::<f64>().sqrt();
assert!(norm <= 1.0 + 1e-6, "norm {} should be <= 1", norm);
}
}
#[test]
fn test_gat_cosine_similarity_bounds() {
let config = GATConfig {
input_dim: 4,
hidden_head_dim: 4,
hidden_num_heads: 2,
output_head_dim: 4,
output_num_heads: 1,
num_layers: 1,
dropout: 0.0,
concat_hidden: true,
avg_output: true,
normalize_output: false,
..Default::default()
};
let model = GATModel::new(config).expect("GAT model should construct");
let g = line_graph(5, 4, 400);
let embs = model.embed(&g).expect("embed should succeed");
for i in 0..5 {
for j in 0..5 {
if let Some(sim) = embs.cosine_similarity(i, j) {
assert!(
(-1.0 - 1e-6..=1.0 + 1e-6).contains(&sim),
"cosine_similarity({i}, {j}) = {sim} out of range"
);
}
}
}
}
#[test]
fn test_gat_top_k_similar() {
let config = GATConfig {
input_dim: 4,
hidden_head_dim: 4,
hidden_num_heads: 2,
output_head_dim: 4,
output_num_heads: 1,
num_layers: 2,
dropout: 0.0,
concat_hidden: true,
avg_output: true,
normalize_output: true,
..Default::default()
};
let model = GATModel::new(config).expect("GAT model should construct");
let g = line_graph(8, 4, 500);
let embs = model.embed(&g).expect("embed should succeed");
let top3 = embs.top_k_similar(0, 3);
assert!(top3.len() <= 3);
for window in top3.windows(2) {
assert!(
window[0].1 >= window[1].1 - 1e-10,
"top_k should be sorted descending"
);
}
}
#[test]
fn test_gat_isolated_node() {
let config = GATConfig {
input_dim: 4,
hidden_head_dim: 4,
hidden_num_heads: 2,
output_head_dim: 4,
output_num_heads: 1,
num_layers: 1,
dropout: 0.0,
concat_hidden: true,
avg_output: true,
normalize_output: false,
..Default::default()
};
let model = GATModel::new(config).expect("GAT model should construct");
let features = vec![vec![1.0f64, 0.5, -0.3, 0.8]];
        let adjacency = vec![vec![]];
        let g = Graph::new(features, adjacency).expect("isolated node graph");
let embs = model.embed(&g).expect("isolated node should embed");
assert_eq!(embs.num_nodes, 1);
assert!(embs.get(0).is_some());
}
#[test]
fn test_gat_invalid_config() {
assert!(GATModel::new(GATConfig {
input_dim: 0,
..Default::default()
})
.is_err());
assert!(GATModel::new(GATConfig {
num_layers: 0,
..Default::default()
})
.is_err());
assert!(GATModel::new(GATConfig {
hidden_num_heads: 0,
..Default::default()
})
.is_err());
assert!(GATModel::new(GATConfig {
output_head_dim: 0,
..Default::default()
})
.is_err());
}
#[test]
fn test_gat_mean_embedding() {
let config = GATConfig {
input_dim: 4,
hidden_head_dim: 4,
hidden_num_heads: 2,
output_head_dim: 4,
output_num_heads: 1,
num_layers: 1,
dropout: 0.0,
concat_hidden: false,
avg_output: true,
normalize_output: true,
..Default::default()
};
let model = GATModel::new(config).expect("GAT model should construct");
let g = line_graph(5, 4, 600);
let embs = model.embed(&g).expect("embed should succeed");
let mean = embs.mean_embedding();
assert_eq!(mean.len(), embs.dim);
}
#[test]
fn test_gat_attention_softmax_sums_to_one() {
let scores = vec![1.0f64, 2.0, 3.0, 0.5, -1.0];
let weights = AttentionHead::softmax(&scores);
let sum: f64 = weights.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-10,
"softmax should sum to 1, got {sum}"
);
assert!(weights[2] > weights[1]);
assert!(weights[1] > weights[0]);
}
#[test]
fn test_gat_three_layer_deep() {
let config = GATConfig {
input_dim: 8,
hidden_head_dim: 4,
hidden_num_heads: 3,
output_head_dim: 4,
output_num_heads: 1,
num_layers: 3,
dropout: 0.0,
concat_hidden: true,
avg_output: true,
normalize_output: true,
seed: 77,
..Default::default()
};
let model = GATModel::new(config.clone()).expect("3-layer GAT should construct");
let g = line_graph(6, 8, 77);
let embs = model.embed(&g).expect("embed should succeed");
assert_eq!(embs.num_nodes, 6);
assert_eq!(embs.dim, config.output_dim());
}
}