god_graph/tensor/
backend.rs1use core::fmt::Debug;
9
10use crate::tensor::traits::{DType, Device};
11
12pub trait TensorStorage: Clone + Send + Sync + Debug {
14 fn dtype(&self) -> DType;
16
17 fn device(&self) -> Device;
19
20 fn nbytes(&self) -> usize;
22
23 fn is_contiguous(&self) -> bool;
25
26 fn alignment(&self) -> usize;
28}
29
/// Storage backed by a flat `Vec<f64>` together with a dtype tag.
///
/// Only compiled when the `tensor` feature is enabled.
#[cfg(feature = "tensor")]
#[derive(Debug, Clone)]
pub struct NdArrayStorage {
    /// Flat element buffer; always held as `f64` values.
    data: Vec<f64>,
    /// Data type tag carried alongside the buffer.
    dtype: DType,
}
39
#[cfg(feature = "tensor")]
impl NdArrayStorage {
    /// Builds a storage from an owned element buffer and its dtype tag.
    pub fn new(data: Vec<f64>, dtype: DType) -> Self {
        NdArrayStorage { data, dtype }
    }

    /// Read-only view of the element buffer.
    pub fn data(&self) -> &[f64] {
        self.data.as_slice()
    }

    /// Mutable view of the element buffer (length cannot be changed through it).
    pub fn data_mut(&mut self) -> &mut [f64] {
        self.data.as_mut_slice()
    }
}
57
#[cfg(feature = "tensor")]
impl TensorStorage for NdArrayStorage {
    /// Returns the stored dtype tag.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// This backend always reports `Device::Cpu`.
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Element count multiplied by the dtype's per-element size.
    ///
    /// NOTE(review): the buffer is always a `Vec<f64>`, so for a dtype whose
    /// `size_bytes()` is not 8 this is a logical size rather than the number
    /// of bytes actually held — confirm which meaning callers expect.
    fn nbytes(&self) -> usize {
        let element_count = self.data.len();
        element_count * self.dtype.size_bytes()
    }

    /// Always `true` for this backend.
    fn is_contiguous(&self) -> bool {
        true
    }

    /// Reports a fixed value of 64.
    ///
    /// NOTE(review): a plain `Vec<f64>` allocation is only guaranteed
    /// `align_of::<f64>()` alignment — confirm callers treat this as a hint.
    fn alignment(&self) -> usize {
        64
    }
}
80
/// Storage wrapping a `dfdx::tensor::Tensor1D<f64>`.
///
/// Only compiled when the `tensor-gpu` feature is enabled.
#[cfg(feature = "tensor-gpu")]
#[derive(Debug, Clone)]
pub struct DfdxStorage {
    /// The wrapped dfdx tensor.
    inner: dfdx::tensor::Tensor1D<f64>,
}
88
#[cfg(feature = "tensor-gpu")]
impl DfdxStorage {
    /// Takes ownership of a dfdx tensor and wraps it.
    pub fn from_dfdx(tensor: dfdx::tensor::Tensor1D<f64>) -> Self {
        DfdxStorage { inner: tensor }
    }

    /// Borrows the wrapped dfdx tensor.
    pub fn inner(&self) -> &dfdx::tensor::Tensor1D<f64> {
        let wrapped = &self.inner;
        wrapped
    }
}
101
#[cfg(feature = "tensor-gpu")]
impl TensorStorage for DfdxStorage {
    /// Always `DType::F64`: the wrapped tensor's element type is `f64`.
    fn dtype(&self) -> DType {
        DType::F64
    }

    /// Always reports `Device::Cuda(0)`; the wrapped tensor is not consulted.
    fn device(&self) -> Device {
        Device::Cuda(0)
    }

    /// First shape component multiplied by 8 (the size of one `f64`).
    fn nbytes(&self) -> usize {
        let length = self.inner.shape().0;
        length * 8
    }

    /// Always `true` for this backend.
    fn is_contiguous(&self) -> bool {
        true
    }

    /// Reports a fixed value of 128.
    fn alignment(&self) -> usize {
        128
    }
}
124
/// Storage wrapping a `candle_core::Tensor`.
///
/// Only compiled when the `tensor-candle` feature is enabled.
#[cfg(feature = "tensor-candle")]
#[derive(Debug, Clone)]
pub struct CandleStorage {
    /// The wrapped candle tensor.
    inner: candle_core::Tensor,
}
132
#[cfg(feature = "tensor-candle")]
impl CandleStorage {
    /// Takes ownership of a candle tensor and wraps it.
    pub fn from_candle(tensor: candle_core::Tensor) -> Self {
        CandleStorage { inner: tensor }
    }

    /// Borrows the wrapped candle tensor.
    pub fn inner(&self) -> &candle_core::Tensor {
        let wrapped = &self.inner;
        wrapped
    }
}
145
#[cfg(feature = "tensor-candle")]
impl TensorStorage for CandleStorage {
    /// Maps the candle dtype onto this crate's `DType`.
    ///
    /// `F32`, `F64`, `I32` and `I64` map one-to-one; every other candle dtype
    /// falls through to `DType::F64`.
    fn dtype(&self) -> DType {
        match self.inner.dtype() {
            candle_core::DType::F32 => DType::F32,
            candle_core::DType::F64 => DType::F64,
            candle_core::DType::I32 => DType::I32,
            candle_core::DType::I64 => DType::I64,
            _ => DType::F64,
        }
    }

    /// Maps the candle device onto this crate's `Device`.
    ///
    /// CUDA tensors are always reported as `Device::Cuda(0)` and Metal tensors
    /// as `Device::Cpu`.
    ///
    /// NOTE(review): the actual CUDA ordinal is dropped here — confirm whether
    /// callers need the real device index.
    fn device(&self) -> Device {
        match self.inner.device() {
            candle_core::Device::Cpu => Device::Cpu,
            candle_core::Device::Cuda(_) => Device::Cuda(0),
            candle_core::Device::Metal(_) => Device::Cpu,
        }
    }

    /// Element count multiplied by the per-element size of the tensor's own
    /// candle dtype.
    ///
    /// The size is taken from candle rather than from `self.dtype()`, because
    /// `dtype()` folds every unmapped candle dtype into `DType::F64`; going
    /// through it would count those tensors at 8 bytes per element.
    fn nbytes(&self) -> usize {
        let element_size = self.inner.dtype().size_in_bytes();
        self.inner.elem_count() * element_size
    }

    /// Delegates to the wrapped tensor's own contiguity check.
    fn is_contiguous(&self) -> bool {
        self.inner.is_contiguous()
    }

    /// Reports a fixed value of 64.
    fn alignment(&self) -> usize {
        64
    }
}
178
179#[derive(Clone)]
181pub enum UnifiedStorage {
182 NdArray(NdArrayStorage),
184 #[cfg(feature = "tensor-gpu")]
186 Dfdx(DfdxStorage),
187 #[cfg(feature = "tensor-candle")]
189 Candle(CandleStorage),
190}
191
#[cfg(feature = "tensor")]
impl UnifiedStorage {
    /// Builds an `NdArrayStorage` from `data` and `dtype` and wraps it in the
    /// `NdArray` variant.
    pub fn ndarray(data: Vec<f64>, dtype: DType) -> Self {
        let storage = NdArrayStorage::new(data, dtype);
        Self::NdArray(storage)
    }
}
199
200impl TensorStorage for UnifiedStorage {
201 fn dtype(&self) -> DType {
202 match self {
203 UnifiedStorage::NdArray(s) => s.dtype(),
204 #[cfg(feature = "tensor-gpu")]
205 UnifiedStorage::Dfdx(s) => s.dtype(),
206 #[cfg(feature = "tensor-candle")]
207 UnifiedStorage::Candle(s) => s.dtype(),
208 }
209 }
210
211 fn device(&self) -> Device {
212 match self {
213 UnifiedStorage::NdArray(s) => s.device(),
214 #[cfg(feature = "tensor-gpu")]
215 UnifiedStorage::Dfdx(s) => s.device(),
216 #[cfg(feature = "tensor-candle")]
217 UnifiedStorage::Candle(s) => s.device(),
218 }
219 }
220
221 fn nbytes(&self) -> usize {
222 match self {
223 UnifiedStorage::NdArray(s) => s.nbytes(),
224 #[cfg(feature = "tensor-gpu")]
225 UnifiedStorage::Dfdx(s) => s.nbytes(),
226 #[cfg(feature = "tensor-candle")]
227 UnifiedStorage::Candle(s) => s.nbytes(),
228 }
229 }
230
231 fn is_contiguous(&self) -> bool {
232 match self {
233 UnifiedStorage::NdArray(s) => s.is_contiguous(),
234 #[cfg(feature = "tensor-gpu")]
235 UnifiedStorage::Dfdx(s) => s.is_contiguous(),
236 #[cfg(feature = "tensor-candle")]
237 UnifiedStorage::Candle(s) => s.is_contiguous(),
238 }
239 }
240
241 fn alignment(&self) -> usize {
242 match self {
243 UnifiedStorage::NdArray(s) => s.alignment(),
244 #[cfg(feature = "tensor-gpu")]
245 UnifiedStorage::Dfdx(s) => s.alignment(),
246 #[cfg(feature = "tensor-candle")]
247 UnifiedStorage::Candle(s) => s.alignment(),
248 }
249 }
250}
251
252impl Debug for UnifiedStorage {
253 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
254 match self {
255 UnifiedStorage::NdArray(_) => write!(f, "UnifiedStorage::NdArray"),
256 #[cfg(feature = "tensor-gpu")]
257 UnifiedStorage::Dfdx(_) => write!(f, "UnifiedStorage::Dfdx"),
258 #[cfg(feature = "tensor-candle")]
259 UnifiedStorage::Candle(_) => write!(f, "UnifiedStorage::Candle"),
260 }
261 }
262}