1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::module::Module;
23use crate::parameter::Parameter;
24
/// Graph convolution layer: aggregates node features through a dense
/// adjacency matrix, then applies a learned linear projection.
///
/// See [`GCNConv::forward_graph`] for the exact input/output contract.
pub struct GCNConv {
    /// `(in_features, out_features)` projection matrix.
    weight: Parameter,
    /// Per-output-feature bias; `None` when built via `without_bias`.
    bias: Option<Parameter>,
    /// Expected size of the last input dimension.
    in_features: usize,
    /// Size of the last output dimension.
    out_features: usize,
}
58
59impl GCNConv {
60 pub fn new(in_features: usize, out_features: usize) -> Self {
62 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
64 let weight_data: Vec<f32> = (0..in_features * out_features)
65 .map(|i| {
66 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
68 x * scale
69 })
70 .collect();
71
72 let weight = Parameter::named(
73 "weight",
74 Tensor::from_vec(weight_data, &[in_features, out_features])
75 .expect("tensor creation failed"),
76 true,
77 );
78
79 let bias_data = vec![0.0; out_features];
80 let bias = Some(Parameter::named(
81 "bias",
82 Tensor::from_vec(bias_data, &[out_features]).expect("tensor creation failed"),
83 true,
84 ));
85
86 Self {
87 weight,
88 bias,
89 in_features,
90 out_features,
91 }
92 }
93
94 pub fn without_bias(in_features: usize, out_features: usize) -> Self {
96 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
97 let weight_data: Vec<f32> = (0..in_features * out_features)
98 .map(|i| {
99 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
100 x * scale
101 })
102 .collect();
103
104 let weight = Parameter::named(
105 "weight",
106 Tensor::from_vec(weight_data, &[in_features, out_features])
107 .expect("tensor creation failed"),
108 true,
109 );
110
111 Self {
112 weight,
113 bias: None,
114 in_features,
115 out_features,
116 }
117 }
118
119 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
128 let shape = x.shape();
129 assert!(
130 shape.len() == 3,
131 "GCNConv expects input shape (batch, nodes, features), got {:?}",
132 shape
133 );
134 assert_eq!(shape[2], self.in_features, "Input features mismatch");
135
136 let batch = shape[0];
137 let adj_shape = adj.shape();
138
139 let weight = self.weight.variable();
145
146 let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
147 for b in 0..batch {
148 let x_b = x.select(0, b);
150
151 let adj_b = if adj_shape.len() == 3 {
153 adj.select(0, b)
154 } else {
155 adj.clone()
156 };
157
158 let msg_b = adj_b.matmul(&x_b);
160
161 let mut out_b = msg_b.matmul(&weight);
163
164 if let Some(bias) = &self.bias {
166 out_b = out_b.add_var(&bias.variable());
167 }
168
169 per_sample.push(out_b.unsqueeze(0));
171 }
172
173 let refs: Vec<&Variable> = per_sample.iter().collect();
175 Variable::cat(&refs, 0)
176 }
177
178 pub fn in_features(&self) -> usize {
180 self.in_features
181 }
182
183 pub fn out_features(&self) -> usize {
185 self.out_features
186 }
187}
188
189impl Module for GCNConv {
190 fn forward(&self, input: &Variable) -> Variable {
193 let n = input.shape()[0];
195 let mut eye_data = vec![0.0f32; n * n];
196 for i in 0..n {
197 eye_data[i * n + i] = 1.0;
198 }
199 let adj = Variable::new(
200 axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
201 .expect("identity matrix creation failed"),
202 false,
203 );
204 self.forward_graph(input, &adj)
205 }
206
207 fn parameters(&self) -> Vec<Parameter> {
208 let mut params = vec![self.weight.clone()];
209 if let Some(bias) = &self.bias {
210 params.push(bias.clone());
211 }
212 params
213 }
214
215 fn named_parameters(&self) -> HashMap<String, Parameter> {
216 let mut params = HashMap::new();
217 params.insert("weight".to_string(), self.weight.clone());
218 if let Some(bias) = &self.bias {
219 params.insert("bias".to_string(), bias.clone());
220 }
221 params
222 }
223
224 fn name(&self) -> &'static str {
225 "GCNConv"
226 }
227}
228
/// Multi-head graph attention layer over a dense adjacency matrix.
///
/// Head outputs are concatenated, so the effective output width is
/// `out_features * num_heads` (see [`GATConv::total_out_features`]).
pub struct GATConv {
    /// `(in_features, out_features * num_heads)` shared projection matrix.
    w: Parameter,
    /// `(num_heads, out_features)` source-side attention vectors.
    attn_src: Parameter,
    /// `(num_heads, out_features)` destination-side attention vectors.
    attn_dst: Parameter,
    /// Bias over the concatenated head outputs; `None` disables it.
    bias: Option<Parameter>,
    /// Expected size of the last input dimension.
    in_features: usize,
    /// Per-head output feature count.
    out_features: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// LeakyReLU slope applied to raw attention logits.
    negative_slope: f32,
}
258
259impl GATConv {
260 pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
267 let total_out = out_features * num_heads;
268 let scale = (2.0 / (in_features + total_out) as f32).sqrt();
269
270 let w_data: Vec<f32> = (0..in_features * total_out)
271 .map(|i| {
272 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
273 x * scale
274 })
275 .collect();
276
277 let w = Parameter::named(
278 "w",
279 Tensor::from_vec(w_data, &[in_features, total_out]).expect("tensor creation failed"),
280 true,
281 );
282
283 let attn_scale = (1.0 / out_features as f32).sqrt();
285 let attn_src_data: Vec<f32> = (0..total_out)
286 .map(|i| {
287 let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
288 x * attn_scale
289 })
290 .collect();
291 let attn_dst_data: Vec<f32> = (0..total_out)
292 .map(|i| {
293 let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
294 x * attn_scale
295 })
296 .collect();
297
298 let attn_src = Parameter::named(
299 "attn_src",
300 Tensor::from_vec(attn_src_data, &[num_heads, out_features])
301 .expect("tensor creation failed"),
302 true,
303 );
304
305 let attn_dst = Parameter::named(
306 "attn_dst",
307 Tensor::from_vec(attn_dst_data, &[num_heads, out_features])
308 .expect("tensor creation failed"),
309 true,
310 );
311
312 let bias_data = vec![0.0; total_out];
313 let bias = Some(Parameter::named(
314 "bias",
315 Tensor::from_vec(bias_data, &[total_out]).expect("tensor creation failed"),
316 true,
317 ));
318
319 Self {
320 w,
321 attn_src,
322 attn_dst,
323 bias,
324 in_features,
325 out_features,
326 num_heads,
327 negative_slope: 0.2,
328 }
329 }
330
331 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
340 let shape = x.shape();
341 assert!(
342 shape.len() == 3,
343 "GATConv expects (batch, nodes, features), got {:?}",
344 shape
345 );
346
347 let batch = shape[0];
348 let nodes = shape[1];
349 let total_out = self.out_features * self.num_heads;
350
351 let x_data = x.data().to_vec();
352 let adj_data = adj.data().to_vec();
353 let w_data = self.w.data().to_vec();
354 let attn_src_data = self.attn_src.data().to_vec();
355 let attn_dst_data = self.attn_dst.data().to_vec();
356
357 let adj_nodes = if adj.shape().len() == 3 {
358 adj.shape()[1]
359 } else {
360 adj.shape()[0]
361 };
362 assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
363
364 let mut output = vec![0.0f32; batch * nodes * total_out];
365
366 for b in 0..batch {
367 let mut h = vec![0.0f32; nodes * total_out];
369 for i in 0..nodes {
370 let x_off = (b * nodes + i) * self.in_features;
371 for o in 0..total_out {
372 let mut val = 0.0;
373 for f in 0..self.in_features {
374 val += x_data[x_off + f] * w_data[f * total_out + o];
375 }
376 h[i * total_out + o] = val;
377 }
378 }
379
380 let adj_off = if adj.shape().len() == 3 {
382 b * nodes * nodes
383 } else {
384 0
385 };
386
387 for head in 0..self.num_heads {
388 let head_off = head * self.out_features;
389
390 let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
393
394 for i in 0..nodes {
395 let mut src_score = 0.0;
397 for f in 0..self.out_features {
398 src_score += h[i * total_out + head_off + f]
399 * attn_src_data[head * self.out_features + f];
400 }
401
402 for j in 0..nodes {
403 let a_ij = adj_data[adj_off + i * nodes + j];
404 if a_ij != 0.0 {
405 let mut dst_score = 0.0;
406 for f in 0..self.out_features {
407 dst_score += h[j * total_out + head_off + f]
408 * attn_dst_data[head * self.out_features + f];
409 }
410
411 let e = src_score + dst_score;
412 let e = if e > 0.0 { e } else { e * self.negative_slope };
414 attn_scores[i * nodes + j] = e;
415 }
416 }
417 }
418
419 for i in 0..nodes {
421 let row_start = i * nodes;
422 let row_end = row_start + nodes;
423 let row = &attn_scores[row_start..row_end];
424
425 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
426 if max_val == f32::NEG_INFINITY {
427 continue; }
429
430 let mut sum_exp = 0.0f32;
431 let mut exps = vec![0.0; nodes];
432 for j in 0..nodes {
433 if row[j] > f32::NEG_INFINITY {
434 exps[j] = (row[j] - max_val).exp();
435 sum_exp += exps[j];
436 }
437 }
438
439 let out_off = (b * nodes + i) * total_out + head_off;
441 for j in 0..nodes {
442 if exps[j] > 0.0 {
443 let alpha = exps[j] / sum_exp;
444 for f in 0..self.out_features {
445 output[out_off + f] += alpha * h[j * total_out + head_off + f];
446 }
447 }
448 }
449 }
450 }
451 }
452
453 if let Some(bias) = &self.bias {
455 let bias_data = bias.data().to_vec();
456 for b in 0..batch {
457 for i in 0..nodes {
458 let offset = (b * nodes + i) * total_out;
459 for o in 0..total_out {
460 output[offset + o] += bias_data[o];
461 }
462 }
463 }
464 }
465
466 Variable::new(
467 Tensor::from_vec(output, &[batch, nodes, total_out]).expect("tensor creation failed"),
468 x.requires_grad(),
469 )
470 }
471
472 pub fn total_out_features(&self) -> usize {
474 self.out_features * self.num_heads
475 }
476}
477
478impl Module for GATConv {
479 fn forward(&self, input: &Variable) -> Variable {
480 let n = input.shape()[0];
482 let mut eye_data = vec![0.0f32; n * n];
483 for i in 0..n {
484 eye_data[i * n + i] = 1.0;
485 }
486 let adj = Variable::new(
487 axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
488 .expect("identity matrix creation failed"),
489 false,
490 );
491 self.forward_graph(input, &adj)
492 }
493
494 fn parameters(&self) -> Vec<Parameter> {
495 let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
496 if let Some(bias) = &self.bias {
497 params.push(bias.clone());
498 }
499 params
500 }
501
502 fn named_parameters(&self) -> HashMap<String, Parameter> {
503 let mut params = HashMap::new();
504 params.insert("w".to_string(), self.w.clone());
505 params.insert("attn_src".to_string(), self.attn_src.clone());
506 params.insert("attn_dst".to_string(), self.attn_dst.clone());
507 if let Some(bias) = &self.bias {
508 params.insert("bias".to_string(), bias.clone());
509 }
510 params
511 }
512
513 fn name(&self) -> &'static str {
514 "GATConv"
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-filled `Variable` of the given shape (no grad).
    fn filled(value: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![value; numel], shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_gcn_conv_shape() {
        let layer = GCNConv::new(72, 128);
        let x = filled(1.0, &[2, 7, 72]);
        let adj = filled(1.0, &[7, 7]);
        assert_eq!(layer.forward_graph(&x, &adj).shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        let layer = GCNConv::new(4, 8);
        let x = filled(1.0, &[1, 3, 4]);

        // 3x3 identity adjacency: every node aggregates only itself.
        let mut eye = vec![0.0; 9];
        for d in 0..3 {
            eye[d * 3 + d] = 1.0;
        }
        let adj = Variable::new(
            Tensor::from_vec(eye, &[3, 3]).expect("tensor creation failed"),
            false,
        );

        let out = layer.forward_graph(&x, &adj);
        assert_eq!(out.shape(), vec![1, 3, 8]);

        // Identical inputs + identity adjacency => every node's output row
        // must equal node 0's row.
        let data = out.data().to_vec();
        for node in 0..3 {
            for f in 0..8 {
                assert!(
                    (data[node * 8 + f] - data[f]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let layer = GCNConv::new(16, 32);
        let params = layer.parameters();
        assert_eq!(params.len(), 2);
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 16 * 32 + 32);
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        let layer = GCNConv::without_bias(16, 32);
        assert_eq!(layer.parameters().len(), 1);
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let named = GCNConv::new(16, 32).named_parameters();
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_shape() {
        let layer = GATConv::new(72, 32, 4);
        let x = filled(1.0, &[2, 7, 72]);
        let adj = filled(1.0, &[7, 7]);
        // 4 heads x 32 features concatenate to 128.
        assert_eq!(layer.forward_graph(&x, &adj).shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gat_conv_single_head() {
        let layer = GATConv::new(16, 8, 1);
        let x = filled(1.0, &[1, 5, 16]);
        let adj = filled(1.0, &[5, 5]);
        assert_eq!(layer.forward_graph(&x, &adj).shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_parameters() {
        let layer = GATConv::new(16, 8, 4);
        assert_eq!(layer.parameters().len(), 4);
        let named = layer.named_parameters();
        assert!(named.contains_key("w"));
        assert!(named.contains_key("attn_src"));
        assert!(named.contains_key("attn_dst"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_total_output() {
        assert_eq!(GATConv::new(16, 32, 4).total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        let layer = GCNConv::new(4, 4);
        let x = filled(99.0, &[1, 3, 4]);
        let adj = filled(0.0, &[3, 3]);

        let data = layer.forward_graph(&x, &adj).data().to_vec();
        for val in &data {
            assert!(
                val.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}