1#[cfg(feature = "tensor-gnn")]
25use rand_distr::{Distribution, StandardNormal};
26
27#[cfg(feature = "tensor-gnn")]
28use crate::tensor::traits::{TensorBase, TensorOps};
29
30#[cfg(feature = "tensor-gnn")]
31use crate::tensor::dense::DenseTensor;
32
33#[cfg(feature = "tensor-gnn")]
34use crate::tensor::sparse::SparseTensor;
35
36#[cfg(all(feature = "tensor-gnn", not(feature = "std")))]
37use rand::{rngs::StdRng, SeedableRng};
38
39#[cfg(all(feature = "tensor-gnn", feature = "std"))]
40use rand::thread_rng;
41
/// Computes the message sent along one edge, given source, optional edge,
/// and destination features.
///
/// Gated on `tensor-gnn`: the `TensorBase` bound is only imported behind
/// that feature, so the ungated trait broke builds without it.
#[cfg(feature = "tensor-gnn")]
pub trait MessageFunction<H: TensorBase>: Send + Sync {
    /// Returns the message for the edge `src -> dst`.
    fn message(&self, src_features: &H, edge_features: Option<&H>, dst_features: &H) -> H;
}
55
/// Reduces a node's incoming messages to a single tensor.
///
/// Gated on `tensor-gnn`: the `TensorBase` bound is only imported behind
/// that feature, so the ungated trait broke builds without it.
#[cfg(feature = "tensor-gnn")]
pub trait Aggregator<H: TensorBase>: Send + Sync {
    /// Combines `messages` into one tensor (e.g. sum / mean / max).
    fn aggregate(&self, messages: &[H]) -> H;
}
67
/// Combines a node's previous state with its aggregated incoming message.
///
/// Gated on `tensor-gnn`: the `TensorBase` bound is only imported behind
/// that feature, so the ungated trait broke builds without it.
#[cfg(feature = "tensor-gnn")]
pub trait UpdateFunction<H: TensorBase>: Send + Sync {
    /// Returns the node's new state.
    fn update(&self, old_state: &H, new_message: &H) -> H;
}
80
/// Aggregates messages by element-wise summation.
#[derive(Debug, Clone, Default)]
pub struct SumAggregator;
84
#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for SumAggregator {
    /// Element-wise sum of all messages; returns `zeros([1])` when there
    /// are no messages at all.
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        match messages.split_first() {
            None => DenseTensor::zeros(vec![1]),
            Some((head, tail)) => tail.iter().fold(head.clone(), |acc, m| acc.add(m)),
        }
    }
}
99
/// Aggregates messages by element-wise arithmetic mean.
#[derive(Debug, Clone, Default)]
pub struct MeanAggregator;
103
#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for MeanAggregator {
    /// Element-wise mean: sums via `SumAggregator`, then scales by 1/n.
    /// Returns `zeros([1])` when there are no messages.
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        let count = messages.len();
        if count == 0 {
            return DenseTensor::zeros(vec![1]);
        }
        let total = SumAggregator.aggregate(messages);
        total.mul_scalar(1.0 / count as f64)
    }
}
115
/// Aggregates messages by element-wise maximum.
#[derive(Debug, Clone, Default)]
pub struct MaxAggregator;
119
#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for MaxAggregator {
    /// Element-wise maximum across all messages; returns `zeros([1])`
    /// when there are no messages.
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        if messages.is_empty() {
            return DenseTensor::zeros(vec![1]);
        }

        let mut running = messages[0].clone();
        for candidate in messages.iter().skip(1) {
            // Pairwise max over the flat data buffers; shape carried over.
            let merged: Vec<f64> = running
                .data()
                .iter()
                .zip(candidate.data().iter())
                .map(|(&lhs, &rhs)| lhs.max(rhs))
                .collect();
            running = DenseTensor::new(merged, running.shape().to_vec());
        }
        running
    }
}
142
/// Message function that forwards source features unchanged.
#[derive(Debug, Clone, Default)]
pub struct IdentityMessage;
146
#[cfg(feature = "tensor-gnn")]
impl MessageFunction<DenseTensor> for IdentityMessage {
    /// Returns a clone of `src_features`; edge and destination features
    /// are ignored.
    fn message(
        &self,
        src_features: &DenseTensor,
        _edge_features: Option<&DenseTensor>,
        _dst_features: &DenseTensor,
    ) -> DenseTensor {
        src_features.clone()
    }
}
158
/// Message function that linearly projects source node features.
///
/// Gated on `tensor-gnn`: the `DenseTensor` field type is only imported
/// behind that feature, so the ungated struct broke builds without it.
#[cfg(feature = "tensor-gnn")]
#[derive(Debug, Clone)]
pub struct LinearMessage {
    /// Weight matrix of shape `[in_features, out_features]`.
    weight: DenseTensor,
}
165
#[cfg(feature = "tensor-gnn")]
impl LinearMessage {
    /// Creates a linear message function with Xavier-style initialized
    /// weights of shape `[in_features, out_features]`.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        // Xavier/Glorot scale for a normal initializer.
        let std = (2.0 / (in_features + out_features) as f64).sqrt();
        // `thread_rng` only exists with the `std` feature; fall back to a
        // fixed-seed StdRng for no_std builds (the unconditional call
        // previously failed to compile there). NOTE(review): the fixed
        // seed makes no_std inits deterministic — confirm acceptable.
        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0x5eed);
        let weight_data: Vec<f64> = (0..in_features * out_features)
            .map(|_| {
                let x: f64 = StandardNormal.sample(&mut rng);
                x * std
            })
            .collect();

        Self {
            weight: DenseTensor::new(weight_data, vec![in_features, out_features]),
        }
    }
}
185
#[cfg(feature = "tensor-gnn")]
impl MessageFunction<DenseTensor> for LinearMessage {
    /// Projects source features: `[N, in] · [in, out] -> [N, out]`.
    fn message(
        &self,
        src_features: &DenseTensor,
        _edge_features: Option<&DenseTensor>,
        _dst_features: &DenseTensor,
    ) -> DenseTensor {
        // `weight` is stored as [in_features, out_features], so it multiplies
        // directly. The previous `transpose(None)` produced [out, in], which
        // is a dimension mismatch unless in_features == out_features.
        src_features.matmul(&self.weight)
    }
}
198
/// Generic message-passing layer: compute messages, aggregate per node,
/// then update each node's state.
pub struct MessagePassingLayer<M, A, U> {
    // Computes the per-edge message.
    message_fn: M,
    // Reduces a node's incoming messages to one tensor.
    aggregator: A,
    // Merges old node state with the aggregated message.
    update_fn: U,
}
208
// Gated on `tensor-gnn`: this impl's bounds use the feature-gated
// `DenseTensor`, so the ungated impl broke builds without the feature.
#[cfg(feature = "tensor-gnn")]
impl<M, A, U> MessagePassingLayer<M, A, U>
where
    M: MessageFunction<DenseTensor>,
    A: Aggregator<DenseTensor>,
    U: UpdateFunction<DenseTensor>,
{
    /// Builds a layer from its three pluggable stages.
    pub fn new(message_fn: M, aggregator: A, update_fn: U) -> Self {
        Self {
            message_fn,
            aggregator,
            update_fn,
        }
    }

    /// Runs one round of message passing.
    ///
    /// `node_features` is `[num_nodes, num_features]`; `edge_index` lists
    /// `(src, dst)` pairs and messages flow src -> dst. Nodes with no
    /// incoming messages keep their previous state.
    ///
    /// NOTE(review): `edge_features` is currently collapsed to a dummy
    /// scalar `1.0` per edge instead of slicing the edge's own feature
    /// row — confirm whether this is an intentional placeholder.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
        edge_features: Option<&DenseTensor>,
    ) -> DenseTensor {
        // Bucket incoming messages per destination node.
        let mut messages: Vec<Vec<DenseTensor>> = vec![Vec::new(); node_features.shape()[0]];

        for (src, dst) in edge_index {
            let src_feat = self.extract_node(node_features, *src);
            let dst_feat = self.extract_node(node_features, *dst);
            // Placeholder edge feature (see NOTE above).
            let edge_feat = edge_features.map(|_| DenseTensor::scalar(1.0));
            let msg = self
                .message_fn
                .message(&src_feat, edge_feat.as_ref(), &dst_feat);
            messages[*dst].push(msg);
        }

        // Aggregate + update per node; isolated nodes pass through.
        let mut updated_features = Vec::new();
        for (node_idx, node_msgs) in messages.iter().enumerate() {
            let old_state = self.extract_node(node_features, node_idx);

            if node_msgs.is_empty() {
                updated_features.extend_from_slice(old_state.data());
            } else {
                let aggregated = self.aggregator.aggregate(node_msgs);
                let updated = self.update_fn.update(&old_state, &aggregated);
                updated_features.extend_from_slice(updated.data());
            }
        }

        // Assumes the update function preserves the per-node feature
        // width — TODO confirm; otherwise this output shape is wrong.
        DenseTensor::new(updated_features, node_features.shape().to_vec())
    }

    /// Extracts one node's feature row as a `[1, num_features]` tensor.
    fn extract_node(&self, features: &DenseTensor, node_idx: usize) -> DenseTensor {
        let num_features = features.shape()[1];
        // (Removed dead `start`/`_end` locals that were never used.)
        features.slice(&[0, 1], &[node_idx..node_idx + 1, 0..num_features])
    }
}
278
/// Graph Convolutional Network layer (symmetric-normalized aggregation).
///
/// Gated on `tensor-gnn`: the `DenseTensor` field type is only imported
/// behind that feature, so the ungated struct broke builds without it.
#[cfg(feature = "tensor-gnn")]
#[allow(dead_code)]
pub struct GCNConv {
    /// Input feature width per node.
    in_features: usize,
    /// Output feature width per node.
    out_features: usize,
    /// Weight matrix `[in_features, out_features]`.
    weight: DenseTensor,
    /// Bias vector `[out_features]`. NOTE(review): never applied in
    /// `forward` — confirm intended.
    bias: DenseTensor,
}
291
#[cfg(feature = "tensor-gnn")]
impl GCNConv {
    /// Creates a GCN layer with Xavier-style initialized weights and a
    /// zero bias vector.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        // Glorot scale; samples are drawn from a standard normal.
        let std = (6.0 / (in_features + out_features) as f64).sqrt();
        // `thread_rng` only exists with the `std` feature; fall back to a
        // fixed-seed StdRng for no_std builds (the unconditional call
        // previously failed to compile there). NOTE(review): fixed seed
        // makes no_std inits deterministic — confirm acceptable.
        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0x5eed);
        let weight_data: Vec<f64> = (0..in_features * out_features)
            .map(|_| {
                let x: f64 = StandardNormal.sample(&mut rng);
                x * std
            })
            .collect();

        let bias_data = vec![0.0; out_features];

        Self {
            in_features,
            out_features,
            weight: DenseTensor::new(weight_data, vec![in_features, out_features]),
            bias: DenseTensor::new(bias_data, vec![out_features]),
        }
    }

    /// Forward pass: `A_norm · (X · W)`.
    ///
    /// NOTE(review): `self.bias` is never added to the result — confirm
    /// whether that is intentional.
    pub fn forward(&self, node_features: &DenseTensor, adjacency: &SparseTensor) -> DenseTensor {
        let h_transformed = node_features.matmul(&self.weight);

        let normalized = self.normalize_adjacency(adjacency);

        normalized.spmv(&h_transformed).unwrap()
    }

    /// Intended to produce the symmetric normalization D^{-1/2} A D^{-1/2}.
    ///
    /// TODO(review): the inverse-sqrt degrees are computed but never
    /// applied — the adjacency is returned UNNORMALIZED. Finishing this
    /// needs a `SparseTensor` value-scaling API not visible here.
    fn normalize_adjacency(&self, adjacency: &SparseTensor) -> SparseTensor {
        let degrees = self.compute_degrees(adjacency);

        let _inv_sqrt_degrees = degrees.map(|d: f64| if d > 1e-10 { 1.0 / d.sqrt() } else { 0.0 });

        adjacency.clone()
    }

    /// Row degrees of the adjacency — counts stored entries per row,
    /// ignoring their values (i.e. treats the graph as unweighted).
    fn compute_degrees(&self, adjacency: &SparseTensor) -> DenseTensor {
        let num_nodes = adjacency.shape()[0];
        let mut degrees = vec![0.0; num_nodes];

        let coo = adjacency.to_coo();
        for &row in coo.row_indices() {
            degrees[row] += 1.0;
        }

        DenseTensor::new(degrees, vec![num_nodes])
    }
}
361
/// Graph Attention Network layer.
///
/// Gated on `tensor-gnn`: the `DenseTensor` field type is only imported
/// behind that feature, so the ungated struct broke builds without it.
#[cfg(feature = "tensor-gnn")]
#[allow(dead_code)]
pub struct GATConv {
    /// Input feature width per node.
    in_features: usize,
    /// Output feature width per node.
    out_features: usize,
    /// NOTE(review): stored but never used by `forward` — single-head only.
    num_heads: usize,
    /// Attention parameter vector of length `2 * out_features`.
    attention_vec: DenseTensor,
}
374
#[cfg(feature = "tensor-gnn")]
impl GATConv {
    /// Creates a GAT layer. Only the attention vector is randomly
    /// initialized here; the feature transform is currently an identity
    /// placeholder (see [`GATConv::weight`]).
    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
        // Glorot scale; samples drawn from a standard normal.
        let std = (6.0 / (in_features + out_features) as f64).sqrt();
        // `thread_rng` only exists with the `std` feature; fall back to a
        // fixed-seed StdRng for no_std builds (the unconditional call
        // previously failed to compile there).
        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0x5eed);
        let attention_data: Vec<f64> = (0..out_features * 2)
            .map(|_| {
                let x: f64 = StandardNormal.sample(&mut rng);
                x * std
            })
            .collect();

        Self {
            in_features,
            out_features,
            num_heads,
            attention_vec: DenseTensor::new(attention_data, vec![out_features * 2]),
        }
    }

    /// Forward pass: transform features, score each edge, softmax the
    /// scores per destination node, then attention-weight the sum.
    ///
    /// NOTE(review): attention is computed on the RAW features while the
    /// aggregation uses the transformed ones — confirm intended.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        let h_transformed = node_features.matmul(&self.weight());

        let attention_scores = self.compute_attention(node_features, edge_index);

        let normalized_attention = self.softmax(&attention_scores, edge_index);

        self.aggregate_with_attention(&h_transformed, &normalized_attention, edge_index)
    }

    /// Placeholder feature transform: an identity matrix, so the
    /// "transformed" features equal the input.
    ///
    /// TODO(review): replace with a learned `[in, out]` weight; as-is the
    /// layer only behaves sensibly when `out_features == in_features`.
    fn weight(&self) -> DenseTensor {
        DenseTensor::eye(self.in_features)
    }

    /// Raw (pre-softmax) attention score per edge: a dot product of the
    /// concatenated `[src || dst]` features with the attention vector,
    /// followed by a ReLU (the paper uses LeakyReLU — TODO confirm).
    fn compute_attention(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> Vec<f64> {
        edge_index
            .iter()
            .map(|(src, dst)| {
                let src_feat = node_features.data()
                    [src * self.in_features..(src + 1) * self.in_features]
                    .to_vec();
                let dst_feat = node_features.data()
                    [dst * self.in_features..(dst + 1) * self.in_features]
                    .to_vec();

                let mut concatenated = src_feat;
                concatenated.extend_from_slice(&dst_feat);

                // NOTE(review): `cycle()` silently repeats the attention
                // vector when 2*in_features != 2*out_features — confirm.
                let score: f64 = concatenated
                    .iter()
                    .zip(self.attention_vec.data().iter().cycle())
                    .map(|(&a, &b)| a * b)
                    .sum();

                score.max(0.0)
            })
            .collect()
    }

    /// Numerically-stable softmax of edge scores, grouped by destination
    /// node (each node's incoming attention weights sum to 1).
    fn softmax(&self, scores: &[f64], edge_index: &[(usize, usize)]) -> Vec<f64> {
        // Group (src, score) pairs by destination.
        let mut dst_scores: std::collections::HashMap<usize, Vec<(usize, f64)>> =
            std::collections::HashMap::new();

        for ((src, dst), score) in edge_index.iter().zip(scores.iter()) {
            dst_scores.entry(*dst).or_default().push((*src, *score));
        }

        let mut normalized = vec![0.0; scores.len()];
        for (dst, scores) in dst_scores {
            // Subtract the max before exponentiating for stability.
            let max_score = scores
                .iter()
                .map(|(_, s)| *s)
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_scores: Vec<(usize, f64)> = scores
                .iter()
                .map(|(src, s)| (*src, (*s - max_score).exp()))
                .collect();

            let sum_exp: f64 = exp_scores.iter().map(|(_, e)| *e).sum();

            // NOTE(review): this `position` scan makes the pass O(E^2) and
            // collapses parallel edges with the same (src, dst) into one
            // slot — fine for small simple graphs, revisit for scale.
            for (src, exp_val) in exp_scores {
                if let Some(idx) = edge_index.iter().position(|(s, d)| *s == src && *d == dst) {
                    normalized[idx] = exp_val / sum_exp;
                }
            }
        }

        normalized
    }

    /// Sums attention-weighted source features into each destination node.
    ///
    /// NOTE(review): indexes the (transformed) features with an
    /// `in_features` stride but copies `out_features` values per node —
    /// only consistent while the transform is the identity placeholder.
    fn aggregate_with_attention(
        &self,
        node_features: &DenseTensor,
        attention: &[f64],
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        let num_nodes = node_features.shape()[0];
        let mut result = vec![0.0; num_nodes * self.out_features];

        for ((src, dst), &attn) in edge_index.iter().zip(attention.iter()) {
            for i in 0..self.out_features {
                result[dst * self.out_features + i] +=
                    attn * node_features.data()[src * self.in_features + i];
            }
        }

        DenseTensor::new(result, vec![num_nodes, self.out_features])
    }
}
508
/// GraphSAGE layer configuration (sample-and-aggregate).
pub struct GraphSAGE {
    // Input feature width per node.
    in_features: usize,
    // Output feature width per node.
    out_features: usize,
    // Maximum number of neighbors considered per node.
    num_samples: usize,
}
518
#[cfg(feature = "tensor-gnn")]
impl GraphSAGE {
    /// Creates a GraphSAGE layer that considers at most `num_samples`
    /// neighbors per node.
    pub fn new(in_features: usize, out_features: usize, num_samples: usize) -> Self {
        Self {
            in_features,
            out_features,
            num_samples,
        }
    }

    /// Forward pass: per node, mean-aggregate up to `num_samples` neighbor
    /// rows, concatenate with the node's own features, then cut the result
    /// down to `out_features` values.
    ///
    /// NOTE(review): neighbor "sampling" deterministically takes the FIRST
    /// `num_samples` outgoing edges (not a random sample), and the final
    /// projection is a bare truncation with no learned weight — confirm
    /// these are intentional placeholders.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        let num_nodes = node_features.shape()[0];
        let mut result = Vec::with_capacity(num_nodes * self.out_features);

        for node_idx in 0..num_nodes {
            // First `num_samples` out-neighbors of this node.
            let neighbors: Vec<usize> = edge_index
                .iter()
                .filter(|(src, _)| *src == node_idx)
                .take(self.num_samples)
                .map(|(_, dst)| *dst)
                .collect();

            // Mean of the sampled neighbors' feature rows; zeros when the
            // node has no outgoing edges.
            let neighbor_features = if neighbors.is_empty() {
                DenseTensor::zeros(vec![self.in_features])
            } else {
                let features: Vec<DenseTensor> = neighbors
                    .iter()
                    .map(|&n| {
                        let start = n * self.in_features;
                        let end = start + self.in_features;
                        DenseTensor::new(
                            node_features.data()[start..end].to_vec(),
                            vec![self.in_features],
                        )
                    })
                    .collect();
                MeanAggregator.aggregate(&features)
            };

            // [self || mean(neighbors)] — length 2 * in_features.
            let self_features = node_features.data()
                [node_idx * self.in_features..(node_idx + 1) * self.in_features]
                .to_vec();
            let mut concatenated = self_features;
            concatenated.extend_from_slice(neighbor_features.data());

            // Truncate to out_features; zero-pad when out_features exceeds
            // the concatenated width (previously the output vector came up
            // short of the declared [num_nodes, out_features] shape).
            let mut transformed: Vec<f64> = concatenated
                .iter()
                .take(self.out_features)
                .copied()
                .collect();
            transformed.resize(self.out_features, 0.0);

            result.extend_from_slice(&transformed);
        }

        DenseTensor::new(result, vec![num_nodes, self.out_features])
    }
}
586
// Gated on the feature as well as `test`: these tests use `DenseTensor`,
// whose import is feature-gated, so `cargo test` without `tensor-gnn`
// previously failed to compile.
#[cfg(all(test, feature = "tensor-gnn"))]
mod tests {
    use super::*;

    /// Sum aggregation is element-wise across all messages.
    #[test]
    fn test_sum_aggregator() {
        let aggregator = SumAggregator;
        let messages = vec![
            DenseTensor::new(vec![1.0, 2.0], vec![2]),
            DenseTensor::new(vec![3.0, 4.0], vec![2]),
            DenseTensor::new(vec![5.0, 6.0], vec![2]),
        ];

        let result = aggregator.aggregate(&messages);
        assert_eq!(result.data(), &[9.0, 12.0]);
    }

    /// Mean aggregation divides the element-wise sum by the count.
    #[test]
    fn test_mean_aggregator() {
        let aggregator = MeanAggregator;
        let messages = vec![
            DenseTensor::new(vec![1.0, 2.0], vec![2]),
            DenseTensor::new(vec![3.0, 4.0], vec![2]),
            DenseTensor::new(vec![5.0, 6.0], vec![2]),
        ];

        let result = aggregator.aggregate(&messages);
        assert_eq!(result.data(), &[3.0, 4.0]);
    }

    /// Identity message passes source features through untouched.
    #[test]
    fn test_identity_message() {
        let message_fn = IdentityMessage;
        let src = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        let dst = DenseTensor::new(vec![4.0, 5.0, 6.0], vec![3]);

        let result = message_fn.message(&src, None, &dst);
        assert_eq!(result.data(), src.data());
    }
}