// god_graph/tensor/gnn.rs

//! GNN (graph neural network) primitives.
//!
//! Core building blocks for graph neural networks:
//! - the message-passing framework
//! - graph convolution layers (GCN, GAT, GraphSAGE)
//! - graph pooling and normalization
//!
//! ## Example
//!
//! ```ignore
//! # #[cfg(feature = "tensor-gnn")]
//! # {
//! use god_graph::tensor::gnn::{MessagePassingLayer, GCNConv, SumAggregator};
//! use god_graph::tensor::DenseTensor;
//!
//! // Create a GCN layer
//! let gcn = GCNConv::new(64, 64);
//!
//! // Forward pass (`node_features` and `adjacency` prepared by the caller)
//! let output = gcn.forward(&node_features, &adjacency);
//! # }
//! ```

#[cfg(feature = "tensor-gnn")]
use rand_distr::{Distribution, StandardNormal};

#[cfg(feature = "tensor-gnn")]
use crate::tensor::traits::{TensorBase, TensorOps};

#[cfg(feature = "tensor-gnn")]
use crate::tensor::dense::DenseTensor;

#[cfg(feature = "tensor-gnn")]
use crate::tensor::sparse::SparseTensor;

#[cfg(all(feature = "tensor-gnn", not(feature = "std")))]
use rand::{rngs::StdRng, SeedableRng};

#[cfg(all(feature = "tensor-gnn", feature = "std"))]
use rand::thread_rng;

/// Message function trait: computes the message sent along an edge.
#[cfg(feature = "tensor-gnn")]
pub trait MessageFunction<H: TensorBase>: Send + Sync {
    /// Compute the message for one edge.
    ///
    /// # Arguments
    /// * `src_features` - source-node features
    /// * `edge_features` - edge features (optional)
    /// * `dst_features` - destination-node features
    ///
    /// # Returns
    /// The computed message tensor.
    fn message(&self, src_features: &H, edge_features: Option<&H>, dst_features: &H) -> H;
}

/// Aggregator trait: combines the messages arriving at a node.
#[cfg(feature = "tensor-gnn")]
pub trait Aggregator<H: TensorBase>: Send + Sync {
    /// Aggregate messages.
    ///
    /// # Arguments
    /// * `messages` - slice of incoming messages
    ///
    /// # Returns
    /// The aggregated tensor.
    fn aggregate(&self, messages: &[H]) -> H;
}

/// Update function trait: computes the new node state.
#[cfg(feature = "tensor-gnn")]
pub trait UpdateFunction<H: TensorBase>: Send + Sync {
    /// Update a node's state.
    ///
    /// # Arguments
    /// * `old_state` - previous node state
    /// * `new_message` - freshly aggregated message
    ///
    /// # Returns
    /// The updated state.
    fn update(&self, old_state: &H, new_message: &H) -> H;
}

/// Sum aggregator.
#[derive(Debug, Clone, Default)]
pub struct SumAggregator;

#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for SumAggregator {
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        if messages.is_empty() {
            // No incoming messages: return a single zero as a neutral element.
            return DenseTensor::zeros(vec![1]);
        }

        let mut result = messages[0].clone();
        for msg in &messages[1..] {
            result = result.add(msg);
        }
        result
    }
}

/// Mean aggregator.
#[derive(Debug, Clone, Default)]
pub struct MeanAggregator;

#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for MeanAggregator {
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        if messages.is_empty() {
            return DenseTensor::zeros(vec![1]);
        }

        let sum = SumAggregator.aggregate(messages);
        sum.mul_scalar(1.0 / messages.len() as f64)
    }
}

/// Max aggregator.
#[derive(Debug, Clone, Default)]
pub struct MaxAggregator;

#[cfg(feature = "tensor-gnn")]
impl Aggregator<DenseTensor> for MaxAggregator {
    fn aggregate(&self, messages: &[DenseTensor]) -> DenseTensor {
        if messages.is_empty() {
            return DenseTensor::zeros(vec![1]);
        }

        let mut result = messages[0].clone();
        for msg in &messages[1..] {
            // Element-wise maximum
            let data = result.data().to_vec();
            let msg_data = msg.data();
            let max_data: Vec<f64> = data
                .iter()
                .zip(msg_data.iter())
                .map(|(&a, &b)| a.max(b))
                .collect();
            result = DenseTensor::new(max_data, result.shape().to_vec());
        }
        result
    }
}

/// Identity message function: forwards the source node's features unchanged.
#[derive(Debug, Clone, Default)]
pub struct IdentityMessage;

#[cfg(feature = "tensor-gnn")]
impl MessageFunction<DenseTensor> for IdentityMessage {
    fn message(
        &self,
        src_features: &DenseTensor,
        _edge_features: Option<&DenseTensor>,
        _dst_features: &DenseTensor,
    ) -> DenseTensor {
        src_features.clone()
    }
}

/// Linear message function: applies a linear transform to the source features.
#[cfg(feature = "tensor-gnn")]
#[derive(Debug, Clone)]
pub struct LinearMessage {
    /// Weight matrix, stored as [in_features, out_features]
    weight: DenseTensor,
}

#[cfg(feature = "tensor-gnn")]
impl LinearMessage {
    /// Create a new linear message function.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        // Xavier (Glorot) initialization
        let std = (2.0 / (in_features + out_features) as f64).sqrt();
        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        // Without `std` there is no thread-local RNG; fall back to a fixed seed.
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0);
        let weight_data: Vec<f64> = (0..in_features * out_features)
            .map(|_| {
                let x: f64 = StandardNormal.sample(&mut rng);
                x * std
            })
            .collect();

        Self {
            weight: DenseTensor::new(weight_data, vec![in_features, out_features]),
        }
    }
}

#[cfg(feature = "tensor-gnn")]
impl MessageFunction<DenseTensor> for LinearMessage {
    fn message(
        &self,
        src_features: &DenseTensor,
        _edge_features: Option<&DenseTensor>,
        _dst_features: &DenseTensor,
    ) -> DenseTensor {
        // src_features [n, in] @ weight [in, out] -> [n, out]
        src_features.matmul(&self.weight)
    }
}

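/// Residual update function: `h_new = h_old + m`.
///
/// An illustrative sketch rather than part of the original API surface: the
/// module defines the `UpdateFunction` trait but ships no implementation, so
/// this minimal one makes `MessagePassingLayer` usable end to end.
#[derive(Debug, Clone, Default)]
pub struct ResidualUpdate;

#[cfg(feature = "tensor-gnn")]
impl UpdateFunction<DenseTensor> for ResidualUpdate {
    fn update(&self, old_state: &DenseTensor, new_message: &DenseTensor) -> DenseTensor {
        // Element-wise sum of the previous state and the aggregated message.
        old_state.add(new_message)
    }
}
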
/// Message-passing layer: the core building block of a GNN.
pub struct MessagePassingLayer<M, A, U> {
    /// Message function
    message_fn: M,
    /// Aggregator
    aggregator: A,
    /// Update function
    update_fn: U,
}

#[cfg(feature = "tensor-gnn")]
impl<M, A, U> MessagePassingLayer<M, A, U>
where
    M: MessageFunction<DenseTensor>,
    A: Aggregator<DenseTensor>,
    U: UpdateFunction<DenseTensor>,
{
    /// Create a new message-passing layer.
    pub fn new(message_fn: M, aggregator: A, update_fn: U) -> Self {
        Self {
            message_fn,
            aggregator,
            update_fn,
        }
    }

    /// Forward pass.
    ///
    /// # Arguments
    /// * `node_features` - node features, shape [num_nodes, hidden_size]
    /// * `edge_index` - edge list [(src, dst), ...]
    /// * `edge_features` - edge features (optional)
    ///
    /// # Returns
    /// The updated node features.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
        edge_features: Option<&DenseTensor>,
    ) -> DenseTensor {
        // Collect incoming messages per node
        let mut messages: Vec<Vec<DenseTensor>> = vec![Vec::new(); node_features.shape()[0]];

        for (src, dst) in edge_index {
            let src_feat = self.extract_node(node_features, *src);
            let dst_feat = self.extract_node(node_features, *dst);
            // Simplified: a placeholder scalar stands in for per-edge feature slicing
            let edge_feat = edge_features.map(|_| DenseTensor::scalar(1.0));

            let msg = self
                .message_fn
                .message(&src_feat, edge_feat.as_ref(), &dst_feat);
            messages[*dst].push(msg);
        }

        // Aggregate messages and update each node
        let mut updated_features = Vec::new();
        for (node_idx, node_msgs) in messages.iter().enumerate() {
            let old_state = self.extract_node(node_features, node_idx);

            if node_msgs.is_empty() {
                // Isolated node: keep its previous state
                updated_features.extend_from_slice(old_state.data());
            } else {
                let aggregated = self.aggregator.aggregate(node_msgs);
                let updated = self.update_fn.update(&old_state, &aggregated);
                updated_features.extend_from_slice(updated.data());
            }
        }

        DenseTensor::new(updated_features, node_features.shape().to_vec())
    }

    /// Extract the feature row of a single node as a [1, num_features] tensor.
    fn extract_node(&self, features: &DenseTensor, node_idx: usize) -> DenseTensor {
        let num_features = features.shape()[1];
        features.slice(&[0, 1], &[node_idx..node_idx + 1, 0..num_features])
    }
}

/// GCN (graph convolutional network) layer
#[cfg(feature = "tensor-gnn")]
#[allow(dead_code)]
pub struct GCNConv {
    /// Input feature dimension
    in_features: usize,
    /// Output feature dimension
    out_features: usize,
    /// Weight matrix
    weight: DenseTensor,
    /// Bias
    bias: DenseTensor,
}

#[cfg(feature = "tensor-gnn")]
impl GCNConv {
    /// Create a new GCN layer.
    pub fn new(in_features: usize, out_features: usize) -> Self {
        // Xavier (Glorot) initialization
        let std = (6.0 / (in_features + out_features) as f64).sqrt();
        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        // Without `std` there is no thread-local RNG; fall back to a fixed seed.
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0);
        let weight_data: Vec<f64> = (0..in_features * out_features)
            .map(|_| {
                let x: f64 = StandardNormal.sample(&mut rng);
                x * std
            })
            .collect();

        let bias_data = vec![0.0; out_features];

        Self {
            in_features,
            out_features,
            weight: DenseTensor::new(weight_data, vec![in_features, out_features]),
            bias: DenseTensor::new(bias_data, vec![out_features]),
        }
    }

    /// Forward pass, computing `D^(-1/2) A D^(-1/2) H W`.
    ///
    /// # Arguments
    /// * `node_features` - node features, shape [num_nodes, in_features]
    /// * `adjacency` - adjacency matrix (sparse format)
    ///
    /// # Returns
    /// The updated node features, shape [num_nodes, out_features].
    pub fn forward(&self, node_features: &DenseTensor, adjacency: &SparseTensor) -> DenseTensor {
        // 1. Linear transform: H @ W
        let h_transformed = node_features.matmul(&self.weight);

        // 2. Degree normalization: D^(-1/2) A D^(-1/2)
        let normalized = self.normalize_adjacency(adjacency);

        // 3. Graph convolution: normalized_adj @ H_transformed
        normalized.spmv(&h_transformed).unwrap()
    }

    /// Normalize the adjacency matrix.
    fn normalize_adjacency(&self, adjacency: &SparseTensor) -> SparseTensor {
        // Node degrees
        let degrees = self.compute_degrees(adjacency);

        // D^(-1/2)
        let _inv_sqrt_degrees = degrees.map(|d: f64| if d > 1e-10 { 1.0 / d.sqrt() } else { 0.0 });

        // Normalization: D^(-1/2) A D^(-1/2)
        // Simplified: a full implementation scales every edge weight by the
        // degree factors of both endpoints (see the standalone sketch below)
        adjacency.clone() // TODO: implement the full normalization
    }

    /// Compute node degrees (each stored entry counts as 1, i.e. unweighted).
    fn compute_degrees(&self, adjacency: &SparseTensor) -> DenseTensor {
        let num_nodes = adjacency.shape()[0];
        let mut degrees = vec![0.0; num_nodes];

        let coo = adjacency.to_coo();
        for &row in coo.row_indices() {
            degrees[row] += 1.0;
        }

        DenseTensor::new(degrees, vec![num_nodes])
    }
}

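// A minimal standalone sketch of the missing symmetric normalization,
// operating on an explicit COO triple list because this file does not show a
// `SparseTensor` construction API. Each edge weight w_ij is scaled to
// w_ij / sqrt(deg(i) * deg(j)); a hypothetical caller would rebuild a sparse
// matrix from the returned triples.
#[cfg(feature = "tensor-gnn")]
#[allow(dead_code)]
fn symmetric_normalize_coo(
    num_nodes: usize,
    edges: &[(usize, usize, f64)],
) -> Vec<(usize, usize, f64)> {
    // Weighted degree of every node (row sums; equals column sums for a
    // symmetric adjacency).
    let mut degrees = vec![0.0_f64; num_nodes];
    for &(row, _, weight) in edges {
        degrees[row] += weight;
    }

    // D^(-1/2), guarding against isolated nodes.
    let inv_sqrt: Vec<f64> = degrees
        .iter()
        .map(|&d| if d > 1e-10 { 1.0 / d.sqrt() } else { 0.0 })
        .collect();

    edges
        .iter()
        .map(|&(r, c, w)| (r, c, w * inv_sqrt[r] * inv_sqrt[c]))
        .collect()
}
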
/// GAT (graph attention network) layer
#[cfg(feature = "tensor-gnn")]
#[allow(dead_code)]
pub struct GATConv {
    /// Input feature dimension
    in_features: usize,
    /// Output feature dimension
    out_features: usize,
    /// Number of attention heads (unused by this simplified implementation)
    num_heads: usize,
    /// Attention weight vector
    attention_vec: DenseTensor,
}

#[cfg(feature = "tensor-gnn")]
impl GATConv {
    /// Create a new GAT layer.
    pub fn new(in_features: usize, out_features: usize, num_heads: usize) -> Self {
        // Xavier (Glorot) initialization
        let std = (6.0 / (in_features + out_features) as f64).sqrt();
        #[cfg(feature = "std")]
        let mut rng = thread_rng();
        // Without `std` there is no thread-local RNG; fall back to a fixed seed.
        #[cfg(not(feature = "std"))]
        let mut rng = StdRng::seed_from_u64(0);
        let attention_data: Vec<f64> = (0..out_features * 2)
            .map(|_| {
                let x: f64 = StandardNormal.sample(&mut rng);
                x * std
            })
            .collect();

        Self {
            in_features,
            out_features,
            num_heads,
            attention_vec: DenseTensor::new(attention_data, vec![out_features * 2]),
        }
    }

    /// Forward pass.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        // 1. Linear transform
        let h_transformed = node_features.matmul(&self.weight());

        // 2. Raw attention score per edge
        let attention_scores = self.compute_attention(node_features, edge_index);

        // 3. Softmax normalization per destination node
        let normalized_attention = self.softmax(&attention_scores, edge_index);

        // 4. Attention-weighted aggregation
        self.aggregate_with_attention(&h_transformed, &normalized_attention, edge_index)
    }

    /// Weight matrix.
    fn weight(&self) -> DenseTensor {
        // Simplified: an identity matrix, so the "transform" above is a no-op;
        // a real layer would learn an [in_features, out_features] matrix
        DenseTensor::eye(self.in_features)
    }

    /// Compute raw attention scores, one per edge.
    fn compute_attention(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> Vec<f64> {
        edge_index
            .iter()
            .map(|(src, dst)| {
                let src_feat = node_features.data()
                    [src * self.in_features..(src + 1) * self.in_features]
                    .to_vec();
                let dst_feat = node_features.data()
                    [dst * self.in_features..(dst + 1) * self.in_features]
                    .to_vec();

                // Concatenate [src || dst] and score against the attention vector
                let mut concatenated = src_feat;
                concatenated.extend_from_slice(&dst_feat);

                // attention_vec @ concatenated; `cycle` papers over any length
                // mismatch between the two vectors
                let score: f64 = concatenated
                    .iter()
                    .zip(self.attention_vec.data().iter().cycle())
                    .map(|(&a, &b)| a * b)
                    .sum();

                score.max(0.0) // ReLU; the GAT paper uses LeakyReLU with slope 0.2
            })
            .collect()
    }

    /// Softmax normalization of attention scores, per destination node.
    fn softmax(&self, scores: &[f64], edge_index: &[(usize, usize)]) -> Vec<f64> {
        // Group scores by destination node
        let mut dst_scores: std::collections::HashMap<usize, Vec<(usize, f64)>> =
            std::collections::HashMap::new();

        for ((src, dst), score) in edge_index.iter().zip(scores.iter()) {
            dst_scores.entry(*dst).or_default().push((*src, *score));
        }

        // Numerically stable softmax over each destination's incoming scores
        let mut normalized = vec![0.0; scores.len()];
        for (dst, scores) in dst_scores {
            let max_score = scores
                .iter()
                .map(|(_, s)| *s)
                .fold(f64::NEG_INFINITY, f64::max);
            let exp_scores: Vec<(usize, f64)> = scores
                .iter()
                .map(|(src, s)| (*src, (*s - max_score).exp()))
                .collect();

            let sum_exp: f64 = exp_scores.iter().map(|(_, e)| *e).sum();

            for (src, exp_val) in exp_scores {
                // Write the value back at the matching edge index
                if let Some(idx) = edge_index.iter().position(|(s, d)| *s == src && *d == dst) {
                    normalized[idx] = exp_val / sum_exp;
                }
            }
        }

        normalized
    }

    /// Attention-weighted aggregation.
    ///
    /// Assumes `out_features <= in_features`: each destination accumulates the
    /// first `out_features` entries of every source row, scaled by attention.
    fn aggregate_with_attention(
        &self,
        node_features: &DenseTensor,
        attention: &[f64],
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        let num_nodes = node_features.shape()[0];
        let mut result = vec![0.0; num_nodes * self.out_features];

        for ((src, dst), &attn) in edge_index.iter().zip(attention.iter()) {
            for i in 0..self.out_features {
                result[dst * self.out_features + i] +=
                    attn * node_features.data()[src * self.in_features + i];
            }
        }

        DenseTensor::new(result, vec![num_nodes, self.out_features])
    }
}

/// GraphSAGE layer
pub struct GraphSAGE {
    /// Input feature dimension
    in_features: usize,
    /// Output feature dimension
    out_features: usize,
    /// Number of neighbours to sample
    num_samples: usize,
}

#[cfg(feature = "tensor-gnn")]
impl GraphSAGE {
    /// Create a new GraphSAGE layer.
    pub fn new(in_features: usize, out_features: usize, num_samples: usize) -> Self {
        Self {
            in_features,
            out_features,
            num_samples,
        }
    }

    /// Forward pass.
    pub fn forward(
        &self,
        node_features: &DenseTensor,
        edge_index: &[(usize, usize)],
    ) -> DenseTensor {
        let num_nodes = node_features.shape()[0];
        let mut result = Vec::new();

        for node_idx in 0..num_nodes {
            // 1. "Sample" neighbours: deterministically take the first
            //    `num_samples` out-edges rather than sampling at random
            let neighbors: Vec<usize> = edge_index
                .iter()
                .filter(|(src, _)| *src == node_idx)
                .take(self.num_samples)
                .map(|(_, dst)| *dst)
                .collect();

            // 2. Aggregate neighbour features (mean)
            let neighbor_features = if neighbors.is_empty() {
                DenseTensor::zeros(vec![self.in_features])
            } else {
                let features: Vec<DenseTensor> = neighbors
                    .iter()
                    .map(|&n| {
                        let start = n * self.in_features;
                        let end = start + self.in_features;
                        DenseTensor::new(
                            node_features.data()[start..end].to_vec(),
                            vec![self.in_features],
                        )
                    })
                    .collect();
                MeanAggregator.aggregate(&features)
            };

            // 3. Concatenate own features with aggregated neighbour features
            let self_features = node_features.data()
                [node_idx * self.in_features..(node_idx + 1) * self.in_features]
                .to_vec();
            let mut concatenated = self_features;
            concatenated.extend_from_slice(neighbor_features.data());

            // 4. "Linear transform" (simplified: keep the first `out_features`
            //    entries; assumes out_features <= 2 * in_features)
            let transformed: Vec<f64> = concatenated
                .iter()
                .take(self.out_features)
                .copied()
                .collect();

            result.extend_from_slice(&transformed);
        }

        DenseTensor::new(result, vec![num_nodes, self.out_features])
    }
}

#[cfg(all(test, feature = "tensor-gnn"))]
mod tests {
    use super::*;

    #[test]
    fn test_sum_aggregator() {
        let aggregator = SumAggregator;
        let messages = vec![
            DenseTensor::new(vec![1.0, 2.0], vec![2]),
            DenseTensor::new(vec![3.0, 4.0], vec![2]),
            DenseTensor::new(vec![5.0, 6.0], vec![2]),
        ];

        let result = aggregator.aggregate(&messages);
        assert_eq!(result.data(), &[9.0, 12.0]);
    }

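    // MaxAggregator takes the element-wise maximum across messages.
    #[test]
    fn test_max_aggregator() {
        let aggregator = MaxAggregator;
        let messages = vec![
            DenseTensor::new(vec![1.0, 6.0], vec![2]),
            DenseTensor::new(vec![5.0, 2.0], vec![2]),
        ];

        let result = aggregator.aggregate(&messages);
        assert_eq!(result.data(), &[5.0, 6.0]);
    }
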
    #[test]
    fn test_mean_aggregator() {
        let aggregator = MeanAggregator;
        let messages = vec![
            DenseTensor::new(vec![1.0, 2.0], vec![2]),
            DenseTensor::new(vec![3.0, 4.0], vec![2]),
            DenseTensor::new(vec![5.0, 6.0], vec![2]),
        ];

        let result = aggregator.aggregate(&messages);
        assert_eq!(result.data(), &[3.0, 4.0]);
    }

    #[test]
    fn test_identity_message() {
        let message_fn = IdentityMessage;
        let src = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        let dst = DenseTensor::new(vec![4.0, 5.0, 6.0], vec![3]);

        let result = message_fn.message(&src, None, &dst);
        assert_eq!(result.data(), src.data());
    }
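
    // End-to-end sketch of `MessagePassingLayer`, using the illustrative
    // `ResidualUpdate` defined above: with identity messages and sum
    // aggregation, each node adds its in-neighbours' features to its own.
    #[test]
    fn test_message_passing_layer() {
        let layer = MessagePassingLayer::new(IdentityMessage, SumAggregator, ResidualUpdate);
        let x = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let edges = [(0usize, 1usize), (1, 0)];

        let h = layer.forward(&x, &edges, None);
        // Node 0: [1,2] + [3,4]; node 1: [3,4] + [1,2]
        assert_eq!(h.data(), &[4.0, 6.0, 4.0, 6.0]);
    }

    // Sketch: the per-destination softmax in `GATConv` splits equal raw
    // scores evenly, so two edges into node 1 each get attention 0.5.
    #[test]
    fn test_gat_softmax_normalizes_per_destination() {
        let gat = GATConv::new(2, 2, 1);
        let edges = [(0usize, 1usize), (2, 1)];

        let normalized = gat.softmax(&[1.0, 1.0], &edges);
        assert!((normalized[0] - 0.5).abs() < 1e-12);
        assert!((normalized[1] - 0.5).abs() < 1e-12);
    }

    // Sketch: with out_features == in_features, the simplified GraphSAGE
    // transform keeps each node's own features, so the output equals the input.
    #[test]
    fn test_graphsage_forward() {
        let sage = GraphSAGE::new(2, 2, 5);
        let x = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let edges = [(0usize, 1usize), (0, 2)];

        let out = sage.forward(&x, &edges);
        assert_eq!(out.shape().to_vec(), vec![3, 2]);
        assert_eq!(out.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }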
}