1use crate::types::CsrGraph;
8use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
9use serde::{Deserialize, Serialize};
10
/// Hyper-parameters for message-passing GNN inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNConfig {
    /// Number of message-passing layers.
    pub num_layers: usize,
    /// Width of every hidden (non-final) layer.
    pub hidden_dim: usize,
    /// Width of the final layer; when > 1 it is treated as a class count.
    pub output_dim: usize,
    /// Neighbor aggregation scheme.
    pub aggregation: AggregationType,
    /// Non-linearity applied after each non-final layer.
    pub activation: ActivationType,
    /// Dropout probability. NOTE(review): never read anywhere in this file —
    /// presumably reserved for a training path; confirm before relying on it.
    pub dropout: f64,
    /// Whether to add a self-loop to every node before message passing.
    pub add_self_loops: bool,
    /// Whether to layer-normalize embeddings after each non-final layer.
    pub layer_norm: bool,
}
35
36impl Default for GNNConfig {
37 fn default() -> Self {
38 Self {
39 num_layers: 2,
40 hidden_dim: 64,
41 output_dim: 32,
42 aggregation: AggregationType::Mean,
43 activation: ActivationType::ReLU,
44 dropout: 0.0,
45 add_self_loops: true,
46 layer_norm: true,
47 }
48 }
49}
50
/// Neighbor aggregation schemes used during message passing.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum AggregationType {
    /// Element-wise sum of neighbor embeddings.
    Sum,
    /// Element-wise mean of neighbor embeddings.
    Mean,
    /// Element-wise maximum of neighbor embeddings.
    Max,
    /// GraphSAGE-style aggregation. NOTE(review): currently computed
    /// identically to `Mean` in `aggregate_neighbors` (no separate
    /// self/neighbor transform).
    SAGE,
}
63
/// Activation functions applied after non-final layers.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ActivationType {
    /// `max(x, 0)`.
    ReLU,
    /// `x` for positive inputs, `0.01 * x` otherwise (fixed slope).
    LeakyReLU,
    /// `x` for positive inputs, `exp(x) - 1` otherwise (alpha = 1).
    ELU,
    /// Logistic sigmoid `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Identity (no activation).
    None,
}
80
/// Learned parameters for GNN inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNWeights {
    /// Per-layer weight matrices, indexed `[layer][input_dim][output_dim]`.
    pub layer_weights: Vec<Vec<Vec<f64>>>,
    /// Per-layer bias vectors, indexed `[layer][output_dim]`.
    pub layer_biases: Vec<Vec<f64>>,
}
89
90impl GNNWeights {
91 pub fn random(input_dim: usize, config: &GNNConfig) -> Self {
93 use rand::{Rng, rng};
94 let mut r = rng();
95
96 let mut layer_weights = Vec::new();
97 let mut layer_biases = Vec::new();
98
99 let mut prev_dim = input_dim;
100
101 for i in 0..config.num_layers {
102 let out_dim = if i == config.num_layers - 1 {
103 config.output_dim
104 } else {
105 config.hidden_dim
106 };
107
108 let scale = (2.0 / (prev_dim + out_dim) as f64).sqrt();
110
111 let weights: Vec<Vec<f64>> = (0..prev_dim)
112 .map(|_| {
113 (0..out_dim)
114 .map(|_| r.random_range(-scale..scale))
115 .collect()
116 })
117 .collect();
118
119 let biases: Vec<f64> = (0..out_dim).map(|_| 0.0).collect();
120
121 layer_weights.push(weights);
122 layer_biases.push(biases);
123 prev_dim = out_dim;
124 }
125
126 Self {
127 layer_weights,
128 layer_biases,
129 }
130 }
131}
132
/// Output of GNN inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNResult {
    /// Final per-node embedding vectors.
    pub embeddings: Vec<Vec<f64>>,
    /// Argmax class per node; `Some` only when `output_dim > 1`.
    pub predictions: Option<Vec<usize>>,
    /// Softmax class probabilities per node; `Some` only when `output_dim > 1`.
    pub probabilities: Option<Vec<Vec<f64>>>,
}
143
/// Kernel for message-passing GNN inference over a CSR graph.
#[derive(Debug, Clone)]
pub struct GNNInference {
    // Registry metadata (kernel id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
153
154impl Default for GNNInference {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
160impl GNNInference {
161 #[must_use]
163 pub fn new() -> Self {
164 Self {
165 metadata: KernelMetadata::batch("graph/gnn-inference", Domain::GraphAnalytics)
166 .with_description("Message passing neural network inference")
167 .with_throughput(10_000)
168 .with_latency_us(100.0),
169 }
170 }
171
172 pub fn compute(
174 graph: &CsrGraph,
175 node_features: &[Vec<f64>],
176 weights: &GNNWeights,
177 config: &GNNConfig,
178 ) -> GNNResult {
179 if graph.num_nodes == 0 || node_features.is_empty() {
180 return GNNResult {
181 embeddings: Vec::new(),
182 predictions: None,
183 probabilities: None,
184 };
185 }
186
187 let n = graph.num_nodes;
188
189 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
191 for node in 0..n {
192 let start = graph.row_offsets[node] as usize;
193 let end = graph.row_offsets[node + 1] as usize;
194 for &neighbor in &graph.col_indices[start..end] {
195 adj[node].push(neighbor as usize);
196 if !adj[neighbor as usize].contains(&node) {
198 adj[neighbor as usize].push(node);
199 }
200 }
201 }
202
203 if config.add_self_loops {
204 for i in 0..n {
205 if !adj[i].contains(&i) {
206 adj[i].push(i);
207 }
208 }
209 }
210
211 let mut embeddings: Vec<Vec<f64>> = node_features.to_vec();
213
214 for layer_idx in 0..config.num_layers {
216 embeddings = Self::message_passing_layer(
217 &embeddings,
218 &adj,
219 &weights.layer_weights[layer_idx],
220 &weights.layer_biases[layer_idx],
221 config,
222 layer_idx == config.num_layers - 1,
223 );
224 }
225
226 let (predictions, probabilities) = if config.output_dim > 1 {
228 let probs: Vec<Vec<f64>> = embeddings.iter().map(|e| Self::softmax(e)).collect();
229 let preds: Vec<usize> = probs
230 .iter()
231 .map(|p| {
232 p.iter()
233 .enumerate()
234 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
235 .map(|(i, _)| i)
236 .unwrap_or(0)
237 })
238 .collect();
239 (Some(preds), Some(probs))
240 } else {
241 (None, None)
242 };
243
244 GNNResult {
245 embeddings,
246 predictions,
247 probabilities,
248 }
249 }
250
251 fn message_passing_layer(
253 embeddings: &[Vec<f64>],
254 adj: &[Vec<usize>],
255 weights: &[Vec<f64>],
256 biases: &[f64],
257 config: &GNNConfig,
258 is_last: bool,
259 ) -> Vec<Vec<f64>> {
260 let n = embeddings.len();
261 let out_dim = biases.len();
262 let mut new_embeddings = vec![vec![0.0; out_dim]; n];
263
264 for i in 0..n {
265 let aggregated = Self::aggregate_neighbors(embeddings, &adj[i], config.aggregation);
267
268 for j in 0..out_dim {
270 let mut val = biases[j];
271 for (k, &agg_val) in aggregated.iter().enumerate() {
272 if k < weights.len() && j < weights[k].len() {
273 val += weights[k][j] * agg_val;
274 }
275 }
276
277 if !is_last {
279 val = Self::activate(val, config.activation);
280 }
281
282 new_embeddings[i][j] = val;
283 }
284
285 if config.layer_norm && !is_last {
287 let mean: f64 = new_embeddings[i].iter().sum::<f64>() / out_dim as f64;
288 let var: f64 = new_embeddings[i]
289 .iter()
290 .map(|x| (x - mean).powi(2))
291 .sum::<f64>()
292 / out_dim as f64;
293 let std = (var + 1e-5).sqrt();
294
295 for j in 0..out_dim {
296 new_embeddings[i][j] = (new_embeddings[i][j] - mean) / std;
297 }
298 }
299 }
300
301 new_embeddings
302 }
303
304 fn aggregate_neighbors(
306 embeddings: &[Vec<f64>],
307 neighbors: &[usize],
308 agg_type: AggregationType,
309 ) -> Vec<f64> {
310 if neighbors.is_empty() {
311 return vec![0.0; embeddings.get(0).map(|e| e.len()).unwrap_or(0)];
312 }
313
314 let dim = embeddings[neighbors[0]].len();
315
316 match agg_type {
317 AggregationType::Sum => {
318 let mut result = vec![0.0; dim];
319 for &n in neighbors {
320 for (i, &v) in embeddings[n].iter().enumerate() {
321 result[i] += v;
322 }
323 }
324 result
325 }
326 AggregationType::Mean => {
327 let mut result = vec![0.0; dim];
328 for &n in neighbors {
329 for (i, &v) in embeddings[n].iter().enumerate() {
330 result[i] += v;
331 }
332 }
333 let count = neighbors.len() as f64;
334 result.iter_mut().for_each(|v| *v /= count);
335 result
336 }
337 AggregationType::Max => {
338 let mut result = vec![f64::NEG_INFINITY; dim];
339 for &n in neighbors {
340 for (i, &v) in embeddings[n].iter().enumerate() {
341 result[i] = result[i].max(v);
342 }
343 }
344 result
345 }
346 AggregationType::SAGE => {
347 let mut result = vec![0.0; dim];
350 for &n in neighbors {
351 for (i, &v) in embeddings[n].iter().enumerate() {
352 result[i] += v;
353 }
354 }
355 let count = neighbors.len() as f64;
356 result.iter_mut().for_each(|v| *v /= count);
357 result
358 }
359 }
360 }
361
362 fn activate(x: f64, activation: ActivationType) -> f64 {
364 match activation {
365 ActivationType::ReLU => x.max(0.0),
366 ActivationType::LeakyReLU => {
367 if x > 0.0 {
368 x
369 } else {
370 0.01 * x
371 }
372 }
373 ActivationType::ELU => {
374 if x > 0.0 {
375 x
376 } else {
377 x.exp() - 1.0
378 }
379 }
380 ActivationType::Sigmoid => 1.0 / (1.0 + (-x).exp()),
381 ActivationType::Tanh => x.tanh(),
382 ActivationType::None => x,
383 }
384 }
385
386 fn softmax(x: &[f64]) -> Vec<f64> {
388 let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
389 let exp_sum: f64 = x.iter().map(|v| (v - max_val).exp()).sum();
390 x.iter().map(|v| (v - max_val).exp() / exp_sum).collect()
391 }
392}
393
impl GpuKernel for GNNInference {
    /// Exposes the kernel's registry metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
399
/// Hyper-parameters for multi-head graph attention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAttentionConfig {
    /// Number of independent attention heads.
    pub num_heads: usize,
    /// Dimension of each head's query/key/value projections.
    pub head_dim: usize,
    /// Dimension of the final output projection.
    pub output_dim: usize,
    /// Attention dropout probability. NOTE(review): never read anywhere in
    /// this file — presumably reserved for training; confirm.
    pub attention_dropout: f64,
    /// Concatenate head outputs when true; average them when false.
    pub concat_heads: bool,
    /// LeakyReLU slope applied to raw attention scores.
    pub negative_slope: f64,
}
420
421impl Default for GraphAttentionConfig {
422 fn default() -> Self {
423 Self {
424 num_heads: 4,
425 head_dim: 16,
426 output_dim: 64,
427 attention_dropout: 0.0,
428 concat_heads: true,
429 negative_slope: 0.2,
430 }
431 }
432}
433
/// Learned parameters for graph attention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATWeights {
    /// Per-head query projections, indexed `[head][input_dim][head_dim]`.
    pub query_weights: Vec<Vec<Vec<f64>>>,
    /// Per-head key projections, indexed `[head][input_dim][head_dim]`.
    pub key_weights: Vec<Vec<Vec<f64>>>,
    /// Per-head value projections, indexed `[head][input_dim][head_dim]`.
    pub value_weights: Vec<Vec<Vec<f64>>>,
    /// Per-head attention vectors over `[query ‖ key]`, length `2 * head_dim`.
    pub attention_vectors: Vec<Vec<f64>>,
    /// Final projection from combined heads to `output_dim`.
    pub output_weights: Vec<Vec<f64>>,
}
448
449impl GATWeights {
450 pub fn random(input_dim: usize, config: &GraphAttentionConfig) -> Self {
452 use rand::{Rng, rng};
453 let mut r = rng();
454
455 let scale = (2.0 / (input_dim + config.head_dim) as f64).sqrt();
456
457 let mut query_weights = Vec::new();
458 let mut key_weights = Vec::new();
459 let mut value_weights = Vec::new();
460 let mut attention_vectors = Vec::new();
461
462 for _ in 0..config.num_heads {
463 let q: Vec<Vec<f64>> = (0..input_dim)
464 .map(|_| {
465 (0..config.head_dim)
466 .map(|_| r.random_range(-scale..scale))
467 .collect()
468 })
469 .collect();
470 let k: Vec<Vec<f64>> = (0..input_dim)
471 .map(|_| {
472 (0..config.head_dim)
473 .map(|_| r.random_range(-scale..scale))
474 .collect()
475 })
476 .collect();
477 let v: Vec<Vec<f64>> = (0..input_dim)
478 .map(|_| {
479 (0..config.head_dim)
480 .map(|_| r.random_range(-scale..scale))
481 .collect()
482 })
483 .collect();
484 let a: Vec<f64> = (0..config.head_dim * 2)
485 .map(|_| r.random_range(-scale..scale))
486 .collect();
487
488 query_weights.push(q);
489 key_weights.push(k);
490 value_weights.push(v);
491 attention_vectors.push(a);
492 }
493
494 let total_dim = if config.concat_heads {
495 config.num_heads * config.head_dim
496 } else {
497 config.head_dim
498 };
499
500 let out_scale = (2.0 / (total_dim + config.output_dim) as f64).sqrt();
501 let output_weights: Vec<Vec<f64>> = (0..total_dim)
502 .map(|_| {
503 (0..config.output_dim)
504 .map(|_| r.random_range(-out_scale..out_scale))
505 .collect()
506 })
507 .collect();
508
509 Self {
510 query_weights,
511 key_weights,
512 value_weights,
513 attention_vectors,
514 output_weights,
515 }
516 }
517}
518
/// Output of graph attention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GATResult {
    /// Final per-node embedding vectors.
    pub embeddings: Vec<Vec<f64>>,
    /// Per-head attention coefficients as `(node, neighbor, weight)` triples;
    /// the weights over each node's neighborhood sum to 1.
    pub attention_weights: Vec<Vec<(usize, usize, f64)>>,
}
527
/// Kernel for multi-head graph attention over a CSR graph.
#[derive(Debug, Clone)]
pub struct GraphAttention {
    // Registry metadata (kernel id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
537
538impl Default for GraphAttention {
539 fn default() -> Self {
540 Self::new()
541 }
542}
543
544impl GraphAttention {
545 #[must_use]
547 pub fn new() -> Self {
548 Self {
549 metadata: KernelMetadata::batch("graph/graph-attention", Domain::GraphAnalytics)
550 .with_description("Graph attention networks with multi-head attention")
551 .with_throughput(5_000)
552 .with_latency_us(200.0),
553 }
554 }
555
556 pub fn compute(
558 graph: &CsrGraph,
559 node_features: &[Vec<f64>],
560 weights: &GATWeights,
561 config: &GraphAttentionConfig,
562 ) -> GATResult {
563 if graph.num_nodes == 0 || node_features.is_empty() {
564 return GATResult {
565 embeddings: Vec::new(),
566 attention_weights: Vec::new(),
567 };
568 }
569
570 let n = graph.num_nodes;
571
572 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
574 for node in 0..n {
575 let start = graph.row_offsets[node] as usize;
576 let end = graph.row_offsets[node + 1] as usize;
577 for &neighbor in &graph.col_indices[start..end] {
578 adj[node].push(neighbor as usize);
579 if !adj[neighbor as usize].contains(&node) {
580 adj[neighbor as usize].push(node);
581 }
582 }
583 }
584 for i in 0..n {
585 if !adj[i].contains(&i) {
586 adj[i].push(i);
587 }
588 }
589
590 let mut head_outputs: Vec<Vec<Vec<f64>>> = Vec::new();
592 let mut all_attention_weights: Vec<Vec<(usize, usize, f64)>> = Vec::new();
593
594 for head in 0..config.num_heads {
595 let (output, attn_weights) = Self::compute_head(
596 node_features,
597 &adj,
598 &weights.query_weights[head],
599 &weights.key_weights[head],
600 &weights.value_weights[head],
601 &weights.attention_vectors[head],
602 config,
603 );
604 head_outputs.push(output);
605 all_attention_weights.push(attn_weights);
606 }
607
608 let combined: Vec<Vec<f64>> = if config.concat_heads {
610 (0..n)
611 .map(|i| head_outputs.iter().flat_map(|h| h[i].clone()).collect())
612 .collect()
613 } else {
614 (0..n)
616 .map(|i| {
617 let dim = head_outputs[0][i].len();
618 let mut avg = vec![0.0; dim];
619 for h in &head_outputs {
620 for (j, &v) in h[i].iter().enumerate() {
621 avg[j] += v;
622 }
623 }
624 avg.iter_mut().for_each(|v| *v /= config.num_heads as f64);
625 avg
626 })
627 .collect()
628 };
629
630 let embeddings: Vec<Vec<f64>> = combined
632 .iter()
633 .map(|c| Self::linear_transform(c, &weights.output_weights))
634 .collect();
635
636 GATResult {
637 embeddings,
638 attention_weights: all_attention_weights,
639 }
640 }
641
642 fn compute_head(
644 features: &[Vec<f64>],
645 adj: &[Vec<usize>],
646 query_w: &[Vec<f64>],
647 key_w: &[Vec<f64>],
648 value_w: &[Vec<f64>],
649 attn_vec: &[f64],
650 config: &GraphAttentionConfig,
651 ) -> (Vec<Vec<f64>>, Vec<(usize, usize, f64)>) {
652 let n = features.len();
653 let head_dim = config.head_dim;
654
655 let queries: Vec<Vec<f64>> = features
657 .iter()
658 .map(|f| Self::linear_transform(f, query_w))
659 .collect();
660 let keys: Vec<Vec<f64>> = features
661 .iter()
662 .map(|f| Self::linear_transform(f, key_w))
663 .collect();
664 let values: Vec<Vec<f64>> = features
665 .iter()
666 .map(|f| Self::linear_transform(f, value_w))
667 .collect();
668
669 let mut output = vec![vec![0.0; head_dim]; n];
670 let mut attention_list: Vec<(usize, usize, f64)> = Vec::new();
671
672 for i in 0..n {
673 if adj[i].is_empty() {
674 continue;
675 }
676
677 let mut scores: Vec<f64> = Vec::with_capacity(adj[i].len());
679
680 for &j in &adj[i] {
681 let mut concat = queries[i].clone();
683 concat.extend(keys[j].iter().cloned());
684
685 let score: f64 = concat.iter().zip(attn_vec.iter()).map(|(c, a)| c * a).sum();
686
687 let score = if score > 0.0 {
689 score
690 } else {
691 config.negative_slope * score
692 };
693
694 scores.push(score);
695 }
696
697 let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
699 let exp_scores: Vec<f64> = scores.iter().map(|s| (s - max_score).exp()).collect();
700 let sum_exp: f64 = exp_scores.iter().sum();
701 let attention: Vec<f64> = exp_scores.iter().map(|e| e / sum_exp).collect();
702
703 for (idx, &j) in adj[i].iter().enumerate() {
705 let attn = attention[idx];
706 attention_list.push((i, j, attn));
707
708 for (k, &v) in values[j].iter().enumerate() {
709 output[i][k] += attn * v;
710 }
711 }
712 }
713
714 (output, attention_list)
715 }
716
717 fn linear_transform(input: &[f64], weights: &[Vec<f64>]) -> Vec<f64> {
719 if weights.is_empty() {
720 return Vec::new();
721 }
722
723 let out_dim = weights[0].len();
724 let mut output = vec![0.0; out_dim];
725
726 for (i, &x) in input.iter().enumerate() {
727 if i < weights.len() {
728 for (j, &w) in weights[i].iter().enumerate() {
729 output[j] += x * w;
730 }
731 }
732 }
733
734 output
735 }
736
737 pub fn node_importance(attention_weights: &[(usize, usize, f64)], n: usize) -> Vec<f64> {
739 let mut importance = vec![0.0; n];
740 let mut counts = vec![0usize; n];
741
742 for &(_, target, weight) in attention_weights {
743 if target < n {
744 importance[target] += weight;
745 counts[target] += 1;
746 }
747 }
748
749 for i in 0..n {
751 if counts[i] > 0 {
752 importance[i] /= counts[i] as f64;
753 }
754 }
755
756 importance
757 }
758}
759
impl GpuKernel for GraphAttention {
    /// Exposes the kernel's registry metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
765
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Directed triangle 0 -> 1 -> 2 -> 0; compute() mirrors the edges.
    fn create_test_graph() -> CsrGraph {
        CsrGraph::from_edges(3, &[(0, 1), (1, 2), (2, 0)])
    }

    #[test]
    fn test_gnn_inference_metadata() {
        let kernel = GNNInference::new();
        assert_eq!(kernel.metadata().id, "graph/gnn-inference");
    }

    // output_dim > 1, so predictions/probabilities should be populated.
    #[test]
    fn test_gnn_inference_basic() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let config = GNNConfig {
            num_layers: 2,
            hidden_dim: 4,
            output_dim: 2,
            ..Default::default()
        };

        let weights = GNNWeights::random(2, &config);
        let result = GNNInference::compute(&graph, &features, &weights, &config);

        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.embeddings[0].len(), 2);
        assert!(result.predictions.is_some());
    }

    // Smoke test: every aggregation variant runs and yields one embedding
    // per node.
    #[test]
    fn test_gnn_aggregation_types() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        for agg in [
            AggregationType::Sum,
            AggregationType::Mean,
            AggregationType::Max,
            AggregationType::SAGE,
        ] {
            let config = GNNConfig {
                aggregation: agg,
                num_layers: 1,
                hidden_dim: 4,
                output_dim: 2,
                ..Default::default()
            };

            let weights = GNNWeights::random(2, &config);
            let result = GNNInference::compute(&graph, &features, &weights, &config);

            assert_eq!(result.embeddings.len(), 3);
        }
    }

    // Empty graph and empty features hit the early-return path.
    #[test]
    fn test_gnn_empty_graph() {
        let graph = CsrGraph::empty();
        let features: Vec<Vec<f64>> = vec![];
        let config = GNNConfig::default();
        let weights = GNNWeights::random(2, &config);

        let result = GNNInference::compute(&graph, &features, &weights, &config);
        assert!(result.embeddings.is_empty());
    }

    #[test]
    fn test_graph_attention_metadata() {
        let kernel = GraphAttention::new();
        assert_eq!(kernel.metadata().id, "graph/graph-attention");
    }

    #[test]
    fn test_graph_attention_basic() {
        let graph = create_test_graph();
        let features = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let config = GraphAttentionConfig {
            num_heads: 2,
            head_dim: 4,
            output_dim: 3,
            ..Default::default()
        };

        let weights = GATWeights::random(4, &config);
        let result = GraphAttention::compute(&graph, &features, &weights, &config);

        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.embeddings[0].len(), 3);
        assert!(!result.attention_weights.is_empty());
    }

    // Softmax invariant: attention over each node's neighborhood sums to 1.
    #[test]
    fn test_attention_weights_sum_to_one() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let config = GraphAttentionConfig {
            num_heads: 1,
            head_dim: 4,
            output_dim: 2,
            ..Default::default()
        };

        let weights = GATWeights::random(2, &config);
        let result = GraphAttention::compute(&graph, &features, &weights, &config);

        // Group the head-0 triples by source node and sum their weights.
        let mut sums: HashMap<usize, f64> = HashMap::new();
        for &(src, _, weight) in &result.attention_weights[0] {
            *sums.entry(src).or_insert(0.0) += weight;
        }

        for (_, sum) in sums {
            assert!(
                (sum - 1.0).abs() < 0.01,
                "Attention should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_node_importance() {
        let attn_weights = vec![
            (0, 1, 0.5),
            (0, 2, 0.5),
            (1, 0, 0.3),
            (1, 2, 0.7),
            (2, 0, 0.4),
            (2, 1, 0.6),
        ];

        let importance = GraphAttention::node_importance(&attn_weights, 3);

        assert_eq!(importance.len(), 3);
        assert!(importance.iter().all(|&i| i >= 0.0));
    }

    #[test]
    fn test_activation_functions() {
        assert_eq!(GNNInference::activate(1.0, ActivationType::ReLU), 1.0);
        assert_eq!(GNNInference::activate(-1.0, ActivationType::ReLU), 0.0);
        assert!((GNNInference::activate(0.0, ActivationType::Sigmoid) - 0.5).abs() < 0.001);
        assert_eq!(GNNInference::activate(1.0, ActivationType::None), 1.0);
    }
}