1use burn_backend::{Backend, QTensorPrimitive, TensorMetadata};
2
3use crate::backends::*;
4
5#[cfg(feature = "autodiff")]
6use burn_backend::tensor::FloatTensor;
7
/// A primitive tensor of backend `B`, tagged with the kind of primitive it holds.
///
/// Each variant wraps one of the backend's associated primitive types. The
/// accessor methods on this type unwrap one specific variant and panic when a
/// different variant is held.
#[derive(Clone, Debug)]
pub enum BackendTensor<B: Backend> {
    /// A float tensor primitive of `B`.
    Float(B::FloatTensorPrimitive),
    /// An integer tensor primitive of `B`.
    Int(B::IntTensorPrimitive),
    /// A boolean tensor primitive of `B`.
    Bool(B::BoolTensorPrimitive),
    /// A quantized tensor primitive of `B`.
    Quantized(B::QuantizedTensorPrimitive),
    /// A float tensor of the `Autodiff<B>` backend (only with the `autodiff` feature).
    #[cfg(feature = "autodiff")]
    Autodiff(FloatTensor<Autodiff<B>>),
}
26
27impl<B: Backend> BackendTensor<B> {
28 pub(crate) fn float(self) -> B::FloatTensorPrimitive {
30 match self {
31 BackendTensor::Float(tensor) => tensor,
32 BackendTensor::Int(_) => panic!("Should be float, got int"),
33 BackendTensor::Bool(_) => panic!("Should be float, got bool"),
34 BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
35 #[cfg(feature = "autodiff")]
36 BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
37 }
38 }
39 pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
41 match self {
42 BackendTensor::Float(tensor) => tensor,
43 BackendTensor::Int(_) => panic!("Should be float, got int"),
44 BackendTensor::Bool(_) => panic!("Should be float, got bool"),
45 BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
46 #[cfg(feature = "autodiff")]
47 BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
48 }
49 }
50
51 pub(crate) fn int(self) -> B::IntTensorPrimitive {
53 match self {
54 BackendTensor::Int(tensor) => tensor,
55 BackendTensor::Float(_) => panic!("Should be int, got float"),
56 BackendTensor::Bool(_) => panic!("Should be int, got bool"),
57 BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
58 #[cfg(feature = "autodiff")]
59 BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
60 }
61 }
62
63 pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
65 match self {
66 BackendTensor::Bool(tensor) => tensor,
67 BackendTensor::Float(_) => panic!("Should be bool, got float"),
68 BackendTensor::Int(_) => panic!("Should be bool, got int"),
69 BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
70 #[cfg(feature = "autodiff")]
71 BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
72 }
73 }
74
75 pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
77 match self {
78 BackendTensor::Quantized(tensor) => tensor,
79 _ => unreachable!(),
80 }
81 }
82
83 #[cfg(feature = "autodiff")]
84 pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
86 match self {
87 BackendTensor::Autodiff(tensor) => tensor,
88 _ => unreachable!(),
90 }
91 }
92
93 #[cfg(feature = "autodiff")]
94 pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
96 match self {
97 BackendTensor::Autodiff(tensor) => tensor,
98 _ => unreachable!(),
99 }
100 }
101
102 #[cfg(feature = "autodiff")]
103 pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
105 match self {
106 BackendTensor::Autodiff(tensor) => tensor.primitive,
107 _ => unreachable!(),
108 }
109 }
110
111 pub(crate) fn device(&self) -> B::Device {
113 match self {
114 BackendTensor::Float(tensor) => B::float_device(tensor),
115 BackendTensor::Int(tensor) => B::int_device(tensor),
116 BackendTensor::Bool(tensor) => B::bool_device(tensor),
117 BackendTensor::Quantized(tensor) => B::q_device(tensor),
118 #[cfg(feature = "autodiff")]
119 BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
120 }
121 }
122}
123
124impl<B: Backend> TensorMetadata for BackendTensor<B> {
125 fn dtype(&self) -> burn_std::DType {
126 match self {
127 BackendTensor::Float(tensor) => tensor.dtype(),
128 BackendTensor::Int(tensor) => tensor.dtype(),
129 BackendTensor::Bool(tensor) => tensor.dtype(),
130 BackendTensor::Quantized(tensor) => tensor.dtype(),
131 #[cfg(feature = "autodiff")]
132 BackendTensor::Autodiff(tensor) => tensor.dtype(),
133 }
134 }
135
136 fn shape(&self) -> burn_std::Shape {
137 match self {
138 BackendTensor::Float(tensor) => tensor.shape(),
139 BackendTensor::Int(tensor) => tensor.shape(),
140 BackendTensor::Bool(tensor) => tensor.shape(),
141 BackendTensor::Quantized(tensor) => tensor.shape(),
142 #[cfg(feature = "autodiff")]
143 BackendTensor::Autodiff(tensor) => tensor.shape(),
144 }
145 }
146}
147
148impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
149 fn scheme(&self) -> &burn_std::QuantScheme {
150 match self {
151 BackendTensor::Quantized(tensor) => tensor.scheme(),
152 _ => panic!(
153 "Quantization scheme is not valid for dtype {:?}",
154 self.dtype(),
155 ),
156 }
157 }
158}
159
/// A tensor that may belong to any of the backends compiled into this crate.
///
/// Each variant wraps a [`BackendTensor`] of one concrete backend and is only
/// present when that backend's cfg flag is enabled. `wgpu_metal`, `wgpu_vulkan`
/// and `wgpu_webgpu` are custom cfg flags rather than Cargo features;
/// NOTE(review): presumably set by the build script — confirm.
#[derive(Clone, Debug)]
pub enum DispatchTensor {
    /// Tensor on the `Cpu` backend.
    #[cfg(feature = "cpu")]
    Cpu(BackendTensor<Cpu>),

    /// Tensor on the `Cuda` backend.
    #[cfg(feature = "cuda")]
    Cuda(BackendTensor<Cuda>),

    /// Tensor on the `Metal` backend.
    #[cfg(wgpu_metal)]
    Metal(BackendTensor<Metal>),

    /// Tensor on the `Rocm` backend.
    #[cfg(feature = "rocm")]
    Rocm(BackendTensor<Rocm>),

    /// Tensor on the `Vulkan` backend.
    #[cfg(wgpu_vulkan)]
    Vulkan(BackendTensor<Vulkan>),

    /// Tensor on the `WebGpu` backend.
    #[cfg(wgpu_webgpu)]
    WebGpu(BackendTensor<WebGpu>),

    /// Tensor on the `NdArray` backend.
    #[cfg(feature = "ndarray")]
    NdArray(BackendTensor<NdArray>),

    /// Tensor on the `LibTorch` backend (enabled by the `tch` feature).
    #[cfg(feature = "tch")]
    LibTorch(BackendTensor<LibTorch>),

    /// Wraps another dispatch tensor; boxed because the type is recursive.
    /// NOTE(review): presumably marks the inner tensor as autodiff-tracked — confirm.
    #[cfg(feature = "autodiff")]
    Autodiff(Box<DispatchTensor>),
}
202
203impl TensorMetadata for DispatchTensor {
204 fn dtype(&self) -> burn_std::DType {
205 match self {
206 #[cfg(feature = "cpu")]
207 DispatchTensor::Cpu(tensor) => tensor.dtype(),
208 #[cfg(feature = "cuda")]
209 DispatchTensor::Cuda(tensor) => tensor.dtype(),
210 #[cfg(wgpu_metal)]
211 DispatchTensor::Metal(tensor) => tensor.dtype(),
212 #[cfg(feature = "rocm")]
213 DispatchTensor::Rocm(tensor) => tensor.dtype(),
214 #[cfg(wgpu_vulkan)]
215 DispatchTensor::Vulkan(tensor) => tensor.dtype(),
216 #[cfg(wgpu_webgpu)]
217 DispatchTensor::WebGpu(tensor) => tensor.dtype(),
218 #[cfg(feature = "ndarray")]
219 DispatchTensor::NdArray(tensor) => tensor.dtype(),
220 #[cfg(feature = "tch")]
221 DispatchTensor::LibTorch(tensor) => tensor.dtype(),
222 #[cfg(feature = "autodiff")]
223 DispatchTensor::Autodiff(tensor) => tensor.dtype(),
224 }
225 }
226
227 fn shape(&self) -> burn_std::Shape {
228 match self {
229 #[cfg(feature = "cpu")]
230 DispatchTensor::Cpu(tensor) => tensor.shape(),
231 #[cfg(feature = "cuda")]
232 DispatchTensor::Cuda(tensor) => tensor.shape(),
233 #[cfg(wgpu_metal)]
234 DispatchTensor::Metal(tensor) => tensor.shape(),
235 #[cfg(feature = "rocm")]
236 DispatchTensor::Rocm(tensor) => tensor.shape(),
237 #[cfg(wgpu_vulkan)]
238 DispatchTensor::Vulkan(tensor) => tensor.shape(),
239 #[cfg(wgpu_webgpu)]
240 DispatchTensor::WebGpu(tensor) => tensor.shape(),
241 #[cfg(feature = "ndarray")]
242 DispatchTensor::NdArray(tensor) => tensor.shape(),
243 #[cfg(feature = "tch")]
244 DispatchTensor::LibTorch(tensor) => tensor.shape(),
245 #[cfg(feature = "autodiff")]
246 DispatchTensor::Autodiff(tensor) => tensor.shape(),
247 }
248 }
249}
250
251impl QTensorPrimitive for DispatchTensor {
252 fn scheme(&self) -> &burn_std::QuantScheme {
253 match self {
254 #[cfg(feature = "cpu")]
255 DispatchTensor::Cpu(tensor) => tensor.scheme(),
256 #[cfg(feature = "cuda")]
257 DispatchTensor::Cuda(tensor) => tensor.scheme(),
258 #[cfg(wgpu_metal)]
259 DispatchTensor::Metal(tensor) => tensor.scheme(),
260 #[cfg(feature = "rocm")]
261 DispatchTensor::Rocm(tensor) => tensor.scheme(),
262 #[cfg(wgpu_vulkan)]
263 DispatchTensor::Vulkan(tensor) => tensor.scheme(),
264 #[cfg(wgpu_webgpu)]
265 DispatchTensor::WebGpu(tensor) => tensor.scheme(),
266 #[cfg(feature = "ndarray")]
267 DispatchTensor::NdArray(tensor) => tensor.scheme(),
268 #[cfg(feature = "tch")]
269 DispatchTensor::LibTorch(tensor) => tensor.scheme(),
270 #[cfg(feature = "autodiff")]
271 DispatchTensor::Autodiff(tensor) => tensor.scheme(),
272 }
273 }
274}