1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::module::Module;
23use crate::parameter::Parameter;
24
/// Graph convolution layer: aggregates node features through an adjacency
/// matrix, then applies a learned linear projection (plus optional bias).
///
/// See [`GCNConv::forward_graph`] for the exact computation.
pub struct GCNConv {
    /// Projection matrix of shape `(in_features, out_features)`.
    weight: Parameter,
    /// Per-output-feature bias; `None` when built via `without_bias`.
    bias: Option<Parameter>,
    /// Expected size of the input feature dimension.
    in_features: usize,
    /// Size of the produced feature dimension.
    out_features: usize,
}
58
59impl GCNConv {
60 pub fn new(in_features: usize, out_features: usize) -> Self {
62 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
64 let weight_data: Vec<f32> = (0..in_features * out_features)
65 .map(|i| {
66 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
68 x * scale
69 })
70 .collect();
71
72 let weight = Parameter::named(
73 "weight",
74 Tensor::from_vec(weight_data, &[in_features, out_features]).expect("tensor creation failed"),
75 true,
76 );
77
78 let bias_data = vec![0.0; out_features];
79 let bias = Some(Parameter::named(
80 "bias",
81 Tensor::from_vec(bias_data, &[out_features]).expect("tensor creation failed"),
82 true,
83 ));
84
85 Self {
86 weight,
87 bias,
88 in_features,
89 out_features,
90 }
91 }
92
93 pub fn without_bias(in_features: usize, out_features: usize) -> Self {
95 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
96 let weight_data: Vec<f32> = (0..in_features * out_features)
97 .map(|i| {
98 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
99 x * scale
100 })
101 .collect();
102
103 let weight = Parameter::named(
104 "weight",
105 Tensor::from_vec(weight_data, &[in_features, out_features]).expect("tensor creation failed"),
106 true,
107 );
108
109 Self {
110 weight,
111 bias: None,
112 in_features,
113 out_features,
114 }
115 }
116
117 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
126 let shape = x.shape();
127 assert!(
128 shape.len() == 3,
129 "GCNConv expects input shape (batch, nodes, features), got {:?}",
130 shape
131 );
132 assert_eq!(shape[2], self.in_features, "Input features mismatch");
133
134 let batch = shape[0];
135 let adj_shape = adj.shape();
136
137 let weight = self.weight.variable();
143
144 let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
145 for b in 0..batch {
146 let x_b = x.select(0, b);
148
149 let adj_b = if adj_shape.len() == 3 {
151 adj.select(0, b)
152 } else {
153 adj.clone()
154 };
155
156 let msg_b = adj_b.matmul(&x_b);
158
159 let mut out_b = msg_b.matmul(&weight);
161
162 if let Some(bias) = &self.bias {
164 out_b = out_b.add_var(&bias.variable());
165 }
166
167 per_sample.push(out_b.unsqueeze(0));
169 }
170
171 let refs: Vec<&Variable> = per_sample.iter().collect();
173 Variable::cat(&refs, 0)
174 }
175
176 pub fn in_features(&self) -> usize {
178 self.in_features
179 }
180
181 pub fn out_features(&self) -> usize {
183 self.out_features
184 }
185}
186
187impl Module for GCNConv {
188 fn forward(&self, input: &Variable) -> Variable {
191 let n = input.shape()[0];
193 let mut eye_data = vec![0.0f32; n * n];
194 for i in 0..n {
195 eye_data[i * n + i] = 1.0;
196 }
197 let adj = Variable::new(
198 axonml_tensor::Tensor::from_vec(eye_data, &[n, n]).expect("identity matrix creation failed"),
199 false,
200 );
201 self.forward_graph(input, &adj)
202 }
203
204 fn parameters(&self) -> Vec<Parameter> {
205 let mut params = vec![self.weight.clone()];
206 if let Some(bias) = &self.bias {
207 params.push(bias.clone());
208 }
209 params
210 }
211
212 fn named_parameters(&self) -> HashMap<String, Parameter> {
213 let mut params = HashMap::new();
214 params.insert("weight".to_string(), self.weight.clone());
215 if let Some(bias) = &self.bias {
216 params.insert("bias".to_string(), bias.clone());
217 }
218 params
219 }
220
221 fn name(&self) -> &'static str {
222 "GCNConv"
223 }
224}
225
/// Multi-head graph attention layer: each head projects node features,
/// scores edges with per-head attention vectors (LeakyReLU then a softmax
/// over neighbours), and the head outputs are concatenated.
pub struct GATConv {
    /// Shared projection of shape `(in_features, out_features * num_heads)`.
    w: Parameter,
    /// Per-head source attention vectors, shape `(num_heads, out_features)`.
    attn_src: Parameter,
    /// Per-head destination attention vectors, shape `(num_heads, out_features)`.
    attn_dst: Parameter,
    /// Optional bias over the concatenated head outputs.
    bias: Option<Parameter>,
    /// Expected size of the input feature dimension.
    in_features: usize,
    /// Output feature width per head.
    out_features: usize,
    /// Number of attention heads (outputs are concatenated, not averaged).
    num_heads: usize,
    /// LeakyReLU slope applied to negative attention logits.
    negative_slope: f32,
}
255
impl GATConv {
    /// Creates a multi-head GAT layer with `num_heads` heads, each producing
    /// `out_features` features (concatenated on output), and a zero-initialised
    /// learnable bias.
    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
        // Heads are concatenated, so the projection maps to out_features * num_heads.
        let total_out = out_features * num_heads;
        // Glorot-style scale for the linear projection.
        let scale = (2.0 / (in_features + total_out) as f32).sqrt();

        // Deterministic pseudo-random init from a golden-ratio sequence —
        // reproducible without an RNG dependency.
        let w_data: Vec<f32> = (0..in_features * total_out)
            .map(|i| {
                let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
                x * scale
            })
            .collect();

        let w = Parameter::named(
            "w",
            Tensor::from_vec(w_data, &[in_features, total_out]).expect("tensor creation failed"),
            true,
        );

        // Per-head attention vectors; the distinct multipliers decorrelate the
        // src/dst sequences from each other and from the weight init above.
        let attn_scale = (1.0 / out_features as f32).sqrt();
        let attn_src_data: Vec<f32> = (0..total_out)
            .map(|i| {
                let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
                x * attn_scale
            })
            .collect();
        let attn_dst_data: Vec<f32> = (0..total_out)
            .map(|i| {
                let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
                x * attn_scale
            })
            .collect();

        let attn_src = Parameter::named(
            "attn_src",
            Tensor::from_vec(attn_src_data, &[num_heads, out_features]).expect("tensor creation failed"),
            true,
        );

        let attn_dst = Parameter::named(
            "attn_dst",
            Tensor::from_vec(attn_dst_data, &[num_heads, out_features]).expect("tensor creation failed"),
            true,
        );

        let bias_data = vec![0.0; total_out];
        let bias = Some(Parameter::named(
            "bias",
            Tensor::from_vec(bias_data, &[total_out]).expect("tensor creation failed"),
            true,
        ));

        Self {
            w,
            attn_src,
            attn_dst,
            bias,
            in_features,
            out_features,
            num_heads,
            // Conventional LeakyReLU slope for GAT attention logits.
            negative_slope: 0.2,
        }
    }

    /// Multi-head graph attention forward pass.
    ///
    /// * `x` — node features of shape `(batch, nodes, in_features)`.
    /// * `adj` — adjacency, `(nodes, nodes)` or `(batch, nodes, nodes)`; a
    ///   zero entry means "no edge" and is excluded from attention.
    ///
    /// Returns `(batch, nodes, out_features * num_heads)` with the head
    /// outputs concatenated along the last dimension.
    ///
    /// NOTE(review): this path computes on raw buffers and wraps the result in
    /// a fresh `Variable`, so no autograd graph connects the output to `x` or
    /// to the parameters — confirm whether gradient flow is expected here
    /// (contrast with `GCNConv::forward_graph`, which stays in Variable ops).
    pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
        let shape = x.shape();
        assert!(
            shape.len() == 3,
            "GATConv expects (batch, nodes, features), got {:?}",
            shape
        );

        let batch = shape[0];
        let nodes = shape[1];
        let total_out = self.out_features * self.num_heads;

        // Pull everything out as flat row-major buffers.
        let x_data = x.data().to_vec();
        let adj_data = adj.data().to_vec();
        let w_data = self.w.data().to_vec();
        let attn_src_data = self.attn_src.data().to_vec();
        let attn_dst_data = self.attn_dst.data().to_vec();

        let adj_nodes = if adj.shape().len() == 3 {
            adj.shape()[1]
        } else {
            adj.shape()[0]
        };
        assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");

        let mut output = vec![0.0f32; batch * nodes * total_out];

        for b in 0..batch {
            // h = x[b] · W : projected node features, (nodes, total_out).
            let mut h = vec![0.0f32; nodes * total_out];
            for i in 0..nodes {
                let x_off = (b * nodes + i) * self.in_features;
                for o in 0..total_out {
                    let mut val = 0.0;
                    for f in 0..self.in_features {
                        val += x_data[x_off + f] * w_data[f * total_out + o];
                    }
                    h[i * total_out + o] = val;
                }
            }

            // A 2-D adjacency is shared across the batch (offset 0).
            let adj_off = if adj.shape().len() == 3 {
                b * nodes * nodes
            } else {
                0
            };

            for head in 0..self.num_heads {
                let head_off = head * self.out_features;

                // NEG_INFINITY marks "no edge"; it contributes 0 after softmax.
                let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];

                for i in 0..nodes {
                    // e_ij = LeakyReLU(a_src · h_i + a_dst · h_j), edges only.
                    // The source term is invariant over j, so compute it once.
                    let mut src_score = 0.0;
                    for f in 0..self.out_features {
                        src_score += h[i * total_out + head_off + f]
                            * attn_src_data[head * self.out_features + f];
                    }

                    for j in 0..nodes {
                        let a_ij = adj_data[adj_off + i * nodes + j];
                        if a_ij != 0.0 {
                            let mut dst_score = 0.0;
                            for f in 0..self.out_features {
                                dst_score += h[j * total_out + head_off + f]
                                    * attn_dst_data[head * self.out_features + f];
                            }

                            let e = src_score + dst_score;
                            // LeakyReLU with the configured negative slope.
                            let e = if e > 0.0 { e } else { e * self.negative_slope };
                            attn_scores[i * nodes + j] = e;
                        }
                    }
                }

                // Row-wise softmax over neighbours, then weighted aggregation.
                for i in 0..nodes {
                    let row_start = i * nodes;
                    let row_end = row_start + nodes;
                    let row = &attn_scores[row_start..row_end];

                    // Subtract the row max for numerical stability.
                    let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                    if max_val == f32::NEG_INFINITY {
                        // Isolated node: no neighbours, its output stays zero.
                        continue; }

                    let mut sum_exp = 0.0f32;
                    let mut exps = vec![0.0; nodes];
                    for j in 0..nodes {
                        if row[j] > f32::NEG_INFINITY {
                            exps[j] = (row[j] - max_val).exp();
                            sum_exp += exps[j];
                        }
                    }

                    // out[i] += sum_j alpha_ij * h_j within this head's slice.
                    let out_off = (b * nodes + i) * total_out + head_off;
                    for j in 0..nodes {
                        if exps[j] > 0.0 {
                            let alpha = exps[j] / sum_exp;
                            for f in 0..self.out_features {
                                output[out_off + f] += alpha * h[j * total_out + head_off + f];
                            }
                        }
                    }
                }
            }
        }

        // Shared bias added to every node in every batch entry.
        if let Some(bias) = &self.bias {
            let bias_data = bias.data().to_vec();
            for b in 0..batch {
                for i in 0..nodes {
                    let offset = (b * nodes + i) * total_out;
                    for o in 0..total_out {
                        output[offset + o] += bias_data[o];
                    }
                }
            }
        }

        Variable::new(
            Tensor::from_vec(output, &[batch, nodes, total_out]).expect("tensor creation failed"),
            x.requires_grad(),
        )
    }

    /// Width of the concatenated output: `out_features * num_heads`.
    pub fn total_out_features(&self) -> usize {
        self.out_features * self.num_heads
    }
}
472
473impl Module for GATConv {
474 fn forward(&self, input: &Variable) -> Variable {
475 let n = input.shape()[0];
477 let mut eye_data = vec![0.0f32; n * n];
478 for i in 0..n {
479 eye_data[i * n + i] = 1.0;
480 }
481 let adj = Variable::new(
482 axonml_tensor::Tensor::from_vec(eye_data, &[n, n]).expect("identity matrix creation failed"),
483 false,
484 );
485 self.forward_graph(input, &adj)
486 }
487
488 fn parameters(&self) -> Vec<Parameter> {
489 let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
490 if let Some(bias) = &self.bias {
491 params.push(bias.clone());
492 }
493 params
494 }
495
496 fn named_parameters(&self) -> HashMap<String, Parameter> {
497 let mut params = HashMap::new();
498 params.insert("w".to_string(), self.w.clone());
499 params.insert("attn_src".to_string(), self.attn_src.clone());
500 params.insert("attn_dst".to_string(), self.attn_dst.clone());
501 if let Some(bias) = &self.bias {
502 params.insert("bias".to_string(), bias.clone());
503 }
504 params
505 }
506
507 fn name(&self) -> &'static str {
508 "GATConv"
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-grad `Variable` filled with a single constant value.
    fn const_var(value: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![value; numel], shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_gcn_conv_shape() {
        let layer = GCNConv::new(72, 128);
        let features = const_var(1.0, &[2, 7, 72]);
        let adjacency = const_var(1.0, &[7, 7]);
        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        let layer = GCNConv::new(4, 8);
        let features = const_var(1.0, &[1, 3, 4]);

        // Build a 3x3 identity adjacency.
        let mut eye = vec![0.0; 9];
        for d in 0..3 {
            eye[d * 3 + d] = 1.0;
        }
        let adjacency = Variable::new(
            Tensor::from_vec(eye, &[3, 3]).expect("tensor creation failed"),
            false,
        );

        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![1, 3, 8]);

        // Identical node inputs + identity adjacency => every node's output
        // row equals node 0's row.
        let values = out.data().to_vec();
        for node in 0..3 {
            for feat in 0..8 {
                assert!(
                    (values[node * 8 + feat] - values[feat]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let layer = GCNConv::new(16, 32);
        let params = layer.parameters();
        assert_eq!(params.len(), 2);
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 16 * 32 + 32);
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        let layer = GCNConv::without_bias(16, 32);
        assert_eq!(layer.parameters().len(), 1);
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let named = GCNConv::new(16, 32).named_parameters();
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_shape() {
        let layer = GATConv::new(72, 32, 4);
        let features = const_var(1.0, &[2, 7, 72]);
        let adjacency = const_var(1.0, &[7, 7]);
        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gat_conv_single_head() {
        let layer = GATConv::new(16, 8, 1);
        let features = const_var(1.0, &[1, 5, 16]);
        let adjacency = const_var(1.0, &[5, 5]);
        let out = layer.forward_graph(&features, &adjacency);
        assert_eq!(out.shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_parameters() {
        let layer = GATConv::new(16, 8, 4);
        assert_eq!(layer.parameters().len(), 4);
        let named = layer.named_parameters();
        for key in ["w", "attn_src", "attn_dst", "bias"] {
            assert!(named.contains_key(key));
        }
    }

    #[test]
    fn test_gat_conv_total_output() {
        assert_eq!(GATConv::new(16, 32, 4).total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        let layer = GCNConv::new(4, 4);
        let features = const_var(99.0, &[1, 3, 4]);
        let adjacency = const_var(0.0, &[3, 3]);
        let out = layer.forward_graph(&features, &adjacency);

        // No edges => zero messages; the bias is initialised to zero, so the
        // whole output must be (numerically) zero.
        for v in out.data().to_vec() {
            assert!(
                v.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}