// oxirs_embed/models/gat_basic.rs

1//! Graph Attention Networks (GAT) - Basic Implementation
2//!
3//! Veličković et al. (2018) - ICLR
4//! "Graph Attention Networks"
5//!
6//! Key innovation: learn attention coefficients between nodes and their neighbors,
7//! enabling the model to selectively focus on relevant structural information.
8//! Multi-head attention provides stability and richer representations.
9
10use crate::EmbeddingError;
11use anyhow::{anyhow, Result};
12use serde::{Deserialize, Serialize};
13
14use super::graphsage::{cosine_similarity_vecs, dot_product, GraphData, SimpleLcg};
15
/// Configuration for a Graph Attention Network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatConfig {
    /// Dimensionality of input node features; must match the graph's
    /// `feature_dim()` at embedding time.
    pub input_dim: usize,
    /// Dimensionality of output per attention head
    pub head_output_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Dropout rate (applied to attention coefficients)
    // NOTE(review): this field is never read anywhere in this file — confirm
    // whether training-time dropout is applied elsewhere or this is vestigial.
    pub dropout: f64,
    /// LeakyReLU negative slope for attention scoring
    pub alpha: f64,
    /// If true, concatenate head outputs (dim = num_heads * head_output_dim);
    /// else average them (dim = head_output_dim). See `output_dim()`.
    pub concat_heads: bool,
    /// L2-normalize output embeddings
    pub normalize_output: bool,
    /// Random seed for parameter initialization
    pub seed: u64,
}
36
37impl Default for GatConfig {
38    fn default() -> Self {
39        Self {
40            input_dim: 64,
41            head_output_dim: 8,
42            num_heads: 8,
43            dropout: 0.6,
44            alpha: 0.2,
45            concat_heads: true,
46            normalize_output: true,
47            seed: 42,
48        }
49    }
50}
51
52impl GatConfig {
53    /// Compute the final output dimensionality
54    pub fn output_dim(&self) -> usize {
55        if self.concat_heads {
56            self.head_output_dim * self.num_heads
57        } else {
58            self.head_output_dim
59        }
60    }
61}
62
/// A single attention head in the GAT layer
///
/// Computes e_ij = LeakyReLU(a^T [Wh_i || Wh_j]) for each edge (i,j),
/// then normalizes with softmax over the neighborhood.
#[derive(Debug, Clone)]
struct AttentionHead {
    /// Linear transform W, stored row-major as [output_dim][input_dim]:
    /// each of the `output_dim` rows holds `input_dim` weights (see `new`
    /// and `transform`; the original comment had the dimensions swapped).
    w: Vec<Vec<f64>>,
    /// Attention source weights a_src: [head_output_dim]
    a_src: Vec<f64>,
    /// Attention target weights a_dst: [head_output_dim]
    a_dst: Vec<f64>,
    /// Output dimensionality (head_output_dim)
    output_dim: usize,
    /// LeakyReLU negative slope
    alpha: f64,
}
80
81impl AttentionHead {
82    /// Create a new attention head with Xavier initialization
83    fn new(input_dim: usize, output_dim: usize, alpha: f64, rng: &mut SimpleLcg) -> Self {
84        let scale = (6.0 / (input_dim + output_dim) as f64).sqrt();
85        let w = (0..output_dim)
86            .map(|_| (0..input_dim).map(|_| rng.next_f64_range(scale)).collect())
87            .collect();
88
89        let attn_scale = (2.0 / output_dim as f64).sqrt();
90        let a_src = (0..output_dim)
91            .map(|_| rng.next_f64_range(attn_scale))
92            .collect();
93        let a_dst = (0..output_dim)
94            .map(|_| rng.next_f64_range(attn_scale))
95            .collect();
96
97        Self {
98            w,
99            a_src,
100            a_dst,
101            output_dim,
102            alpha,
103        }
104    }
105
106    /// Apply the linear transform W to a feature vector
107    fn transform(&self, feat: &[f64]) -> Vec<f64> {
108        let mut out = vec![0.0f64; self.output_dim];
109        for (i, row) in self.w.iter().enumerate() {
110            for (j, &wv) in row.iter().enumerate() {
111                if j < feat.len() {
112                    out[i] += wv * feat[j];
113                }
114            }
115        }
116        out
117    }
118
119    /// Compute unnormalized attention coefficient e_ij
120    ///
121    /// e_ij = LeakyReLU(a_src^T Wh_i + a_dst^T Wh_j)
122    fn attention_coeff(&self, h_i: &[f64], h_j: &[f64]) -> f64 {
123        let src_score = dot_product(&self.a_src, h_i);
124        let dst_score = dot_product(&self.a_dst, h_j);
125        Self::leaky_relu(src_score + dst_score, self.alpha)
126    }
127
128    /// LeakyReLU: max(alpha*x, x)
129    fn leaky_relu(x: f64, alpha: f64) -> f64 {
130        if x >= 0.0 {
131            x
132        } else {
133            alpha * x
134        }
135    }
136
137    /// Softmax over a slice of scores, returns normalized attention weights
138    fn softmax(scores: &[f64]) -> Vec<f64> {
139        if scores.is_empty() {
140            return Vec::new();
141        }
142        // Numerical stability: subtract max
143        let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
144        let exps: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
145        let sum: f64 = exps.iter().sum();
146        if sum < 1e-12 {
147            // Uniform if all exps are ~0
148            return vec![1.0 / scores.len() as f64; scores.len()];
149        }
150        exps.iter().map(|e| e / sum).collect()
151    }
152
153    /// Forward pass for a single node.
154    ///
155    /// Computes: h'_i = sigma(sum_{j in N(i)} alpha_ij * W*h_j)
156    /// where alpha_ij are the softmax-normalized attention coefficients.
157    fn forward(&self, node_feat: &[f64], neighbor_feats: &[Vec<f64>]) -> Vec<f64> {
158        // Transform self and neighbors
159        let h_self = self.transform(node_feat);
160
161        if neighbor_feats.is_empty() {
162            // No neighbors: return self-transformed feature with self-attention
163            return h_self;
164        }
165
166        let neighbor_transformed: Vec<Vec<f64>> =
167            neighbor_feats.iter().map(|f| self.transform(f)).collect();
168
169        // Compute attention coefficients (including self-loop)
170        let mut all_feats = vec![&h_self as &Vec<f64>];
171        all_feats.extend(neighbor_transformed.iter());
172
173        let scores: Vec<f64> = all_feats
174            .iter()
175            .map(|h_j| self.attention_coeff(&h_self, h_j))
176            .collect();
177
178        let weights = Self::softmax(&scores);
179
180        // Weighted sum of transformed features
181        let mut output = vec![0.0f64; self.output_dim];
182        for (weight, h_j) in weights.iter().zip(all_feats.iter()) {
183            for (o, &v) in output.iter_mut().zip(h_j.iter()) {
184                *o += weight * v;
185            }
186        }
187
188        // Apply ELU activation (approximated as LeakyReLU here)
189        output
190            .into_iter()
191            .map(|x| Self::leaky_relu(x, self.alpha))
192            .collect()
193    }
194}
195
/// Graph Attention Network embedding model
///
/// Implements multi-head graph attention as described in Veličković et al. (2018).
/// Each attention head independently computes attention-weighted aggregations,
/// and the results are either concatenated or averaged (per
/// `GatConfig::concat_heads`).
#[derive(Debug, Clone)]
pub struct Gat {
    /// Model configuration
    config: GatConfig,
    /// Attention heads; length equals `config.num_heads`
    heads: Vec<AttentionHead>,
}
208
209impl Gat {
210    /// Create a new GAT model with the given configuration
211    pub fn new(config: GatConfig) -> Result<Self> {
212        if config.input_dim == 0 {
213            return Err(anyhow!("input_dim must be > 0"));
214        }
215        if config.num_heads == 0 {
216            return Err(anyhow!("num_heads must be > 0"));
217        }
218        if config.head_output_dim == 0 {
219            return Err(anyhow!("head_output_dim must be > 0"));
220        }
221
222        let mut rng = SimpleLcg::new(config.seed);
223        let heads = (0..config.num_heads)
224            .map(|_| {
225                AttentionHead::new(
226                    config.input_dim,
227                    config.head_output_dim,
228                    config.alpha,
229                    &mut rng,
230                )
231            })
232            .collect();
233
234        Ok(Self { config, heads })
235    }
236
237    /// Generate embeddings for all nodes using multi-head attention
238    pub fn embed(&self, graph: &GraphData) -> Result<GatEmbeddings> {
239        if graph.num_nodes() == 0 {
240            return Err(anyhow!("Graph has no nodes"));
241        }
242        if graph.feature_dim() != self.config.input_dim {
243            return Err(anyhow!(
244                "Graph feature_dim {} != GAT input_dim {}",
245                graph.feature_dim(),
246                self.config.input_dim
247            ));
248        }
249
250        let embeddings: Vec<Vec<f64>> = (0..graph.num_nodes())
251            .map(|node| self.forward_node(node, graph))
252            .collect();
253
254        let embeddings = if self.config.normalize_output {
255            embeddings.into_iter().map(|e| normalize_l2(&e)).collect()
256        } else {
257            embeddings
258        };
259
260        let output_dim = self.config.output_dim();
261        let num_nodes = graph.num_nodes();
262
263        Ok(GatEmbeddings {
264            embeddings,
265            config: self.config.clone(),
266            num_nodes,
267            dim: output_dim,
268        })
269    }
270
271    /// Compute embedding for a single node using all attention heads
272    fn forward_node(&self, node: usize, graph: &GraphData) -> Vec<f64> {
273        let node_feat = match graph.node_features.get(node) {
274            Some(f) => f.as_slice(),
275            None => return vec![0.0; self.config.output_dim()],
276        };
277
278        let neighbors = graph.neighbors(node);
279        let neighbor_feats: Vec<Vec<f64>> = neighbors
280            .iter()
281            .filter_map(|&n| graph.node_features.get(n).cloned())
282            .collect();
283
284        // Run each attention head
285        let head_outputs: Vec<Vec<f64>> = self
286            .heads
287            .iter()
288            .map(|head| head.forward(node_feat, &neighbor_feats))
289            .collect();
290
291        if self.config.concat_heads {
292            // Concatenate all head outputs
293            let mut concat = Vec::with_capacity(self.config.output_dim());
294            for head_out in &head_outputs {
295                concat.extend(head_out.iter().copied());
296            }
297            concat
298        } else {
299            // Average across heads
300            let dim = self.config.head_output_dim;
301            let mut avg = vec![0.0f64; dim];
302            for head_out in &head_outputs {
303                for (a, &v) in avg.iter_mut().zip(head_out.iter()) {
304                    *a += v;
305                }
306            }
307            let n = self.heads.len() as f64;
308            avg.iter_mut().for_each(|v| *v /= n);
309            avg
310        }
311    }
312}
313
/// Output embeddings from GAT inference
#[derive(Debug, Clone)]
pub struct GatEmbeddings {
    /// Embedding vectors indexed by node ID; each has length `dim`
    pub embeddings: Vec<Vec<f64>>,
    /// Configuration used to produce these embeddings
    pub config: GatConfig,
    /// Number of nodes (equals `embeddings.len()`)
    pub num_nodes: usize,
    /// Embedding dimensionality (the model's `output_dim()`)
    pub dim: usize,
}
326
327impl GatEmbeddings {
328    /// Get embedding for a specific node
329    pub fn get(&self, node: usize) -> Option<&[f64]> {
330        self.embeddings.get(node).map(|v| v.as_slice())
331    }
332
333    /// Compute cosine similarity between two nodes
334    pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
335        let va = self.get(a)?;
336        let vb = self.get(b)?;
337        Some(cosine_similarity_vecs(va, vb))
338    }
339
340    /// Get the top-k most similar nodes to a given node
341    pub fn top_k_similar(&self, node: usize, k: usize) -> Vec<(usize, f64)> {
342        let query = match self.get(node) {
343            Some(v) => v,
344            None => return Vec::new(),
345        };
346
347        let mut similarities: Vec<(usize, f64)> = (0..self.num_nodes)
348            .filter(|&i| i != node)
349            .filter_map(|i| self.get(i).map(|v| (i, cosine_similarity_vecs(query, v))))
350            .collect();
351
352        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
353        similarities.truncate(k);
354        similarities
355    }
356}
357
/// L2-normalize a vector.
///
/// Vectors with (near-)zero norm are returned unchanged to avoid a
/// division by zero.
fn normalize_l2(v: &[f64]) -> Vec<f64> {
    let squared_sum: f64 = v.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    if norm < 1e-12 {
        v.to_vec()
    } else {
        v.iter().map(|x| x / norm).collect()
    }
}
366
/// Softmax over a slice (re-exported for use in ab_test module)
///
/// Mirrors `AttentionHead::softmax`: subtracts the maximum score before
/// exponentiating for numerical stability, and falls back to a uniform
/// distribution when all exponentials underflow to ~0. An empty input
/// yields an empty Vec.
pub fn softmax(scores: &[f64]) -> Vec<f64> {
    let n = scores.len();
    if n == 0 {
        return Vec::new();
    }
    // Numerical stability: shift by the max score.
    let max_score = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
    let total: f64 = exps.iter().sum();
    if total < 1e-12 {
        // Uniform when everything underflowed.
        vec![1.0 / n as f64; n]
    } else {
        exps.into_iter().map(|e| e / total).collect()
    }
}
371
372/// Convert to EmbeddingError
373pub fn gat_err(msg: impl Into<String>) -> EmbeddingError {
374    EmbeddingError::Other(anyhow!(msg.into()))
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an undirected path graph 0-1-2-...-(n-1) with pseudo-random
    /// node features of dimension `feat_dim`, seeded for reproducibility.
    fn make_line_graph(n: usize, feat_dim: usize, seed: u64) -> GraphData {
        let mut rng = SimpleLcg::new(seed);
        let features: Vec<Vec<f64>> = (0..n)
            .map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
        for i in 0..n.saturating_sub(1) {
            // Undirected edge between consecutive nodes.
            adjacency[i].push(i + 1);
            adjacency[i + 1].push(i);
        }
        GraphData::new(features, adjacency).expect("line graph construction should succeed")
    }

    /// Default config should concatenate 8 heads of 8 dims each.
    #[test]
    fn test_gat_config_default() {
        let config = GatConfig::default();
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.head_output_dim, 8);
        assert_eq!(config.output_dim(), 64); // concat: 8 * 8
    }

    /// When heads are averaged, output_dim is just head_output_dim.
    #[test]
    fn test_gat_config_avg() {
        let config = GatConfig {
            concat_heads: false,
            num_heads: 4,
            head_output_dim: 16,
            ..Default::default()
        };
        assert_eq!(config.output_dim(), 16); // average: head_output_dim
    }

    /// Embedding shape must be num_nodes x (num_heads * head_output_dim)
    /// in concat mode.
    #[test]
    fn test_gat_embed_shape() {
        let config = GatConfig {
            input_dim: 8,
            head_output_dim: 4,
            num_heads: 2,
            concat_heads: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = Gat::new(config.clone()).expect("GAT construction should succeed");
        let graph = make_line_graph(5, 8, 100);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        assert_eq!(embeddings.num_nodes, 5);
        assert_eq!(embeddings.dim, 8); // 2 heads * 4 per head
        for i in 0..5 {
            assert_eq!(embeddings.get(i).expect("embedding should exist").len(), 8);
        }
    }

    /// Averaging mode keeps the per-head dimensionality.
    #[test]
    fn test_gat_embed_avg_heads() {
        let config = GatConfig {
            input_dim: 8,
            head_output_dim: 4,
            num_heads: 3,
            concat_heads: false,
            normalize_output: false,
            ..Default::default()
        };
        let model = Gat::new(config.clone()).expect("GAT should construct");
        let graph = make_line_graph(4, 8, 200);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        assert_eq!(embeddings.dim, 4); // avg: head_output_dim
        for i in 0..4 {
            assert_eq!(embeddings.get(i).expect("embedding exists").len(), 4);
        }
    }

    /// With normalize_output, every embedding has norm ~1 (or 0 if the
    /// activation zeroed the whole vector).
    #[test]
    fn test_gat_normalized_output() {
        let config = GatConfig {
            input_dim: 4,
            head_output_dim: 4,
            num_heads: 2,
            concat_heads: false,
            normalize_output: true,
            ..Default::default()
        };
        let model = Gat::new(config).expect("GAT should construct");
        let graph = make_line_graph(5, 4, 300);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        for i in 0..5 {
            let emb = embeddings.get(i).expect("embedding exists");
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            // Either 0 (all ReLU killed) or ~1
            assert!(norm <= 1.0 + 1e-6, "norm {} should be <= 1", norm);
        }
    }

    /// Cosine similarity between any two node embeddings stays in [-1, 1].
    #[test]
    fn test_gat_cosine_similarity() {
        let config = GatConfig {
            input_dim: 4,
            head_output_dim: 4,
            num_heads: 1,
            concat_heads: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = Gat::new(config).expect("GAT should construct");
        let graph = make_line_graph(5, 4, 400);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        // Cosine similarity should be in [-1, 1]
        for i in 0..5 {
            for j in 0..5 {
                if let Some(sim) = embeddings.cosine_similarity(i, j) {
                    assert!(
                        (-1.0 - 1e-6..=1.0 + 1e-6).contains(&sim),
                        "cosine_similarity({}, {}) = {} out of range",
                        i,
                        j,
                        sim
                    );
                }
            }
        }
    }

    /// top_k_similar must return at most k results, sorted descending.
    #[test]
    fn test_gat_top_k_similar() {
        let config = GatConfig {
            input_dim: 4,
            head_output_dim: 4,
            num_heads: 2,
            concat_heads: true,
            normalize_output: true,
            ..Default::default()
        };
        let model = Gat::new(config).expect("GAT should construct");
        let graph = make_line_graph(6, 4, 500);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        let top3 = embeddings.top_k_similar(0, 3);
        assert!(top3.len() <= 3);
        // Results should be in descending similarity order
        for window in top3.windows(2) {
            assert!(
                window[0].1 >= window[1].1 - 1e-10,
                "top_k should be sorted descending"
            );
        }
    }

    /// Softmax output must be a valid probability distribution that
    /// preserves the ordering of the input scores.
    #[test]
    fn test_attention_head_softmax() {
        // Verify softmax sums to 1
        let scores = vec![1.0, 2.0, 3.0, 0.5, -1.0];
        let weights = AttentionHead::softmax(&scores);
        assert_eq!(weights.len(), scores.len());
        let sum: f64 = weights.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax should sum to 1, got {}",
            sum
        );
        // Larger scores should have larger weights
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);
    }

    /// Zero-valued dimensions in the config are rejected at construction.
    #[test]
    fn test_gat_invalid_config() {
        assert!(Gat::new(GatConfig {
            num_heads: 0,
            ..Default::default()
        })
        .is_err());
        assert!(Gat::new(GatConfig {
            input_dim: 0,
            ..Default::default()
        })
        .is_err());
        assert!(Gat::new(GatConfig {
            head_output_dim: 0,
            ..Default::default()
        })
        .is_err());
    }

    /// A node with no neighbors must still receive an embedding
    /// (self-loop-only attention path).
    #[test]
    fn test_gat_isolated_node() {
        // A single isolated node should still produce an embedding
        let config = GatConfig {
            input_dim: 4,
            head_output_dim: 4,
            num_heads: 2,
            concat_heads: true,
            normalize_output: false,
            ..Default::default()
        };
        let model = Gat::new(config).expect("GAT should construct");
        let features = vec![vec![1.0, 0.5, -0.5, 0.2]];
        let adjacency = vec![vec![]]; // no neighbors
        let graph = GraphData::new(features, adjacency).expect("graph should construct");
        let embeddings = model.embed(&graph).expect("should embed isolated node");
        assert_eq!(embeddings.num_nodes, 1);
        assert!(embeddings.get(0).is_some());
    }
}
586}