1use anyhow::{anyhow, Result};
17use serde::{Deserialize, Serialize};
18
19use super::graphsage::{Graph, Lcg};
20
/// Hyperparameters for a multi-head Graph Attention Network (GAT).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATConfig {
    /// Dimensionality of the input node feature vectors.
    pub input_dim: usize,
    /// Per-head output dimensionality of each hidden layer.
    pub hidden_head_dim: usize,
    /// Number of attention heads in each hidden layer.
    pub hidden_num_heads: usize,
    /// Per-head output dimensionality of the final (output) layer.
    pub output_head_dim: usize,
    /// Number of attention heads in the final (output) layer.
    pub output_num_heads: usize,
    /// Total number of GAT layers; the last one is the output layer.
    pub num_layers: usize,
    /// Probability that a candidate's attention weight is dropped during aggregation.
    pub dropout: f64,
    /// Negative slope of the LeakyReLU applied to raw attention scores.
    pub alpha: f64,
    /// If true, hidden layers concatenate head outputs; otherwise they average them.
    pub concat_hidden: bool,
    /// If true, the output layer averages head outputs; otherwise it concatenates them.
    pub avg_output: bool,
    /// If true, final embeddings are L2-normalized in place.
    pub normalize_output: bool,
    /// Seed for the deterministic weight-initialization / dropout RNG.
    pub seed: u64,
}
53
54impl Default for GATConfig {
55 fn default() -> Self {
56 Self {
57 input_dim: 64,
58 hidden_head_dim: 8,
59 hidden_num_heads: 8,
60 output_head_dim: 8,
61 output_num_heads: 1,
62 num_layers: 2,
63 dropout: 0.6,
64 alpha: 0.2,
65 concat_hidden: true,
66 avg_output: true,
67 normalize_output: true,
68 seed: 42,
69 }
70 }
71}
72
73impl GATConfig {
74 pub fn output_dim(&self) -> usize {
76 if self.avg_output {
77 self.output_head_dim
78 } else {
79 self.output_head_dim * self.output_num_heads
80 }
81 }
82
83 pub fn hidden_layer_out_dim(&self) -> usize {
85 if self.concat_hidden {
86 self.hidden_head_dim * self.hidden_num_heads
87 } else {
88 self.hidden_head_dim
89 }
90 }
91}
92
/// A single attention head: a linear projection plus the source/destination
/// halves of the GAT attention vector.
#[derive(Debug, Clone)]
struct AttentionHead {
    // Projection matrix W, shape (head_dim x in_dim), row-major.
    w: Vec<Vec<f64>>,
    // Attention coefficients dotted with the transformed source node.
    a_src: Vec<f64>,
    // Attention coefficients dotted with the transformed candidate node.
    a_dst: Vec<f64>,
    // Output dimensionality of this head (number of rows in `w`).
    head_dim: usize,
    // Negative slope of the LeakyReLU applied to raw attention scores.
    alpha: f64,
}
110
111impl AttentionHead {
112 fn new(in_dim: usize, head_dim: usize, alpha: f64, rng: &mut Lcg) -> Self {
113 let w_scale = (6.0 / (in_dim + head_dim) as f64).sqrt();
114 let w = (0..head_dim)
115 .map(|_| (0..in_dim).map(|_| rng.next_f64_range(w_scale)).collect())
116 .collect();
117 let a_scale = (2.0 / head_dim as f64).sqrt();
118 let a_src = (0..head_dim).map(|_| rng.next_f64_range(a_scale)).collect();
119 let a_dst = (0..head_dim).map(|_| rng.next_f64_range(a_scale)).collect();
120 Self {
121 w,
122 a_src,
123 a_dst,
124 head_dim,
125 alpha,
126 }
127 }
128
129 fn linear(&self, x: &[f64]) -> Vec<f64> {
131 self.w
132 .iter()
133 .map(|row| row.iter().zip(x.iter()).map(|(&w, &xi)| w * xi).sum())
134 .collect()
135 }
136
137 fn leaky_relu(&self, x: f64) -> f64 {
139 if x >= 0.0 {
140 x
141 } else {
142 self.alpha * x
143 }
144 }
145
146 fn softmax(scores: &[f64]) -> Vec<f64> {
148 if scores.is_empty() {
149 return Vec::new();
150 }
151 let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
152 let exps: Vec<f64> = scores.iter().map(|&s| (s - max).exp()).collect();
153 let sum: f64 = exps.iter().sum();
154 if sum < 1e-12 {
155 return vec![1.0 / scores.len() as f64; scores.len()];
156 }
157 exps.iter().map(|&e| e / sum).collect()
158 }
159
160 fn forward(
164 &self,
165 v: usize,
166 all_transformed: &[Vec<f64>], neighbors: &[usize],
168 dropout_mask: &[bool], ) -> Vec<f64> {
170 let mut candidates: Vec<usize> = vec![v];
172 candidates.extend_from_slice(neighbors);
173
174 let h_v = &all_transformed[v];
175
176 let scores: Vec<f64> = candidates
178 .iter()
179 .map(|&u| {
180 let h_u = &all_transformed[u];
181 let src: f64 = self
183 .a_src
184 .iter()
185 .zip(h_v.iter())
186 .map(|(&a, &h)| a * h)
187 .sum();
188 let dst: f64 = self
189 .a_dst
190 .iter()
191 .zip(h_u.iter())
192 .map(|(&a, &h)| a * h)
193 .sum();
194 self.leaky_relu(src + dst)
195 })
196 .collect();
197
198 let weights = Self::softmax(&scores);
199
200 let mut out = vec![0.0f64; self.head_dim];
202 for (k, (&u, &w)) in candidates.iter().zip(weights.iter()).enumerate() {
203 let keep = dropout_mask.get(k).copied().unwrap_or(true);
205 let effective_w = if keep { w } else { 0.0 };
206 let h_u = &all_transformed[u];
207 for (j, &val) in h_u.iter().enumerate() {
208 out[j] += effective_w * val;
209 }
210 }
211 out.iter_mut().for_each(|x| {
213 if *x < 0.0 {
214 *x = (*x).exp() - 1.0;
215 }
216 });
217 out
218 }
219}
220
/// One GAT layer: a set of independent attention heads whose per-node outputs
/// are either concatenated or averaged.
pub struct GATLayer {
    // The layer's attention heads; all share the same input/output geometry.
    heads: Vec<AttentionHead>,
    // Expected dimensionality of incoming node embeddings.
    in_dim: usize,
    // Output dimensionality of each individual head.
    head_dim: usize,
    // Number of heads (used by out_dim() and the averaging path).
    num_heads: usize,
    // true => concatenate head outputs, false => average them.
    concat: bool,
    // Probability that a candidate's attention weight is dropped in forward().
    dropout_rate: f64,
}
240
241impl std::fmt::Debug for GATLayer {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 f.debug_struct("GATLayer")
244 .field("in_dim", &self.in_dim)
245 .field("num_heads", &self.num_heads)
246 .field("head_dim", &self.head_dim)
247 .field("concat", &self.concat)
248 .finish()
249 }
250}
251
252impl GATLayer {
253 pub fn new(
255 in_dim: usize,
256 head_dim: usize,
257 num_heads: usize,
258 alpha: f64,
259 dropout: f64,
260 concat: bool,
261 rng: &mut Lcg,
262 ) -> Result<Self> {
263 if in_dim == 0 {
264 return Err(anyhow!("GATLayer: in_dim must be > 0"));
265 }
266 if head_dim == 0 {
267 return Err(anyhow!("GATLayer: head_dim must be > 0"));
268 }
269 if num_heads == 0 {
270 return Err(anyhow!("GATLayer: num_heads must be > 0"));
271 }
272 let heads = (0..num_heads)
273 .map(|_| AttentionHead::new(in_dim, head_dim, alpha, rng))
274 .collect();
275 Ok(Self {
276 heads,
277 in_dim,
278 head_dim,
279 num_heads,
280 concat,
281 dropout_rate: dropout,
282 })
283 }
284
285 pub fn out_dim(&self) -> usize {
287 if self.concat {
288 self.head_dim * self.num_heads
289 } else {
290 self.head_dim
291 }
292 }
293
294 pub fn forward(
296 &self,
297 graph: &Graph,
298 current_embeddings: &[Vec<f64>],
299 rng: &mut Lcg,
300 ) -> Vec<Vec<f64>> {
301 let n = graph.num_nodes();
302
303 let all_transformed: Vec<Vec<Vec<f64>>> = self
306 .heads
307 .iter()
308 .map(|head| {
309 current_embeddings
310 .iter()
311 .map(|emb| head.linear(emb))
312 .collect()
313 })
314 .collect();
315
316 (0..n)
318 .map(|v| {
319 let neighbors = graph.neighbors(v);
320 let num_candidates = 1 + neighbors.len(); let dropout_mask: Vec<bool> = (0..num_candidates)
323 .map(|_| rng.next_f64() > self.dropout_rate)
324 .collect();
325
326 let head_outputs: Vec<Vec<f64>> = self
327 .heads
328 .iter()
329 .enumerate()
330 .map(|(k, head)| head.forward(v, &all_transformed[k], neighbors, &dropout_mask))
331 .collect();
332
333 if self.concat {
334 head_outputs.into_iter().flatten().collect()
336 } else {
337 let mut avg = vec![0.0f64; self.head_dim];
339 for h in &head_outputs {
340 for (i, &v) in h.iter().enumerate() {
341 avg[i] += v;
342 }
343 }
344 let k = self.num_heads as f64;
345 avg.iter_mut().for_each(|x| *x /= k);
346 avg
347 }
348 })
349 .collect()
350 }
351}
352
/// A stack of GAT layers built from a [`GATConfig`].
pub struct GATModel {
    // Layers in application order; the last layer is the output layer.
    layers: Vec<GATLayer>,
    // The configuration the model was built from (consulted by embed()).
    config: GATConfig,
}
362
363impl std::fmt::Debug for GATModel {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 f.debug_struct("GATModel")
366 .field("num_layers", &self.layers.len())
367 .field("output_dim", &self.config.output_dim())
368 .finish()
369 }
370}
371
372impl GATModel {
373 pub fn new(config: GATConfig) -> Result<Self> {
375 if config.input_dim == 0 {
376 return Err(anyhow!("GATConfig: input_dim must be > 0"));
377 }
378 if config.num_layers == 0 {
379 return Err(anyhow!("GATConfig: num_layers must be > 0"));
380 }
381 if config.hidden_head_dim == 0 {
382 return Err(anyhow!("GATConfig: hidden_head_dim must be > 0"));
383 }
384 if config.output_head_dim == 0 {
385 return Err(anyhow!("GATConfig: output_head_dim must be > 0"));
386 }
387 if config.hidden_num_heads == 0 || config.output_num_heads == 0 {
388 return Err(anyhow!("GATConfig: num_heads must be > 0"));
389 }
390
391 let mut rng = Lcg::new(config.seed);
392 let mut layers = Vec::with_capacity(config.num_layers);
393
394 let mut current_in_dim = config.input_dim;
396 for layer_idx in 0..config.num_layers {
397 let is_last = layer_idx == config.num_layers - 1;
398 let (head_dim, num_heads, concat) = if is_last {
399 (
400 config.output_head_dim,
401 config.output_num_heads,
402 !config.avg_output,
403 )
404 } else {
405 (
406 config.hidden_head_dim,
407 config.hidden_num_heads,
408 config.concat_hidden,
409 )
410 };
411
412 let layer = GATLayer::new(
413 current_in_dim,
414 head_dim,
415 num_heads,
416 config.alpha,
417 config.dropout,
418 concat,
419 &mut rng,
420 )?;
421 current_in_dim = layer.out_dim();
422 layers.push(layer);
423 }
424
425 Ok(Self { layers, config })
426 }
427
428 pub fn embed(&self, graph: &Graph) -> Result<GATEmbeddings> {
430 if graph.num_nodes() == 0 {
431 return Err(anyhow!("GATModel: graph has no nodes"));
432 }
433 let mut rng = Lcg::new(self.config.seed.wrapping_add(0xcafe_babe));
434 let mut current: Vec<Vec<f64>> = graph.node_features.clone();
435 for layer in &self.layers {
436 current = layer.forward(graph, ¤t, &mut rng);
437 }
438 if self.config.normalize_output {
439 for emb in &mut current {
440 l2_normalize_inplace(emb);
441 }
442 }
443 let dim = self.config.output_dim();
444 let num_nodes = graph.num_nodes();
445 Ok(GATEmbeddings {
446 embeddings: current,
447 num_nodes,
448 dim,
449 })
450 }
451}
452
/// Node embeddings produced by a GAT model's embed pass.
#[derive(Debug, Clone)]
pub struct GATEmbeddings {
    /// One embedding vector per node, indexed by node id.
    pub embeddings: Vec<Vec<f64>>,
    /// Number of nodes in the source graph.
    pub num_nodes: usize,
    /// Expected dimensionality of each embedding vector.
    pub dim: usize,
}
464
465impl GATEmbeddings {
466 pub fn get(&self, v: usize) -> Option<&[f64]> {
468 self.embeddings.get(v).map(|e| e.as_slice())
469 }
470
471 pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
473 let ea = self.embeddings.get(a)?;
474 let eb = self.embeddings.get(b)?;
475 Some(cosine_similarity_vecs(ea, eb))
476 }
477
478 pub fn top_k_similar(&self, query_node: usize, k: usize) -> Vec<(usize, f64)> {
480 let qe = match self.embeddings.get(query_node) {
481 Some(e) => e,
482 None => return Vec::new(),
483 };
484 let mut sims: Vec<(usize, f64)> = self
485 .embeddings
486 .iter()
487 .enumerate()
488 .filter(|(i, _)| *i != query_node)
489 .map(|(i, e)| (i, cosine_similarity_vecs(qe, e)))
490 .collect();
491 sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
492 sims.truncate(k);
493 sims
494 }
495
496 pub fn mean_embedding(&self) -> Vec<f64> {
498 if self.embeddings.is_empty() {
499 return Vec::new();
500 }
501 let mut mean = vec![0.0f64; self.dim];
502 for emb in &self.embeddings {
503 for (i, &v) in emb.iter().enumerate().take(self.dim) {
504 mean[i] += v;
505 }
506 }
507 let n = self.embeddings.len() as f64;
508 mean.iter_mut().for_each(|v| *v /= n);
509 mean
510 }
511}
512
/// Cosine similarity of two vectors, clamped to [-1, 1].
///
/// Returns 0.0 when either vector has (near-)zero norm. The dot product zips
/// to the shorter length, while each norm is taken over its full slice.
fn cosine_similarity_vecs(a: &[f64], b: &[f64]) -> f64 {
    let dot: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
    let norm = |v: &[f64]| v.iter().map(|&x| x * x).sum::<f64>().sqrt();
    let na = norm(a);
    let nb = norm(b);
    if na < 1e-12 || nb < 1e-12 {
        0.0
    } else {
        (dot / (na * nb)).clamp(-1.0, 1.0)
    }
}
526
/// Scales `v` to unit L2 norm in place; vectors whose norm is at most 1e-12
/// are left untouched to avoid dividing by (near-)zero.
fn l2_normalize_inplace(v: &mut [f64]) {
    let norm = v.iter().fold(0.0f64, |acc, &x| acc + x * x).sqrt();
    if norm > 1e-12 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
533
// Unit tests. All graphs are tiny, all RNGs are seeded, so every test is
// fast and deterministic.
#[cfg(test)]
mod tests {
    use super::super::graphsage::{Graph, Lcg};
    use super::*;

    // Builds a path graph 0-1-2-...-(n-1) with seeded random features of
    // width `feat_dim`.
    fn line_graph(n: usize, feat_dim: usize, seed: u64) -> Graph {
        let mut rng = Lcg::new(seed);
        let features: Vec<Vec<f64>> = (0..n)
            .map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
        for i in 0..n.saturating_sub(1) {
            adjacency[i].push(i + 1);
            adjacency[i + 1].push(i);
        }
        Graph::new(features, adjacency).expect("line graph construction should succeed")
    }

    #[test]
    fn test_gat_config_default() {
        let config = GATConfig::default();
        assert_eq!(config.num_layers, 2);
        assert_eq!(config.hidden_num_heads, 8);
        // Default avg_output = true, so output_dim collapses to one head's width.
        assert_eq!(config.output_dim(), config.output_head_dim);
    }

    #[test]
    fn test_gat_config_concat_hidden() {
        let config = GATConfig {
            hidden_head_dim: 8,
            hidden_num_heads: 4,
            concat_hidden: true,
            ..Default::default()
        };
        // Concatenation: 8 dims x 4 heads.
        assert_eq!(config.hidden_layer_out_dim(), 32);
    }

    #[test]
    fn test_gat_config_avg_hidden() {
        let config = GATConfig {
            hidden_head_dim: 8,
            hidden_num_heads: 4,
            concat_hidden: false,
            ..Default::default()
        };
        // Averaging: head count does not affect the width.
        assert_eq!(config.hidden_layer_out_dim(), 8);
    }

    #[test]
    fn test_gat_layer_construction() {
        let mut rng = Lcg::new(42);
        let layer =
            GATLayer::new(8, 4, 2, 0.2, 0.0, true, &mut rng).expect("layer should construct");
        // concat = true: 4 dims x 2 heads.
        assert_eq!(layer.out_dim(), 8);
    }

    #[test]
    fn test_gat_layer_avg() {
        let mut rng = Lcg::new(43);
        let layer =
            GATLayer::new(8, 4, 3, 0.2, 0.0, false, &mut rng).expect("layer should construct");
        // concat = false: averaged heads keep a single head's width.
        assert_eq!(layer.out_dim(), 4);
    }

    #[test]
    fn test_gat_layer_invalid() {
        // Any zero dimension or head count must be rejected.
        let mut rng = Lcg::new(1);
        assert!(GATLayer::new(0, 4, 2, 0.2, 0.0, true, &mut rng).is_err());
        assert!(GATLayer::new(8, 0, 2, 0.2, 0.0, true, &mut rng).is_err());
        assert!(GATLayer::new(8, 4, 0, 0.2, 0.0, true, &mut rng).is_err());
    }

    #[test]
    fn test_gat_model_embed_shape() {
        // Embedding count and width must match the graph and config.
        let config = GATConfig {
            input_dim: 8,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 2,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = GATModel::new(config.clone()).expect("GAT model should construct");
        let g = line_graph(5, 8, 100);
        let embs = model.embed(&g).expect("embed should succeed");

        assert_eq!(embs.num_nodes, 5);
        assert_eq!(embs.dim, config.output_dim());
        for i in 0..5 {
            assert_eq!(
                embs.get(i).expect("embedding should exist").len(),
                config.output_dim()
            );
        }
    }

    #[test]
    fn test_gat_model_single_layer() {
        // With one layer, the output layer concatenates when avg_output = false.
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 8,
            hidden_num_heads: 2,
            output_head_dim: 8,
            output_num_heads: 2,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: false,
            normalize_output: false,
            ..Default::default()
        };
        let model = GATModel::new(config.clone()).expect("GAT model should construct");
        let g = line_graph(4, 4, 200);
        let embs = model.embed(&g).expect("embed should succeed");
        // 8 dims x 2 concatenated output heads.
        assert_eq!(embs.dim, 16);
    }

    #[test]
    fn test_gat_model_normalized_output() {
        // normalize_output = true must leave every embedding with norm <= 1.
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: false,
            avg_output: true,
            normalize_output: true,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let g = line_graph(5, 4, 300);
        let embs = model.embed(&g).expect("embed should succeed");
        for i in 0..5 {
            let emb = embs.get(i).expect("embedding exists");
            let norm: f64 = emb.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!(norm <= 1.0 + 1e-6, "norm {} should be <= 1", norm);
        }
    }

    #[test]
    fn test_gat_cosine_similarity_bounds() {
        // Pairwise similarities must stay inside [-1, 1] (up to fp tolerance).
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let g = line_graph(5, 4, 400);
        let embs = model.embed(&g).expect("embed should succeed");
        for i in 0..5 {
            for j in 0..5 {
                if let Some(sim) = embs.cosine_similarity(i, j) {
                    assert!(
                        (-1.0 - 1e-6..=1.0 + 1e-6).contains(&sim),
                        "cosine_similarity({i}, {j}) = {sim} out of range"
                    );
                }
            }
        }
    }

    #[test]
    fn test_gat_top_k_similar() {
        // top_k_similar must return at most k hits, sorted descending.
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 2,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: true,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let g = line_graph(8, 4, 500);
        let embs = model.embed(&g).expect("embed should succeed");
        let top3 = embs.top_k_similar(0, 3);
        assert!(top3.len() <= 3);
        for window in top3.windows(2) {
            assert!(
                window[0].1 >= window[1].1 - 1e-10,
                "top_k should be sorted descending"
            );
        }
    }

    #[test]
    fn test_gat_isolated_node() {
        // A node with no neighbors still attends to itself.
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let features = vec![vec![1.0f64, 0.5, -0.3, 0.8]];
        let adjacency = vec![vec![]];
        let g = Graph::new(features, adjacency).expect("isolated node graph");
        let embs = model.embed(&g).expect("isolated node should embed");
        assert_eq!(embs.num_nodes, 1);
        assert!(embs.get(0).is_some());
    }

    #[test]
    fn test_gat_invalid_config() {
        // Every zero-valued dimension/count in the config must be rejected.
        assert!(GATModel::new(GATConfig {
            input_dim: 0,
            ..Default::default()
        })
        .is_err());
        assert!(GATModel::new(GATConfig {
            num_layers: 0,
            ..Default::default()
        })
        .is_err());
        assert!(GATModel::new(GATConfig {
            hidden_num_heads: 0,
            ..Default::default()
        })
        .is_err());
        assert!(GATModel::new(GATConfig {
            output_head_dim: 0,
            ..Default::default()
        })
        .is_err());
    }

    #[test]
    fn test_gat_mean_embedding() {
        // mean_embedding returns a vector of the configured output width.
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: false,
            avg_output: true,
            normalize_output: true,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let g = line_graph(5, 4, 600);
        let embs = model.embed(&g).expect("embed should succeed");
        let mean = embs.mean_embedding();
        assert_eq!(mean.len(), embs.dim);
    }

    #[test]
    fn test_gat_attention_softmax_sums_to_one() {
        // Softmax weights must sum to 1 and preserve score ordering.
        let scores = vec![1.0f64, 2.0, 3.0, 0.5, -1.0];
        let weights = AttentionHead::softmax(&scores);
        let sum: f64 = weights.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax should sum to 1, got {sum}"
        );
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);
    }

    #[test]
    fn test_gat_three_layer_deep() {
        // A deeper (3-layer) stack must still produce correctly shaped output.
        let config = GATConfig {
            input_dim: 8,
            hidden_head_dim: 4,
            hidden_num_heads: 3,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 3,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: true,
            seed: 77,
            ..Default::default()
        };
        let model = GATModel::new(config.clone()).expect("3-layer GAT should construct");
        let g = line_graph(6, 8, 77);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 6);
        assert_eq!(embs.dim, config.output_dim());
    }
}