// god_graph/tensor/backend.rs
//! Tensor backend abstraction: supports multiple backend implementations.
//!
//! This module provides the abstraction layer for tensor backends, supporting:
//! - `NdArrayStorage`: ndarray-based CPU backend
//! - `DfdxStorage`: dfdx-based GPU backend (with autodiff support)
//! - `CandleStorage`: lightweight candle-based backend
use core::fmt::Debug;

use crate::tensor::traits::{DType, Device};

12/// Tensor 存储后端 trait
13pub trait TensorStorage: Clone + Send + Sync + Debug {
14    /// 获取数据类型
15    fn dtype(&self) -> DType;
16
17    /// 获取设备类型
18    fn device(&self) -> Device;
19
20    /// 获取字节大小
21    fn nbytes(&self) -> usize;
22
23    /// 检查是否连续存储
24    fn is_contiguous(&self) -> bool;
25
26    /// 获取对齐字节数
27    fn alignment(&self) -> usize;
28}
29
/// CPU storage backend backed by a plain `Vec<f64>`.
// Fix: `#[cfg]` now precedes `#[derive]`, matching the attribute order used
// by the sibling storage types in this module.
#[cfg(feature = "tensor")]
#[derive(Clone, Debug)]
pub struct NdArrayStorage {
    /// Raw element buffer.
    // NOTE(review): the original comment claimed 64-byte alignment, but a
    // `Vec<f64>` allocation only guarantees `align_of::<f64>()` (8 bytes) —
    // confirm before relying on wider alignment.
    data: Vec<f64>,
    /// Logical element type reported to callers (the buffer itself is `f64`).
    dtype: DType,
}
39
#[cfg(feature = "tensor")]
impl NdArrayStorage {
    /// Builds a storage from an owned buffer and its logical element type.
    pub fn new(data: Vec<f64>, dtype: DType) -> Self {
        Self { data, dtype }
    }

    /// Borrows the underlying elements as a slice.
    pub fn data(&self) -> &[f64] {
        self.data.as_slice()
    }

    /// Mutably borrows the underlying elements as a slice.
    pub fn data_mut(&mut self) -> &mut [f64] {
        self.data.as_mut_slice()
    }
}
57
#[cfg(feature = "tensor")]
impl TensorStorage for NdArrayStorage {
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Always CPU — the data lives in host memory.
    fn device(&self) -> Device {
        Device::Cpu
    }

    fn nbytes(&self) -> usize {
        self.data.len() * self.dtype.size_bytes()
    }

    /// A `Vec` is always one contiguous allocation.
    fn is_contiguous(&self) -> bool {
        true
    }

    fn alignment(&self) -> usize {
        // Fix: the previous hard-coded 64 overstated the guarantee. A
        // `Vec<f64>` allocation is only guaranteed to be aligned to `f64`
        // (8 bytes); reporting 64 could mislead SIMD/aligned-load code paths.
        core::mem::align_of::<f64>()
    }
}
80
/// GPU-accelerated storage backend built on a dfdx rank-1 tensor.
#[cfg(feature = "tensor-gpu")]
#[derive(Clone, Debug)]
pub struct DfdxStorage {
    /// Wrapped dfdx tensor holding the actual device buffer.
    inner: dfdx::tensor::Tensor1D<f64>,
}
88
#[cfg(feature = "tensor-gpu")]
impl DfdxStorage {
    /// Wraps an existing dfdx tensor.
    pub fn from_dfdx(tensor: dfdx::tensor::Tensor1D<f64>) -> Self {
        Self { inner: tensor }
    }

    /// Borrows the wrapped dfdx tensor.
    pub fn inner(&self) -> &dfdx::tensor::Tensor1D<f64> {
        &self.inner
    }
}
101
#[cfg(feature = "tensor-gpu")]
impl TensorStorage for DfdxStorage {
    /// dfdx storage in this module is always `f64`.
    fn dtype(&self) -> DType {
        DType::F64
    }

    fn device(&self) -> Device {
        // NOTE(review): always reports GPU 0 — the actual device index is not
        // tracked; confirm whether multi-GPU setups need the real index.
        Device::Cuda(0)
    }

    fn nbytes(&self) -> usize {
        // Fix: derive the element size from the reported dtype instead of a
        // hard-coded 8, keeping this consistent with the other backends.
        // Behavior is unchanged while `dtype()` is F64.
        self.inner.shape().0 * self.dtype().size_bytes()
    }

    /// dfdx tensors are contiguous.
    fn is_contiguous(&self) -> bool {
        true
    }

    fn alignment(&self) -> usize {
        // presumably CUDA allocation alignment — TODO confirm against the
        // allocator actually used by dfdx.
        128
    }
}
124
/// Lightweight storage backend built on a Hugging Face candle tensor.
#[cfg(feature = "tensor-candle")]
#[derive(Clone, Debug)]
pub struct CandleStorage {
    /// Wrapped candle tensor.
    inner: candle_core::Tensor,
}
132
#[cfg(feature = "tensor-candle")]
impl CandleStorage {
    /// Wraps an existing candle tensor.
    pub fn from_candle(tensor: candle_core::Tensor) -> Self {
        Self { inner: tensor }
    }

    /// Borrows the wrapped candle tensor.
    pub fn inner(&self) -> &candle_core::Tensor {
        &self.inner
    }
}
145
#[cfg(feature = "tensor-candle")]
impl TensorStorage for CandleStorage {
    fn dtype(&self) -> DType {
        match self.inner.dtype() {
            candle_core::DType::F32 => DType::F32,
            candle_core::DType::F64 => DType::F64,
            candle_core::DType::I32 => DType::I32,
            candle_core::DType::I64 => DType::I64,
            // NOTE(review): lossy fallback — unmapped element types (e.g.
            // F16/U8) are reported as F64. Callers must not size buffers
            // from this value; see `nbytes` below.
            _ => DType::F64,
        }
    }

    fn device(&self) -> Device {
        match self.inner.device() {
            candle_core::Device::Cpu => Device::Cpu,
            candle_core::Device::Cuda(_) => Device::Cuda(0),
            // Metal has no counterpart in `Device` yet; treated as CPU.
            candle_core::Device::Metal(_) => Device::Cpu,
        }
    }

    fn nbytes(&self) -> usize {
        // Fix: size from candle's own dtype rather than `self.dtype()`,
        // whose catch-all fallback to F64 would overstate the byte count
        // for unmapped element types (e.g. F16: 2 bytes, reported as 8).
        self.inner.elem_count() * self.inner.dtype().size_in_bytes()
    }

    fn is_contiguous(&self) -> bool {
        self.inner.is_contiguous()
    }

    fn alignment(&self) -> usize {
        64
    }
}
178
/// Unified storage enum erasing the concrete backend behind one type.
#[derive(Clone)]
pub enum UnifiedStorage {
    /// ndarray CPU backend.
    // Fix: gate the variant like its payload type — `NdArrayStorage` only
    // exists under the `tensor` feature, so an ungated variant fails to
    // compile when that feature is disabled.
    #[cfg(feature = "tensor")]
    NdArray(NdArrayStorage),
    /// dfdx GPU backend.
    #[cfg(feature = "tensor-gpu")]
    Dfdx(DfdxStorage),
    /// candle backend.
    #[cfg(feature = "tensor-candle")]
    Candle(CandleStorage),
}
191
#[cfg(feature = "tensor")]
impl UnifiedStorage {
    /// Convenience constructor: wraps a raw buffer in the ndarray backend.
    pub fn ndarray(data: Vec<f64>, dtype: DType) -> Self {
        Self::NdArray(NdArrayStorage::new(data, dtype))
    }
}
199
200impl TensorStorage for UnifiedStorage {
201    fn dtype(&self) -> DType {
202        match self {
203            UnifiedStorage::NdArray(s) => s.dtype(),
204            #[cfg(feature = "tensor-gpu")]
205            UnifiedStorage::Dfdx(s) => s.dtype(),
206            #[cfg(feature = "tensor-candle")]
207            UnifiedStorage::Candle(s) => s.dtype(),
208        }
209    }
210
211    fn device(&self) -> Device {
212        match self {
213            UnifiedStorage::NdArray(s) => s.device(),
214            #[cfg(feature = "tensor-gpu")]
215            UnifiedStorage::Dfdx(s) => s.device(),
216            #[cfg(feature = "tensor-candle")]
217            UnifiedStorage::Candle(s) => s.device(),
218        }
219    }
220
221    fn nbytes(&self) -> usize {
222        match self {
223            UnifiedStorage::NdArray(s) => s.nbytes(),
224            #[cfg(feature = "tensor-gpu")]
225            UnifiedStorage::Dfdx(s) => s.nbytes(),
226            #[cfg(feature = "tensor-candle")]
227            UnifiedStorage::Candle(s) => s.nbytes(),
228        }
229    }
230
231    fn is_contiguous(&self) -> bool {
232        match self {
233            UnifiedStorage::NdArray(s) => s.is_contiguous(),
234            #[cfg(feature = "tensor-gpu")]
235            UnifiedStorage::Dfdx(s) => s.is_contiguous(),
236            #[cfg(feature = "tensor-candle")]
237            UnifiedStorage::Candle(s) => s.is_contiguous(),
238        }
239    }
240
241    fn alignment(&self) -> usize {
242        match self {
243            UnifiedStorage::NdArray(s) => s.alignment(),
244            #[cfg(feature = "tensor-gpu")]
245            UnifiedStorage::Dfdx(s) => s.alignment(),
246            #[cfg(feature = "tensor-candle")]
247            UnifiedStorage::Candle(s) => s.alignment(),
248        }
249    }
250}
251
252impl Debug for UnifiedStorage {
253    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
254        match self {
255            UnifiedStorage::NdArray(_) => write!(f, "UnifiedStorage::NdArray"),
256            #[cfg(feature = "tensor-gpu")]
257            UnifiedStorage::Dfdx(_) => write!(f, "UnifiedStorage::Dfdx"),
258            #[cfg(feature = "tensor-candle")]
259            UnifiedStorage::Candle(_) => write!(f, "UnifiedStorage::Candle"),
260        }
261    }
262}