Skip to main content

god_graph/tensor/
traits.rs

1//! Tensor Trait 系统:定义张量的抽象接口
2//!
3//! 本模块提供了 tensor 的 trait 层次结构,用于后端抽象和操作定义
4
5use core::fmt::Debug;
6
7/// 数据类型枚举
8#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9pub enum DType {
10    /// 32 位浮点数
11    F32,
12    /// 64 位浮点数
13    F64,
14    /// 32 位整数
15    I32,
16    /// 64 位整数
17    I64,
18    /// 布尔类型
19    Bool,
20}
21
22impl DType {
23    /// 获取数据类型的字节大小
24    pub fn size_bytes(&self) -> usize {
25        match self {
26            DType::F32 => 4,
27            DType::F64 => 8,
28            DType::I32 => 4,
29            DType::I64 => 8,
30            DType::Bool => 1,
31        }
32    }
33}
34
35/// 设备类型枚举
36#[derive(Debug, Clone, Copy, PartialEq, Eq)]
37pub enum Device {
38    /// CPU 设备
39    Cpu,
40    /// CUDA GPU 设备(未来支持)
41    Cuda(usize),
42    /// WebGPU 设备(未来支持)
43    Wgpu,
44}
45
46impl Default for Device {
47    #[inline]
48    fn default() -> Self {
49        Device::Cpu
50    }
51}
52
53/// Tensor 基础 trait:所有张量必须实现的核心操作
54///
55/// 提供形状、数据类型、设备等基本信息查询
56pub trait TensorBase: Clone + Send + Sync + Debug {
57    /// 获取张量的形状(各维度大小)
58    fn shape(&self) -> &[usize];
59
60    /// 获取数据类型
61    fn dtype(&self) -> DType;
62
63    /// 获取设备类型
64    fn device(&self) -> Device;
65
66    /// 获取张量的维度数(rank)
67    fn ndim(&self) -> usize {
68        self.shape().len()
69    }
70
71    /// 获取总元素数量
72    fn numel(&self) -> usize {
73        self.shape().iter().product()
74    }
75
76    /// 检查是否为标量(0 维张量)
77    fn is_scalar(&self) -> bool {
78        self.ndim() == 0
79    }
80
81    /// 检查是否为向量(1 维张量)
82    fn is_vector(&self) -> bool {
83        self.ndim() == 1
84    }
85
86    /// 检查是否为矩阵(2 维张量)
87    fn is_matrix(&self) -> bool {
88        self.ndim() == 2
89    }
90
91    /// 转换为密集张量
92    fn to_dense(&self) -> crate::tensor::dense::DenseTensor;
93
94    /// 转换为稀疏张量(如果适用)
95    #[cfg(feature = "tensor")]
96    fn to_sparse(&self) -> Option<crate::tensor::sparse::SparseTensor>;
97}
98
99/// Tensor 操作 trait:定义数学运算
100///
101/// 提供加法、乘法、矩阵乘法、转置等操作
102pub trait TensorOps: TensorBase {
103    /// 张量加法
104    fn add(&self, other: &Self) -> Self;
105
106    /// 张量减法
107    fn sub(&self, other: &Self) -> Self;
108
109    /// 逐元素乘法(Hadamard 积)
110    fn mul(&self, other: &Self) -> Self;
111
112    /// 逐元素除法
113    fn div(&self, other: &Self) -> Self;
114
115    /// 矩阵乘法(仅适用于 2D 张量)
116    fn matmul(&self, other: &Self) -> Self;
117
118    /// 转置(交换维度)
119    fn transpose(&self, axes: Option<&[usize]>) -> Self;
120
121    /// 沿指定轴求和
122    fn sum(&self, axes: Option<&[usize]>) -> Self;
123
124    /// 沿指定轴求均值
125    fn mean(&self, axes: Option<&[usize]>) -> Self;
126
127    /// 逐元素乘以标量
128    fn mul_scalar(&self, scalar: f64) -> Self;
129
130    /// 逐元素加上标量
131    fn add_scalar(&self, scalar: f64) -> Self;
132
133    /// 逐元素应用函数
134    fn map<F>(&self, f: F) -> Self
135    where
136        F: Fn(f64) -> f64 + Send + Sync;
137
138    /// 重塑张量形状
139    fn reshape(&self, new_shape: &[usize]) -> Self;
140
141    /// 切片操作
142    fn slice(&self, axes: &[usize], ranges: &[core::ops::Range<usize>]) -> Self;
143
144    /// 拼接两个张量
145    fn concat(&self, other: &Self, axis: usize) -> Self;
146
147    /// 获取最大值
148    fn max(&self) -> f64;
149
150    /// 获取最小值
151    fn min(&self) -> f64;
152
153    /// 获取 L2 范数
154    fn norm(&self) -> f64;
155
156    /// 归一化到单位范数
157    fn normalize(&self) -> Self;
158}
159
160/// 稀疏张量操作 trait
161pub trait SparseTensorOps: Clone + Send + Sync + TensorBase {
162    /// 获取非零元素数量
163    fn nnz(&self) -> usize;
164
165    /// 获取稀疏度(非零元素比例)
166    fn sparsity(&self) -> f64 {
167        let total = self.numel();
168        if total == 0 {
169            0.0
170        } else {
171            1.0 - (self.nnz() as f64 / total as f64)
172        }
173    }
174
175    /// 获取 COO 格式视图
176    fn coo(&self) -> COOView<'_>;
177
178    /// 获取行索引
179    fn row_indices(&self) -> &[usize];
180
181    /// 获取列索引
182    fn col_indices(&self) -> &[usize];
183
184    /// 获取非零值(作为 DenseTensor)
185    fn values(&self) -> &crate::tensor::dense::DenseTensor;
186}
187
188/// COO 格式视图:稀疏张量的只读视图
189#[derive(Debug, Clone)]
190pub struct COOView<'a> {
191    /// 行索引
192    pub row_indices: &'a [usize],
193    /// 列索引
194    pub col_indices: &'a [usize],
195    /// 非零值
196    pub values: &'a [f64],
197    /// 张量形状
198    pub shape: [usize; 2],
199}
200
201impl<'a> COOView<'a> {
202    /// 创建新的 COO 视图
203    pub fn new(
204        row_indices: &'a [usize],
205        col_indices: &'a [usize],
206        values: &'a [f64],
207        shape: [usize; 2],
208    ) -> Self {
209        Self {
210            row_indices,
211            col_indices,
212            values,
213            shape,
214        }
215    }
216
217    /// 获取非零元素数量
218    pub fn nnz(&self) -> usize {
219        self.values.len()
220    }
221
222    /// 迭代所有非零元素
223    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, f64)> + '_ {
224        self.row_indices
225            .iter()
226            .zip(self.col_indices.iter())
227            .zip(self.values.iter())
228            .map(|((&r, &c), &v)| (r, c, v))
229    }
230}