1use std::collections::HashMap;
25
26use axonml_autograd::Variable;
27use axonml_tensor::Tensor;
28
29use crate::module::Module;
30use crate::parameter::Parameter;
31
/// Graph Convolutional Network layer: projects per-node features through a
/// learnable weight after aggregating neighbour features via an adjacency
/// matrix (see `forward_graph`).
pub struct GCNConv {
    /// Learnable weight matrix, shape (in_features, out_features).
    weight: Parameter,
    /// Optional learnable bias, shape (out_features,); `None` when built
    /// via `without_bias`.
    bias: Option<Parameter>,
    /// Expected feature width of each input node.
    in_features: usize,
    /// Feature width of each output node.
    out_features: usize,
}
65
66impl GCNConv {
67 pub fn new(in_features: usize, out_features: usize) -> Self {
69 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
71 let weight_data: Vec<f32> = (0..in_features * out_features)
72 .map(|i| {
73 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
75 x * scale
76 })
77 .collect();
78
79 let weight = Parameter::named(
80 "weight",
81 Tensor::from_vec(weight_data, &[in_features, out_features])
82 .expect("tensor creation failed"),
83 true,
84 );
85
86 let bias_data = vec![0.0; out_features];
87 let bias = Some(Parameter::named(
88 "bias",
89 Tensor::from_vec(bias_data, &[out_features]).expect("tensor creation failed"),
90 true,
91 ));
92
93 Self {
94 weight,
95 bias,
96 in_features,
97 out_features,
98 }
99 }
100
101 pub fn without_bias(in_features: usize, out_features: usize) -> Self {
103 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
104 let weight_data: Vec<f32> = (0..in_features * out_features)
105 .map(|i| {
106 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
107 x * scale
108 })
109 .collect();
110
111 let weight = Parameter::named(
112 "weight",
113 Tensor::from_vec(weight_data, &[in_features, out_features])
114 .expect("tensor creation failed"),
115 true,
116 );
117
118 Self {
119 weight,
120 bias: None,
121 in_features,
122 out_features,
123 }
124 }
125
126 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
135 let shape = x.shape();
136 assert!(
137 shape.len() == 3,
138 "GCNConv expects input shape (batch, nodes, features), got {:?}",
139 shape
140 );
141 assert_eq!(shape[2], self.in_features, "Input features mismatch");
142
143 let batch = shape[0];
144 let adj_shape = adj.shape();
145
146 let weight = self.weight.variable();
152
153 let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
154 for b in 0..batch {
155 let x_b = x.select(0, b);
157
158 let adj_b = if adj_shape.len() == 3 {
160 adj.select(0, b)
161 } else {
162 adj.clone()
163 };
164
165 let msg_b = adj_b.matmul(&x_b);
167
168 let mut out_b = msg_b.matmul(&weight);
170
171 if let Some(bias) = &self.bias {
173 out_b = out_b.add_var(&bias.variable());
174 }
175
176 per_sample.push(out_b.unsqueeze(0));
178 }
179
180 let refs: Vec<&Variable> = per_sample.iter().collect();
182 Variable::cat(&refs, 0)
183 }
184
185 pub fn in_features(&self) -> usize {
187 self.in_features
188 }
189
190 pub fn out_features(&self) -> usize {
192 self.out_features
193 }
194}
195
196impl Module for GCNConv {
197 fn forward(&self, input: &Variable) -> Variable {
200 let n = input.shape()[0];
202 let mut eye_data = vec![0.0f32; n * n];
203 for i in 0..n {
204 eye_data[i * n + i] = 1.0;
205 }
206 let adj = Variable::new(
207 axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
208 .expect("identity matrix creation failed"),
209 false,
210 );
211 self.forward_graph(input, &adj)
212 }
213
214 fn parameters(&self) -> Vec<Parameter> {
215 let mut params = vec![self.weight.clone()];
216 if let Some(bias) = &self.bias {
217 params.push(bias.clone());
218 }
219 params
220 }
221
222 fn named_parameters(&self) -> HashMap<String, Parameter> {
223 let mut params = HashMap::new();
224 params.insert("weight".to_string(), self.weight.clone());
225 if let Some(bias) = &self.bias {
226 params.insert("bias".to_string(), bias.clone());
227 }
228 params
229 }
230
231 fn name(&self) -> &'static str {
232 "GCNConv"
233 }
234}
235
/// Graph Attention Network layer: multi-head attention over graph edges,
/// with head outputs concatenated along the feature axis (see
/// `forward_graph`).
pub struct GATConv {
    /// Shared projection weight, shape (in_features, out_features * num_heads).
    w: Parameter,
    /// Per-head source-node attention vectors, shape (num_heads, out_features).
    attn_src: Parameter,
    /// Per-head destination-node attention vectors, shape (num_heads, out_features).
    attn_dst: Parameter,
    /// Optional bias over the concatenated output, shape (out_features * num_heads,).
    bias: Option<Parameter>,
    /// Expected feature width of each input node.
    in_features: usize,
    /// Feature width produced by each individual head.
    out_features: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// LeakyReLU slope applied to raw attention scores (0.2 in `new`).
    negative_slope: f32,
}
265
266impl GATConv {
267 pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
274 let total_out = out_features * num_heads;
275 let scale = (2.0 / (in_features + total_out) as f32).sqrt();
276
277 let w_data: Vec<f32> = (0..in_features * total_out)
278 .map(|i| {
279 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
280 x * scale
281 })
282 .collect();
283
284 let w = Parameter::named(
285 "w",
286 Tensor::from_vec(w_data, &[in_features, total_out]).expect("tensor creation failed"),
287 true,
288 );
289
290 let attn_scale = (1.0 / out_features as f32).sqrt();
292 let attn_src_data: Vec<f32> = (0..total_out)
293 .map(|i| {
294 let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
295 x * attn_scale
296 })
297 .collect();
298 let attn_dst_data: Vec<f32> = (0..total_out)
299 .map(|i| {
300 let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
301 x * attn_scale
302 })
303 .collect();
304
305 let attn_src = Parameter::named(
306 "attn_src",
307 Tensor::from_vec(attn_src_data, &[num_heads, out_features])
308 .expect("tensor creation failed"),
309 true,
310 );
311
312 let attn_dst = Parameter::named(
313 "attn_dst",
314 Tensor::from_vec(attn_dst_data, &[num_heads, out_features])
315 .expect("tensor creation failed"),
316 true,
317 );
318
319 let bias_data = vec![0.0; total_out];
320 let bias = Some(Parameter::named(
321 "bias",
322 Tensor::from_vec(bias_data, &[total_out]).expect("tensor creation failed"),
323 true,
324 ));
325
326 Self {
327 w,
328 attn_src,
329 attn_dst,
330 bias,
331 in_features,
332 out_features,
333 num_heads,
334 negative_slope: 0.2,
335 }
336 }
337
338 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
347 let shape = x.shape();
348 assert!(
349 shape.len() == 3,
350 "GATConv expects (batch, nodes, features), got {:?}",
351 shape
352 );
353
354 let batch = shape[0];
355 let nodes = shape[1];
356 let total_out = self.out_features * self.num_heads;
357
358 let x_data = x.data().to_vec();
359 let adj_data = adj.data().to_vec();
360 let w_data = self.w.data().to_vec();
361 let attn_src_data = self.attn_src.data().to_vec();
362 let attn_dst_data = self.attn_dst.data().to_vec();
363
364 let adj_nodes = if adj.shape().len() == 3 {
365 adj.shape()[1]
366 } else {
367 adj.shape()[0]
368 };
369 assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
370
371 let mut output = vec![0.0f32; batch * nodes * total_out];
372
373 for b in 0..batch {
374 let mut h = vec![0.0f32; nodes * total_out];
376 for i in 0..nodes {
377 let x_off = (b * nodes + i) * self.in_features;
378 for o in 0..total_out {
379 let mut val = 0.0;
380 for f in 0..self.in_features {
381 val += x_data[x_off + f] * w_data[f * total_out + o];
382 }
383 h[i * total_out + o] = val;
384 }
385 }
386
387 let adj_off = if adj.shape().len() == 3 {
389 b * nodes * nodes
390 } else {
391 0
392 };
393
394 for head in 0..self.num_heads {
395 let head_off = head * self.out_features;
396
397 let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
400
401 for i in 0..nodes {
402 let mut src_score = 0.0;
404 for f in 0..self.out_features {
405 src_score += h[i * total_out + head_off + f]
406 * attn_src_data[head * self.out_features + f];
407 }
408
409 for j in 0..nodes {
410 let a_ij = adj_data[adj_off + i * nodes + j];
411 if a_ij != 0.0 {
412 let mut dst_score = 0.0;
413 for f in 0..self.out_features {
414 dst_score += h[j * total_out + head_off + f]
415 * attn_dst_data[head * self.out_features + f];
416 }
417
418 let e = src_score + dst_score;
419 let e = if e > 0.0 { e } else { e * self.negative_slope };
421 attn_scores[i * nodes + j] = e;
422 }
423 }
424 }
425
426 for i in 0..nodes {
428 let row_start = i * nodes;
429 let row_end = row_start + nodes;
430 let row = &attn_scores[row_start..row_end];
431
432 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
433 if max_val == f32::NEG_INFINITY {
434 continue; }
436
437 let mut sum_exp = 0.0f32;
438 let mut exps = vec![0.0; nodes];
439 for j in 0..nodes {
440 if row[j] > f32::NEG_INFINITY {
441 exps[j] = (row[j] - max_val).exp();
442 sum_exp += exps[j];
443 }
444 }
445
446 let out_off = (b * nodes + i) * total_out + head_off;
448 for j in 0..nodes {
449 if exps[j] > 0.0 {
450 let alpha = exps[j] / sum_exp;
451 for f in 0..self.out_features {
452 output[out_off + f] += alpha * h[j * total_out + head_off + f];
453 }
454 }
455 }
456 }
457 }
458 }
459
460 if let Some(bias) = &self.bias {
462 let bias_data = bias.data().to_vec();
463 for b in 0..batch {
464 for i in 0..nodes {
465 let offset = (b * nodes + i) * total_out;
466 for o in 0..total_out {
467 output[offset + o] += bias_data[o];
468 }
469 }
470 }
471 }
472
473 Variable::new(
474 Tensor::from_vec(output, &[batch, nodes, total_out]).expect("tensor creation failed"),
475 x.requires_grad(),
476 )
477 }
478
479 pub fn total_out_features(&self) -> usize {
481 self.out_features * self.num_heads
482 }
483}
484
485impl Module for GATConv {
486 fn forward(&self, input: &Variable) -> Variable {
487 let n = input.shape()[0];
489 let mut eye_data = vec![0.0f32; n * n];
490 for i in 0..n {
491 eye_data[i * n + i] = 1.0;
492 }
493 let adj = Variable::new(
494 axonml_tensor::Tensor::from_vec(eye_data, &[n, n])
495 .expect("identity matrix creation failed"),
496 false,
497 );
498 self.forward_graph(input, &adj)
499 }
500
501 fn parameters(&self) -> Vec<Parameter> {
502 let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
503 if let Some(bias) = &self.bias {
504 params.push(bias.clone());
505 }
506 params
507 }
508
509 fn named_parameters(&self) -> HashMap<String, Parameter> {
510 let mut params = HashMap::new();
511 params.insert("w".to_string(), self.w.clone());
512 params.insert("attn_src".to_string(), self.attn_src.clone());
513 params.insert("attn_dst".to_string(), self.attn_dst.clone());
514 if let Some(bias) = &self.bias {
515 params.insert("bias".to_string(), bias.clone());
516 }
517 params
518 }
519
520 fn name(&self) -> &'static str {
521 "GATConv"
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Batched input with a shared dense adjacency yields the expected shape.
    #[test]
    fn test_gcn_conv_shape() {
        let layer = GCNConv::new(72, 128);
        let features = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).expect("tensor creation failed"),
            false,
        );
        let adjacency = Variable::new(
            Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).expect("tensor creation failed"),
            false,
        );
        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    /// With an identity adjacency and identical node features, every node's
    /// output row must match node 0's.
    #[test]
    fn test_gcn_conv_identity_adjacency() {
        let layer = GCNConv::new(4, 8);
        let features = Variable::new(
            Tensor::from_vec(vec![1.0; 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );

        let mut eye = vec![0.0; 9];
        for d in 0..3 {
            eye[d * 3 + d] = 1.0;
        }
        let adjacency = Variable::new(
            Tensor::from_vec(eye, &[3, 3]).expect("tensor creation failed"),
            false,
        );

        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![1, 3, 8]);

        let values = out.data().to_vec();
        for node in 0..3 {
            for feat in 0..8 {
                assert!(
                    (values[node * 8 + feat] - values[feat]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    /// Two parameters (weight + bias) with the expected total element count.
    #[test]
    fn test_gcn_conv_parameters() {
        let layer = GCNConv::new(16, 32);
        let params = layer.parameters();
        assert_eq!(params.len(), 2);
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 16 * 32 + 32);
    }

    /// The bias-free constructor exposes only the weight.
    #[test]
    fn test_gcn_conv_no_bias() {
        let layer = GCNConv::without_bias(16, 32);
        assert_eq!(layer.parameters().len(), 1);
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let named = GCNConv::new(16, 32).named_parameters();
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    /// Output width is out_features * num_heads (32 * 4 = 128).
    #[test]
    fn test_gat_conv_shape() {
        let layer = GATConv::new(72, 32, 4);
        let features = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).expect("tensor creation failed"),
            false,
        );
        let adjacency = Variable::new(
            Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).expect("tensor creation failed"),
            false,
        );
        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gat_conv_single_head() {
        let layer = GATConv::new(16, 8, 1);
        let features = Variable::new(
            Tensor::from_vec(vec![1.0; 5 * 16], &[1, 5, 16]).expect("tensor creation failed"),
            false,
        );
        let adjacency = Variable::new(
            Tensor::from_vec(vec![1.0; 5 * 5], &[5, 5]).expect("tensor creation failed"),
            false,
        );
        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![1, 5, 8]);
    }

    /// Four parameters: w, attn_src, attn_dst, bias.
    #[test]
    fn test_gat_conv_parameters() {
        let layer = GATConv::new(16, 8, 4);
        assert_eq!(layer.parameters().len(), 4);
        let named = layer.named_parameters();
        assert!(named.contains_key("w"));
        assert!(named.contains_key("attn_src"));
        assert!(named.contains_key("attn_dst"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_total_output() {
        assert_eq!(GATConv::new(16, 32, 4).total_out_features(), 128);
    }

    /// An all-zero adjacency blocks every message, so the output (zero bias)
    /// must be all zeros regardless of the input magnitude.
    #[test]
    fn test_gcn_zero_adjacency() {
        let layer = GCNConv::new(4, 4);
        let features = Variable::new(
            Tensor::from_vec(vec![99.0; 3 * 4], &[1, 3, 4]).expect("tensor creation failed"),
            false,
        );
        let adjacency = Variable::new(
            Tensor::from_vec(vec![0.0; 9], &[3, 3]).expect("tensor creation failed"),
            false,
        );
        let out = layer.forward_graph(&features, &adjacency);

        for value in &out.data().to_vec() {
            assert!(
                value.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}