1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::module::Module;
23use crate::parameter::Parameter;
24
/// Graph Convolutional Network layer: per batch sample computes
/// `adj_b @ x_b @ weight (+ bias)` over a dense adjacency matrix.
///
/// The adjacency is applied as-is; any GCN normalization (e.g.
/// D^-1/2 A D^-1/2) is the caller's responsibility.
pub struct GCNConv {
    /// Learnable projection, shape (in_features, out_features).
    weight: Parameter,
    /// Optional learnable bias, shape (out_features,); `None` when built via `without_bias`.
    bias: Option<Parameter>,
    /// Expected size of the input's last (feature) dimension.
    in_features: usize,
    /// Size of the produced feature dimension.
    out_features: usize,
}
58
59impl GCNConv {
60 pub fn new(in_features: usize, out_features: usize) -> Self {
62 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
64 let weight_data: Vec<f32> = (0..in_features * out_features)
65 .map(|i| {
66 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
68 x * scale
69 })
70 .collect();
71
72 let weight = Parameter::named(
73 "weight",
74 Tensor::from_vec(weight_data, &[in_features, out_features]).unwrap(),
75 true,
76 );
77
78 let bias_data = vec![0.0; out_features];
79 let bias = Some(Parameter::named(
80 "bias",
81 Tensor::from_vec(bias_data, &[out_features]).unwrap(),
82 true,
83 ));
84
85 Self {
86 weight,
87 bias,
88 in_features,
89 out_features,
90 }
91 }
92
93 pub fn without_bias(in_features: usize, out_features: usize) -> Self {
95 let scale = (2.0 / (in_features + out_features) as f32).sqrt();
96 let weight_data: Vec<f32> = (0..in_features * out_features)
97 .map(|i| {
98 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
99 x * scale
100 })
101 .collect();
102
103 let weight = Parameter::named(
104 "weight",
105 Tensor::from_vec(weight_data, &[in_features, out_features]).unwrap(),
106 true,
107 );
108
109 Self {
110 weight,
111 bias: None,
112 in_features,
113 out_features,
114 }
115 }
116
117 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
126 let shape = x.shape();
127 assert!(
128 shape.len() == 3,
129 "GCNConv expects input shape (batch, nodes, features), got {:?}",
130 shape
131 );
132 assert_eq!(shape[2], self.in_features, "Input features mismatch");
133
134 let batch = shape[0];
135 let adj_shape = adj.shape();
136
137 let weight = self.weight.variable();
143
144 let mut per_sample: Vec<Variable> = Vec::with_capacity(batch);
145 for b in 0..batch {
146 let x_b = x.select(0, b);
148
149 let adj_b = if adj_shape.len() == 3 {
151 adj.select(0, b)
152 } else {
153 adj.clone()
154 };
155
156 let msg_b = adj_b.matmul(&x_b);
158
159 let mut out_b = msg_b.matmul(&weight);
161
162 if let Some(bias) = &self.bias {
164 out_b = out_b.add_var(&bias.variable());
165 }
166
167 per_sample.push(out_b.unsqueeze(0));
169 }
170
171 let refs: Vec<&Variable> = per_sample.iter().collect();
173 Variable::cat(&refs, 0)
174 }
175
176 pub fn in_features(&self) -> usize {
178 self.in_features
179 }
180
181 pub fn out_features(&self) -> usize {
183 self.out_features
184 }
185}
186
187impl Module for GCNConv {
188 fn forward(&self, _input: &Variable) -> Variable {
189 panic!("GCNConv requires an adjacency matrix. Use forward_graph(x, adj) instead.")
190 }
191
192 fn parameters(&self) -> Vec<Parameter> {
193 let mut params = vec![self.weight.clone()];
194 if let Some(bias) = &self.bias {
195 params.push(bias.clone());
196 }
197 params
198 }
199
200 fn named_parameters(&self) -> HashMap<String, Parameter> {
201 let mut params = HashMap::new();
202 params.insert("weight".to_string(), self.weight.clone());
203 if let Some(bias) = &self.bias {
204 params.insert("bias".to_string(), bias.clone());
205 }
206 params
207 }
208
209 fn name(&self) -> &'static str {
210 "GCNConv"
211 }
212}
213
/// Graph Attention Network layer with multi-head attention over a dense
/// adjacency matrix. Head outputs are concatenated, so the layer produces
/// `out_features * num_heads` output channels.
pub struct GATConv {
    /// Shared linear projection, shape (in_features, out_features * num_heads).
    w: Parameter,
    /// Per-head attention vector scoring the aggregating node i, shape (num_heads, out_features).
    attn_src: Parameter,
    /// Per-head attention vector scoring each neighbor j, shape (num_heads, out_features).
    attn_dst: Parameter,
    /// Optional bias over the concatenated output, shape (out_features * num_heads,).
    bias: Option<Parameter>,
    /// Expected size of the input's last (feature) dimension.
    in_features: usize,
    /// Per-head output size (before head concatenation).
    out_features: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// LeakyReLU slope applied to attention logits (set to 0.2 in `new`).
    negative_slope: f32,
}
243
244impl GATConv {
245 pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
252 let total_out = out_features * num_heads;
253 let scale = (2.0 / (in_features + total_out) as f32).sqrt();
254
255 let w_data: Vec<f32> = (0..in_features * total_out)
256 .map(|i| {
257 let x = ((i as f32 * 0.618_034) % 1.0) * 2.0 - 1.0;
258 x * scale
259 })
260 .collect();
261
262 let w = Parameter::named(
263 "w",
264 Tensor::from_vec(w_data, &[in_features, total_out]).unwrap(),
265 true,
266 );
267
268 let attn_scale = (1.0 / out_features as f32).sqrt();
270 let attn_src_data: Vec<f32> = (0..total_out)
271 .map(|i| {
272 let x = ((i as f32 * 0.723_606_8) % 1.0) * 2.0 - 1.0;
273 x * attn_scale
274 })
275 .collect();
276 let attn_dst_data: Vec<f32> = (0..total_out)
277 .map(|i| {
278 let x = ((i as f32 * 0.381_966_02) % 1.0) * 2.0 - 1.0;
279 x * attn_scale
280 })
281 .collect();
282
283 let attn_src = Parameter::named(
284 "attn_src",
285 Tensor::from_vec(attn_src_data, &[num_heads, out_features]).unwrap(),
286 true,
287 );
288
289 let attn_dst = Parameter::named(
290 "attn_dst",
291 Tensor::from_vec(attn_dst_data, &[num_heads, out_features]).unwrap(),
292 true,
293 );
294
295 let bias_data = vec![0.0; total_out];
296 let bias = Some(Parameter::named(
297 "bias",
298 Tensor::from_vec(bias_data, &[total_out]).unwrap(),
299 true,
300 ));
301
302 Self {
303 w,
304 attn_src,
305 attn_dst,
306 bias,
307 in_features,
308 out_features,
309 num_heads,
310 negative_slope: 0.2,
311 }
312 }
313
314 pub fn forward_graph(&self, x: &Variable, adj: &Variable) -> Variable {
323 let shape = x.shape();
324 assert!(
325 shape.len() == 3,
326 "GATConv expects (batch, nodes, features), got {:?}",
327 shape
328 );
329
330 let batch = shape[0];
331 let nodes = shape[1];
332 let total_out = self.out_features * self.num_heads;
333
334 let x_data = x.data().to_vec();
335 let adj_data = adj.data().to_vec();
336 let w_data = self.w.data().to_vec();
337 let attn_src_data = self.attn_src.data().to_vec();
338 let attn_dst_data = self.attn_dst.data().to_vec();
339
340 let adj_nodes = if adj.shape().len() == 3 {
341 adj.shape()[1]
342 } else {
343 adj.shape()[0]
344 };
345 assert_eq!(adj_nodes, nodes, "Adjacency matrix size mismatch");
346
347 let mut output = vec![0.0f32; batch * nodes * total_out];
348
349 for b in 0..batch {
350 let mut h = vec![0.0f32; nodes * total_out];
352 for i in 0..nodes {
353 let x_off = (b * nodes + i) * self.in_features;
354 for o in 0..total_out {
355 let mut val = 0.0;
356 for f in 0..self.in_features {
357 val += x_data[x_off + f] * w_data[f * total_out + o];
358 }
359 h[i * total_out + o] = val;
360 }
361 }
362
363 let adj_off = if adj.shape().len() == 3 {
365 b * nodes * nodes
366 } else {
367 0
368 };
369
370 for head in 0..self.num_heads {
371 let head_off = head * self.out_features;
372
373 let mut attn_scores = vec![f32::NEG_INFINITY; nodes * nodes];
376
377 for i in 0..nodes {
378 let mut src_score = 0.0;
380 for f in 0..self.out_features {
381 src_score += h[i * total_out + head_off + f]
382 * attn_src_data[head * self.out_features + f];
383 }
384
385 for j in 0..nodes {
386 let a_ij = adj_data[adj_off + i * nodes + j];
387 if a_ij != 0.0 {
388 let mut dst_score = 0.0;
389 for f in 0..self.out_features {
390 dst_score += h[j * total_out + head_off + f]
391 * attn_dst_data[head * self.out_features + f];
392 }
393
394 let e = src_score + dst_score;
395 let e = if e > 0.0 { e } else { e * self.negative_slope };
397 attn_scores[i * nodes + j] = e;
398 }
399 }
400 }
401
402 for i in 0..nodes {
404 let row_start = i * nodes;
405 let row_end = row_start + nodes;
406 let row = &attn_scores[row_start..row_end];
407
408 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
409 if max_val == f32::NEG_INFINITY {
410 continue; }
412
413 let mut sum_exp = 0.0f32;
414 let mut exps = vec![0.0; nodes];
415 for j in 0..nodes {
416 if row[j] > f32::NEG_INFINITY {
417 exps[j] = (row[j] - max_val).exp();
418 sum_exp += exps[j];
419 }
420 }
421
422 let out_off = (b * nodes + i) * total_out + head_off;
424 for j in 0..nodes {
425 if exps[j] > 0.0 {
426 let alpha = exps[j] / sum_exp;
427 for f in 0..self.out_features {
428 output[out_off + f] += alpha * h[j * total_out + head_off + f];
429 }
430 }
431 }
432 }
433 }
434 }
435
436 if let Some(bias) = &self.bias {
438 let bias_data = bias.data().to_vec();
439 for b in 0..batch {
440 for i in 0..nodes {
441 let offset = (b * nodes + i) * total_out;
442 for o in 0..total_out {
443 output[offset + o] += bias_data[o];
444 }
445 }
446 }
447 }
448
449 Variable::new(
450 Tensor::from_vec(output, &[batch, nodes, total_out]).unwrap(),
451 x.requires_grad(),
452 )
453 }
454
455 pub fn total_out_features(&self) -> usize {
457 self.out_features * self.num_heads
458 }
459}
460
461impl Module for GATConv {
462 fn forward(&self, _input: &Variable) -> Variable {
463 panic!("GATConv requires an adjacency matrix. Use forward_graph(x, adj) instead.")
464 }
465
466 fn parameters(&self) -> Vec<Parameter> {
467 let mut params = vec![self.w.clone(), self.attn_src.clone(), self.attn_dst.clone()];
468 if let Some(bias) = &self.bias {
469 params.push(bias.clone());
470 }
471 params
472 }
473
474 fn named_parameters(&self) -> HashMap<String, Parameter> {
475 let mut params = HashMap::new();
476 params.insert("w".to_string(), self.w.clone());
477 params.insert("attn_src".to_string(), self.attn_src.clone());
478 params.insert("attn_dst".to_string(), self.attn_dst.clone());
479 if let Some(bias) = &self.bias {
480 params.insert("bias".to_string(), bias.clone());
481 }
482 params
483 }
484
485 fn name(&self) -> &'static str {
486 "GATConv"
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gcn_conv_shape() {
        let layer = GCNConv::new(72, 128);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).unwrap(),
            false,
        );
        let adj = Variable::new(Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).unwrap(), false);

        let out = layer.forward_graph(&x, &adj);
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gcn_conv_identity_adjacency() {
        let layer = GCNConv::new(4, 8);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 3 * 4], &[1, 3, 4]).unwrap(),
            false,
        );

        // Identity adjacency: each node aggregates only itself.
        let mut adj_data = vec![0.0; 9];
        for d in 0..3 {
            adj_data[d * 3 + d] = 1.0;
        }
        let adj = Variable::new(Tensor::from_vec(adj_data, &[3, 3]).unwrap(), false);

        let out = layer.forward_graph(&x, &adj);
        assert_eq!(out.shape(), vec![1, 3, 8]);

        // Identical inputs + identity adjacency => every node matches node 0.
        let data = out.data().to_vec();
        for node in 0..3 {
            for feat in 0..8 {
                assert!(
                    (data[node * 8 + feat] - data[feat]).abs() < 1e-6,
                    "Node outputs should be identical with identity adj and same input"
                );
            }
        }
    }

    #[test]
    fn test_gcn_conv_parameters() {
        let layer = GCNConv::new(16, 32);
        let params = layer.parameters();
        assert_eq!(params.len(), 2);

        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 16 * 32 + 32);
    }

    #[test]
    fn test_gcn_conv_no_bias() {
        let layer = GCNConv::without_bias(16, 32);
        assert_eq!(layer.parameters().len(), 1);
    }

    #[test]
    fn test_gcn_conv_named_parameters() {
        let named = GCNConv::new(16, 32).named_parameters();
        assert!(named.contains_key("weight"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_shape() {
        let layer = GATConv::new(72, 32, 4);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 7 * 72], &[2, 7, 72]).unwrap(),
            false,
        );
        let adj = Variable::new(Tensor::from_vec(vec![1.0; 7 * 7], &[7, 7]).unwrap(), false);

        let out = layer.forward_graph(&x, &adj);
        // 4 heads * 32 features concatenated.
        assert_eq!(out.shape(), vec![2, 7, 128]);
    }

    #[test]
    fn test_gat_conv_single_head() {
        let layer = GATConv::new(16, 8, 1);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 5 * 16], &[1, 5, 16]).unwrap(),
            false,
        );
        let adj = Variable::new(Tensor::from_vec(vec![1.0; 5 * 5], &[5, 5]).unwrap(), false);

        let out = layer.forward_graph(&x, &adj);
        assert_eq!(out.shape(), vec![1, 5, 8]);
    }

    #[test]
    fn test_gat_conv_parameters() {
        let layer = GATConv::new(16, 8, 4);
        assert_eq!(layer.parameters().len(), 4);

        let named = layer.named_parameters();
        assert!(named.contains_key("w"));
        assert!(named.contains_key("attn_src"));
        assert!(named.contains_key("attn_dst"));
        assert!(named.contains_key("bias"));
    }

    #[test]
    fn test_gat_conv_total_output() {
        let layer = GATConv::new(16, 32, 4);
        assert_eq!(layer.total_out_features(), 128);
    }

    #[test]
    fn test_gcn_zero_adjacency() {
        let layer = GCNConv::new(4, 4);
        let x = Variable::new(
            Tensor::from_vec(vec![99.0; 1 * 3 * 4], &[1, 3, 4]).unwrap(),
            false,
        );
        let adj = Variable::new(Tensor::from_vec(vec![0.0; 9], &[3, 3]).unwrap(), false);

        let out = layer.forward_graph(&x, &adj);
        for val in &out.data().to_vec() {
            assert!(
                val.abs() < 1e-6,
                "Zero adjacency should zero out message passing"
            );
        }
    }
}