// god_graph/tensor/types.rs
//! Tensor-aware node and edge types.
//!
//! Extends the existing node/edge system to natively support tensor data,
//! for graph neural networks (GNNs) and other machine-learning applications.

use core::fmt;
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;

use crate::edge::EdgeIndex;
use crate::node::NodeIndex;
use crate::tensor::dense::DenseTensor;
use crate::tensor::traits::TensorBase;

#[cfg(feature = "tensor")]
use crate::tensor::sparse::SparseTensor;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

21/// Tensor 节点:带有 tensor 数据的节点包装器
22///
23/// 提供零成本抽象,与现有 NodeIndex 兼容
24/// 支持任意实现 TensorBase trait 的 tensor 类型
25#[derive(Clone)]
26#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
27pub struct TensorNode<T: TensorBase> {
28    /// 节点索引
29    index: NodeIndex,
30    /// Tensor 数据
31    data: T,
32    /// 类型标记
33    _marker: PhantomData<T>,
34}
35
36impl<T: TensorBase> TensorNode<T> {
37    /// 创建新的 TensorNode
38    pub fn new(index: NodeIndex, data: T) -> Self {
39        Self {
40            index,
41            data,
42            _marker: PhantomData,
43        }
44    }
45
46    /// 获取节点索引
47    pub fn index(&self) -> NodeIndex {
48        self.index
49    }
50
51    /// 获取 tensor 数据引用
52    pub fn data(&self) -> &T {
53        &self.data
54    }
55
56    /// 获取 tensor 数据可变引用
57    pub fn data_mut(&mut self) -> &mut T {
58        &mut self.data
59    }
60
61    /// 获取 tensor 的形状
62    pub fn shape(&self) -> &[usize] {
63        self.data.shape()
64    }
65
66    /// 设置新的 tensor 数据
67    pub fn set_data(&mut self, data: T) {
68        self.data = data;
69    }
70
71    /// 消耗 self 并返回内部数据
72    pub fn into_data(self) -> T {
73        self.data
74    }
75}
76
77impl<T: TensorBase> fmt::Debug for TensorNode<T> {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.debug_struct("TensorNode")
80            .field("index", &self.index)
81            .field("shape", &self.data.shape())
82            .field("dtype", &self.data.dtype())
83            .finish()
84    }
85}
86
87impl<T: TensorBase> PartialEq for TensorNode<T> {
88    fn eq(&self, other: &Self) -> bool {
89        self.index == other.index
90    }
91}
92
93impl<T: TensorBase> Eq for TensorNode<T> {}
94
95impl<T: TensorBase> Hash for TensorNode<T> {
96    fn hash<H: Hasher>(&self, state: &mut H) {
97        self.index.hash(state);
98    }
99}
100
/// Tensor edge: an edge wrapper carrying tensor data.
///
/// Used to store edge features (attention weights, relation types, ...).
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TensorEdge<E: TensorBase> {
    /// Edge index identifying this edge in the graph.
    index: EdgeIndex,
    /// Tensor payload attached to the edge.
    data: E,
    /// Source node index.
    source: NodeIndex,
    /// Target node index.
    target: NodeIndex,
}

117impl<E: TensorBase> TensorEdge<E> {
118    /// 创建新的 TensorEdge
119    pub fn new(index: EdgeIndex, data: E, source: NodeIndex, target: NodeIndex) -> Self {
120        Self {
121            index,
122            data,
123            source,
124            target,
125        }
126    }
127
128    /// 获取边索引
129    pub fn index(&self) -> EdgeIndex {
130        self.index
131    }
132
133    /// 获取 tensor 数据引用
134    pub fn data(&self) -> &E {
135        &self.data
136    }
137
138    /// 获取 tensor 数据可变引用
139    pub fn data_mut(&mut self) -> &mut E {
140        &mut self.data
141    }
142
143    /// 获取源节点索引
144    pub fn source(&self) -> NodeIndex {
145        self.source
146    }
147
148    /// 获取目标节点索引
149    pub fn target(&self) -> NodeIndex {
150        self.target
151    }
152
153    /// 获取端点对
154    pub fn endpoints(&self) -> (NodeIndex, NodeIndex) {
155        (self.source, self.target)
156    }
157
158    /// 获取 tensor 的形状
159    pub fn shape(&self) -> &[usize] {
160        self.data.shape()
161    }
162
163    /// 设置新的 tensor 数据
164    pub fn set_data(&mut self, data: E) {
165        self.data = data;
166    }
167}
168
169impl<E: TensorBase> fmt::Debug for TensorEdge<E> {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        f.debug_struct("TensorEdge")
172            .field("index", &self.index)
173            .field(
174                "endpoints",
175                &format!("({:?}, {:?})", self.source, self.target),
176            )
177            .field("shape", &self.data.shape())
178            .field("dtype", &self.data.dtype())
179            .finish()
180    }
181}
182
183impl<E: TensorBase> PartialEq for TensorEdge<E> {
184    fn eq(&self, other: &Self) -> bool {
185        self.index == other.index
186    }
187}
188
189impl<E: TensorBase> Eq for TensorEdge<E> {}
190
191impl<E: TensorBase> Hash for TensorEdge<E> {
192    fn hash<H: Hasher>(&self, state: &mut H) {
193        self.index.hash(state);
194    }
195}
196
/// Node feature tensor: stores a node's feature vector/matrix.
///
/// Convenience alias for `TensorNode` with the `DenseTensor` backend.
pub type NodeFeatures = TensorNode<DenseTensor>;

/// Edge feature tensor: stores edge features (e.g. attention weights).
///
/// Convenience alias for `TensorEdge` with the `DenseTensor` backend.
pub type EdgeFeatures = TensorEdge<DenseTensor>;

/// Node embedding: node representation for graph neural networks.
///
/// Typically a low-dimensional dense vector.
pub type NodeEmbedding = TensorNode<DenseTensor>;

/// Hidden state inside a graph neural network.
///
/// Stores intermediate activations of a GNN layer.
pub type HiddenState = DenseTensor;

/// Batched node features for mini-batch processing.
///
/// Feature shape is [batch_size, num_features] or
/// [batch_size, num_nodes, num_features].
pub struct BatchedNodeFeatures<T: TensorBase> {
    /// Graph index of each sample in the batch.
    pub graph_indices: Vec<usize>,
    /// Node indices contained in the batch.
    pub node_indices: Vec<NodeIndex>,
    /// Batched feature tensor.
    pub features: T,
}

impl<T: TensorBase> BatchedNodeFeatures<T> {
    /// Creates a new batched feature set.
    ///
    /// NOTE(review): the lengths of `graph_indices` / `node_indices` are not
    /// validated against each other or against `features` — presumably the
    /// caller upholds consistency; confirm.
    pub fn new(graph_indices: Vec<usize>, node_indices: Vec<NodeIndex>, features: T) -> Self {
        Self {
            graph_indices,
            node_indices,
            features,
        }
    }

    /// Returns the batch size (the number of entries in `graph_indices`).
    pub fn batch_size(&self) -> usize {
        self.graph_indices.len()
    }

    /// Returns the full batched feature tensor.
    pub fn features(&self) -> &T {
        &self.features
    }

    /// Returns the features for the given sample index.
    ///
    /// NOTE(review): this currently returns a reference to the ENTIRE batch
    /// tensor for every in-range `sample_idx` — it only bound-checks the
    /// index, it does not slice out the sample. Extracting a per-sample view
    /// would need a slicing operation on `T`; confirm intended semantics.
    pub fn get_sample(&self, sample_idx: usize) -> Option<&T> {
        if sample_idx < self.graph_indices.len() {
            Some(&self.features)
        } else {
            None
        }
    }
}

/// Graph-neural-network message used during message passing.
///
/// Bundles source-node features, optional edge features, and
/// target-node features.
pub struct GNMessage<T: TensorBase> {
    /// Source-node features.
    pub source_features: T,
    /// Edge features, if any.
    pub edge_features: Option<T>,
    /// Target-node features.
    pub target_features: T,
}

271impl<T: TensorBase> GNMessage<T> {
272    /// 创建新的 GNN 消息
273    pub fn new(source_features: T, edge_features: Option<T>, target_features: T) -> Self {
274        Self {
275            source_features,
276            edge_features,
277            target_features,
278        }
279    }
280
281    /// 获取源节点特征
282    pub fn source(&self) -> &T {
283        &self.source_features
284    }
285
286    /// 获取边特征
287    pub fn edge(&self) -> Option<&T> {
288        self.edge_features.as_ref()
289    }
290
291    /// 获取目标节点特征
292    pub fn target(&self) -> &T {
293        &self.target_features
294    }
295}
296
/// Adjacency-matrix representation for graph-neural-network computation.
///
/// Stores the graph's adjacency matrix in a sparse tensor format.
#[cfg(feature = "tensor")]
pub struct AdjacencyMatrix {
    /// Adjacency tensor (sparse format).
    pub tensor: SparseTensor,
    /// Number of nodes (the matrix is `num_nodes` x `num_nodes`).
    pub num_nodes: usize,
}

#[cfg(feature = "tensor")]
impl AdjacencyMatrix {
    /// Builds an adjacency matrix from `(row, col, weight)` edge triples.
    pub fn from_edges(edges: &[(usize, usize, f64)], num_nodes: usize) -> Self {
        AdjacencyMatrix {
            tensor: SparseTensor::from_edges(edges, [num_nodes, num_nodes]),
            num_nodes,
        }
    }

    /// Returns the number of non-zero entries in the matrix.
    pub fn nnz(&self) -> usize {
        self.tensor.nnz()
    }

    /// Returns a clone of the underlying sparse tensor.
    pub fn to_sparse(&self) -> SparseTensor {
        self.tensor.clone()
    }

    /// Materializes the matrix as a dense tensor.
    pub fn to_dense(&self) -> DenseTensor {
        self.tensor.to_dense()
    }
}

/// Degree matrix: a diagonal matrix whose diagonal holds each node's degree.
pub struct DegreeMatrix {
    /// Degree vector (the diagonal entries).
    pub degrees: DenseTensor,
    /// Number of nodes.
    pub num_nodes: usize,
}

#[cfg(feature = "tensor")]
impl DegreeMatrix {
    /// Computes the (out-)degree matrix from an adjacency matrix.
    ///
    /// Each non-zero entry in row `r` contributes 1.0 to `r`'s degree;
    /// edge weights are ignored, matching the original behavior.
    pub fn from_adjacency(adj: &AdjacencyMatrix) -> Self {
        // Accumulate counts in a plain Vec and build the tensor once,
        // replacing the original per-entry get()/set()/unwrap() round-trips
        // through the tensor API.
        let mut degrees = vec![0.0_f64; adj.num_nodes];
        let coo = adj.tensor.to_coo();
        for &row in coo.row_indices() {
            // `row` is assumed < num_nodes; an out-of-range index panics
            // here just as the original unwrap() did.
            degrees[row] += 1.0;
        }

        Self {
            degrees: DenseTensor::new(degrees, vec![adj.num_nodes]),
            num_nodes: adj.num_nodes,
        }
    }

    /// Returns the degree vector.
    pub fn degrees(&self) -> &DenseTensor {
        &self.degrees
    }

    /// Computes D^(-1/2), used to normalize graph convolutions.
    ///
    /// Degrees <= `epsilon` map to 0.0 to avoid dividing by zero for
    /// isolated nodes.
    pub fn inverse_sqrt(&self, epsilon: f64) -> DenseTensor {
        let shape = self.degrees.shape().to_vec();
        let inv_sqrt: Vec<f64> = self
            .degrees
            .data()
            .iter()
            .map(|&d| if d > epsilon { 1.0 / d.sqrt() } else { 0.0 })
            .collect();
        DenseTensor::new(inv_sqrt, shape)
    }
}

376#[cfg(test)]
377mod tests {
378    use super::*;
379
    // Verifies that a TensorNode exposes its index, data, and shape.
    #[test]
    fn test_tensor_node_creation() {
        let index = NodeIndex::new(0, 1);
        let data = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        let node = TensorNode::new(index, data.clone());

        assert_eq!(node.index(), index);
        assert_eq!(node.data(), &data);
        assert_eq!(node.shape(), &[3]);
    }

    // Verifies that a TensorEdge exposes its index and endpoints.
    #[test]
    fn test_tensor_edge_creation() {
        let index = EdgeIndex::new(0, 1);
        let source = NodeIndex::new(0, 1);
        let target = NodeIndex::new(1, 1);
        let data = DenseTensor::scalar(0.5);

        let edge = TensorEdge::new(index, data.clone(), source, target);

        assert_eq!(edge.index(), index);
        assert_eq!(edge.source(), source);
        assert_eq!(edge.target(), target);
        assert_eq!(edge.endpoints(), (source, target));
    }

    // Builds a 3-node adjacency matrix from edge triples and checks its
    // nnz count and dense conversion.
    #[test]
    #[cfg(feature = "tensor")]
    fn test_adjacency_matrix() {
        let edges = vec![(0, 1, 1.0), (0, 2, 1.0), (1, 2, 1.0)];
        let adj = AdjacencyMatrix::from_edges(&edges, 3);

        assert_eq!(adj.num_nodes, 3);
        assert_eq!(adj.nnz(), 3);

        let dense = adj.to_dense();
        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense.get(&[0, 1]).unwrap(), 1.0);
        assert_eq!(dense.get(&[0, 2]).unwrap(), 1.0);
    }

    // Checks that DegreeMatrix counts one degree per non-zero row entry
    // (i.e. out-degree of each node).
    #[test]
    #[cfg(feature = "tensor")]
    fn test_degree_matrix() {
        let edges = vec![(0, 1, 1.0), (0, 2, 1.0), (1, 2, 1.0)];
        let adj = AdjacencyMatrix::from_edges(&edges, 3);
        let degree = DegreeMatrix::from_adjacency(&adj);

        assert_eq!(degree.num_nodes, 3);
        // node 0 has degree 2 (out-edges: 0->1, 0->2)
        // node 1 has degree 1 (out-edge: 1->2)
        // node 2 has degree 0 (no out-edges)
        assert!((degree.degrees().get(&[0]).unwrap() - 2.0).abs() < 1e-10);
        assert!((degree.degrees().get(&[1]).unwrap() - 1.0).abs() < 1e-10);
        assert!((degree.degrees().get(&[2]).unwrap() - 0.0).abs() < 1e-10);
    }

    // Verifies that GNMessage returns the source, edge, and target tensors
    // it was constructed with.
    #[test]
    fn test_gnn_message() {
        let src = DenseTensor::new(vec![1.0, 2.0], vec![2]);
        let edge = DenseTensor::scalar(0.5);
        let dst = DenseTensor::new(vec![3.0, 4.0], vec![2]);

        let msg = GNMessage::new(src.clone(), Some(edge.clone()), dst.clone());

        assert_eq!(msg.source().data(), &[1.0, 2.0]);
        assert_eq!(msg.edge().unwrap().data(), &[0.5]);
        assert_eq!(msg.target().data(), &[3.0, 4.0]);
    }
449}