use crate::EmbeddingError;
use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::graphsage::SimpleLcg;
/// Hyper-parameters controlling the architecture and training of a [`GatEmbedder`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatEmbedderConfig {
/// Number of stacked attention layers; one set of layer weights is created per layer.
pub num_layers: usize,
/// Width of the initial node features and of every layer's output vector.
pub hidden_dim: usize,
/// Attention heads per layer; each head projects to `hidden_dim / max(num_heads, 1)` dimensions.
pub num_heads: usize,
/// Dropout probability.
/// NOTE(review): no code visible in this file reads this field — confirm whether dropout is applied elsewhere.
pub dropout_rate: f64,
/// Number of passes over the triple set performed by `fit`.
pub num_epochs: usize,
/// Base step size; `fit` divides it by the number of margin-violating triples in each epoch.
pub learning_rate: f64,
/// Margin of the hinge loss `max(0, margin - cos(s, o) + cos(s, o_neg))`.
pub margin: f64,
/// Seed for weight/feature initialisation; `seed + 1` (wrapping) seeds negative sampling.
pub seed: u64,
}
impl Default for GatEmbedderConfig {
    /// Returns the stock configuration: a 2-layer, 4-head network of width 64
    /// trained for 50 epochs with learning rate 0.01, margin 1.0 and seed 42.
    fn default() -> Self {
        // Architecture.
        let (num_layers, hidden_dim, num_heads) = (2, 64, 4);
        // Optimisation schedule.
        let (num_epochs, learning_rate, margin) = (50, 0.01, 1.0);
        Self {
            num_layers,
            hidden_dim,
            num_heads,
            dropout_rate: 0.1,
            num_epochs,
            learning_rate,
            margin,
            seed: 42,
        }
    }
}
/// Builds a `rows x cols` matrix using the Xavier/Glorot bound
/// `sqrt(6 / (rows + cols))`, with the denominator clamped to at least 1.
///
/// Values are drawn row by row via `rng.next_f64_range(limit)`.
/// NOTE(review): presumably that call samples uniformly from `[-limit, limit]` —
/// confirm against `SimpleLcg`.
fn xavier_uniform_2d(rows: usize, cols: usize, rng: &mut SimpleLcg) -> Vec<Vec<f64>> {
    let fan_sum = (rows + cols).max(1) as f64;
    let limit = (6.0_f64 / fan_sum).sqrt();
    let mut matrix = Vec::with_capacity(rows);
    for _ in 0..rows {
        let mut row = Vec::with_capacity(cols);
        for _ in 0..cols {
            row.push(rng.next_f64_range(limit));
        }
        matrix.push(row);
    }
    matrix
}
/// Multiplies the row-major matrix `w` by the vector `x`.
///
/// Each output element is the dot product of one row with `x`; if a row and
/// `x` differ in length, the extra elements of the longer one are ignored.
#[inline]
fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(w.len());
    for row in w {
        let dot: f64 = row.iter().zip(x).map(|(weight, value)| weight * value).sum();
        product.push(dot);
    }
    product
}
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not greater than `1e-12` are left untouched so a
/// (near-)zero vector is never divided by a vanishing norm.
fn l2_normalize_inplace(v: &mut [f64]) {
    let squared_sum = v.iter().map(|x| x * x).sum::<f64>();
    let norm = squared_sum.sqrt();
    if norm > 1e-12 {
        for component in v.iter_mut() {
            *component /= norm;
        }
    }
}
/// Applies ReLU element-wise, returning a new vector of `max(x, 0.0)` values.
#[inline]
fn relu_vec(v: &[f64]) -> Vec<f64> {
    let mut rectified = Vec::with_capacity(v.len());
    for &value in v {
        rectified.push(f64::max(value, 0.0));
    }
    rectified
}
/// Cosine similarity of `a` and `b`, with `1e-8` added to the denominator so
/// zero vectors yield `0.0` instead of a division by zero.
///
/// When the slices differ in length, the dot product covers only the common
/// prefix while each norm covers its whole slice.
#[inline]
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let norm = |v: &[f64]| -> f64 { v.iter().map(|x| x * x).sum::<f64>().sqrt() };
    let dot: f64 = a.iter().zip(b).map(|(ai, bi)| ai * bi).sum();
    dot / (norm(a) * norm(b) + 1e-8)
}
/// Leaky ReLU: returns `x` unchanged when `x >= 0.0`, otherwise
/// `negative_slope * x`.
#[inline]
pub fn leaky_relu(x: f64, negative_slope: f64) -> f64 {
    // Guard clause for the identity branch; everything else is attenuated.
    if x >= 0.0 {
        return x;
    }
    negative_slope * x
}
/// Numerically stable softmax over `scores`.
///
/// The maximum score is subtracted before exponentiation so large inputs do
/// not overflow. Returns an empty vector for empty input.
///
/// If the normaliser is vanishingly small or not finite — which happens when
/// every score is `-inf`, or when any score is NaN or `+inf` — the function
/// falls back to the uniform distribution instead of returning NaN weights.
pub fn softmax(scores: &[f64]) -> Vec<f64> {
    if scores.is_empty() {
        return Vec::new();
    }
    let max_val = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scores.iter().map(|&s| (s - max_val).exp()).collect();
    let sum: f64 = exps.iter().sum();
    // `NaN < 1e-30` is false, so the finiteness check is what actually catches
    // degenerate inputs (`-inf - -inf` and `inf - inf` are both NaN).
    if !sum.is_finite() || sum < 1e-30 {
        return vec![1.0 / scores.len() as f64; scores.len()];
    }
    exps.iter().map(|e| e / sum).collect()
}
/// Trainable parameters of a single multi-head attention layer.
struct GatLayerWeights {
// Per-head query projections; each entry is a `head_dim x hidden_dim` matrix.
w_query: Vec<Vec<Vec<f64>>>,
// Per-head key projections; each entry is a `head_dim x hidden_dim` matrix.
w_key: Vec<Vec<Vec<f64>>>,
// Per-head value projections; each entry is a `head_dim x hidden_dim` matrix.
w_value: Vec<Vec<Vec<f64>>>,
// Output projection applied to the concatenated heads:
// `hidden_dim x (head_dim * num_heads)`.
w_out: Vec<Vec<f64>>,
// Number of attention heads in this layer.
num_heads: usize,
// Width of one head, computed as `hidden_dim / max(num_heads, 1)`.
head_dim: usize,
// Width of the layer's input and output vectors.
hidden_dim: usize,
}
impl GatLayerWeights {
    /// Creates Xavier-initialised weights for one attention layer.
    ///
    /// Every head receives query, key and value matrices of shape
    /// `head_dim x hidden_dim`, where `head_dim = hidden_dim / max(num_heads, 1)`
    /// (integer division). The output projection maps the concatenated heads
    /// (`head_dim * num_heads` values) back to `hidden_dim`.
    ///
    /// Random draws happen in a fixed order — query, key, value for each head
    /// in turn, then the output projection — so a given `rng` state always
    /// produces the same weights.
    fn new(hidden_dim: usize, num_heads: usize, rng: &mut SimpleLcg) -> Self {
        let head_dim = hidden_dim / num_heads.max(1);
        let mut w_query = Vec::with_capacity(num_heads);
        let mut w_key = Vec::with_capacity(num_heads);
        let mut w_value = Vec::with_capacity(num_heads);

        let mut draw_projection = || xavier_uniform_2d(head_dim, hidden_dim, rng);
        for _head in 0..num_heads {
            w_query.push(draw_projection());
            w_key.push(draw_projection());
            w_value.push(draw_projection());
        }

        let w_out = xavier_uniform_2d(hidden_dim, head_dim * num_heads, rng);
        Self {
            w_query,
            w_key,
            w_value,
            w_out,
            num_heads,
            head_dim,
            hidden_dim,
        }
    }
}
/// Graph-attention embedder that learns one vector per entity from
/// `(subject, predicate, object)` triples.
pub struct GatEmbedder {
// Hyper-parameters supplied at construction time.
config: GatEmbedderConfig,
// Maps each entity string to its dense index; rebuilt by every `fit` call.
entity_index: HashMap<String, usize>,
// Final per-entity vectors, addressed by the indices in `entity_index`.
embeddings: Vec<Vec<f64>>,
// One weight set per attention layer; empty until `fit` is called.
layer_weights: Vec<GatLayerWeights>,
// Set to `true` at the end of a successful `fit`.
trained: bool,
}
impl std::fmt::Debug for GatEmbedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GatEmbedder")
.field("num_entities", &self.entity_index.len())
.field("trained", &self.trained)
.field("num_layers", &self.config.num_layers)
.field("hidden_dim", &self.config.hidden_dim)
.field("num_heads", &self.config.num_heads)
.finish()
}
}
impl GatEmbedder {
/// Creates an untrained embedder holding `config`.
///
/// The entity index, embeddings and layer weights start out empty and are
/// populated by `fit`.
pub fn new(config: GatEmbedderConfig) -> Self {
    let entity_index = HashMap::new();
    let embeddings = Vec::new();
    let layer_weights = Vec::new();
    Self {
        config,
        entity_index,
        embeddings,
        layer_weights,
        trained: false,
    }
}
pub fn fit(&mut self, triples: &[(String, String, String)]) -> Result<(), EmbeddingError> {
if triples.is_empty() {
return Err(EmbeddingError::Other(anyhow!("Triple set is empty")));
}
let (entity_index, adj_by_idx) = Self::build_graph(triples);
let num_entities = entity_index.len();
self.entity_index = entity_index;
let mut rng = SimpleLcg::new(self.config.seed);
let hidden_dim = self.config.hidden_dim;
let num_heads = self.config.num_heads;
let num_layers = self.config.num_layers;
self.layer_weights = (0..num_layers)
.map(|_| GatLayerWeights::new(hidden_dim, num_heads, &mut rng))
.collect();
let mut h0: Vec<Vec<f64>> = (0..num_entities)
.map(|_| {
let mut v: Vec<f64> = (0..hidden_dim)
.map(|_| rng.next_f64_range(0.5_f64))
.collect();
l2_normalize_inplace(&mut v);
v
})
.collect();
let mut lcg = SimpleLcg::new(self.config.seed.wrapping_add(1));
for _epoch in 0..self.config.num_epochs {
let h_all = self.forward_all(&h0, &adj_by_idx, num_entities);
let mut deltas: Vec<Vec<Vec<Vec<f64>>>> = self
.layer_weights
.iter()
.map(|lw| {
let heads: Vec<Vec<Vec<f64>>> = (0..lw.num_heads)
.map(|_| vec![vec![0.0; hidden_dim]; lw.head_dim])
.collect();
let mut all = heads.clone();
all.extend(heads.clone()); all.extend(heads.clone()); all.push(vec![vec![0.0; lw.head_dim * lw.num_heads]; lw.hidden_dim]); all
})
.collect();
let mut grad_count = 0usize;
for (s_str, _p_str, o_str) in triples {
let s_idx = match self.entity_index.get(s_str.as_str()) {
Some(&i) => i,
None => continue,
};
let o_idx = match self.entity_index.get(o_str.as_str()) {
Some(&i) => i,
None => continue,
};
let o_neg_idx = Self::sample_negative(o_idx, num_entities, &mut lcg);
let h_s = &h_all[s_idx];
let h_o = &h_all[o_idx];
let h_neg = &h_all[o_neg_idx];
let loss =
(self.config.margin - cosine_sim(h_s, h_o) + cosine_sim(h_s, h_neg)).max(0.0);
if loss > 0.0 {
for (l, lw) in self.layer_weights.iter().enumerate() {
let nh = lw.num_heads;
let hd = lw.head_dim;
for h in 0..nh {
for (r, row) in deltas[l][h].iter_mut().enumerate().take(hd) {
let sign = if h_s.get(r % h_s.len()).copied().unwrap_or(0.0) > 0.0 {
1.0_f64
} else {
-1.0_f64
};
for delta in row.iter_mut() {
*delta += sign * loss;
}
}
for (r, row) in deltas[l][nh + h].iter_mut().enumerate().take(hd) {
let sign = if h_o.get(r % h_o.len()).copied().unwrap_or(0.0) > 0.0 {
1.0_f64
} else {
-1.0_f64
};
for delta in row.iter_mut() {
*delta += sign * loss;
}
}
for (r, row) in deltas[l][2 * nh + h].iter_mut().enumerate().take(hd) {
let sign = if h_o.get(r % h_o.len()).copied().unwrap_or(0.0) > 0.0 {
1.0_f64
} else {
-1.0_f64
};
for delta in row.iter_mut() {
*delta += sign * loss;
}
}
}
for (r, row) in deltas[l][3 * nh].iter_mut().enumerate() {
let sign = if h_s.get(r % h_s.len()).copied().unwrap_or(0.0) > 0.0 {
1.0_f64
} else {
-1.0_f64
};
for delta in row.iter_mut() {
*delta += sign * loss;
}
}
}
grad_count += 1;
}
}
if grad_count > 0 {
let lr = self.config.learning_rate / grad_count as f64;
for (l, lw) in self.layer_weights.iter_mut().enumerate() {
let nh = lw.num_heads;
let hd = lw.head_dim;
for h in 0..nh {
for (r, delta_row) in deltas[l][h].iter().enumerate().take(hd) {
let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
for (w, d) in lw.w_query[h][r].iter_mut().zip(delta_row.iter()) {
*w -= d * clip * lr;
}
}
for (r, delta_row) in deltas[l][nh + h].iter().enumerate().take(hd) {
let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
for (w, d) in lw.w_key[h][r].iter_mut().zip(delta_row.iter()) {
*w -= d * clip * lr;
}
}
for (r, delta_row) in deltas[l][2 * nh + h].iter().enumerate().take(hd) {
let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
for (w, d) in lw.w_value[h][r].iter_mut().zip(delta_row.iter()) {
*w -= d * clip * lr;
}
}
}
for (r, delta_row) in deltas[l][3 * nh].iter().enumerate() {
let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
for (w, d) in lw.w_out[r].iter_mut().zip(delta_row.iter()) {
*w -= d * clip * lr;
}
}
}
}
for feat in h0.iter_mut() {
l2_normalize_inplace(feat);
}
}
self.embeddings = self.forward_all(&h0, &adj_by_idx, num_entities);
self.trained = true;
Ok(())
}
/// Returns a copy of the learned embedding for `entity`.
///
/// Entities that are unknown (or whose index has no stored vector) yield a
/// zero vector of length `config.hidden_dim`.
pub fn embed_entity(&self, entity: &str) -> Vec<f64> {
    self.entity_index
        .get(entity)
        .and_then(|&idx| self.embeddings.get(idx))
        .cloned()
        .unwrap_or_else(|| vec![0.0; self.config.hidden_dim])
}
/// Runs one attention layer for a single node and returns its new feature vector.
///
/// The node attends over itself followed by its neighbours from `adj`
/// (duplicate neighbour entries are kept and therefore weighted more than
/// once). For every head, the scaled dot product of the node's query with
/// each key is passed through `leaky_relu(_, 0.2)` and a softmax, and the
/// resulting weights blend the value projections. Head outputs are
/// concatenated, projected by `w_out`, passed through ReLU and
/// L2-normalised.
///
/// A neighbour index missing from `embeddings` falls back to the node's own
/// features. If `layer_idx` or `entity_idx` is out of range (for example when
/// called before `fit`), a zero vector of length `config.hidden_dim` is
/// returned instead of panicking.
pub fn attention_forward(
    &self,
    entity_idx: usize,
    adj: &HashMap<usize, Vec<usize>>,
    embeddings: &[Vec<f64>],
    layer_idx: usize,
) -> Vec<f64> {
    let lw = match self.layer_weights.get(layer_idx) {
        Some(lw) => lw,
        None => return vec![0.0; self.config.hidden_dim],
    };
    let h_self = match embeddings.get(entity_idx) {
        Some(e) => e,
        None => return vec![0.0; self.config.hidden_dim],
    };

    // Resolve the attended feature vectors once: the node itself first, then
    // its neighbours in adjacency order.
    let neighbors: &[usize] = adj.get(&entity_idx).map(|n| n.as_slice()).unwrap_or(&[]);
    let attended: Vec<&Vec<f64>> = std::iter::once(entity_idx)
        .chain(neighbors.iter().copied())
        .map(|j| embeddings.get(j).unwrap_or(h_self))
        .collect();

    let scale = (lw.head_dim.max(1) as f64).sqrt();
    let mut concat_heads: Vec<f64> = Vec::with_capacity(lw.head_dim * lw.num_heads);
    for h in 0..lw.num_heads {
        let q_i: Vec<f64> = matvec(&lw.w_query[h], h_self);
        let scores: Vec<f64> = attended
            .iter()
            .map(|&h_j| {
                let k_j: Vec<f64> = matvec(&lw.w_key[h], h_j);
                let raw_score: f64 = q_i.iter().zip(k_j.iter()).map(|(&a, &b)| a * b).sum();
                leaky_relu(raw_score / scale, 0.2)
            })
            .collect();
        let alphas = softmax(&scores);

        let mut head_out = vec![0.0_f64; lw.head_dim];
        for (&h_j, &alpha) in attended.iter().zip(alphas.iter()) {
            let v_j: Vec<f64> = matvec(&lw.w_value[h], h_j);
            for (acc, vv) in head_out.iter_mut().zip(v_j.iter()) {
                *acc += alpha * vv;
            }
        }
        concat_heads.extend_from_slice(&head_out);
    }

    let mut out = relu_vec(&matvec(&lw.w_out, &concat_heads));
    l2_normalize_inplace(&mut out);
    out
}
/// Returns `true` once `fit` has completed successfully.
pub fn is_trained(&self) -> bool {
self.trained
}
/// Returns the number of distinct entities currently indexed (0 before the first `fit`).
pub fn num_entities(&self) -> usize {
self.entity_index.len()
}
/// Returns the configured embedding width (`config.hidden_dim`).
pub fn embedding_dim(&self) -> usize {
self.config.hidden_dim
}
/// Builds the entity index and an undirected adjacency list from `triples`.
///
/// Entities receive dense ids in first-seen order (subject before object
/// within each triple). Every triple appends the object to the subject's
/// neighbour list and the subject to the object's; predicates are ignored
/// and repeated edges are kept as duplicates.
fn build_graph(
    triples: &[(String, String, String)],
) -> (HashMap<String, usize>, HashMap<usize, Vec<usize>>) {
    /// Returns the id of `iri`, assigning the next free id on first sight.
    fn intern(index: &mut HashMap<String, usize>, iri: &str) -> usize {
        if let Some(&known) = index.get(iri) {
            return known;
        }
        let fresh = index.len();
        index.insert(iri.to_string(), fresh);
        fresh
    }

    let mut entity_index: HashMap<String, usize> = HashMap::new();
    let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
    for (subject, _predicate, object) in triples {
        let s_idx = intern(&mut entity_index, subject);
        let o_idx = intern(&mut entity_index, object);
        adj.entry(s_idx).or_default().push(o_idx);
        adj.entry(o_idx).or_default().push(s_idx);
    }
    (entity_index, adj)
}
/// Propagates the features `h0` through all `config.num_layers` layers.
///
/// Each layer recomputes every node `0..num_entities` from the previous
/// layer's complete output, so updates within a layer never see each other.
/// With zero layers the result is simply a copy of `h0`.
fn forward_all(
    &self,
    h0: &[Vec<f64>],
    adj: &HashMap<usize, Vec<usize>>,
    num_entities: usize,
) -> Vec<Vec<f64>> {
    (0..self.config.num_layers).fold(h0.to_vec(), |h_prev, layer_idx| {
        (0..num_entities)
            .map(|node_idx| self.attention_forward_on(node_idx, adj, &h_prev, layer_idx))
            .collect()
    })
}
/// Private alias used by `forward_all`; forwards all arguments unchanged to
/// the public `attention_forward`.
fn attention_forward_on(
&self,
entity_idx: usize,
adj: &HashMap<usize, Vec<usize>>,
embeddings: &[Vec<f64>],
layer_idx: usize,
) -> Vec<f64> {
self.attention_forward(entity_idx, adj, embeddings, layer_idx)
}
/// Draws an entity index to act as a corrupted object for `positive_idx`.
///
/// Takes one value from `lcg` reduced modulo `num_entities`; if it collides
/// with `positive_idx` the next index (wrapping around) is used instead.
/// Returns 0 without consuming randomness when there are fewer than two
/// entities, in which case no distinct negative exists.
fn sample_negative(positive_idx: usize, num_entities: usize, lcg: &mut SimpleLcg) -> usize {
    if num_entities <= 1 {
        return 0;
    }
    let drawn = lcg.next_usize() % num_entities;
    // With at least two entities a single step always leaves the positive.
    if drawn == positive_idx {
        (drawn + 1) % num_entities
    } else {
        drawn
    }
}
}
// Unit tests for the GAT embedder and its numeric helpers.
#[cfg(test)]
mod tests {
use super::*;
// Builds `n_triples` triples forming a ring over `n_entities` entities:
// triple `i` links `e{i % n}` to `e{(i + 1) % n}` with a single shared predicate.
fn toy_triples(n_entities: usize, n_triples: usize) -> Vec<(String, String, String)> {
let mut triples = Vec::with_capacity(n_triples);
for i in 0..n_triples {
let s = format!("http://ex.org/e{}", i % n_entities);
let p = "http://ex.org/rel".to_string();
let o = format!("http://ex.org/e{}", (i + 1) % n_entities);
triples.push((s, p, o));
}
triples
}
// The default configuration exposes the expected architecture and a head
// width of 16.
#[test]
fn test_default_config_dimensions() {
let config = GatEmbedderConfig::default();
assert_eq!(config.num_layers, 2);
assert_eq!(config.hidden_dim, 64);
assert_eq!(config.num_heads, 4);
assert_eq!(config.num_epochs, 50);
assert_eq!(config.hidden_dim / config.num_heads, 16);
}
// `fit` succeeds on a small ring graph, marks the model trained and indexes
// all five entities.
#[test]
fn test_fit_completes_small_graph() {
let config = GatEmbedderConfig {
num_layers: 2,
hidden_dim: 16,
num_heads: 4,
num_epochs: 5,
seed: 7,
..Default::default()
};
let triples = toy_triples(5, 8);
let mut embedder = GatEmbedder::new(config);
let result = embedder.fit(&triples);
assert!(result.is_ok(), "fit should succeed: {result:?}");
assert!(embedder.is_trained());
assert_eq!(embedder.num_entities(), 5);
}
// Every known entity yields an embedding of length `hidden_dim`.
#[test]
fn test_embed_entity_dimension() {
let config = GatEmbedderConfig {
num_layers: 2,
hidden_dim: 32,
num_heads: 4,
num_epochs: 3,
seed: 11,
..Default::default()
};
let triples = toy_triples(5, 8);
let mut embedder = GatEmbedder::new(config.clone());
embedder.fit(&triples).expect("fit should succeed");
for i in 0..5usize {
let iri = format!("http://ex.org/e{}", i);
let emb = embedder.embed_entity(&iri);
assert_eq!(
emb.len(),
config.hidden_dim,
"embedding length mismatch for entity {iri}"
);
}
}
// An entity absent from the training triples maps to an all-zero vector of
// the configured width.
#[test]
fn test_unseen_entity_returns_zero_vector() {
let config = GatEmbedderConfig {
num_layers: 1,
hidden_dim: 16,
num_heads: 2,
num_epochs: 2,
seed: 3,
..Default::default()
};
let triples = toy_triples(5, 8);
let mut embedder = GatEmbedder::new(config.clone());
embedder.fit(&triples).expect("fit should succeed");
let unseen = "http://ex.org/TOTALLY_UNSEEN";
let emb = embedder.embed_entity(unseen);
assert_eq!(emb.len(), config.hidden_dim);
assert!(
emb.iter().all(|&v| v == 0.0),
"unseen entity must return a zero vector"
);
}
// Softmax output has the input's length, sums to 1 and stays within (0, 1].
#[test]
fn test_softmax_sums_to_one() {
let scores = vec![1.0_f64, 2.0, 0.5, -1.0, 3.5];
let probs = softmax(&scores);
assert_eq!(probs.len(), scores.len());
let total: f64 = probs.iter().sum();
assert!(
(total - 1.0).abs() < 1e-10,
"softmax outputs must sum to 1.0, got {total}"
);
for &p in &probs {
assert!(p > 0.0 && p <= 1.0, "softmax value out of (0,1]: {p}");
}
}
// Leaky ReLU is the identity for non-negative input and scales negative
// input by the slope.
#[test]
fn test_leaky_relu_behavior() {
let neg_slope = 0.2_f64;
let pos = 3.7_f64;
assert!((leaky_relu(pos, neg_slope) - pos).abs() < 1e-12);
assert!((leaky_relu(0.0, neg_slope)).abs() < 1e-12);
let neg = -4.0_f64;
let expected = neg_slope * neg;
assert!(
(leaky_relu(neg, neg_slope) - expected).abs() < 1e-12,
"leaky_relu({neg}) should be {expected}"
);
assert!(
leaky_relu(-5.0, neg_slope).abs() < 5.0,
"negative input should be attenuated"
);
}
// Trained embeddings are unit-length within a 0.1 tolerance; all-zero
// vectors (norm below 1e-12) are skipped.
#[test]
fn test_embeddings_l2_normalized() {
let config = GatEmbedderConfig {
num_layers: 2,
hidden_dim: 16,
num_heads: 4,
num_epochs: 3,
seed: 13,
..Default::default()
};
let triples = toy_triples(5, 8);
let mut embedder = GatEmbedder::new(config.clone());
embedder.fit(&triples).expect("fit should succeed");
for i in 0..5usize {
let iri = format!("http://ex.org/e{}", i);
let emb = embedder.embed_entity(&iri);
let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
if norm > 1e-12 {
assert!(
(norm - 1.0).abs() < 0.1,
"L2 norm out of tolerance for {iri}: got {norm}"
);
}
}
}
// Both the stored embeddings and a direct `attention_forward` call produce
// vectors of length `hidden_dim`, and the heads tile `hidden_dim` exactly.
#[test]
fn test_multi_head_output_dimension() {
let config = GatEmbedderConfig {
num_layers: 1,
hidden_dim: 32,
num_heads: 4,
num_epochs: 1,
seed: 17,
..Default::default()
};
let triples = toy_triples(5, 8);
let mut embedder = GatEmbedder::new(config.clone());
embedder.fit(&triples).expect("fit should succeed");
// Rebuild the graph and the initial features the same way `fit` does so
// `attention_forward` can be called directly below.
let (entity_index, adj) = GatEmbedder::build_graph(&triples);
let num_entities = entity_index.len();
let mut rng = SimpleLcg::new(config.seed);
let hidden_dim = config.hidden_dim;
let h0: Vec<Vec<f64>> = (0..num_entities)
.map(|_| {
let mut v: Vec<f64> = (0..hidden_dim)
.map(|_| rng.next_f64_range(0.5_f64))
.collect();
l2_normalize_inplace(&mut v);
v
})
.collect();
for i in 0..5usize {
let iri = format!("http://ex.org/e{}", i);
let emb = embedder.embed_entity(&iri);
assert_eq!(
emb.len(),
hidden_dim,
"expected output dim {hidden_dim} for entity {i}"
);
let head_dim = hidden_dim / config.num_heads;
assert_eq!(
head_dim * config.num_heads,
hidden_dim,
"concat dim mismatch: {} * {} ≠ {}",
head_dim,
config.num_heads,
hidden_dim
);
}
let emb0 = embedder.attention_forward(0, &adj, &h0, 0);
assert_eq!(
emb0.len(),
hidden_dim,
"attention_forward should output hidden_dim={hidden_dim}"
);
}
// Regression guard: the mean cosine similarity of linked entities after 50
// epochs must not fall more than 0.5 below the value after a single epoch.
// Note that this does not assert a strict improvement.
#[test]
fn test_loss_decreases_over_epochs() {
let triples = toy_triples(5, 8);
let make_config = |epochs: usize, seed: u64| GatEmbedderConfig {
num_layers: 2,
hidden_dim: 16,
num_heads: 4,
num_epochs: epochs,
learning_rate: 0.05,
margin: 1.0,
seed,
..Default::default()
};
// Mean cosine similarity over all training triples, ignoring pairs where
// either embedding is (near-)zero.
let avg_sim = |embedder: &GatEmbedder| -> f64 {
let (mut total, mut count) = (0.0_f64, 0usize);
for (s, _, o) in &triples {
let hs = embedder.embed_entity(s);
let ho = embedder.embed_entity(o);
let ns: f64 = hs.iter().map(|x| x * x).sum::<f64>().sqrt();
let no: f64 = ho.iter().map(|x| x * x).sum::<f64>().sqrt();
if ns > 1e-12 && no > 1e-12 {
total += cosine_sim(&hs, &ho);
count += 1;
}
}
if count > 0 {
total / count as f64
} else {
0.0
}
};
let mut e_early = GatEmbedder::new(make_config(1, 42));
e_early.fit(&triples).expect("1-epoch fit should succeed");
let sim_early = avg_sim(&e_early);
let mut e_trained = GatEmbedder::new(make_config(50, 42));
e_trained
.fit(&triples)
.expect("50-epoch fit should succeed");
let sim_trained = avg_sim(&e_trained);
assert!(
sim_trained >= sim_early - 0.5,
"similarity regression: 1-epoch={sim_early:.4} 50-epoch={sim_trained:.4}"
);
}
}