// god_graph/tensor/traits.rs — core tensor traits and descriptor types.

use core::fmt::Debug;

/// Element data type of a tensor.
///
/// The per-element byte width of each variant is reported by [`DType::size_bytes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit floating point (4 bytes per element).
    F32,
    /// 64-bit floating point (8 bytes per element).
    F64,
    /// 32-bit signed integer (4 bytes per element).
    I32,
    /// 64-bit signed integer (8 bytes per element).
    I64,
    /// Boolean (1 byte per element).
    Bool,
}

impl DType {
    /// Returns the size in bytes of a single element of this data type.
    ///
    /// This is a `const fn`, so the width can be used in constant expressions
    /// such as array lengths or `const` items.
    #[must_use]
    #[inline]
    pub const fn size_bytes(&self) -> usize {
        match self {
            DType::F32 => 4,
            DType::F64 => 8,
            DType::I32 => 4,
            DType::I64 => 8,
            DType::Bool => 1,
        }
    }
}

/// Compute device a tensor is associated with.
///
/// [`Device::Cpu`] is the [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    /// CPU device.
    Cpu,
    /// CUDA device; the `usize` payload is presumably the device ordinal — TODO confirm.
    Cuda(usize),
    /// WGPU backend — NOTE(review): backend meaning inferred from the variant name.
    Wgpu,
}

impl Default for Device {
    /// Returns [`Device::Cpu`].
    #[inline]
    fn default() -> Self {
        Device::Cpu
    }
}

53pub trait TensorBase: Clone + Send + Sync + Debug {
57 fn shape(&self) -> &[usize];
59
60 fn dtype(&self) -> DType;
62
63 fn device(&self) -> Device;
65
66 fn ndim(&self) -> usize {
68 self.shape().len()
69 }
70
71 fn numel(&self) -> usize {
73 self.shape().iter().product()
74 }
75
76 fn is_scalar(&self) -> bool {
78 self.ndim() == 0
79 }
80
81 fn is_vector(&self) -> bool {
83 self.ndim() == 1
84 }
85
86 fn is_matrix(&self) -> bool {
88 self.ndim() == 2
89 }
90
91 fn to_dense(&self) -> crate::tensor::dense::DenseTensor;
93
94 #[cfg(feature = "tensor")]
96 fn to_sparse(&self) -> Option<crate::tensor::sparse::SparseTensor>;
97}
98
99pub trait TensorOps: TensorBase {
103 fn add(&self, other: &Self) -> Self;
105
106 fn sub(&self, other: &Self) -> Self;
108
109 fn mul(&self, other: &Self) -> Self;
111
112 fn div(&self, other: &Self) -> Self;
114
115 fn matmul(&self, other: &Self) -> Self;
117
118 fn transpose(&self, axes: Option<&[usize]>) -> Self;
120
121 fn sum(&self, axes: Option<&[usize]>) -> Self;
123
124 fn mean(&self, axes: Option<&[usize]>) -> Self;
126
127 fn mul_scalar(&self, scalar: f64) -> Self;
129
130 fn add_scalar(&self, scalar: f64) -> Self;
132
133 fn map<F>(&self, f: F) -> Self
135 where
136 F: Fn(f64) -> f64 + Send + Sync;
137
138 fn reshape(&self, new_shape: &[usize]) -> Self;
140
141 fn slice(&self, axes: &[usize], ranges: &[core::ops::Range<usize>]) -> Self;
143
144 fn concat(&self, other: &Self, axis: usize) -> Self;
146
147 fn max(&self) -> f64;
149
150 fn min(&self) -> f64;
152
153 fn norm(&self) -> f64;
155
156 fn normalize(&self) -> Self;
158}
159
160pub trait SparseTensorOps: Clone + Send + Sync + TensorBase {
162 fn nnz(&self) -> usize;
164
165 fn sparsity(&self) -> f64 {
167 let total = self.numel();
168 if total == 0 {
169 0.0
170 } else {
171 1.0 - (self.nnz() as f64 / total as f64)
172 }
173 }
174
175 fn coo(&self) -> COOView<'_>;
177
178 fn row_indices(&self) -> &[usize];
180
181 fn col_indices(&self) -> &[usize];
183
184 fn values(&self) -> &crate::tensor::dense::DenseTensor;
186}
187
/// Borrowed coordinate-format (COO) view of a 2-D sparse matrix.
///
/// Entry `k` is `(row_indices[k], col_indices[k], values[k])`. The three slices
/// are expected to have equal length; [`COOView::new`] checks this in debug
/// builds. NOTE(review): `shape` is assumed to be `[rows, cols]` — confirm
/// against producers.
///
/// The view holds only shared slices and a small array, so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct COOView<'a> {
    /// Row index of each stored entry.
    pub row_indices: &'a [usize],
    /// Column index of each stored entry.
    pub col_indices: &'a [usize],
    /// Value of each stored entry.
    pub values: &'a [f64],
    /// Logical matrix dimensions.
    pub shape: [usize; 2],
}

impl<'a> COOView<'a> {
    /// Creates a view from parallel index/value slices and the matrix shape.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `row_indices`, `col_indices` and `values`
    /// do not all have the same length. Release builds skip this check.
    pub fn new(
        row_indices: &'a [usize],
        col_indices: &'a [usize],
        values: &'a [f64],
        shape: [usize; 2],
    ) -> Self {
        // `nnz()` trusts `values.len()` while `iter()` stops at the shortest
        // slice, so mismatched lengths would make the two disagree.
        debug_assert_eq!(
            row_indices.len(),
            values.len(),
            "COOView: row_indices and values must have equal length"
        );
        debug_assert_eq!(
            col_indices.len(),
            values.len(),
            "COOView: col_indices and values must have equal length"
        );
        Self {
            row_indices,
            col_indices,
            values,
            shape,
        }
    }

    /// Number of stored entries, i.e. `values.len()`.
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Iterates over the stored entries as `(row, col, value)` triples in storage order.
    ///
    /// The iterator's lifetime is `'a`, the lifetime of the underlying slices,
    /// rather than the `&self` borrow. Under pre-2024-edition capture rules, this
    /// lets it outlive a temporary view (e.g. `tensor.coo().iter()` can be
    /// returned from a function). If the slices differ in length, which is
    /// possible when fields are set directly, iteration stops at the shortest one.
    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, f64)> + 'a {
        // Copy the `&'a` slices out of `self` so the iterator borrows only `'a` data.
        let rows: &'a [usize] = self.row_indices;
        let cols: &'a [usize] = self.col_indices;
        let vals: &'a [f64] = self.values;
        rows.iter()
            .zip(cols)
            .zip(vals)
            .map(|((&r, &c), &v)| (r, c, v))
    }
}