1use crate::EmbeddingError;
11use anyhow::{anyhow, Result};
12use serde::{Deserialize, Serialize};
13
14use super::graphsage::{cosine_similarity_vecs, dot_product, GraphData, SimpleLcg};
15
/// Hyper-parameters for a [`Gat`] encoder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatConfig {
    /// Dimensionality of the raw node features fed into each head.
    pub input_dim: usize,
    /// Output width of a single attention head.
    pub head_output_dim: usize,
    /// Number of independent attention heads.
    pub num_heads: usize,
    /// Dropout rate. NOTE(review): not referenced anywhere in this file's
    /// forward pass — confirm whether it is consumed elsewhere or vestigial.
    pub dropout: f64,
    /// Negative slope of the LeakyReLU used for attention logits and outputs.
    pub alpha: f64,
    /// If true, concatenate head outputs; otherwise average them.
    pub concat_heads: bool,
    /// If true, L2-normalize each final node embedding.
    pub normalize_output: bool,
    /// Seed for the deterministic weight-initialization RNG.
    pub seed: u64,
}
36
37impl Default for GatConfig {
38 fn default() -> Self {
39 Self {
40 input_dim: 64,
41 head_output_dim: 8,
42 num_heads: 8,
43 dropout: 0.6,
44 alpha: 0.2,
45 concat_heads: true,
46 normalize_output: true,
47 seed: 42,
48 }
49 }
50}
51
52impl GatConfig {
53 pub fn output_dim(&self) -> usize {
55 if self.concat_heads {
56 self.head_output_dim * self.num_heads
57 } else {
58 self.head_output_dim
59 }
60 }
61}
62
/// One attention head: a linear projection plus a pair of attention
/// vectors implementing the additive (split) GAT scoring function.
#[derive(Debug, Clone)]
struct AttentionHead {
    /// Projection matrix: `output_dim` rows of `input_dim` weights.
    w: Vec<Vec<f64>>,
    /// Attention vector applied to the projected source/query node.
    a_src: Vec<f64>,
    /// Attention vector applied to each projected attended node.
    a_dst: Vec<f64>,
    /// Width of this head's output.
    output_dim: usize,
    /// LeakyReLU negative slope.
    alpha: f64,
}
80
81impl AttentionHead {
82 fn new(input_dim: usize, output_dim: usize, alpha: f64, rng: &mut SimpleLcg) -> Self {
84 let scale = (6.0 / (input_dim + output_dim) as f64).sqrt();
85 let w = (0..output_dim)
86 .map(|_| (0..input_dim).map(|_| rng.next_f64_range(scale)).collect())
87 .collect();
88
89 let attn_scale = (2.0 / output_dim as f64).sqrt();
90 let a_src = (0..output_dim)
91 .map(|_| rng.next_f64_range(attn_scale))
92 .collect();
93 let a_dst = (0..output_dim)
94 .map(|_| rng.next_f64_range(attn_scale))
95 .collect();
96
97 Self {
98 w,
99 a_src,
100 a_dst,
101 output_dim,
102 alpha,
103 }
104 }
105
106 fn transform(&self, feat: &[f64]) -> Vec<f64> {
108 let mut out = vec![0.0f64; self.output_dim];
109 for (i, row) in self.w.iter().enumerate() {
110 for (j, &wv) in row.iter().enumerate() {
111 if j < feat.len() {
112 out[i] += wv * feat[j];
113 }
114 }
115 }
116 out
117 }
118
119 fn attention_coeff(&self, h_i: &[f64], h_j: &[f64]) -> f64 {
123 let src_score = dot_product(&self.a_src, h_i);
124 let dst_score = dot_product(&self.a_dst, h_j);
125 Self::leaky_relu(src_score + dst_score, self.alpha)
126 }
127
128 fn leaky_relu(x: f64, alpha: f64) -> f64 {
130 if x >= 0.0 {
131 x
132 } else {
133 alpha * x
134 }
135 }
136
137 fn softmax(scores: &[f64]) -> Vec<f64> {
139 if scores.is_empty() {
140 return Vec::new();
141 }
142 let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
144 let exps: Vec<f64> = scores.iter().map(|&s| (s - max_score).exp()).collect();
145 let sum: f64 = exps.iter().sum();
146 if sum < 1e-12 {
147 return vec![1.0 / scores.len() as f64; scores.len()];
149 }
150 exps.iter().map(|e| e / sum).collect()
151 }
152
153 fn forward(&self, node_feat: &[f64], neighbor_feats: &[Vec<f64>]) -> Vec<f64> {
158 let h_self = self.transform(node_feat);
160
161 if neighbor_feats.is_empty() {
162 return h_self;
164 }
165
166 let neighbor_transformed: Vec<Vec<f64>> =
167 neighbor_feats.iter().map(|f| self.transform(f)).collect();
168
169 let mut all_feats = vec![&h_self as &Vec<f64>];
171 all_feats.extend(neighbor_transformed.iter());
172
173 let scores: Vec<f64> = all_feats
174 .iter()
175 .map(|h_j| self.attention_coeff(&h_self, h_j))
176 .collect();
177
178 let weights = Self::softmax(&scores);
179
180 let mut output = vec![0.0f64; self.output_dim];
182 for (weight, h_j) in weights.iter().zip(all_feats.iter()) {
183 for (o, &v) in output.iter_mut().zip(h_j.iter()) {
184 *o += weight * v;
185 }
186 }
187
188 output
190 .into_iter()
191 .map(|x| Self::leaky_relu(x, self.alpha))
192 .collect()
193 }
194}
195
/// Multi-head Graph Attention Network encoder over a [`GraphData`].
#[derive(Debug, Clone)]
pub struct Gat {
    /// Hyper-parameters the model was built with.
    config: GatConfig,
    /// One [`AttentionHead`] per `config.num_heads`, all seeded from one RNG.
    heads: Vec<AttentionHead>,
}
208
209impl Gat {
210 pub fn new(config: GatConfig) -> Result<Self> {
212 if config.input_dim == 0 {
213 return Err(anyhow!("input_dim must be > 0"));
214 }
215 if config.num_heads == 0 {
216 return Err(anyhow!("num_heads must be > 0"));
217 }
218 if config.head_output_dim == 0 {
219 return Err(anyhow!("head_output_dim must be > 0"));
220 }
221
222 let mut rng = SimpleLcg::new(config.seed);
223 let heads = (0..config.num_heads)
224 .map(|_| {
225 AttentionHead::new(
226 config.input_dim,
227 config.head_output_dim,
228 config.alpha,
229 &mut rng,
230 )
231 })
232 .collect();
233
234 Ok(Self { config, heads })
235 }
236
237 pub fn embed(&self, graph: &GraphData) -> Result<GatEmbeddings> {
239 if graph.num_nodes() == 0 {
240 return Err(anyhow!("Graph has no nodes"));
241 }
242 if graph.feature_dim() != self.config.input_dim {
243 return Err(anyhow!(
244 "Graph feature_dim {} != GAT input_dim {}",
245 graph.feature_dim(),
246 self.config.input_dim
247 ));
248 }
249
250 let embeddings: Vec<Vec<f64>> = (0..graph.num_nodes())
251 .map(|node| self.forward_node(node, graph))
252 .collect();
253
254 let embeddings = if self.config.normalize_output {
255 embeddings.into_iter().map(|e| normalize_l2(&e)).collect()
256 } else {
257 embeddings
258 };
259
260 let output_dim = self.config.output_dim();
261 let num_nodes = graph.num_nodes();
262
263 Ok(GatEmbeddings {
264 embeddings,
265 config: self.config.clone(),
266 num_nodes,
267 dim: output_dim,
268 })
269 }
270
271 fn forward_node(&self, node: usize, graph: &GraphData) -> Vec<f64> {
273 let node_feat = match graph.node_features.get(node) {
274 Some(f) => f.as_slice(),
275 None => return vec![0.0; self.config.output_dim()],
276 };
277
278 let neighbors = graph.neighbors(node);
279 let neighbor_feats: Vec<Vec<f64>> = neighbors
280 .iter()
281 .filter_map(|&n| graph.node_features.get(n).cloned())
282 .collect();
283
284 let head_outputs: Vec<Vec<f64>> = self
286 .heads
287 .iter()
288 .map(|head| head.forward(node_feat, &neighbor_feats))
289 .collect();
290
291 if self.config.concat_heads {
292 let mut concat = Vec::with_capacity(self.config.output_dim());
294 for head_out in &head_outputs {
295 concat.extend(head_out.iter().copied());
296 }
297 concat
298 } else {
299 let dim = self.config.head_output_dim;
301 let mut avg = vec![0.0f64; dim];
302 for head_out in &head_outputs {
303 for (a, &v) in avg.iter_mut().zip(head_out.iter()) {
304 *a += v;
305 }
306 }
307 let n = self.heads.len() as f64;
308 avg.iter_mut().for_each(|v| *v /= n);
309 avg
310 }
311 }
312}
313
/// Node embeddings produced by [`Gat::embed`], plus bookkeeping metadata.
#[derive(Debug, Clone)]
pub struct GatEmbeddings {
    /// One embedding vector per node, indexed by node id.
    pub embeddings: Vec<Vec<f64>>,
    /// Configuration of the model that produced these embeddings.
    pub config: GatConfig,
    /// Number of nodes in the source graph.
    pub num_nodes: usize,
    /// Width of each embedding vector (`config.output_dim()`).
    pub dim: usize,
}
326
327impl GatEmbeddings {
328 pub fn get(&self, node: usize) -> Option<&[f64]> {
330 self.embeddings.get(node).map(|v| v.as_slice())
331 }
332
333 pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
335 let va = self.get(a)?;
336 let vb = self.get(b)?;
337 Some(cosine_similarity_vecs(va, vb))
338 }
339
340 pub fn top_k_similar(&self, node: usize, k: usize) -> Vec<(usize, f64)> {
342 let query = match self.get(node) {
343 Some(v) => v,
344 None => return Vec::new(),
345 };
346
347 let mut similarities: Vec<(usize, f64)> = (0..self.num_nodes)
348 .filter(|&i| i != node)
349 .filter_map(|i| self.get(i).map(|v| (i, cosine_similarity_vecs(query, v))))
350 .collect();
351
352 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
353 similarities.truncate(k);
354 similarities
355 }
356}
357
/// Scale `v` to unit Euclidean length. Near-zero vectors (norm < 1e-12)
/// are returned unchanged to avoid division blow-up.
fn normalize_l2(v: &[f64]) -> Vec<f64> {
    let norm = v.iter().fold(0.0f64, |acc, x| acc + x * x).sqrt();
    if norm >= 1e-12 {
        v.iter().map(|x| x / norm).collect()
    } else {
        v.to_vec()
    }
}
366
/// Public convenience wrapper over the numerically stable softmax used by
/// the attention heads (see `AttentionHead::softmax` for edge-case rules).
pub fn softmax(scores: &[f64]) -> Vec<f64> {
    AttentionHead::softmax(scores)
}
371
/// Wrap an arbitrary message into the crate-wide `EmbeddingError` type.
pub fn gat_err(msg: impl Into<String>) -> EmbeddingError {
    EmbeddingError::Other(anyhow!(msg.into()))
}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a path graph 0-1-2-...-(n-1) with seeded random features.
    fn make_line_graph(n: usize, feat_dim: usize, seed: u64) -> GraphData {
        let mut rng = SimpleLcg::new(seed);
        let features: Vec<Vec<f64>> = (0..n)
            .map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
        for i in 1..n {
            adjacency[i - 1].push(i);
            adjacency[i].push(i - 1);
        }
        GraphData::new(features, adjacency).expect("line graph construction should succeed")
    }

    #[test]
    fn test_gat_config_default() {
        let cfg = GatConfig::default();
        assert_eq!(cfg.num_heads, 8);
        assert_eq!(cfg.head_output_dim, 8);
        // 8 heads of width 8, concatenated.
        assert_eq!(cfg.output_dim(), 64);
    }

    #[test]
    fn test_gat_config_avg() {
        let cfg = GatConfig {
            concat_heads: false,
            num_heads: 4,
            head_output_dim: 16,
            ..Default::default()
        };
        // Averaged heads: output width equals a single head's width.
        assert_eq!(cfg.output_dim(), 16);
    }

    #[test]
    fn test_gat_embed_shape() {
        let cfg = GatConfig {
            input_dim: 8,
            head_output_dim: 4,
            num_heads: 2,
            concat_heads: true,
            normalize_output: false,
            ..Default::default()
        };
        let gat = Gat::new(cfg.clone()).expect("GAT construction should succeed");
        let graph = make_line_graph(5, 8, 100);
        let emb = gat.embed(&graph).expect("embed should succeed");

        assert_eq!(emb.num_nodes, 5);
        // 2 heads of width 4, concatenated.
        assert_eq!(emb.dim, 8);
        for node in 0..5 {
            assert_eq!(emb.get(node).expect("embedding should exist").len(), 8);
        }
    }

    #[test]
    fn test_gat_embed_avg_heads() {
        let cfg = GatConfig {
            input_dim: 8,
            head_output_dim: 4,
            num_heads: 3,
            concat_heads: false,
            normalize_output: false,
            ..Default::default()
        };
        let gat = Gat::new(cfg.clone()).expect("GAT should construct");
        let graph = make_line_graph(4, 8, 200);
        let emb = gat.embed(&graph).expect("embed should succeed");

        // Averaged heads keep a single head's width.
        assert_eq!(emb.dim, 4);
        for node in 0..4 {
            assert_eq!(emb.get(node).expect("embedding exists").len(), 4);
        }
    }

    #[test]
    fn test_gat_normalized_output() {
        let cfg = GatConfig {
            input_dim: 4,
            head_output_dim: 4,
            num_heads: 2,
            concat_heads: false,
            normalize_output: true,
            ..Default::default()
        };
        let gat = Gat::new(cfg).expect("GAT should construct");
        let graph = make_line_graph(5, 4, 300);
        let emb = gat.embed(&graph).expect("embed should succeed");

        // L2-normalized vectors have norm 1 (or 0 for a zero vector).
        for node in 0..5 {
            let e = emb.get(node).expect("embedding exists");
            let norm: f64 = e.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(norm <= 1.0 + 1e-6, "norm {} should be <= 1", norm);
        }
    }

    #[test]
    fn test_gat_cosine_similarity() {
        let cfg = GatConfig {
            input_dim: 4,
            head_output_dim: 4,
            num_heads: 1,
            concat_heads: true,
            normalize_output: false,
            ..Default::default()
        };
        let gat = Gat::new(cfg).expect("GAT should construct");
        let graph = make_line_graph(5, 4, 400);
        let emb = gat.embed(&graph).expect("embed should succeed");

        // Every pairwise similarity must lie in [-1, 1] (plus epsilon).
        for i in 0..5 {
            for j in 0..5 {
                if let Some(sim) = emb.cosine_similarity(i, j) {
                    assert!(
                        (-1.0 - 1e-6..=1.0 + 1e-6).contains(&sim),
                        "cosine_similarity({}, {}) = {} out of range",
                        i,
                        j,
                        sim
                    );
                }
            }
        }
    }

    #[test]
    fn test_gat_top_k_similar() {
        let cfg = GatConfig {
            input_dim: 4,
            head_output_dim: 4,
            num_heads: 2,
            concat_heads: true,
            normalize_output: true,
            ..Default::default()
        };
        let gat = Gat::new(cfg).expect("GAT should construct");
        let graph = make_line_graph(6, 4, 500);
        let emb = gat.embed(&graph).expect("embed should succeed");

        let top3 = emb.top_k_similar(0, 3);
        assert!(top3.len() <= 3);
        // Results must come back in non-increasing similarity order.
        for pair in top3.windows(2) {
            assert!(
                pair[0].1 >= pair[1].1 - 1e-10,
                "top_k should be sorted descending"
            );
        }
    }

    #[test]
    fn test_attention_head_softmax() {
        let scores = vec![1.0, 2.0, 3.0, 0.5, -1.0];
        let weights = AttentionHead::softmax(&scores);
        assert_eq!(weights.len(), scores.len());
        let sum: f64 = weights.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax should sum to 1, got {}",
            sum
        );
        // Softmax preserves the ordering of its inputs.
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);
    }

    #[test]
    fn test_gat_invalid_config() {
        let zero_heads = GatConfig {
            num_heads: 0,
            ..Default::default()
        };
        assert!(Gat::new(zero_heads).is_err());

        let zero_input = GatConfig {
            input_dim: 0,
            ..Default::default()
        };
        assert!(Gat::new(zero_input).is_err());

        let zero_head_dim = GatConfig {
            head_output_dim: 0,
            ..Default::default()
        };
        assert!(Gat::new(zero_head_dim).is_err());
    }

    #[test]
    fn test_gat_isolated_node() {
        let cfg = GatConfig {
            input_dim: 4,
            head_output_dim: 4,
            num_heads: 2,
            concat_heads: true,
            normalize_output: false,
            ..Default::default()
        };
        let gat = Gat::new(cfg).expect("GAT should construct");
        let features = vec![vec![1.0, 0.5, -0.5, 0.2]];
        // One node, zero edges.
        let adjacency = vec![vec![]];
        let graph = GraphData::new(features, adjacency).expect("graph should construct");
        let emb = gat.embed(&graph).expect("should embed isolated node");
        assert_eq!(emb.num_nodes, 1);
        assert!(emb.get(0).is_some());
    }
}