1use crate::EmbeddingError;
14use anyhow::anyhow;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
18use super::graphsage::SimpleLcg;
19
/// Hyper-parameters for [`GatEmbedder`].
///
/// All fields are public so callers can use struct-update syntax on top of
/// [`Default::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatEmbedderConfig {
    /// Number of stacked attention layers run in each forward pass.
    pub num_layers: usize,
    /// Length of the initial node features and of every produced embedding.
    pub hidden_dim: usize,
    /// Attention heads per layer; each head projects into
    /// `hidden_dim / num_heads` dimensions (integer division).
    pub num_heads: usize,
    /// Dropout probability.
    ///
    /// NOTE(review): nothing visible in this file reads this field — confirm
    /// whether dropout is applied elsewhere or the field is reserved.
    pub dropout_rate: f64,
    /// Number of passes over the triple set during `fit`.
    pub num_epochs: usize,
    /// Step size for weight updates; `fit` divides it by the number of
    /// triples that produced a positive loss in the epoch.
    pub learning_rate: f64,
    /// Margin of the hinge ranking loss on cosine similarities.
    pub margin: f64,
    /// Seed for the deterministic generators used for initialisation and
    /// negative sampling.
    pub seed: u64,
}
42
43impl Default for GatEmbedderConfig {
44 fn default() -> Self {
45 Self {
46 num_layers: 2,
47 hidden_dim: 64,
48 num_heads: 4,
49 dropout_rate: 0.1,
50 num_epochs: 50,
51 learning_rate: 0.01,
52 margin: 1.0,
53 seed: 42,
54 }
55 }
56}
57
58fn xavier_uniform_2d(rows: usize, cols: usize, rng: &mut SimpleLcg) -> Vec<Vec<f64>> {
62 let limit = (6.0_f64 / (rows + cols).max(1) as f64).sqrt();
63 (0..rows)
64 .map(|_| (0..cols).map(|_| rng.next_f64_range(limit)).collect())
65 .collect()
66}
67
/// Multiplies the row-major matrix `w` by the column vector `x`.
///
/// Returns one dot product per row of `w`. If a row and `x` differ in
/// length, the dot product covers only their common prefix.
#[inline]
fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(w.len());
    for row in w {
        let dot: f64 = row.iter().zip(x).map(|(coeff, value)| coeff * value).sum();
        product.push(dot);
    }
    product
}
75
/// Rescales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not greater than `1e-12` (including NaN norms) are
/// left untouched, so near-zero input is never blown up.
fn l2_normalize_inplace(v: &mut [f64]) {
    let squared_sum: f64 = v.iter().map(|c| c * c).sum();
    let length = squared_sum.sqrt();
    if length > 1e-12 {
        for component in v.iter_mut() {
            *component /= length;
        }
    }
}
83
/// Element-wise ReLU: returns a new vector with every component replaced by
/// `max(component, 0.0)`.
#[inline]
fn relu_vec(v: &[f64]) -> Vec<f64> {
    let mut rectified = Vec::with_capacity(v.len());
    for &value in v {
        rectified.push(value.max(0.0));
    }
    rectified
}
89
/// Cosine similarity of `a` and `b`.
///
/// `1e-8` is added to the product of the norms, so a zero vector yields
/// `0.0` rather than NaN. The dot product covers the common prefix of the
/// two slices, while each norm covers its whole slice.
#[inline]
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let norm = |v: &[f64]| -> f64 { v.iter().map(|c| c * c).sum::<f64>().sqrt() };

    let numerator: f64 = a.iter().zip(b).map(|(p, q)| p * q).sum();
    let denominator = norm(a) * norm(b) + 1e-8;
    numerator / denominator
}
98
/// Leaky ReLU activation.
///
/// Returns `x` unchanged when `x >= 0.0`; otherwise returns
/// `negative_slope * x`.
#[inline]
pub fn leaky_relu(x: f64, negative_slope: f64) -> f64 {
    if x >= 0.0 {
        return x;
    }
    negative_slope * x
}
111
/// Numerically stable softmax.
///
/// Returns a probability vector of the same length as `scores` (empty input
/// gives an empty output). The maximum score is subtracted before
/// exponentiation to avoid overflow.
///
/// Non-finite input is handled explicitly so the result is always a valid
/// distribution instead of a vector of NaNs:
///
/// * if any score is `+inf`, the mass is split evenly among the `+inf`
///   entries and every other entry gets `0.0`;
/// * if the exponent sum is unusable (all scores `-inf`, or a NaN score is
///   present), the uniform distribution is returned.
///
/// For finite input the result is unchanged from the plain formula.
pub fn softmax(scores: &[f64]) -> Vec<f64> {
    if scores.is_empty() {
        return Vec::new();
    }
    let n = scores.len();

    // `f64::max` ignores NaN operands, so `max_val` is never NaN.
    let max_val = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    if max_val == f64::INFINITY {
        // `inf - inf` would be NaN below; use the limiting distribution.
        let winners = scores.iter().filter(|&&s| s == f64::INFINITY).count();
        let share = 1.0 / winners as f64;
        return scores
            .iter()
            .map(|&s| if s == f64::INFINITY { share } else { 0.0 })
            .collect();
    }

    let exps: Vec<f64> = scores.iter().map(|&s| (s - max_val).exp()).collect();
    let sum: f64 = exps.iter().sum();

    // A NaN sum fails every ordered comparison, so test for the usable case
    // and let everything else fall through to the uniform fallback.
    if sum.is_finite() && sum >= 1e-30 {
        exps.iter().map(|e| e / sum).collect()
    } else {
        vec![1.0 / n as f64; n]
    }
}
127
/// Learnable parameters of a single multi-head attention layer.
struct GatLayerWeights {
    /// Query projections, one `head_dim x hidden_dim` matrix per head.
    w_query: Vec<Vec<Vec<f64>>>,
    /// Key projections, one `head_dim x hidden_dim` matrix per head.
    w_key: Vec<Vec<Vec<f64>>>,
    /// Value projections, one `head_dim x hidden_dim` matrix per head.
    w_value: Vec<Vec<Vec<f64>>>,
    /// Output projection, `hidden_dim x (head_dim * num_heads)`, applied to
    /// the concatenated head outputs.
    w_out: Vec<Vec<f64>>,
    /// Number of attention heads in this layer.
    num_heads: usize,
    /// Per-head dimensionality, `hidden_dim / max(num_heads, 1)`.
    head_dim: usize,
    /// Input and output feature dimensionality of the layer.
    hidden_dim: usize,
}
147
148impl GatLayerWeights {
149 fn new(hidden_dim: usize, num_heads: usize, rng: &mut SimpleLcg) -> Self {
150 let head_dim = hidden_dim / num_heads.max(1);
151 let mut w_query = Vec::with_capacity(num_heads);
152 let mut w_key = Vec::with_capacity(num_heads);
153 let mut w_value = Vec::with_capacity(num_heads);
154 for _ in 0..num_heads {
155 w_query.push(xavier_uniform_2d(head_dim, hidden_dim, rng));
156 w_key.push(xavier_uniform_2d(head_dim, hidden_dim, rng));
157 w_value.push(xavier_uniform_2d(head_dim, hidden_dim, rng));
158 }
159 let concat_dim = head_dim * num_heads;
161 let w_out = xavier_uniform_2d(hidden_dim, concat_dim, rng);
162 Self {
163 w_query,
164 w_key,
165 w_value,
166 w_out,
167 num_heads,
168 head_dim,
169 hidden_dim,
170 }
171 }
172}
173
/// Multi-head graph-attention embedder over `(subject, predicate, object)`
/// triples.
///
/// Train with [`GatEmbedder::fit`], then look entities up with
/// [`GatEmbedder::embed_entity`].
pub struct GatEmbedder {
    /// Hyper-parameters supplied at construction.
    config: GatEmbedderConfig,
    /// Maps every entity string seen by `fit` to its dense index.
    entity_index: HashMap<String, usize>,
    /// Final embedding per entity, indexed by the ids in `entity_index`.
    embeddings: Vec<Vec<f64>>,
    /// One weight bundle per attention layer; empty until `fit` runs.
    layer_weights: Vec<GatLayerWeights>,
    /// `true` once `fit` has completed successfully.
    trained: bool,
}
194
195impl std::fmt::Debug for GatEmbedder {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 f.debug_struct("GatEmbedder")
198 .field("num_entities", &self.entity_index.len())
199 .field("trained", &self.trained)
200 .field("num_layers", &self.config.num_layers)
201 .field("hidden_dim", &self.config.hidden_dim)
202 .field("num_heads", &self.config.num_heads)
203 .finish()
204 }
205}
206
207impl GatEmbedder {
208 pub fn new(config: GatEmbedderConfig) -> Self {
210 Self {
211 config,
212 entity_index: HashMap::new(),
213 embeddings: Vec::new(),
214 layer_weights: Vec::new(),
215 trained: false,
216 }
217 }
218
    /// Trains the embedder on `triples` and stores one embedding per entity.
    ///
    /// Steps:
    /// 1. Index every subject/object and build an undirected adjacency list.
    /// 2. Initialise layer weights and unit-length random node features from
    ///    a generator seeded with `config.seed`.
    /// 3. For `config.num_epochs` epochs: run the forward pass, evaluate a
    ///    margin ranking loss on cosine similarity against one sampled
    ///    negative object per triple, and apply a clipped weight update.
    /// 4. Run a final forward pass and store its output as the embeddings.
    ///
    /// Predicates are ignored; only subject/object connectivity is used.
    ///
    /// # Errors
    ///
    /// Returns `EmbeddingError::Other` when `triples` is empty.
    pub fn fit(&mut self, triples: &[(String, String, String)]) -> Result<(), EmbeddingError> {
        if triples.is_empty() {
            return Err(EmbeddingError::Other(anyhow!("Triple set is empty")));
        }

        let (entity_index, adj_by_idx) = Self::build_graph(triples);
        let num_entities = entity_index.len();
        self.entity_index = entity_index;

        // One generator drives weight and feature initialisation so the run
        // is reproducible for a given seed.
        let mut rng = SimpleLcg::new(self.config.seed);
        let hidden_dim = self.config.hidden_dim;
        let num_heads = self.config.num_heads;
        let num_layers = self.config.num_layers;
        self.layer_weights = (0..num_layers)
            .map(|_| GatLayerWeights::new(hidden_dim, num_heads, &mut rng))
            .collect();

        // Initial node features: random vectors scaled to unit L2 norm.
        let mut h0: Vec<Vec<f64>> = (0..num_entities)
            .map(|_| {
                let mut v: Vec<f64> = (0..hidden_dim)
                    .map(|_| rng.next_f64_range(0.5_f64))
                    .collect();
                l2_normalize_inplace(&mut v);
                v
            })
            .collect();

        // Separate generator (seed + 1) for negative sampling.
        let mut lcg = SimpleLcg::new(self.config.seed.wrapping_add(1));

        for _epoch in 0..self.config.num_epochs {
            let h_all = self.forward_all(&h0, &adj_by_idx, num_entities);

            // Per-layer update accumulators. For a layer with `nh` heads:
            //   [0, nh)      -> query matrices (head_dim x hidden_dim)
            //   [nh, 2nh)    -> key matrices
            //   [2nh, 3nh)   -> value matrices
            //   [3nh]        -> output matrix (hidden_dim x head_dim * nh)
            let mut deltas: Vec<Vec<Vec<Vec<f64>>>> = self
                .layer_weights
                .iter()
                .map(|lw| {
                    let heads: Vec<Vec<Vec<f64>>> = (0..lw.num_heads)
                        .map(|_| vec![vec![0.0; hidden_dim]; lw.head_dim])
                        .collect();
                    let mut all = heads.clone();
                    all.extend(heads.clone());
                    all.extend(heads.clone());
                    all.push(vec![vec![0.0; lw.head_dim * lw.num_heads]; lw.hidden_dim]);
                    all
                })
                .collect();

            // Number of triples that produced a positive loss this epoch.
            let mut grad_count = 0usize;

            for (s_str, _p_str, o_str) in triples {
                let s_idx = match self.entity_index.get(s_str.as_str()) {
                    Some(&i) => i,
                    None => continue,
                };
                let o_idx = match self.entity_index.get(o_str.as_str()) {
                    Some(&i) => i,
                    None => continue,
                };
                let o_neg_idx = Self::sample_negative(o_idx, num_entities, &mut lcg);

                let h_s = &h_all[s_idx];
                let h_o = &h_all[o_idx];
                let h_neg = &h_all[o_neg_idx];

                // Hinge ranking loss:
                // max(0, margin - cos(h_s, h_o) + cos(h_s, h_neg)).
                let loss =
                    (self.config.margin - cosine_sim(h_s, h_o) + cosine_sim(h_s, h_neg)).max(0.0);

                if loss > 0.0 {
                    // NOTE(review): this is a heuristic, not a derivative.
                    // Every column of row `r` receives `sign * loss`, where
                    // `sign` depends only on whether component `r` (modulo
                    // the embedding length) of the chosen embedding is > 0.
                    for (l, lw) in self.layer_weights.iter().enumerate() {
                        let nh = lw.num_heads;
                        let hd = lw.head_dim;
                        for h in 0..nh {
                            // Query accumulators: sign taken from the subject.
                            for (r, row) in deltas[l][h].iter_mut().enumerate().take(hd) {
                                let sign = if h_s.get(r % h_s.len()).copied().unwrap_or(0.0) > 0.0 {
                                    1.0_f64
                                } else {
                                    -1.0_f64
                                };
                                for delta in row.iter_mut() {
                                    *delta += sign * loss;
                                }
                            }
                            // Key accumulators: sign taken from the object.
                            for (r, row) in deltas[l][nh + h].iter_mut().enumerate().take(hd) {
                                let sign = if h_o.get(r % h_o.len()).copied().unwrap_or(0.0) > 0.0 {
                                    1.0_f64
                                } else {
                                    -1.0_f64
                                };
                                for delta in row.iter_mut() {
                                    *delta += sign * loss;
                                }
                            }
                            // Value accumulators: sign taken from the object.
                            for (r, row) in deltas[l][2 * nh + h].iter_mut().enumerate().take(hd) {
                                let sign = if h_o.get(r % h_o.len()).copied().unwrap_or(0.0) > 0.0 {
                                    1.0_f64
                                } else {
                                    -1.0_f64
                                };
                                for delta in row.iter_mut() {
                                    *delta += sign * loss;
                                }
                            }
                        }
                        // Output-projection accumulator: sign taken from the subject.
                        for (r, row) in deltas[l][3 * nh].iter_mut().enumerate() {
                            let sign = if h_s.get(r % h_s.len()).copied().unwrap_or(0.0) > 0.0 {
                                1.0_f64
                            } else {
                                -1.0_f64
                            };
                            for delta in row.iter_mut() {
                                *delta += sign * loss;
                            }
                        }
                    }
                    grad_count += 1;
                }
            }

            if grad_count > 0 {
                // Average over contributing triples by folding 1/grad_count
                // into the step size.
                let lr = self.config.learning_rate / grad_count as f64;
                for (l, lw) in self.layer_weights.iter_mut().enumerate() {
                    let nh = lw.num_heads;
                    let hd = lw.head_dim;

                    // Each row is clipped independently: rows whose L2 norm
                    // exceeds 1 are scaled back to unit norm before the step.
                    for h in 0..nh {
                        for (r, delta_row) in deltas[l][h].iter().enumerate().take(hd) {
                            let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
                            let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
                            for (w, d) in lw.w_query[h][r].iter_mut().zip(delta_row.iter()) {
                                *w -= d * clip * lr;
                            }
                        }
                        for (r, delta_row) in deltas[l][nh + h].iter().enumerate().take(hd) {
                            let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
                            let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
                            for (w, d) in lw.w_key[h][r].iter_mut().zip(delta_row.iter()) {
                                *w -= d * clip * lr;
                            }
                        }
                        for (r, delta_row) in deltas[l][2 * nh + h].iter().enumerate().take(hd) {
                            let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
                            let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
                            for (w, d) in lw.w_value[h][r].iter_mut().zip(delta_row.iter()) {
                                *w -= d * clip * lr;
                            }
                        }
                    }
                    for (r, delta_row) in deltas[l][3 * nh].iter().enumerate() {
                        let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
                        let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
                        for (w, d) in lw.w_out[r].iter_mut().zip(delta_row.iter()) {
                            *w -= d * clip * lr;
                        }
                    }
                }
            }

            // NOTE(review): nothing in this loop writes to `h0`, so this
            // re-normalisation looks like a no-op apart from rounding —
            // confirm whether feature updates were intended here.
            for feat in h0.iter_mut() {
                l2_normalize_inplace(feat);
            }
        }

        // Final forward pass with the trained weights gives the embeddings.
        self.embeddings = self.forward_all(&h0, &adj_by_idx, num_entities);
        self.trained = true;
        Ok(())
    }
418
419 pub fn embed_entity(&self, entity: &str) -> Vec<f64> {
422 match self.entity_index.get(entity) {
423 Some(&idx) => self
424 .embeddings
425 .get(idx)
426 .cloned()
427 .unwrap_or_else(|| vec![0.0; self.config.hidden_dim]),
428 None => vec![0.0; self.config.hidden_dim],
429 }
430 }
431
432 pub fn attention_forward(
444 &self,
445 entity_idx: usize,
446 adj: &HashMap<usize, Vec<usize>>,
447 embeddings: &[Vec<f64>],
448 layer_idx: usize,
449 ) -> Vec<f64> {
450 let lw = &self.layer_weights[layer_idx];
451 let h_self = match embeddings.get(entity_idx) {
452 Some(e) => e,
453 None => return vec![0.0; self.config.hidden_dim],
454 };
455
456 let neighbor_indices: Vec<usize> = adj.get(&entity_idx).cloned().unwrap_or_default();
458 let all_indices: Vec<usize> = {
459 let mut v = vec![entity_idx];
460 v.extend_from_slice(&neighbor_indices);
461 v
462 };
463
464 let scale = (lw.head_dim.max(1) as f64).sqrt();
465
466 let mut concat_heads: Vec<f64> = Vec::with_capacity(lw.head_dim * lw.num_heads);
468
469 for h in 0..lw.num_heads {
470 let q_i: Vec<f64> = matvec(&lw.w_query[h], h_self);
472
473 let scores: Vec<f64> = all_indices
475 .iter()
476 .map(|&j| {
477 let h_j = match embeddings.get(j) {
478 Some(e) => e,
479 None => h_self,
480 };
481 let k_j: Vec<f64> = matvec(&lw.w_key[h], h_j);
482 let raw_score: f64 = q_i.iter().zip(k_j.iter()).map(|(&a, &b)| a * b).sum();
483 leaky_relu(raw_score / scale, 0.2)
484 })
485 .collect();
486
487 let alphas = softmax(&scores);
488
489 let mut head_out = vec![0.0_f64; lw.head_dim];
491 for (&j, &alpha) in all_indices.iter().zip(alphas.iter()) {
492 let h_j = match embeddings.get(j) {
493 Some(e) => e,
494 None => h_self,
495 };
496 let v_j: Vec<f64> = matvec(&lw.w_value[h], h_j);
497 for (acc, vv) in head_out.iter_mut().zip(v_j.iter()) {
498 *acc += alpha * vv;
499 }
500 }
501 concat_heads.extend_from_slice(&head_out);
502 }
503
504 let mut out = relu_vec(&matvec(&lw.w_out, &concat_heads));
506 l2_normalize_inplace(&mut out);
507 out
508 }
509
    /// Returns `true` once `fit` has completed successfully.
    pub fn is_trained(&self) -> bool {
        self.trained
    }
516
    /// Number of distinct entities currently indexed (zero before `fit`).
    pub fn num_entities(&self) -> usize {
        self.entity_index.len()
    }
521
    /// Configured embedding length (`config.hidden_dim`).
    pub fn embedding_dim(&self) -> usize {
        self.config.hidden_dim
    }
526
527 fn build_graph(
531 triples: &[(String, String, String)],
532 ) -> (HashMap<String, usize>, HashMap<usize, Vec<usize>>) {
533 let mut entity_index: HashMap<String, usize> = HashMap::new();
534 let mut next_id = 0usize;
535
536 let mut get_or_insert = |iri: &str| -> usize {
537 if let Some(&id) = entity_index.get(iri) {
538 return id;
539 }
540 let id = next_id;
541 next_id += 1;
542 entity_index.insert(iri.to_string(), id);
543 id
544 };
545
546 for (s, _p, o) in triples {
548 get_or_insert(s.as_str());
549 get_or_insert(o.as_str());
550 }
551
552 let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
554 for (s, _p, o) in triples {
555 let s_idx = *entity_index.get(s.as_str()).expect("just inserted");
556 let o_idx = *entity_index.get(o.as_str()).expect("just inserted");
557 adj.entry(s_idx).or_default().push(o_idx);
558 adj.entry(o_idx).or_default().push(s_idx);
559 }
560
561 (entity_index, adj)
562 }
563
564 fn forward_all(
566 &self,
567 h0: &[Vec<f64>],
568 adj: &HashMap<usize, Vec<usize>>,
569 num_entities: usize,
570 ) -> Vec<Vec<f64>> {
571 let mut h_prev = h0.to_vec();
572
573 for layer_idx in 0..self.config.num_layers {
574 let mut h_next: Vec<Vec<f64>> = Vec::with_capacity(num_entities);
575 for node_idx in 0..num_entities {
576 let out = self.attention_forward_on(node_idx, adj, &h_prev, layer_idx);
578 h_next.push(out);
579 }
580 h_prev = h_next;
581 }
582
583 h_prev
584 }
585
    /// Private alias used by the forward pass; forwards all arguments
    /// unchanged to `attention_forward` and returns its result.
    fn attention_forward_on(
        &self,
        entity_idx: usize,
        adj: &HashMap<usize, Vec<usize>>,
        embeddings: &[Vec<f64>],
        layer_idx: usize,
    ) -> Vec<f64> {
        self.attention_forward(entity_idx, adj, embeddings, layer_idx)
    }
596
597 fn sample_negative(positive_idx: usize, num_entities: usize, lcg: &mut SimpleLcg) -> usize {
599 if num_entities <= 1 {
600 return 0;
601 }
602 let mut candidate = lcg.next_usize() % num_entities;
603 let mut attempts = 0usize;
604 while candidate == positive_idx && attempts < num_entities {
605 candidate = (candidate + 1) % num_entities;
606 attempts += 1;
607 }
608 candidate
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// IRI of the `i`-th toy entity.
    fn entity_iri(i: usize) -> String {
        format!("http://ex.org/e{}", i)
    }

    /// Builds `n_triples` edges `e_i -> e_{i+1}` over a ring of `n_entities`
    /// entities, all sharing one predicate.
    fn toy_triples(n_entities: usize, n_triples: usize) -> Vec<(String, String, String)> {
        (0..n_triples)
            .map(|i| {
                (
                    entity_iri(i % n_entities),
                    "http://ex.org/rel".to_string(),
                    entity_iri((i + 1) % n_entities),
                )
            })
            .collect()
    }

    /// Fits a fresh embedder on the standard 5-entity / 8-triple toy graph.
    fn fit_on_toy_graph(config: &GatEmbedderConfig) -> GatEmbedder {
        let mut embedder = GatEmbedder::new(config.clone());
        embedder
            .fit(&toy_triples(5, 8))
            .expect("fit should succeed");
        embedder
    }

    /// The default configuration exposes the documented baseline values.
    #[test]
    fn test_default_config_dimensions() {
        let defaults = GatEmbedderConfig::default();
        assert_eq!(defaults.num_layers, 2);
        assert_eq!(defaults.hidden_dim, 64);
        assert_eq!(defaults.num_heads, 4);
        assert_eq!(defaults.num_epochs, 50);
        assert_eq!(defaults.hidden_dim / defaults.num_heads, 16);
    }

    /// Training on a tiny graph succeeds and indexes every entity.
    #[test]
    fn test_fit_completes_small_graph() {
        let mut embedder = GatEmbedder::new(GatEmbedderConfig {
            num_layers: 2,
            hidden_dim: 16,
            num_heads: 4,
            num_epochs: 5,
            seed: 7,
            ..Default::default()
        });
        let result = embedder.fit(&toy_triples(5, 8));
        assert!(result.is_ok(), "fit should succeed: {result:?}");
        assert!(embedder.is_trained());
        assert_eq!(embedder.num_entities(), 5);
    }

    /// Every known entity gets an embedding of length `hidden_dim`.
    #[test]
    fn test_embed_entity_dimension() {
        let config = GatEmbedderConfig {
            num_layers: 2,
            hidden_dim: 32,
            num_heads: 4,
            num_epochs: 3,
            seed: 11,
            ..Default::default()
        };
        let embedder = fit_on_toy_graph(&config);

        for iri in (0..5usize).map(entity_iri) {
            let emb = embedder.embed_entity(&iri);
            assert_eq!(
                emb.len(),
                config.hidden_dim,
                "embedding length mismatch for entity {iri}"
            );
        }
    }

    /// Unknown entities map to an all-zero vector of the right length.
    #[test]
    fn test_unseen_entity_returns_zero_vector() {
        let config = GatEmbedderConfig {
            num_layers: 1,
            hidden_dim: 16,
            num_heads: 2,
            num_epochs: 2,
            seed: 3,
            ..Default::default()
        };
        let embedder = fit_on_toy_graph(&config);

        let emb = embedder.embed_entity("http://ex.org/TOTALLY_UNSEEN");
        assert_eq!(emb.len(), config.hidden_dim);
        assert!(
            emb.iter().all(|&v| v == 0.0),
            "unseen entity must return a zero vector"
        );
    }

    /// Softmax output is a probability distribution.
    #[test]
    fn test_softmax_sums_to_one() {
        let scores = vec![1.0_f64, 2.0, 0.5, -1.0, 3.5];
        let probs = softmax(&scores);
        assert_eq!(probs.len(), scores.len());

        let total: f64 = probs.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-10,
            "softmax outputs must sum to 1.0, got {total}"
        );
        for p in probs.iter().copied() {
            assert!(p > 0.0 && p <= 1.0, "softmax value out of (0,1]: {p}");
        }
    }

    /// Leaky ReLU is the identity for x >= 0 and scales negatives by the slope.
    #[test]
    fn test_leaky_relu_behavior() {
        let neg_slope = 0.2_f64;

        let pos = 3.7_f64;
        assert!((leaky_relu(pos, neg_slope) - pos).abs() < 1e-12);
        assert!((leaky_relu(0.0, neg_slope)).abs() < 1e-12);

        let neg = -4.0_f64;
        let expected = neg_slope * neg;
        assert!(
            (leaky_relu(neg, neg_slope) - expected).abs() < 1e-12,
            "leaky_relu({neg}) should be {expected}"
        );
        assert!(
            leaky_relu(-5.0, neg_slope).abs() < 5.0,
            "negative input should be attenuated"
        );
    }

    /// Non-zero embeddings come out (approximately) unit length.
    #[test]
    fn test_embeddings_l2_normalized() {
        let config = GatEmbedderConfig {
            num_layers: 2,
            hidden_dim: 16,
            num_heads: 4,
            num_epochs: 3,
            seed: 13,
            ..Default::default()
        };
        let embedder = fit_on_toy_graph(&config);

        for iri in (0..5usize).map(entity_iri) {
            let emb = embedder.embed_entity(&iri);
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            // All-zero embeddings (everything clipped by ReLU) are exempt.
            if norm > 1e-12 {
                assert!(
                    (norm - 1.0).abs() < 0.1,
                    "L2 norm out of tolerance for {iri}: got {norm}"
                );
            }
        }
    }

    /// Multi-head outputs are projected back to `hidden_dim`.
    #[test]
    fn test_multi_head_output_dimension() {
        let config = GatEmbedderConfig {
            num_layers: 1,
            hidden_dim: 32,
            num_heads: 4,
            num_epochs: 1,
            seed: 17,
            ..Default::default()
        };
        let triples = toy_triples(5, 8);
        let embedder = fit_on_toy_graph(&config);

        // Rebuild the graph and a set of unit-length input features so a
        // single layer can be driven directly.
        let (entity_index, adj) = GatEmbedder::build_graph(&triples);
        let num_entities = entity_index.len();
        let mut rng = SimpleLcg::new(config.seed);
        let hidden_dim = config.hidden_dim;
        let mut h0: Vec<Vec<f64>> = Vec::with_capacity(num_entities);
        for _ in 0..num_entities {
            let mut v: Vec<f64> = (0..hidden_dim)
                .map(|_| rng.next_f64_range(0.5_f64))
                .collect();
            l2_normalize_inplace(&mut v);
            h0.push(v);
        }

        for i in 0..5usize {
            let emb = embedder.embed_entity(&entity_iri(i));
            assert_eq!(
                emb.len(),
                hidden_dim,
                "expected output dim {hidden_dim} for entity {i}"
            );
            let head_dim = hidden_dim / config.num_heads;
            assert_eq!(
                head_dim * config.num_heads,
                hidden_dim,
                "concat dim mismatch: {} * {} ≠ {}",
                head_dim,
                config.num_heads,
                hidden_dim
            );
        }

        let emb0 = embedder.attention_forward(0, &adj, &h0, 0);
        assert_eq!(
            emb0.len(),
            hidden_dim,
            "attention_forward should output hidden_dim={hidden_dim}"
        );
    }

    /// Longer training must not make the mean cosine similarity between
    /// linked entities fall by more than 0.5 relative to a single epoch.
    #[test]
    fn test_loss_decreases_over_epochs() {
        let triples = toy_triples(5, 8);

        let make_config = |epochs: usize, seed: u64| GatEmbedderConfig {
            num_layers: 2,
            hidden_dim: 16,
            num_heads: 4,
            num_epochs: epochs,
            learning_rate: 0.05,
            margin: 1.0,
            seed,
            ..Default::default()
        };

        // Mean cosine similarity over triples whose endpoints both have a
        // non-zero embedding; 0.0 when no triple qualifies.
        let avg_sim = |embedder: &GatEmbedder| -> f64 {
            let l2 = |v: &[f64]| -> f64 { v.iter().map(|x| x * x).sum::<f64>().sqrt() };
            let (total, count) =
                triples
                    .iter()
                    .fold((0.0_f64, 0usize), |(total, count), (s, _, o)| {
                        let hs = embedder.embed_entity(s);
                        let ho = embedder.embed_entity(o);
                        if l2(&hs) > 1e-12 && l2(&ho) > 1e-12 {
                            (total + cosine_sim(&hs, &ho), count + 1)
                        } else {
                            (total, count)
                        }
                    });
            if count > 0 {
                total / count as f64
            } else {
                0.0
            }
        };

        let mut e_early = GatEmbedder::new(make_config(1, 42));
        e_early.fit(&triples).expect("1-epoch fit should succeed");
        let sim_early = avg_sim(&e_early);

        let mut e_trained = GatEmbedder::new(make_config(50, 42));
        e_trained
            .fit(&triples)
            .expect("50-epoch fit should succeed");
        let sim_trained = avg_sim(&e_trained);

        assert!(
            sim_trained >= sim_early - 0.5,
            "similarity regression: 1-epoch={sim_early:.4} 50-epoch={sim_trained:.4}"
        );
    }
}