// god_graph/tensor/mod.rs
//! Tensor module: tensor support for graph neural networks and high-performance computing.
//!
//! This module implements God-Graph's tensor infrastructure, including:
//! - Dense tensors: N-dimensional arrays backed by ndarray
//! - Sparse tensors: COO, CSR, and BSR formats
//! - Tensor operations: matrix multiplication, transpose, reductions, etc.
//! - Multiple backends: NdArray, Dfdx (GPU), Candle
//! - Memory-pool optimization: reduces allocation overhead
//! - Gradient checkpointing: lowers backward-pass memory usage
//!
//! ## Features
//!
//! - **Backend abstraction**: multiple backends (ndarray, dfdx, candle) via the trait system
//! - **Sparse formats**: COO (coordinate), CSR (compressed sparse row), BSR (block sparse row)
//! - **SIMD optimization**: vectorization implemented with the `wide` crate
//! - **Memory alignment**: 64-byte cache-line alignment to avoid false sharing
//! - **Memory pool**: reusable tensor allocations for iterative algorithms
//!
//! ## Example
//!
//! ```
//! # #[cfg(feature = "tensor")]
//! # {
//! use god_graph::tensor::{DenseTensor, TensorBase};
//!
//! // Create a 2x3 dense tensor
//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
//! let tensor = DenseTensor::from_vec(data, vec![2, 3]);
//!
//! assert_eq!(tensor.shape(), &[2, 3]);
//! assert_eq!(tensor.ndim(), 2);
//! # }
//! ```

#[cfg(feature = "tensor")]
pub mod traits;

#[cfg(feature = "tensor")]
pub mod dense;

#[cfg(feature = "tensor")]
pub mod sparse;

#[cfg(feature = "tensor")]
pub mod ops;

#[cfg(feature = "tensor")]
pub mod error;

#[cfg(feature = "tensor")]
pub mod types;

#[cfg(feature = "tensor")]
pub mod backend;

// Pool and GNN layers are gated behind their own, more specific features.
#[cfg(feature = "tensor-pool")]
pub mod pool;

#[cfg(feature = "tensor-gnn")]
pub mod gnn;

#[cfg(feature = "tensor")]
pub mod graph_tensor;

#[cfg(feature = "tensor")]
pub mod differentiable;

#[cfg(feature = "tensor")]
pub mod decomposition;

#[cfg(feature = "tensor")]
pub mod unified_graph;

// Re-export core types so downstream code can use `crate::tensor::DenseTensor`
// etc. without naming the submodules.
#[cfg(feature = "tensor")]
pub use traits::{COOView, DType, Device, SparseTensorOps, TensorBase, TensorOps};

#[cfg(feature = "tensor")]
pub use dense::DenseTensor;

#[cfg(feature = "tensor")]
pub use sparse::{COOTensor, CSRTensor, SparseTensor};

#[cfg(feature = "tensor")]
pub use error::TensorError;

#[cfg(feature = "tensor")]
pub use types::{
    AdjacencyMatrix, DegreeMatrix, EdgeFeatures, NodeFeatures, TensorEdge, TensorNode,
};

#[cfg(feature = "tensor")]
pub use backend::{NdArrayStorage, TensorStorage, UnifiedStorage};

#[cfg(feature = "tensor-pool")]
pub use pool::{ArenaStats, ArenaTensor, PoolConfig, PoolStats, PooledTensor, TensorArena, TensorPool};

// Gradient checkpointing lives in `pool` but is only exported with autograd support.
#[cfg(feature = "tensor-autograd")]
pub use pool::GradientCheckpoint;

#[cfg(feature = "tensor-gnn")]
pub use gnn::{
    Aggregator, GATConv, GCNConv, GraphSAGE, IdentityMessage, LinearMessage, MaxAggregator,
    MeanAggregator, MessageFunction, MessagePassingLayer, SumAggregator, UpdateFunction,
};

#[cfg(feature = "tensor")]
pub use graph_tensor::{
    GraphAdjacencyMatrix, GraphBatch, GraphFeatureExtractor, GraphReconstructor, GraphTensorExt,
};

#[cfg(feature = "tensor")]
pub use differentiable::{
    DifferentiableEdge, DifferentiableGraph, DifferentiableNode, EdgeEditOp, EdgeEditPolicy,
    EditOperation, GradientConfig, GradientRecorder, GraphTransformer, GumbelSoftmaxSampler,
    NodeEditOp, StructureEdit, ThresholdEditPolicy,
};

#[cfg(feature = "tensor")]
pub use unified_graph::{EdgeData, NodeData, UnifiedConfig, UnifiedGraph};

#[cfg(feature = "tensor")]
pub use decomposition::{
    lie_algebra::{lie_exponential, lie_logarithm, skew_symmetric_projection},
    qr::{orthogonalize, qr_decompose},
    svd::{low_rank_approx, svd_decompose},
    tensor_ring::{compress_tensor_ring, tensor_ring_decompose, TensorRing},
};