// oxirs_embed/graph_models/gat.rs
1//! GAT: Graph Attention Network Embeddings (v0.3.0)
2//!
3//! Veličković et al., ICLR 2018 — "Graph Attention Networks"
4//!
5//! Key contributions:
6//! - Multi-head attention over node neighborhoods
7//! - No structural assumptions (unlike GCN's fixed Laplacian)
8//! - Implicit structural weighting via learned attention coefficients
9//!
10//! This module provides:
11//! - `GATLayer`:   single multi-head attention layer
12//! - `GATModel`:   stacked GAT layers for deep graph embeddings
13//! - `GATConfig`:  hyperparameter configuration
14//! - `GATEmbeddings`: output container with similarity utilities
15
16use anyhow::{anyhow, Result};
17use serde::{Deserialize, Serialize};
18
19use super::graphsage::{Graph, Lcg};
20
// ---------------------------------------------------------------------------
// GATConfig
// ---------------------------------------------------------------------------

/// Full hyperparameter configuration for a GAT model.
///
/// Intermediate layers are shaped by `hidden_head_dim`/`hidden_num_heads`
/// (combined per `concat_hidden`); the final layer uses `output_head_dim`/
/// `output_num_heads` (combined per `avg_output`). The resulting embedding
/// width is given by [`GATConfig::output_dim`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATConfig {
    /// Dimension of raw input node features.
    pub input_dim: usize,
    /// Per-head output dimensionality in intermediate layers.
    pub hidden_head_dim: usize,
    /// Number of attention heads in intermediate layers.
    pub hidden_num_heads: usize,
    /// Per-head output dimensionality in the final layer.
    pub output_head_dim: usize,
    /// Number of attention heads in the final layer.
    pub output_num_heads: usize,
    /// Number of stacked GAT layers (must be >= 1).
    pub num_layers: usize,
    /// Dropout rate applied to attention coefficients (0.0 = disabled).
    /// Expected in [0.0, 1.0); a rate >= 1.0 would drop every edge.
    pub dropout: f64,
    /// Negative slope for LeakyReLU in attention scoring.
    pub alpha: f64,
    /// If true, concatenate head outputs in intermediate layers; else average.
    pub concat_hidden: bool,
    /// Average head outputs in the final layer (standard GAT for classification).
    pub avg_output: bool,
    /// L2-normalize final node embeddings.
    pub normalize_output: bool,
    /// Random seed for weight initialization.
    pub seed: u64,
}
53
54impl Default for GATConfig {
55    fn default() -> Self {
56        Self {
57            input_dim: 64,
58            hidden_head_dim: 8,
59            hidden_num_heads: 8,
60            output_head_dim: 8,
61            output_num_heads: 1,
62            num_layers: 2,
63            dropout: 0.6,
64            alpha: 0.2,
65            concat_hidden: true,
66            avg_output: true,
67            normalize_output: true,
68            seed: 42,
69        }
70    }
71}
72
73impl GATConfig {
74    /// Compute the total output dimensionality of this configuration.
75    pub fn output_dim(&self) -> usize {
76        if self.avg_output {
77            self.output_head_dim
78        } else {
79            self.output_head_dim * self.output_num_heads
80        }
81    }
82
83    /// Compute the output dimensionality of each hidden (intermediate) layer.
84    pub fn hidden_layer_out_dim(&self) -> usize {
85        if self.concat_hidden {
86            self.hidden_head_dim * self.hidden_num_heads
87        } else {
88            self.hidden_head_dim
89        }
90    }
91}
92
// ---------------------------------------------------------------------------
// Attention head
// ---------------------------------------------------------------------------

/// A single attention head: projects node features with `W`, then scores
/// edges with a decomposed attention vector (`a_src`, `a_dst`) and LeakyReLU.
#[derive(Debug, Clone)]
struct AttentionHead {
    /// Linear transform W: [in_dim -> head_dim]
    w: Vec<Vec<f64>>, // row-major: [head_dim][in_dim]
    /// Attention source parameter a_src: [head_dim]; dotted with Wh_i.
    a_src: Vec<f64>,
    /// Attention target parameter a_dst: [head_dim]; dotted with Wh_j.
    a_dst: Vec<f64>,
    /// Output dimensionality of this head.
    head_dim: usize,
    /// LeakyReLU negative slope used in attention scoring.
    alpha: f64,
}
110
111impl AttentionHead {
112    fn new(in_dim: usize, head_dim: usize, alpha: f64, rng: &mut Lcg) -> Self {
113        let w_scale = (6.0 / (in_dim + head_dim) as f64).sqrt();
114        let w = (0..head_dim)
115            .map(|_| (0..in_dim).map(|_| rng.next_f64_range(w_scale)).collect())
116            .collect();
117        let a_scale = (2.0 / head_dim as f64).sqrt();
118        let a_src = (0..head_dim).map(|_| rng.next_f64_range(a_scale)).collect();
119        let a_dst = (0..head_dim).map(|_| rng.next_f64_range(a_scale)).collect();
120        Self {
121            w,
122            a_src,
123            a_dst,
124            head_dim,
125            alpha,
126        }
127    }
128
129    /// Compute Wh for a single node.
130    fn linear(&self, x: &[f64]) -> Vec<f64> {
131        self.w
132            .iter()
133            .map(|row| row.iter().zip(x.iter()).map(|(&w, &xi)| w * xi).sum())
134            .collect()
135    }
136
137    /// LeakyReLU(x) with configured negative slope.
138    fn leaky_relu(&self, x: f64) -> f64 {
139        if x >= 0.0 {
140            x
141        } else {
142            self.alpha * x
143        }
144    }
145
146    /// Softmax over a slice.
147    fn softmax(scores: &[f64]) -> Vec<f64> {
148        if scores.is_empty() {
149            return Vec::new();
150        }
151        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
152        let exps: Vec<f64> = scores.iter().map(|&s| (s - max).exp()).collect();
153        let sum: f64 = exps.iter().sum();
154        if sum < 1e-12 {
155            return vec![1.0 / scores.len() as f64; scores.len()];
156        }
157        exps.iter().map(|&e| e / sum).collect()
158    }
159
160    /// Compute attention-weighted aggregation for node `v`.
161    ///
162    /// Returns the new head embedding for `v`.
163    fn forward(
164        &self,
165        v: usize,
166        all_transformed: &[Vec<f64>], // Wh for every node
167        neighbors: &[usize],
168        dropout_mask: &[bool], // pre-computed dropout, true = keep
169    ) -> Vec<f64> {
170        // Self-connection plus neighbors
171        let mut candidates: Vec<usize> = vec![v];
172        candidates.extend_from_slice(neighbors);
173
174        let h_v = &all_transformed[v];
175
176        // Compute attention score for each candidate
177        let scores: Vec<f64> = candidates
178            .iter()
179            .map(|&u| {
180                let h_u = &all_transformed[u];
181                // e_ij = LeakyReLU(a_src^T h_i + a_dst^T h_j)
182                let src: f64 = self
183                    .a_src
184                    .iter()
185                    .zip(h_v.iter())
186                    .map(|(&a, &h)| a * h)
187                    .sum();
188                let dst: f64 = self
189                    .a_dst
190                    .iter()
191                    .zip(h_u.iter())
192                    .map(|(&a, &h)| a * h)
193                    .sum();
194                self.leaky_relu(src + dst)
195            })
196            .collect();
197
198        let weights = Self::softmax(&scores);
199
200        // Weighted sum of neighbor transformed features (with dropout)
201        let mut out = vec![0.0f64; self.head_dim];
202        for (k, (&u, &w)) in candidates.iter().zip(weights.iter()).enumerate() {
203            // Apply dropout: if mask says drop, skip (effectively zero weight)
204            let keep = dropout_mask.get(k).copied().unwrap_or(true);
205            let effective_w = if keep { w } else { 0.0 };
206            let h_u = &all_transformed[u];
207            for (j, &val) in h_u.iter().enumerate() {
208                out[j] += effective_w * val;
209            }
210        }
211        // ELU activation on output (standard in GAT)
212        out.iter_mut().for_each(|x| {
213            if *x < 0.0 {
214                *x = (*x).exp() - 1.0;
215            }
216        });
217        out
218    }
219}
220
// ---------------------------------------------------------------------------
// GATLayer
// ---------------------------------------------------------------------------

/// A multi-head graph attention layer.
///
/// For each node `v`:
///   For each head `k`:
///     α_ij = softmax_j( LeakyReLU( a^T [Wh_i || Wh_j] ) )
///     h_v^k = ELU( Σ_j α_ij * W * h_j )
///   h_v = CONCAT or AVG over heads
pub struct GATLayer {
    /// One independently parameterized attention head per `num_heads`.
    heads: Vec<AttentionHead>,
    /// Input feature dimensionality expected by every head.
    in_dim: usize,
    /// Per-head output dimensionality.
    head_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Concatenate head outputs when true; average them when false.
    concat: bool,
    /// Probability of zeroing an attention weight during `forward`.
    dropout_rate: f64,
}
240
241impl std::fmt::Debug for GATLayer {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        f.debug_struct("GATLayer")
244            .field("in_dim", &self.in_dim)
245            .field("num_heads", &self.num_heads)
246            .field("head_dim", &self.head_dim)
247            .field("concat", &self.concat)
248            .finish()
249    }
250}
251
252impl GATLayer {
253    /// Create a new GAT layer.
254    pub fn new(
255        in_dim: usize,
256        head_dim: usize,
257        num_heads: usize,
258        alpha: f64,
259        dropout: f64,
260        concat: bool,
261        rng: &mut Lcg,
262    ) -> Result<Self> {
263        if in_dim == 0 {
264            return Err(anyhow!("GATLayer: in_dim must be > 0"));
265        }
266        if head_dim == 0 {
267            return Err(anyhow!("GATLayer: head_dim must be > 0"));
268        }
269        if num_heads == 0 {
270            return Err(anyhow!("GATLayer: num_heads must be > 0"));
271        }
272        let heads = (0..num_heads)
273            .map(|_| AttentionHead::new(in_dim, head_dim, alpha, rng))
274            .collect();
275        Ok(Self {
276            heads,
277            in_dim,
278            head_dim,
279            num_heads,
280            concat,
281            dropout_rate: dropout,
282        })
283    }
284
285    /// Output dimensionality of this layer.
286    pub fn out_dim(&self) -> usize {
287        if self.concat {
288            self.head_dim * self.num_heads
289        } else {
290            self.head_dim
291        }
292    }
293
294    /// Forward pass: returns new embeddings for all nodes.
295    pub fn forward(
296        &self,
297        graph: &Graph,
298        current_embeddings: &[Vec<f64>],
299        rng: &mut Lcg,
300    ) -> Vec<Vec<f64>> {
301        let n = graph.num_nodes();
302
303        // Pre-compute Wh for each head and all nodes
304        // all_transformed[head_idx][node_idx] = head.linear(node_emb)
305        let all_transformed: Vec<Vec<Vec<f64>>> = self
306            .heads
307            .iter()
308            .map(|head| {
309                current_embeddings
310                    .iter()
311                    .map(|emb| head.linear(emb))
312                    .collect()
313            })
314            .collect();
315
316        // Compute new embedding for each node
317        (0..n)
318            .map(|v| {
319                let neighbors = graph.neighbors(v);
320                // Generate dropout masks (one per head per candidate)
321                let num_candidates = 1 + neighbors.len(); // self + neighbors
322                let dropout_mask: Vec<bool> = (0..num_candidates)
323                    .map(|_| rng.next_f64() > self.dropout_rate)
324                    .collect();
325
326                let head_outputs: Vec<Vec<f64>> = self
327                    .heads
328                    .iter()
329                    .enumerate()
330                    .map(|(k, head)| head.forward(v, &all_transformed[k], neighbors, &dropout_mask))
331                    .collect();
332
333                if self.concat {
334                    // Concatenate: [h_1 || h_2 || ... || h_K]
335                    head_outputs.into_iter().flatten().collect()
336                } else {
337                    // Average: mean over heads
338                    let mut avg = vec![0.0f64; self.head_dim];
339                    for h in &head_outputs {
340                        for (i, &v) in h.iter().enumerate() {
341                            avg[i] += v;
342                        }
343                    }
344                    let k = self.num_heads as f64;
345                    avg.iter_mut().for_each(|x| *x /= k);
346                    avg
347                }
348            })
349            .collect()
350    }
351}
352
// ---------------------------------------------------------------------------
// GATModel
// ---------------------------------------------------------------------------

/// Multi-layer Graph Attention Network.
///
/// Built from a [`GATConfig`]; each layer's output width feeds the next
/// layer's input width.
pub struct GATModel {
    /// Stacked attention layers, applied in order by `embed`.
    layers: Vec<GATLayer>,
    /// Configuration the model was built from (also seeds `embed`'s RNG).
    config: GATConfig,
}
362
363impl std::fmt::Debug for GATModel {
364    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365        f.debug_struct("GATModel")
366            .field("num_layers", &self.layers.len())
367            .field("output_dim", &self.config.output_dim())
368            .finish()
369    }
370}
371
372impl GATModel {
373    /// Construct a GAT model from configuration.
374    pub fn new(config: GATConfig) -> Result<Self> {
375        if config.input_dim == 0 {
376            return Err(anyhow!("GATConfig: input_dim must be > 0"));
377        }
378        if config.num_layers == 0 {
379            return Err(anyhow!("GATConfig: num_layers must be > 0"));
380        }
381        if config.hidden_head_dim == 0 {
382            return Err(anyhow!("GATConfig: hidden_head_dim must be > 0"));
383        }
384        if config.output_head_dim == 0 {
385            return Err(anyhow!("GATConfig: output_head_dim must be > 0"));
386        }
387        if config.hidden_num_heads == 0 || config.output_num_heads == 0 {
388            return Err(anyhow!("GATConfig: num_heads must be > 0"));
389        }
390
391        let mut rng = Lcg::new(config.seed);
392        let mut layers = Vec::with_capacity(config.num_layers);
393
394        // Compute layer-by-layer input dimensions
395        let mut current_in_dim = config.input_dim;
396        for layer_idx in 0..config.num_layers {
397            let is_last = layer_idx == config.num_layers - 1;
398            let (head_dim, num_heads, concat) = if is_last {
399                (
400                    config.output_head_dim,
401                    config.output_num_heads,
402                    !config.avg_output,
403                )
404            } else {
405                (
406                    config.hidden_head_dim,
407                    config.hidden_num_heads,
408                    config.concat_hidden,
409                )
410            };
411
412            let layer = GATLayer::new(
413                current_in_dim,
414                head_dim,
415                num_heads,
416                config.alpha,
417                config.dropout,
418                concat,
419                &mut rng,
420            )?;
421            current_in_dim = layer.out_dim();
422            layers.push(layer);
423        }
424
425        Ok(Self { layers, config })
426    }
427
428    /// Compute embeddings for all nodes in `graph`.
429    pub fn embed(&self, graph: &Graph) -> Result<GATEmbeddings> {
430        if graph.num_nodes() == 0 {
431            return Err(anyhow!("GATModel: graph has no nodes"));
432        }
433        let mut rng = Lcg::new(self.config.seed.wrapping_add(0xcafe_babe));
434        let mut current: Vec<Vec<f64>> = graph.node_features.clone();
435        for layer in &self.layers {
436            current = layer.forward(graph, &current, &mut rng);
437        }
438        if self.config.normalize_output {
439            for emb in &mut current {
440                l2_normalize_inplace(emb);
441            }
442        }
443        let dim = self.config.output_dim();
444        let num_nodes = graph.num_nodes();
445        Ok(GATEmbeddings {
446            embeddings: current,
447            num_nodes,
448            dim,
449        })
450    }
451}
452
// ---------------------------------------------------------------------------
// GATEmbeddings
// ---------------------------------------------------------------------------

/// Node embeddings produced by a `GATModel`, plus similarity utilities.
#[derive(Debug, Clone)]
pub struct GATEmbeddings {
    pub embeddings: Vec<Vec<f64>>,
    pub num_nodes: usize,
    pub dim: usize,
}

impl GATEmbeddings {
    /// Borrow the embedding of node `v`, if present.
    pub fn get(&self, v: usize) -> Option<&[f64]> {
        self.embeddings.get(v).map(Vec::as_slice)
    }

    /// Cosine similarity between the embeddings of nodes `a` and `b`;
    /// `None` when either index is out of range.
    pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
        match (self.embeddings.get(a), self.embeddings.get(b)) {
            (Some(ea), Some(eb)) => Some(cosine_similarity_vecs(ea, eb)),
            _ => None,
        }
    }

    /// The `k` nodes most similar to `query_node` (excluding itself),
    /// ordered by descending cosine similarity.
    pub fn top_k_similar(&self, query_node: usize, k: usize) -> Vec<(usize, f64)> {
        let qe = if let Some(e) = self.embeddings.get(query_node) {
            e
        } else {
            return Vec::new();
        };
        let mut ranked: Vec<(usize, f64)> = Vec::new();
        for (i, e) in self.embeddings.iter().enumerate() {
            if i != query_node {
                ranked.push((i, cosine_similarity_vecs(qe, e)));
            }
        }
        // Stable sort keeps tied entries in index order.
        ranked.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(k);
        ranked
    }

    /// Mean embedding across all nodes; empty vec when there are no nodes.
    pub fn mean_embedding(&self) -> Vec<f64> {
        let count = self.embeddings.len();
        if count == 0 {
            return Vec::new();
        }
        let mut acc = vec![0.0f64; self.dim];
        for emb in &self.embeddings {
            // `zip` stops at the shorter side, so rows never overrun `dim`.
            for (slot, &value) in acc.iter_mut().zip(emb.iter()) {
                *slot += value;
            }
        }
        let n = count as f64;
        for slot in acc.iter_mut() {
            *slot /= n;
        }
        acc
    }
}

// ---------------------------------------------------------------------------
// Utility
// ---------------------------------------------------------------------------

/// Cosine similarity of two vectors, clamped to [-1, 1]; 0.0 when either
/// vector is (near-)zero.
fn cosine_similarity_vecs(a: &[f64], b: &[f64]) -> f64 {
    let dot: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
    let na = a.iter().fold(0.0f64, |s, &x| s + x * x).sqrt();
    let nb = b.iter().fold(0.0f64, |s, &x| s + x * x).sqrt();
    if na < 1e-12 || nb < 1e-12 {
        0.0
    } else {
        (dot / (na * nb)).clamp(-1.0, 1.0)
    }
}

/// Scale `v` to unit L2 norm in place; near-zero vectors are left untouched.
fn l2_normalize_inplace(v: &mut [f64]) {
    let norm = v.iter().fold(0.0f64, |s, &x| s + x * x).sqrt();
    if norm > 1e-12 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
533
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::super::graphsage::{Graph, Lcg};
    use super::*;

    /// Build a path graph 0-1-2-...-(n-1) with deterministic random features
    /// of width `feat_dim`, seeded for reproducibility.
    fn line_graph(n: usize, feat_dim: usize, seed: u64) -> Graph {
        let mut rng = Lcg::new(seed);
        let features: Vec<Vec<f64>> = (0..n)
            .map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
        for i in 0..n.saturating_sub(1) {
            adjacency[i].push(i + 1);
            adjacency[i + 1].push(i);
        }
        Graph::new(features, adjacency).expect("line graph construction should succeed")
    }

    // --- Configuration arithmetic ---

    #[test]
    fn test_gat_config_default() {
        let config = GATConfig::default();
        assert_eq!(config.num_layers, 2);
        assert_eq!(config.hidden_num_heads, 8);
        // Default: avg_output=true => output_dim = output_head_dim
        assert_eq!(config.output_dim(), config.output_head_dim);
    }

    #[test]
    fn test_gat_config_concat_hidden() {
        let config = GATConfig {
            hidden_head_dim: 8,
            hidden_num_heads: 4,
            concat_hidden: true,
            ..Default::default()
        };
        assert_eq!(config.hidden_layer_out_dim(), 32); // 8 * 4
    }

    #[test]
    fn test_gat_config_avg_hidden() {
        let config = GATConfig {
            hidden_head_dim: 8,
            hidden_num_heads: 4,
            concat_hidden: false,
            ..Default::default()
        };
        assert_eq!(config.hidden_layer_out_dim(), 8); // just head_dim
    }

    // --- Layer construction and validation ---

    #[test]
    fn test_gat_layer_construction() {
        let mut rng = Lcg::new(42);
        let layer =
            GATLayer::new(8, 4, 2, 0.2, 0.0, true, &mut rng).expect("layer should construct");
        assert_eq!(layer.out_dim(), 8); // 2 heads * 4 dim, concat
    }

    #[test]
    fn test_gat_layer_avg() {
        let mut rng = Lcg::new(43);
        let layer =
            GATLayer::new(8, 4, 3, 0.2, 0.0, false, &mut rng).expect("layer should construct");
        assert_eq!(layer.out_dim(), 4); // avg mode: just head_dim
    }

    #[test]
    fn test_gat_layer_invalid() {
        let mut rng = Lcg::new(1);
        assert!(GATLayer::new(0, 4, 2, 0.2, 0.0, true, &mut rng).is_err());
        assert!(GATLayer::new(8, 0, 2, 0.2, 0.0, true, &mut rng).is_err());
        assert!(GATLayer::new(8, 4, 0, 0.2, 0.0, true, &mut rng).is_err());
    }

    // --- End-to-end model behavior ---

    #[test]
    fn test_gat_model_embed_shape() {
        let config = GATConfig {
            input_dim: 8,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 2,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = GATModel::new(config.clone()).expect("GAT model should construct");
        let g = line_graph(5, 8, 100);
        let embs = model.embed(&g).expect("embed should succeed");

        assert_eq!(embs.num_nodes, 5);
        assert_eq!(embs.dim, config.output_dim());
        for i in 0..5 {
            assert_eq!(
                embs.get(i).expect("embedding should exist").len(),
                config.output_dim()
            );
        }
    }

    #[test]
    fn test_gat_model_single_layer() {
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 8,
            hidden_num_heads: 2,
            output_head_dim: 8,
            output_num_heads: 2,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: false,
            normalize_output: false,
            ..Default::default()
        };
        let model = GATModel::new(config.clone()).expect("GAT model should construct");
        let g = line_graph(4, 4, 200);
        let embs = model.embed(&g).expect("embed should succeed");
        // Single layer, no avg: concat of 2 heads * 8 = 16
        assert_eq!(embs.dim, 16);
    }

    #[test]
    fn test_gat_model_normalized_output() {
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: false,
            avg_output: true,
            normalize_output: true,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let g = line_graph(5, 4, 300);
        let embs = model.embed(&g).expect("embed should succeed");
        for i in 0..5 {
            let emb = embs.get(i).expect("embedding exists");
            let norm: f64 = emb.iter().map(|&x| x * x).sum::<f64>().sqrt();
            // <= rather than ==: l2_normalize_inplace leaves near-zero
            // vectors untouched, so norms are either ~1 or ~0.
            assert!(norm <= 1.0 + 1e-6, "norm {} should be <= 1", norm);
        }
    }

    #[test]
    fn test_gat_cosine_similarity_bounds() {
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let g = line_graph(5, 4, 400);
        let embs = model.embed(&g).expect("embed should succeed");
        for i in 0..5 {
            for j in 0..5 {
                if let Some(sim) = embs.cosine_similarity(i, j) {
                    assert!(
                        (-1.0 - 1e-6..=1.0 + 1e-6).contains(&sim),
                        "cosine_similarity({i}, {j}) = {sim} out of range"
                    );
                }
            }
        }
    }

    #[test]
    fn test_gat_top_k_similar() {
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 2,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: true,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let g = line_graph(8, 4, 500);
        let embs = model.embed(&g).expect("embed should succeed");
        let top3 = embs.top_k_similar(0, 3);
        assert!(top3.len() <= 3);
        for window in top3.windows(2) {
            assert!(
                window[0].1 >= window[1].1 - 1e-10,
                "top_k should be sorted descending"
            );
        }
    }

    #[test]
    fn test_gat_isolated_node() {
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let features = vec![vec![1.0f64, 0.5, -0.3, 0.8]];
        let adjacency = vec![vec![]]; // isolated node
        let g = Graph::new(features, adjacency).expect("isolated node graph");
        // An isolated node still embeds via its self-connection.
        let embs = model.embed(&g).expect("isolated node should embed");
        assert_eq!(embs.num_nodes, 1);
        assert!(embs.get(0).is_some());
    }

    #[test]
    fn test_gat_invalid_config() {
        assert!(GATModel::new(GATConfig {
            input_dim: 0,
            ..Default::default()
        })
        .is_err());
        assert!(GATModel::new(GATConfig {
            num_layers: 0,
            ..Default::default()
        })
        .is_err());
        assert!(GATModel::new(GATConfig {
            hidden_num_heads: 0,
            ..Default::default()
        })
        .is_err());
        assert!(GATModel::new(GATConfig {
            output_head_dim: 0,
            ..Default::default()
        })
        .is_err());
    }

    #[test]
    fn test_gat_mean_embedding() {
        let config = GATConfig {
            input_dim: 4,
            hidden_head_dim: 4,
            hidden_num_heads: 2,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 1,
            dropout: 0.0,
            concat_hidden: false,
            avg_output: true,
            normalize_output: true,
            ..Default::default()
        };
        let model = GATModel::new(config).expect("GAT model should construct");
        let g = line_graph(5, 4, 600);
        let embs = model.embed(&g).expect("embed should succeed");
        let mean = embs.mean_embedding();
        assert_eq!(mean.len(), embs.dim);
    }

    #[test]
    fn test_gat_attention_softmax_sums_to_one() {
        let scores = vec![1.0f64, 2.0, 3.0, 0.5, -1.0];
        let weights = AttentionHead::softmax(&scores);
        let sum: f64 = weights.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax should sum to 1, got {sum}"
        );
        // Larger scores should produce larger weights
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);
    }

    #[test]
    fn test_gat_three_layer_deep() {
        let config = GATConfig {
            input_dim: 8,
            hidden_head_dim: 4,
            hidden_num_heads: 3,
            output_head_dim: 4,
            output_num_heads: 1,
            num_layers: 3,
            dropout: 0.0,
            concat_hidden: true,
            avg_output: true,
            normalize_output: true,
            seed: 77,
            ..Default::default()
        };
        let model = GATModel::new(config.clone()).expect("3-layer GAT should construct");
        let g = line_graph(6, 8, 77);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 6);
        assert_eq!(embs.dim, config.output_dim());
    }
}