Skip to main content

oxirs_embed/models/
gat.rs

//! Graph Attention Network (GAT) Embedder for RDF Knowledge Graphs
//!
//! Based on Veličković et al. (2018) — ICLR: "Graph Attention Networks".
//!
//! Implements multi-head self-attention over graph neighbourhoods to produce
//! context-aware entity embeddings.  Each attention head learns independent Q/K/V
//! projections and scores neighbours with LeakyReLU-gated, scaled dot products;
//! head outputs are concatenated and projected through W_out, and the result is
//! passed through ReLU and L2-normalised.
//!
//! Trained with a margin-ranking loss and a sign-based proxy-gradient update with
//! row-norm clipping, following the GraphSAGE embedder pattern in the `graphsage`
//! module.
12
13use crate::EmbeddingError;
14use anyhow::anyhow;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
18use super::graphsage::SimpleLcg;
19
20// ── Configuration ─────────────────────────────────────────────────────────────
21
22/// Configuration for the multi-head GAT embedder.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct GatEmbedderConfig {
25    /// Number of GAT layers. Default: 2.
26    pub num_layers: usize,
27    /// Hidden/output dimension (must be divisible by `num_heads`). Default: 64.
28    pub hidden_dim: usize,
29    /// Number of attention heads. Default: 4.
30    pub num_heads: usize,
31    /// Dropout rate applied to attention coefficients. Default: 0.1.
32    pub dropout_rate: f64,
33    /// Number of training epochs. Default: 50.
34    pub num_epochs: usize,
35    /// Sign-SGD learning rate. Default: 0.01.
36    pub learning_rate: f64,
37    /// Margin γ for ranking loss: max(0, γ − sim_pos + sim_neg). Default: 1.0.
38    pub margin: f64,
39    /// Seed for reproducibility. Default: 42.
40    pub seed: u64,
41}
42
43impl Default for GatEmbedderConfig {
44    fn default() -> Self {
45        Self {
46            num_layers: 2,
47            hidden_dim: 64,
48            num_heads: 4,
49            dropout_rate: 0.1,
50            num_epochs: 50,
51            learning_rate: 0.01,
52            margin: 1.0,
53            seed: 42,
54        }
55    }
56}
57
58// ── Internal helpers ──────────────────────────────────────────────────────────
59
60/// Xavier-uniform init: values drawn from U(−√(6/(in+out)), √(6/(in+out))).
61fn xavier_uniform_2d(rows: usize, cols: usize, rng: &mut SimpleLcg) -> Vec<Vec<f64>> {
62    let limit = (6.0_f64 / (rows + cols).max(1) as f64).sqrt();
63    (0..rows)
64        .map(|_| (0..cols).map(|_| rng.next_f64_range(limit)).collect())
65        .collect()
66}
67
/// Dense matrix–vector product `W · x`, with `w` given as a list of rows.
///
/// Each output entry is the dot product of one row with `x`. When a row and
/// `x` differ in length, the extra trailing entries of the longer one are
/// ignored.
#[inline]
fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(w.len());
    for row in w {
        product.push(row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f64>());
    }
    product
}
75
/// Rescale `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not above `1e-12` are left untouched. This covers
/// zero and near-zero vectors, and a NaN norm also fails the comparison.
fn l2_normalize_inplace(v: &mut [f64]) {
    let length = v.iter().map(|&c| c * c).sum::<f64>().sqrt();
    if length > 1e-12 {
        for c in v.iter_mut() {
            *c /= length;
        }
    }
}
83
/// Return a copy of `v` with every entry replaced by `max(entry, 0.0)`.
///
/// NaN entries become `0.0`, because `f64::max` ignores NaN.
#[inline]
fn relu_vec(v: &[f64]) -> Vec<f64> {
    let mut rectified = v.to_vec();
    rectified.iter_mut().for_each(|x| *x = x.max(0.0));
    rectified
}
89
/// Cosine similarity `a·b / (‖a‖‖b‖ + 1e-8)`.
///
/// The `1e-8` in the denominator keeps zero vectors from dividing by zero:
/// they score `0.0`. Mismatched lengths are truncated to the shorter slice
/// for the dot product.
#[inline]
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let euclid = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let inner = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
    inner / (euclid(a) * euclid(b) + 1e-8)
}
98
99// ── Public scalar utilities (exposed for testing) ─────────────────────────────
100
/// LeakyReLU with a configurable slope on the negative half-line.
///
/// Inputs `>= 0` (including `-0.0`) are returned unchanged. Strictly
/// negative inputs are multiplied by `negative_slope`; the GAT paper uses
/// `0.2`.
#[inline]
pub fn leaky_relu(x: f64, negative_slope: f64) -> f64 {
    if x < 0.0 {
        negative_slope * x
    } else {
        x
    }
}
111
/// Numerically stable softmax over a slice.
///
/// Returns a same-length vector of non-negative entries that sum to 1.0. An
/// empty input yields an empty output.
///
/// Non-finite inputs are handled explicitly instead of propagating NaN:
/// - if any score is `+∞`, all mass is split uniformly over the `+∞` entries,
///   which is the limit of softmax as those scores grow;
/// - if the exponentials cannot be normalised (all scores `−∞`, or a NaN
///   score), a uniform distribution is returned.
pub fn softmax(scores: &[f64]) -> Vec<f64> {
    if scores.is_empty() {
        return Vec::new();
    }
    let n = scores.len();
    // `f64::max` ignores NaN, so this is the largest non-NaN score
    // (or −∞ when every score is NaN).
    let max_val = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    if max_val == f64::INFINITY {
        // exp(∞ − ∞) would be NaN; take the limit instead.
        let num_inf = scores.iter().filter(|&&s| s == f64::INFINITY).count() as f64;
        return scores
            .iter()
            .map(|&s| if s == f64::INFINITY { 1.0 / num_inf } else { 0.0 })
            .collect();
    }

    // Shifting by the max keeps every exponent ≤ 0, so `exp` cannot overflow.
    let exps: Vec<f64> = scores.iter().map(|&s| (s - max_val).exp()).collect();
    let sum: f64 = exps.iter().sum();
    // For finite inputs `sum >= 1` (the max contributes exp(0) = 1). A NaN sum
    // (NaN score, or all scores −∞ giving −∞ − −∞) fails this check.
    if sum.is_finite() && sum >= 1e-30 {
        exps.iter().map(|e| e / sum).collect()
    } else {
        vec![1.0 / n as f64; n]
    }
}
127
128// ── Per-layer weight matrices ──────────────────────────────────────────────────
129
130/// Weight matrices for one GAT layer (all `num_heads` heads + output projection).
131struct GatLayerWeights {
132    /// Query projections, one per head: [num_heads][head_dim × hidden_dim]
133    w_query: Vec<Vec<Vec<f64>>>,
134    /// Key projections, one per head: [num_heads][head_dim × hidden_dim]
135    w_key: Vec<Vec<Vec<f64>>>,
136    /// Value projections, one per head: [num_heads][head_dim × hidden_dim]
137    w_value: Vec<Vec<Vec<f64>>>,
138    /// Output projection: [hidden_dim × (head_dim * num_heads)]
139    w_out: Vec<Vec<f64>>,
140    /// Number of heads
141    num_heads: usize,
142    /// Per-head dimension (= hidden_dim / num_heads)
143    head_dim: usize,
144    /// Full hidden dimension
145    hidden_dim: usize,
146}
147
148impl GatLayerWeights {
149    fn new(hidden_dim: usize, num_heads: usize, rng: &mut SimpleLcg) -> Self {
150        let head_dim = hidden_dim / num_heads.max(1);
151        let mut w_query = Vec::with_capacity(num_heads);
152        let mut w_key = Vec::with_capacity(num_heads);
153        let mut w_value = Vec::with_capacity(num_heads);
154        for _ in 0..num_heads {
155            w_query.push(xavier_uniform_2d(head_dim, hidden_dim, rng));
156            w_key.push(xavier_uniform_2d(head_dim, hidden_dim, rng));
157            w_value.push(xavier_uniform_2d(head_dim, hidden_dim, rng));
158        }
159        // Output projection: from (head_dim * num_heads) → hidden_dim
160        let concat_dim = head_dim * num_heads;
161        let w_out = xavier_uniform_2d(hidden_dim, concat_dim, rng);
162        Self {
163            w_query,
164            w_key,
165            w_value,
166            w_out,
167            num_heads,
168            head_dim,
169            hidden_dim,
170        }
171    }
172}
173
174// ── Main embedder ──────────────────────────────────────────────────────────────
175
176/// Multi-head graph attention network embedder trained on RDF triple lists.
177///
178/// Architecture follows Veličković et al. (2018): for each node the model
179/// attends over its in-neighbourhood using learned Q/K/V projections,
180/// applies LeakyReLU-gated softmax attention, concatenates the `num_heads`
181/// outputs, projects through W_out, applies ReLU, and L2-normalises.
182///
183/// Training uses a margin-ranking loss and sign-SGD with gradient clipping.
184pub struct GatEmbedder {
185    config: GatEmbedderConfig,
186    /// Entity IRI → sequential integer index.
187    entity_index: HashMap<String, usize>,
188    /// Cached post-training embeddings indexed by entity id.
189    embeddings: Vec<Vec<f64>>,
190    /// Per-layer weight matrices (length = num_layers).
191    layer_weights: Vec<GatLayerWeights>,
192    trained: bool,
193}
194
195impl std::fmt::Debug for GatEmbedder {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        f.debug_struct("GatEmbedder")
198            .field("num_entities", &self.entity_index.len())
199            .field("trained", &self.trained)
200            .field("num_layers", &self.config.num_layers)
201            .field("hidden_dim", &self.config.hidden_dim)
202            .field("num_heads", &self.config.num_heads)
203            .finish()
204    }
205}
206
207impl GatEmbedder {
208    /// Create a new, un-trained GAT embedder.
209    pub fn new(config: GatEmbedderConfig) -> Self {
210        Self {
211            config,
212            entity_index: HashMap::new(),
213            embeddings: Vec::new(),
214            layer_weights: Vec::new(),
215            trained: false,
216        }
217    }
218
219    // ── Public API ─────────────────────────────────────────────────────────────
220
221    /// Train on `(subject_iri, predicate_iri, object_iri)` triples.
222    ///
223    /// Steps:
224    /// 1. Build entity index and adjacency map (undirected: s→o and o→s).
225    /// 2. Xavier-initialise all Q/K/V and W_out weight matrices.
226    /// 3. Initialise random L2-normalised entity feature vectors.
227    /// 4. For each epoch: attention forward-pass → margin-ranking loss →
228    ///    sign-SGD weight update.
229    /// 5. Cache final embeddings.
230    pub fn fit(&mut self, triples: &[(String, String, String)]) -> Result<(), EmbeddingError> {
231        if triples.is_empty() {
232            return Err(EmbeddingError::Other(anyhow!("Triple set is empty")));
233        }
234
235        // 1. Build entity index and integer-keyed adjacency list
236        let (entity_index, adj_by_idx) = Self::build_graph(triples);
237        let num_entities = entity_index.len();
238        self.entity_index = entity_index;
239
240        // 2. Initialise weight matrices
241        let mut rng = SimpleLcg::new(self.config.seed);
242        let hidden_dim = self.config.hidden_dim;
243        let num_heads = self.config.num_heads;
244        let num_layers = self.config.num_layers;
245        self.layer_weights = (0..num_layers)
246            .map(|_| GatLayerWeights::new(hidden_dim, num_heads, &mut rng))
247            .collect();
248
249        // 3. Initialise entity feature vectors ~ U(-0.5, 0.5), L2-normalised
250        let mut h0: Vec<Vec<f64>> = (0..num_entities)
251            .map(|_| {
252                let mut v: Vec<f64> = (0..hidden_dim)
253                    .map(|_| rng.next_f64_range(0.5_f64))
254                    .collect();
255                l2_normalize_inplace(&mut v);
256                v
257            })
258            .collect();
259
260        // 4. Training loop
261        let mut lcg = SimpleLcg::new(self.config.seed.wrapping_add(1));
262
263        for _epoch in 0..self.config.num_epochs {
264            // Forward pass: compute updated embeddings for all entities
265            let h_all = self.forward_all(&h0, &adj_by_idx, num_entities);
266
267            // Accumulate sign-SGD deltas for weight matrices (simplified proxy
268            // gradient: outer product of embedding sign × loss magnitude)
269            let mut deltas: Vec<Vec<Vec<Vec<f64>>>> = self
270                .layer_weights
271                .iter()
272                .map(|lw| {
273                    let heads: Vec<Vec<Vec<f64>>> = (0..lw.num_heads)
274                        .map(|_| vec![vec![0.0; hidden_dim]; lw.head_dim])
275                        .collect();
276                    // index 0..num_heads → Q, num_heads..2*num_heads → K,
277                    // 2*num_heads..3*num_heads → V, 3*num_heads → W_out
278                    let mut all = heads.clone();
279                    all.extend(heads.clone()); // K
280                    all.extend(heads.clone()); // V
281                    all.push(vec![vec![0.0; lw.head_dim * lw.num_heads]; lw.hidden_dim]); // W_out
282                    all
283                })
284                .collect();
285
286            let mut grad_count = 0usize;
287
288            for (s_str, _p_str, o_str) in triples {
289                let s_idx = match self.entity_index.get(s_str.as_str()) {
290                    Some(&i) => i,
291                    None => continue,
292                };
293                let o_idx = match self.entity_index.get(o_str.as_str()) {
294                    Some(&i) => i,
295                    None => continue,
296                };
297                let o_neg_idx = Self::sample_negative(o_idx, num_entities, &mut lcg);
298
299                let h_s = &h_all[s_idx];
300                let h_o = &h_all[o_idx];
301                let h_neg = &h_all[o_neg_idx];
302
303                let loss =
304                    (self.config.margin - cosine_sim(h_s, h_o) + cosine_sim(h_s, h_neg)).max(0.0);
305
306                if loss > 0.0 {
307                    // Accumulate magnitude-scaled sign gradient for every layer
308                    for (l, lw) in self.layer_weights.iter().enumerate() {
309                        // Proxy gradient: sign of subject embedding component
310                        let nh = lw.num_heads;
311                        let hd = lw.head_dim;
312                        for h in 0..nh {
313                            // Q deltas
314                            for (r, row) in deltas[l][h].iter_mut().enumerate().take(hd) {
315                                let sign = if h_s.get(r % h_s.len()).copied().unwrap_or(0.0) > 0.0 {
316                                    1.0_f64
317                                } else {
318                                    -1.0_f64
319                                };
320                                for delta in row.iter_mut() {
321                                    *delta += sign * loss;
322                                }
323                            }
324                            // K deltas
325                            for (r, row) in deltas[l][nh + h].iter_mut().enumerate().take(hd) {
326                                let sign = if h_o.get(r % h_o.len()).copied().unwrap_or(0.0) > 0.0 {
327                                    1.0_f64
328                                } else {
329                                    -1.0_f64
330                                };
331                                for delta in row.iter_mut() {
332                                    *delta += sign * loss;
333                                }
334                            }
335                            // V deltas
336                            for (r, row) in deltas[l][2 * nh + h].iter_mut().enumerate().take(hd) {
337                                let sign = if h_o.get(r % h_o.len()).copied().unwrap_or(0.0) > 0.0 {
338                                    1.0_f64
339                                } else {
340                                    -1.0_f64
341                                };
342                                for delta in row.iter_mut() {
343                                    *delta += sign * loss;
344                                }
345                            }
346                        }
347                        // W_out deltas
348                        for (r, row) in deltas[l][3 * nh].iter_mut().enumerate() {
349                            let sign = if h_s.get(r % h_s.len()).copied().unwrap_or(0.0) > 0.0 {
350                                1.0_f64
351                            } else {
352                                -1.0_f64
353                            };
354                            for delta in row.iter_mut() {
355                                *delta += sign * loss;
356                            }
357                        }
358                    }
359                    grad_count += 1;
360                }
361            }
362
363            // Apply sign-SGD updates with row-norm gradient clipping
364            if grad_count > 0 {
365                let lr = self.config.learning_rate / grad_count as f64;
366                for (l, lw) in self.layer_weights.iter_mut().enumerate() {
367                    let nh = lw.num_heads;
368                    let hd = lw.head_dim;
369
370                    for h in 0..nh {
371                        // Update Q
372                        for (r, delta_row) in deltas[l][h].iter().enumerate().take(hd) {
373                            let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
374                            let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
375                            for (w, d) in lw.w_query[h][r].iter_mut().zip(delta_row.iter()) {
376                                *w -= d * clip * lr;
377                            }
378                        }
379                        // Update K
380                        for (r, delta_row) in deltas[l][nh + h].iter().enumerate().take(hd) {
381                            let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
382                            let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
383                            for (w, d) in lw.w_key[h][r].iter_mut().zip(delta_row.iter()) {
384                                *w -= d * clip * lr;
385                            }
386                        }
387                        // Update V
388                        for (r, delta_row) in deltas[l][2 * nh + h].iter().enumerate().take(hd) {
389                            let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
390                            let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
391                            for (w, d) in lw.w_value[h][r].iter_mut().zip(delta_row.iter()) {
392                                *w -= d * clip * lr;
393                            }
394                        }
395                    }
396                    // Update W_out
397                    for (r, delta_row) in deltas[l][3 * nh].iter().enumerate() {
398                        let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
399                        let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
400                        for (w, d) in lw.w_out[r].iter_mut().zip(delta_row.iter()) {
401                            *w -= d * clip * lr;
402                        }
403                    }
404                }
405            }
406
407            // Re-normalise input features for next epoch
408            for feat in h0.iter_mut() {
409                l2_normalize_inplace(feat);
410            }
411        }
412
413        // 5. Cache final embeddings
414        self.embeddings = self.forward_all(&h0, &adj_by_idx, num_entities);
415        self.trained = true;
416        Ok(())
417    }
418
419    /// Return the embedding for a known entity IRI.
420    /// Returns a zero vector for unseen entities (inductive fallback — never panics).
421    pub fn embed_entity(&self, entity: &str) -> Vec<f64> {
422        match self.entity_index.get(entity) {
423            Some(&idx) => self
424                .embeddings
425                .get(idx)
426                .cloned()
427                .unwrap_or_else(|| vec![0.0; self.config.hidden_dim]),
428            None => vec![0.0; self.config.hidden_dim],
429        }
430    }
431
432    /// Multi-head attention forward pass for a single entity.
433    ///
434    /// For each head `h`:
435    ///   Q_h = W_query[h] · e_i
436    ///   K_h = W_key[h]   · e_j  for each neighbour j
437    ///   V_h = W_value[h] · e_j
438    ///   score_j = LeakyReLU(Q_h · K_h_j / √head_dim)
439    ///   α_j = softmax({score_j})
440    ///   head_out_h = Σ_j α_j · V_h_j
441    ///
442    /// Heads are concatenated → projected by W_out → ReLU → L2-normalised.
443    pub fn attention_forward(
444        &self,
445        entity_idx: usize,
446        adj: &HashMap<usize, Vec<usize>>,
447        embeddings: &[Vec<f64>],
448        layer_idx: usize,
449    ) -> Vec<f64> {
450        let lw = &self.layer_weights[layer_idx];
451        let h_self = match embeddings.get(entity_idx) {
452            Some(e) => e,
453            None => return vec![0.0; self.config.hidden_dim],
454        };
455
456        // Collect neighbour embeddings (include self for isolated nodes)
457        let neighbor_indices: Vec<usize> = adj.get(&entity_idx).cloned().unwrap_or_default();
458        let all_indices: Vec<usize> = {
459            let mut v = vec![entity_idx];
460            v.extend_from_slice(&neighbor_indices);
461            v
462        };
463
464        let scale = (lw.head_dim.max(1) as f64).sqrt();
465
466        // Compute per-head outputs
467        let mut concat_heads: Vec<f64> = Vec::with_capacity(lw.head_dim * lw.num_heads);
468
469        for h in 0..lw.num_heads {
470            // Q for entity i
471            let q_i: Vec<f64> = matvec(&lw.w_query[h], h_self);
472
473            // Attention scores for all neighbours (including self)
474            let scores: Vec<f64> = all_indices
475                .iter()
476                .map(|&j| {
477                    let h_j = match embeddings.get(j) {
478                        Some(e) => e,
479                        None => h_self,
480                    };
481                    let k_j: Vec<f64> = matvec(&lw.w_key[h], h_j);
482                    let raw_score: f64 = q_i.iter().zip(k_j.iter()).map(|(&a, &b)| a * b).sum();
483                    leaky_relu(raw_score / scale, 0.2)
484                })
485                .collect();
486
487            let alphas = softmax(&scores);
488
489            // Weighted sum of value vectors
490            let mut head_out = vec![0.0_f64; lw.head_dim];
491            for (&j, &alpha) in all_indices.iter().zip(alphas.iter()) {
492                let h_j = match embeddings.get(j) {
493                    Some(e) => e,
494                    None => h_self,
495                };
496                let v_j: Vec<f64> = matvec(&lw.w_value[h], h_j);
497                for (acc, vv) in head_out.iter_mut().zip(v_j.iter()) {
498                    *acc += alpha * vv;
499                }
500            }
501            concat_heads.extend_from_slice(&head_out);
502        }
503
504        // Output projection → ReLU → L2-normalise
505        let mut out = relu_vec(&matvec(&lw.w_out, &concat_heads));
506        l2_normalize_inplace(&mut out);
507        out
508    }
509
510    // ── Accessors ──────────────────────────────────────────────────────────────
511
512    /// Whether `fit` has been called successfully.
513    pub fn is_trained(&self) -> bool {
514        self.trained
515    }
516
517    /// Number of distinct entities seen during training.
518    pub fn num_entities(&self) -> usize {
519        self.entity_index.len()
520    }
521
522    /// Dimension of each output embedding.
523    pub fn embedding_dim(&self) -> usize {
524        self.config.hidden_dim
525    }
526
527    // ── Private helpers ────────────────────────────────────────────────────────
528
529    /// Build an entity-IRI→index map and an integer adjacency list from triples.
530    fn build_graph(
531        triples: &[(String, String, String)],
532    ) -> (HashMap<String, usize>, HashMap<usize, Vec<usize>>) {
533        let mut entity_index: HashMap<String, usize> = HashMap::new();
534        let mut next_id = 0usize;
535
536        let mut get_or_insert = |iri: &str| -> usize {
537            if let Some(&id) = entity_index.get(iri) {
538                return id;
539            }
540            let id = next_id;
541            next_id += 1;
542            entity_index.insert(iri.to_string(), id);
543            id
544        };
545
546        // First pass: build entity index
547        for (s, _p, o) in triples {
548            get_or_insert(s.as_str());
549            get_or_insert(o.as_str());
550        }
551
552        // Second pass: build adjacency (undirected: s→o and o→s)
553        let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
554        for (s, _p, o) in triples {
555            let s_idx = *entity_index.get(s.as_str()).expect("just inserted");
556            let o_idx = *entity_index.get(o.as_str()).expect("just inserted");
557            adj.entry(s_idx).or_default().push(o_idx);
558            adj.entry(o_idx).or_default().push(s_idx);
559        }
560
561        (entity_index, adj)
562    }
563
564    /// Forward pass over all entities for all layers.
565    fn forward_all(
566        &self,
567        h0: &[Vec<f64>],
568        adj: &HashMap<usize, Vec<usize>>,
569        num_entities: usize,
570    ) -> Vec<Vec<f64>> {
571        let mut h_prev = h0.to_vec();
572
573        for layer_idx in 0..self.config.num_layers {
574            let mut h_next: Vec<Vec<f64>> = Vec::with_capacity(num_entities);
575            for node_idx in 0..num_entities {
576                // Use a temporary GatEmbedder-like context pointing at h_prev
577                let out = self.attention_forward_on(node_idx, adj, &h_prev, layer_idx);
578                h_next.push(out);
579            }
580            h_prev = h_next;
581        }
582
583        h_prev
584    }
585
586    /// Internal variant of `attention_forward` that takes explicit embeddings slice.
587    fn attention_forward_on(
588        &self,
589        entity_idx: usize,
590        adj: &HashMap<usize, Vec<usize>>,
591        embeddings: &[Vec<f64>],
592        layer_idx: usize,
593    ) -> Vec<f64> {
594        self.attention_forward(entity_idx, adj, embeddings, layer_idx)
595    }
596
597    /// Sample a negative entity index different from `positive_idx`.
598    fn sample_negative(positive_idx: usize, num_entities: usize, lcg: &mut SimpleLcg) -> usize {
599        if num_entities <= 1 {
600            return 0;
601        }
602        let mut candidate = lcg.next_usize() % num_entities;
603        let mut attempts = 0usize;
604        while candidate == positive_idx && attempts < num_entities {
605            candidate = (candidate + 1) % num_entities;
606            attempts += 1;
607        }
608        candidate
609    }
610}
611
612// ── Tests ─────────────────────────────────────────────────────────────────────
613
614#[cfg(test)]
615mod tests {
616    use super::*;
617
618    /// Build a small, fully-connected knowledge graph for testing.
619    fn toy_triples(n_entities: usize, n_triples: usize) -> Vec<(String, String, String)> {
620        let mut triples = Vec::with_capacity(n_triples);
621        for i in 0..n_triples {
622            let s = format!("http://ex.org/e{}", i % n_entities);
623            let p = "http://ex.org/rel".to_string();
624            let o = format!("http://ex.org/e{}", (i + 1) % n_entities);
625            triples.push((s, p, o));
626        }
627        triples
628    }
629
630    // ── Test 1: Default config produces expected dimensions ────────────────────
631    #[test]
632    fn test_default_config_dimensions() {
633        let config = GatEmbedderConfig::default();
634        assert_eq!(config.num_layers, 2);
635        assert_eq!(config.hidden_dim, 64);
636        assert_eq!(config.num_heads, 4);
637        assert_eq!(config.num_epochs, 50);
638        // head_dim = hidden_dim / num_heads = 16
639        assert_eq!(config.hidden_dim / config.num_heads, 16);
640    }
641
642    // ── Test 2: fit completes without error on a small graph ──────────────────
643    #[test]
644    fn test_fit_completes_small_graph() {
645        let config = GatEmbedderConfig {
646            num_layers: 2,
647            hidden_dim: 16,
648            num_heads: 4,
649            num_epochs: 5,
650            seed: 7,
651            ..Default::default()
652        };
653        let triples = toy_triples(5, 8);
654        let mut embedder = GatEmbedder::new(config);
655        let result = embedder.fit(&triples);
656        assert!(result.is_ok(), "fit should succeed: {result:?}");
657        assert!(embedder.is_trained());
658        assert_eq!(embedder.num_entities(), 5);
659    }
660
661    // ── Test 3: embed_entity returns correct dimension after fit ───────────────
662    #[test]
663    fn test_embed_entity_dimension() {
664        let config = GatEmbedderConfig {
665            num_layers: 2,
666            hidden_dim: 32,
667            num_heads: 4,
668            num_epochs: 3,
669            seed: 11,
670            ..Default::default()
671        };
672        let triples = toy_triples(5, 8);
673        let mut embedder = GatEmbedder::new(config.clone());
674        embedder.fit(&triples).expect("fit should succeed");
675
676        for i in 0..5usize {
677            let iri = format!("http://ex.org/e{}", i);
678            let emb = embedder.embed_entity(&iri);
679            assert_eq!(
680                emb.len(),
681                config.hidden_dim,
682                "embedding length mismatch for entity {iri}"
683            );
684        }
685    }
686
687    // ── Test 4: Unseen entity returns zero vector (not panic) ─────────────────
688    #[test]
689    fn test_unseen_entity_returns_zero_vector() {
690        let config = GatEmbedderConfig {
691            num_layers: 1,
692            hidden_dim: 16,
693            num_heads: 2,
694            num_epochs: 2,
695            seed: 3,
696            ..Default::default()
697        };
698        let triples = toy_triples(5, 8);
699        let mut embedder = GatEmbedder::new(config.clone());
700        embedder.fit(&triples).expect("fit should succeed");
701
702        let unseen = "http://ex.org/TOTALLY_UNSEEN";
703        let emb = embedder.embed_entity(unseen);
704        assert_eq!(emb.len(), config.hidden_dim);
705        assert!(
706            emb.iter().all(|&v| v == 0.0),
707            "unseen entity must return a zero vector"
708        );
709    }
710
711    // ── Test 5: Softmax: attention scores sum to 1.0 ──────────────────────────
712    #[test]
713    fn test_softmax_sums_to_one() {
714        let scores = vec![1.0_f64, 2.0, 0.5, -1.0, 3.5];
715        let probs = softmax(&scores);
716        assert_eq!(probs.len(), scores.len());
717        let total: f64 = probs.iter().sum();
718        assert!(
719            (total - 1.0).abs() < 1e-10,
720            "softmax outputs must sum to 1.0, got {total}"
721        );
722        // All values must be in (0, 1)
723        for &p in &probs {
724            assert!(p > 0.0 && p <= 1.0, "softmax value out of (0,1]: {p}");
725        }
726    }
727
728    // ── Test 6: LeakyReLU passes positive inputs, attenuates negative ones ─────
729    #[test]
730    fn test_leaky_relu_behavior() {
731        let neg_slope = 0.2_f64;
732        // Positive input: passes unchanged
733        let pos = 3.7_f64;
734        assert!((leaky_relu(pos, neg_slope) - pos).abs() < 1e-12);
735        // Zero: passes unchanged
736        assert!((leaky_relu(0.0, neg_slope)).abs() < 1e-12);
737        // Negative input: attenuated by slope
738        let neg = -4.0_f64;
739        let expected = neg_slope * neg;
740        assert!(
741            (leaky_relu(neg, neg_slope) - expected).abs() < 1e-12,
742            "leaky_relu({neg}) should be {expected}"
743        );
744        // Attenuation: |output| < |input| for negative input
745        assert!(
746            leaky_relu(-5.0, neg_slope).abs() < 5.0,
747            "negative input should be attenuated"
748        );
749    }
750
751    // ── Test 7: Embeddings are L2-normalised after forward pass ───────────────
752    #[test]
753    fn test_embeddings_l2_normalized() {
754        let config = GatEmbedderConfig {
755            num_layers: 2,
756            hidden_dim: 16,
757            num_heads: 4,
758            num_epochs: 3,
759            seed: 13,
760            ..Default::default()
761        };
762        let triples = toy_triples(5, 8);
763        let mut embedder = GatEmbedder::new(config.clone());
764        embedder.fit(&triples).expect("fit should succeed");
765
766        for i in 0..5usize {
767            let iri = format!("http://ex.org/e{}", i);
768            let emb = embedder.embed_entity(&iri);
769            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
770            // Allow collapsed (all-zero) embeddings if ReLU kills all activations
771            if norm > 1e-12 {
772                assert!(
773                    (norm - 1.0).abs() < 0.1,
774                    "L2 norm out of tolerance for {iri}: got {norm}"
775                );
776            }
777        }
778    }
779
780    // ── Test 8: num_heads attention heads → full hidden_dim output ─────────────
781    #[test]
782    fn test_multi_head_output_dimension() {
783        let config = GatEmbedderConfig {
784            num_layers: 1,
785            hidden_dim: 32,
786            num_heads: 4,
787            num_epochs: 1,
788            seed: 17,
789            ..Default::default()
790        };
791        let triples = toy_triples(5, 8);
792        let mut embedder = GatEmbedder::new(config.clone());
793        embedder.fit(&triples).expect("fit should succeed");
794
795        // Build the structures used internally
796        let (entity_index, adj) = GatEmbedder::build_graph(&triples);
797        let num_entities = entity_index.len();
798        let mut rng = SimpleLcg::new(config.seed);
799        let hidden_dim = config.hidden_dim;
800        let h0: Vec<Vec<f64>> = (0..num_entities)
801            .map(|_| {
802                let mut v: Vec<f64> = (0..hidden_dim)
803                    .map(|_| rng.next_f64_range(0.5_f64))
804                    .collect();
805                l2_normalize_inplace(&mut v);
806                v
807            })
808            .collect();
809
810        // After fit, each cached embedding has length == hidden_dim
811        for i in 0..5usize {
812            let iri = format!("http://ex.org/e{}", i);
813            let emb = embedder.embed_entity(&iri);
814            assert_eq!(
815                emb.len(),
816                hidden_dim,
817                "expected output dim {hidden_dim} for entity {i}"
818            );
819            // The W_out projection maps from head_dim*num_heads → hidden_dim;
820            // confirm head_dim * num_heads == hidden_dim (concat property)
821            let head_dim = hidden_dim / config.num_heads;
822            assert_eq!(
823                head_dim * config.num_heads,
824                hidden_dim,
825                "concat dim mismatch: {} * {} ≠ {}",
826                head_dim,
827                config.num_heads,
828                hidden_dim
829            );
830        }
831
832        // Direct test: attention_forward on entity 0 produces hidden_dim output
833        let emb0 = embedder.attention_forward(0, &adj, &h0, 0);
834        assert_eq!(
835            emb0.len(),
836            hidden_dim,
837            "attention_forward should output hidden_dim={hidden_dim}"
838        );
839    }
840
841    // ── Test 9: Loss decreases over training epochs ────────────────────────────
842    #[test]
843    fn test_loss_decreases_over_epochs() {
844        let triples = toy_triples(5, 8);
845
846        let make_config = |epochs: usize, seed: u64| GatEmbedderConfig {
847            num_layers: 2,
848            hidden_dim: 16,
849            num_heads: 4,
850            num_epochs: epochs,
851            learning_rate: 0.05,
852            margin: 1.0,
853            seed,
854            ..Default::default()
855        };
856
857        // Compute average positive-pair cosine similarity as a proxy for loss
858        let avg_sim = |embedder: &GatEmbedder| -> f64 {
859            let (mut total, mut count) = (0.0_f64, 0usize);
860            for (s, _, o) in &triples {
861                let hs = embedder.embed_entity(s);
862                let ho = embedder.embed_entity(o);
863                // Only count non-zero embeddings
864                let ns: f64 = hs.iter().map(|x| x * x).sum::<f64>().sqrt();
865                let no: f64 = ho.iter().map(|x| x * x).sum::<f64>().sqrt();
866                if ns > 1e-12 && no > 1e-12 {
867                    total += cosine_sim(&hs, &ho);
868                    count += 1;
869                }
870            }
871            if count > 0 {
872                total / count as f64
873            } else {
874                0.0
875            }
876        };
877
878        let mut e_early = GatEmbedder::new(make_config(1, 42));
879        e_early.fit(&triples).expect("1-epoch fit should succeed");
880        let sim_early = avg_sim(&e_early);
881
882        let mut e_trained = GatEmbedder::new(make_config(50, 42));
883        e_trained
884            .fit(&triples)
885            .expect("50-epoch fit should succeed");
886        let sim_trained = avg_sim(&e_trained);
887
888        // Trained model should not be dramatically worse; allow ±0.5 slack
889        assert!(
890            sim_trained >= sim_early - 0.5,
891            "similarity regression: 1-epoch={sim_early:.4} 50-epoch={sim_trained:.4}"
892        );
893    }
894}