1use std::collections::HashMap;
10
11use axonml_autograd::Variable;
12use axonml_tensor::Tensor;
13
14use crate::module::Module;
15use crate::parameter::Parameter;
16
17
/// Graph Convolutional Network layer.
///
/// Computes `out = A · X · W (+ bias)` over a batch of graphs, where the
/// (possibly batched) adjacency matrix `A` is supplied per call via
/// `forward_graph`.
pub struct GCNConv {
    // Weight matrix, shape (in_features, out_features).
    weight: Parameter,
    // Optional bias, shape (out_features,); `None` when built via `without_bias`.
    bias: Option<Parameter>,
    // Expected size of the input's last (feature) dimension.
    in_features: usize,
    // Size of the produced feature dimension.
    out_features: usize,
}
51
52impl GCNConv {
53 pub fn new(in_features: usize, out_features: usize) -> Self {
55 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
57 let weight_data: Vec<f32> = (0..in_features * out_features)
58 .map(|i| {
59 let x = ((i as f32 * 0.6180339887) % 1.0) * 2.0 - 1.0;
61 x * scale
62 })
63 .collect();
64
65 let weight = Parameter::named(
66 "weight",
67 Tensor::from_vec(weight_data, &[in_features, out_features]).unwrap(),
68 true,
69 );
70
71 let bias_data = vec![0.0; out_features];
72 let bias = Some(Parameter::named(
73 "bias",
74 Tensor::from_vec(bias_data, &[out_features]).unwrap(),
75 true,
76 ));
77
78 Self {
79 weight,
80 bias,
81 in_features,
82 out_features,
83 }
84 }
85
86 pub fn without_bias(in_features: usize, out_features: usize) -> Self {
88 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
89 let weight_data: Vec<f32> = (0..in_features * out_features)
90 .map(|i| {
91 let x = ((i as f32 * 0.6180339887) % 1.0) * 2.0 - 1.0;
92 x * scale
93 })
94 .collect();
95
96 let weight = Parameter::named(
97 "weight",
98 Tensor::from_vec(weight_data, &[in_features, out_features]).unwrap(),
99 true,
100 );
101
102 Self {
103 weight,
104 bias: None,
105 in_features,
106 out_features,
107 }
108 }
109
110 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
119 let shape = x.shape();
120 assert!(shape.len() == 3, "GCNConv expects input shape (batch, nodes, features), got {:?}", shape);
121 assert_eq!(shape[2], self.in_features, "Input features mismatch");
122
123 let batch = shape[0];
124 let nodes = shape[1];
125 let adj_shape = adj.shape();
126
127 let x_data = x.data().to_vec();
128 let adj_data = adj.data().to_vec();
129 let w_data = self.weight.data().to_vec();
130
131 let mut output = vec![0.0f32; batch * nodes * self.out_features];
132
133 for b in 0..batch {
134 let adj_offset = if adj_shape.len() == 3 {
136 b * nodes * nodes
137 } else {
138 0 };
140
141 for i in 0..nodes {
144 let mut message = vec![0.0f32; self.in_features];
146 for j in 0..nodes {
147 let a_ij = adj_data[adj_offset + i * nodes + j];
148 if a_ij != 0.0 {
149 let x_offset = (b * nodes + j) * self.in_features;
150 for f in 0..self.in_features {
151 message[f] += a_ij * x_data[x_offset + f];
152 }
153 }
154 }
155
156 let out_offset = (b * nodes + i) * self.out_features;
158 for o in 0..self.out_features {
159 let mut val = 0.0;
160 for f in 0..self.in_features {
161 val += message[f] * w_data[f * self.out_features + o];
162 }
163 output[out_offset + o] = val;
164 }
165 }
166 }
167
168 if let Some(bias) = &self.bias {
170 let bias_data = bias.data().to_vec();
171 for b in 0..batch {
172 for i in 0..nodes {
173 let offset = (b * nodes + i) * self.out_features;
174 for o in 0..self.out_features {
175 output[offset + o] += bias_data[o];
176 }
177 }
178 }
179 }
180
181 Variable::new(
182 Tensor::from_vec(output, &[batch, nodes, self.out_features]).unwrap(),
183 x.requires_grad() || adj.requires_grad(),
184 )
185 }
186
187 pub fn in_features(&self) -> usize {
189 self.in_features
190 }
191
192 pub fn out_features(&self) -> usize {
194 self.out_features
195 }
196}
197
198impl Module for GCNConv {
199 fn forward(&self, _input: &Variable) -> Variable {
200 panic!("GCNConv requires an adjacency matrix. Use forward_graph(x, adj) instead.")
201 }
202
203 fn parameters(&self) -> Vec<Parameter> {
204 let mut params = vec![self.weight.clone()];
205 if let Some(bias) = &self.bias {
206 params.push(bias.clone());
207 }
208 params
209 }
210
211 fn named_parameters(&self) -> HashMap<String, Parameter> {
212 let mut params = HashMap::new();
213 params.insert("weight".to_string(), self.weight.clone());
214 if let Some(bias) = &self.bias {
215 params.insert("bias".to_string(), bias.clone());
216 }
217 params
218 }
219
220 fn name(&self) -> &'static str {
221 "GCNConv"
222 }
223}
224
/// Multi-head Graph Attention Network (GAT) layer.
///
/// Each head projects node features and computes additive attention over the
/// graph's edges; head outputs are concatenated, so the total output width is
/// `out_features * num_heads`.
pub struct GATConv {
    // Shared projection matrix, shape (in_features, out_features * num_heads).
    w: Parameter,
    // Per-head attention vector applied to the aggregating (row) node,
    // shape (num_heads, out_features).
    attn_src: Parameter,
    // Per-head attention vector applied to the neighbor (column) node,
    // shape (num_heads, out_features).
    attn_dst: Parameter,
    // Optional bias over the concatenated output, shape (out_features * num_heads,).
    bias: Option<Parameter>,
    // Expected size of the input's last (feature) dimension.
    in_features: usize,
    // Output features produced by each attention head.
    out_features: usize,
    // Number of attention heads; their outputs are concatenated.
    num_heads: usize,
    // LeakyReLU slope applied to raw attention logits (set to 0.2 in `new`).
    negative_slope: f32,
}
254
255impl GATConv {
256 pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
263 let total_out = out_features * num_heads;
264 let scale = (2.0 / (in_features + total_out) as f32).sqrt();
265
266 let w_data: Vec<f32> = (0..in_features * total_out)
267 .map(|i| {
268 let x = ((i as f32 * 0.6180339887) % 1.0) * 2.0 - 1.0;
269 x * scale
270 })
271 .collect();
272
273 let w = Parameter::named(
274 "w",
275 Tensor::from_vec(w_data, &[in_features, total_out]).unwrap(),
276 true,
277 );
278
279 let attn_scale = (1.0 / out_features as f32).sqrt();
281 let attn_src_data: Vec<f32> = (0..total_out)
282 .map(|i| {
283 let x = ((i as f32 * 0.7236067977) % 1.0) * 2.0 - 1.0;
284 x * attn_scale
285 })
286 .collect();
287 let attn_dst_data: Vec<f32> = (0..total_out)
288 .map(|i| {
289 let x = ((i as f32 * 0.3819660113) % 1.0) * 2.0 - 1.0;
290 x * attn_scale
291 })
292 .collect();
293
294 let attn_src = Parameter::named(
295 "attn_src",
296 Tensor::from_vec(attn_src_data, &[num_heads, out_features]).unwrap(),
297 true,
298 );
299
300 let attn_dst = Parameter::named(
301 "attn_dst",
302 Tensor::from_vec(attn_dst_data, &[num_heads, out_features]).unwrap(),
303 true,
304 );
305
306 let bias_data = vec![0.0; total_out];
307 let bias = Some(Parameter::named(
308 "bias",
309 Tensor::from_vec(bias_data, &[total_out]).unwrap(),
310 true,
311 ));
312
313 Self {
314 w,
315 attn_src,
316 attn_dst,
317 bias,
318 in_features,
319 out_features,
320 num_heads,
321 negative_slope: 0.2,
322 }
323 }
324
325 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
334 let shape = x.shape();
335 assert!(shape.len() == 3, "GATConv expects (batch, nodes, features), got {:?}", shape);
336
337 let batch = shape[0];
338 let nodes = shape[1];
339 let total_out = self.out_features * self.num_heads;
340
341 let x_data = x.data().to_vec();
342 let adj_data = adj.data().to_vec();
343 let w_data = self.w.data().to_vec();
344 let attn_src_data = self.attn_src.data().to_vec();
345 let attn_dst_data = self.attn_dst.data().to_vec();
346
347 let adj_nodes = if adj.shape().len() == 3 { adj.shape()[1] } else { adj.shape()[0] };
348 assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
349
350 let mut output = vec![0.0f32; batch * nodes * total_out];
351
352 for b in 0..batch {
353 let mut h = vec![0.0f32; nodes * total_out];
355 for i in 0..nodes {
356 let x_off = (b * nodes + i) * self.in_features;
357 for o in 0..total_out {
358 let mut val = 0.0;
359 for f in 0..self.in_features {
360 val += x_data[x_off + f] * w_data[f * total_out + o];
361 }
362 h[i * total_out + o] = val;
363 }
364 }
365
366 let adj_off = if adj.shape().len() == 3 { b * nodes * nodes } else { 0 };
368
369 for head in 0..self.num_heads {
370 let head_off = head * self.out_features;
371
372 let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
375
376 for i in 0..nodes {
377 let mut src_score = 0.0;
379 for f in 0..self.out_features {
380 src_score += h[i * total_out + head_off + f] * attn_src_data[head * self.out_features + f];
381 }
382
383 for j in 0..nodes {
384 let a_ij = adj_data[adj_off + i * nodes + j];
385 if a_ij != 0.0 {
386 let mut dst_score = 0.0;
387 for f in 0..self.out_features {
388 dst_score += h[j * total_out + head_off + f] * attn_dst_data[head * self.out_features + f];
389 }
390
391 let e = src_score + dst_score;
392 let e = if e > 0.0 { e } else { e * self.negative_slope };
394 attn_scores[i * nodes + j] = e;
395 }
396 }
397 }
398
399 for i in 0..nodes {
401 let row_start = i * nodes;
402 let row_end = row_start + nodes;
403 let row = &attn_scores[row_start..row_end];
404
405 let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
406 if max_val == f32::NEG_INFINITY {
407 continue; }
409
410 let mut sum_exp = 0.0f32;
411 let mut exps = vec![0.0; nodes];
412 for j in 0..nodes {
413 if row[j] > f32::NEG_INFINITY {
414 exps[j] = (row[j] - max_val).exp();
415 sum_exp += exps[j];
416 }
417 }
418
419 let out_off = (b * nodes + i) * total_out + head_off;
421 for j in 0..nodes {
422 if exps[j] > 0.0 {
423 let alpha = exps[j] / sum_exp;
424 for f in 0..self.out_features {
425 output[out_off + f] += alpha * h[j * total_out + head_off + f];
426 }
427 }
428 }
429 }
430 }
431 }
432
433 if let Some(bias) = &self.bias {
435 let bias_data = bias.data().to_vec();
436 for b in 0..batch {
437 for i in 0..nodes {
438 let offset = (b * nodes + i) * total_out;
439 for o in 0..total_out {
440 output[offset + o] += bias_data[o];
441 }
442 }
443 }
444 }
445
446 Variable::new(
447 Tensor::from_vec(output, &[batch, nodes, total_out]).unwrap(),
448 x.requires_grad(),
449 )
450 }
451
452 pub fn total_out_features(&self) -> usize {
454 self.out_features * self.num_heads
455 }
456}
457
458impl Module for GATConv {
459 fn forward(&self, _input: &Variable) -> Variable {
460 panic!("GATConv requires an adjacency matrix. Use forward_graph(x, adj) instead.")
461 }
462
463 fn parameters(&self) -> Vec<Parameter> {
464 let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
465 if let Some(bias) = &self.bias {
466 params.push(bias.clone());
467 }
468 params
469 }
470
471 fn named_parameters(&self) -> HashMap<String, Parameter> {
472 let mut params = HashMap::new();
473 params.insert("w".to_string(), self.w.clone());
474 params.insert("attn_src".to_string(), self.attn_src.clone());
475 params.insert("attn_dst".to_string(), self.attn_dst.clone());
476 if let Some(bias) = &self.bias {
477 params.insert("bias".to_string(), bias.clone());
478 }
479 params
480 }
481
482 fn name(&self) -> &'static str {
483 "GATConv"
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a non-grad Variable filled with `value` in the given shape.
    fn filled(value: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![value; numel], shape).unwrap(), false)
    }

    #[test]
    fn test_gcn_conv_shape() {
        let layer = GCNConv::new(72, 128);
        let x = filled(1.0, &[2, 7, 72]);
        let adj = filled(1.0, &[7, 7]);
        assert_eq!(layer.forward_graph(&x, &adj).shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        let layer = GCNConv::new(4, 8);
        let x = filled(1.0, &[1, 3, 4]);

        // Identity adjacency: each node only receives its own features.
        let mut eye = vec![0.0; 9];
        for d in 0..3 {
            eye[d * 3 + d] = 1.0;
        }
        let adj = Variable::new(Tensor::from_vec(eye, &[3, 3]).unwrap(), false);

        let out = layer.forward_graph(&x, &adj);
        assert_eq!(out.shape(), vec![1, 3, 8]);

        // Identical inputs + identity adjacency => every node matches node 0.
        let vals = out.data().to_vec();
        for node in 0..3 {
            for feat in 0..8 {
                assert!(
                    (vals[node * 8 + feat] - vals[feat]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let layer = GCNConv::new(16, 32);
        let params = layer.parameters();
        assert_eq!(params.len(), 2);
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 16 * 32 + 32);
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        assert_eq!(GCNConv::without_bias(16, 32).parameters().len(), 1);
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let named = GCNConv::new(16, 32).named_parameters();
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_shape() {
        // 32 features per head * 4 heads = 128 output channels.
        let layer = GATConv::new(72, 32, 4);
        let x = filled(1.0, &[2, 7, 72]);
        let adj = filled(1.0, &[7, 7]);
        assert_eq!(layer.forward_graph(&x, &adj).shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gat_conv_single_head() {
        let layer = GATConv::new(16, 8, 1);
        let x = filled(1.0, &[1, 5, 16]);
        let adj = filled(1.0, &[5, 5]);
        assert_eq!(layer.forward_graph(&x, &adj).shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_parameters() {
        let layer = GATConv::new(16, 8, 4);
        assert_eq!(layer.parameters().len(), 4);
        let named = layer.named_parameters();
        for key in ["w", "attn_src", "attn_dst", "bias"].iter() {
            assert!(named.contains_key(*key));
        }
    }

    #[test]
    fn test_gat_conv_total_output() {
        assert_eq!(GATConv::new(16, 32, 4).total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        let layer = GCNConv::new(4, 4);
        let x = filled(99.0, &[1, 3, 4]);
        let adj = filled(0.0, &[3, 3]);

        let out = layer.forward_graph(&x, &adj);
        let vals = out.data().to_vec();
        for v in &vals {
            assert!(v.abs() < 1e-6, "Zero adjacency should zero out message passing");
        }
    }
}