1use alloc::boxed::Box;
2
3use burn_backend::{
4 Backend, DType, QTensorPrimitive, Shape, TensorMetadata, quantization::QuantScheme,
5};
6
7#[cfg(feature = "autodiff")]
8use crate::CheckpointingStrategy;
9use crate::backends::*;
10
11#[cfg(feature = "autodiff")]
12use burn_backend::tensor::FloatTensor;
13
14#[derive(Clone, Debug)]
19pub enum BackendTensor<B: Backend> {
20 Float(B::FloatTensorPrimitive),
22 Int(B::IntTensorPrimitive),
24 Bool(B::BoolTensorPrimitive),
26 Quantized(B::QuantizedTensorPrimitive),
28 #[cfg(feature = "autodiff")]
29 Autodiff(FloatTensor<Autodiff<B>>),
31}
32
33impl<B: Backend> BackendTensor<B> {
34 pub(crate) fn float(self) -> B::FloatTensorPrimitive {
36 match self {
37 BackendTensor::Float(tensor) => tensor,
38 BackendTensor::Int(_) => panic!("Should be float, got int"),
39 BackendTensor::Bool(_) => panic!("Should be float, got bool"),
40 BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
41 #[cfg(feature = "autodiff")]
42 BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
43 }
44 }
45 pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
47 match self {
48 BackendTensor::Float(tensor) => tensor,
49 BackendTensor::Int(_) => panic!("Should be float, got int"),
50 BackendTensor::Bool(_) => panic!("Should be float, got bool"),
51 BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
52 #[cfg(feature = "autodiff")]
53 BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
54 }
55 }
56
57 pub(crate) fn int(self) -> B::IntTensorPrimitive {
59 match self {
60 BackendTensor::Int(tensor) => tensor,
61 BackendTensor::Float(_) => panic!("Should be int, got float"),
62 BackendTensor::Bool(_) => panic!("Should be int, got bool"),
63 BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
64 #[cfg(feature = "autodiff")]
65 BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
66 }
67 }
68
69 pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
71 match self {
72 BackendTensor::Bool(tensor) => tensor,
73 BackendTensor::Float(_) => panic!("Should be bool, got float"),
74 BackendTensor::Int(_) => panic!("Should be bool, got int"),
75 BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
76 #[cfg(feature = "autodiff")]
77 BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
78 }
79 }
80
81 pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
83 match self {
84 BackendTensor::Quantized(tensor) => tensor,
85 _ => unreachable!(),
86 }
87 }
88
89 #[cfg(feature = "autodiff")]
90 pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
92 match self {
93 BackendTensor::Autodiff(tensor) => tensor,
94 _ => unreachable!(),
96 }
97 }
98
99 #[cfg(feature = "autodiff")]
100 pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
102 match self {
103 BackendTensor::Autodiff(tensor) => tensor,
104 _ => unreachable!(),
105 }
106 }
107
108 #[cfg(feature = "autodiff")]
109 pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
111 match self {
112 BackendTensor::Autodiff(tensor) => tensor.primitive,
113 _ => unreachable!(),
114 }
115 }
116
117 pub(crate) fn device(&self) -> B::Device {
119 match self {
120 BackendTensor::Float(tensor) => B::float_device(tensor),
121 BackendTensor::Int(tensor) => B::int_device(tensor),
122 BackendTensor::Bool(tensor) => B::bool_device(tensor),
123 BackendTensor::Quantized(tensor) => B::q_device(tensor),
124 #[cfg(feature = "autodiff")]
125 BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
126 }
127 }
128}
129
130impl<B: Backend> TensorMetadata for BackendTensor<B> {
131 fn dtype(&self) -> DType {
132 match self {
133 BackendTensor::Float(tensor) => tensor.dtype(),
134 BackendTensor::Int(tensor) => tensor.dtype(),
135 BackendTensor::Bool(tensor) => tensor.dtype(),
136 BackendTensor::Quantized(tensor) => tensor.dtype(),
137 #[cfg(feature = "autodiff")]
138 BackendTensor::Autodiff(tensor) => tensor.dtype(),
139 }
140 }
141
142 fn shape(&self) -> Shape {
143 match self {
144 BackendTensor::Float(tensor) => tensor.shape(),
145 BackendTensor::Int(tensor) => tensor.shape(),
146 BackendTensor::Bool(tensor) => tensor.shape(),
147 BackendTensor::Quantized(tensor) => tensor.shape(),
148 #[cfg(feature = "autodiff")]
149 BackendTensor::Autodiff(tensor) => tensor.shape(),
150 }
151 }
152}
153
154impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
155 fn scheme(&self) -> &QuantScheme {
156 match self {
157 BackendTensor::Quantized(tensor) => tensor.scheme(),
158 _ => panic!(
159 "Quantization scheme is not valid for dtype {:?}",
160 self.dtype(),
161 ),
162 }
163 }
164}
165
166#[derive(Clone, Debug)]
172pub struct DispatchTensor {
173 pub(crate) kind: DispatchTensorKind,
175 #[cfg(feature = "autodiff")]
177 pub(crate) checkpointing: CheckpointingStrategy,
178}
179
/// Backend-tagged tensor payload of a [`DispatchTensor`].
///
/// A variant exists only when its backend is enabled at compile time. Most are gated on
/// Cargo features; `Metal`, `Vulkan` and `Wgpu` are gated on custom `cfg` flags
/// (`wgpu_metal`, `wgpu_vulkan`, `wgpu_webgpu`), presumably set by a build script — confirm.
#[derive(Debug, Clone)]
pub enum DispatchTensorKind {
    /// Tensor on the `Cpu` backend.
    #[cfg(feature = "cpu")]
    Cpu(BackendTensor<Cpu>),

    /// Tensor on the `Cuda` backend.
    #[cfg(feature = "cuda")]
    Cuda(BackendTensor<Cuda>),

    /// Tensor on the `Metal` backend.
    #[cfg(wgpu_metal)]
    Metal(BackendTensor<Metal>),

    /// Tensor on the `Rocm` backend.
    #[cfg(feature = "rocm")]
    Rocm(BackendTensor<Rocm>),

    /// Tensor on the `Vulkan` backend.
    #[cfg(wgpu_vulkan)]
    Vulkan(BackendTensor<Vulkan>),

    /// Tensor on the `Wgpu` backend.
    #[cfg(wgpu_webgpu)]
    Wgpu(BackendTensor<Wgpu>),

    /// Tensor on the `Flex` backend.
    #[cfg(feature = "flex")]
    Flex(BackendTensor<Flex>),

    /// Tensor on the `NdArray` backend.
    #[cfg(feature = "ndarray")]
    NdArray(BackendTensor<NdArray>),

    /// Tensor on the `LibTorch` backend.
    #[cfg(feature = "tch")]
    LibTorch(BackendTensor<LibTorch>),

    /// Autodiff marker wrapping another (boxed) kind; metadata is read from the inner kind.
    #[cfg(feature = "autodiff")]
    Autodiff(Box<DispatchTensorKind>),
}
229
230impl TensorMetadata for DispatchTensorKind {
231 fn dtype(&self) -> DType {
232 match self {
233 #[cfg(feature = "cpu")]
234 Self::Cpu(tensor) => tensor.dtype(),
235 #[cfg(feature = "cuda")]
236 Self::Cuda(tensor) => tensor.dtype(),
237 #[cfg(wgpu_metal)]
238 Self::Metal(tensor) => tensor.dtype(),
239 #[cfg(feature = "rocm")]
240 Self::Rocm(tensor) => tensor.dtype(),
241 #[cfg(wgpu_vulkan)]
242 Self::Vulkan(tensor) => tensor.dtype(),
243 #[cfg(wgpu_webgpu)]
244 Self::Wgpu(tensor) => tensor.dtype(),
245 #[cfg(feature = "flex")]
246 Self::Flex(tensor) => tensor.dtype(),
247 #[cfg(feature = "ndarray")]
248 Self::NdArray(tensor) => tensor.dtype(),
249 #[cfg(feature = "tch")]
250 Self::LibTorch(tensor) => tensor.dtype(),
251 #[cfg(feature = "autodiff")]
252 Self::Autodiff(tensor) => tensor.dtype(),
253 }
254 }
255
256 fn shape(&self) -> Shape {
257 match self {
258 #[cfg(feature = "cpu")]
259 Self::Cpu(tensor) => tensor.shape(),
260 #[cfg(feature = "cuda")]
261 Self::Cuda(tensor) => tensor.shape(),
262 #[cfg(wgpu_metal)]
263 Self::Metal(tensor) => tensor.shape(),
264 #[cfg(feature = "rocm")]
265 Self::Rocm(tensor) => tensor.shape(),
266 #[cfg(wgpu_vulkan)]
267 Self::Vulkan(tensor) => tensor.shape(),
268 #[cfg(wgpu_webgpu)]
269 Self::Wgpu(tensor) => tensor.shape(),
270 #[cfg(feature = "flex")]
271 Self::Flex(tensor) => tensor.shape(),
272 #[cfg(feature = "ndarray")]
273 Self::NdArray(tensor) => tensor.shape(),
274 #[cfg(feature = "tch")]
275 Self::LibTorch(tensor) => tensor.shape(),
276 #[cfg(feature = "autodiff")]
277 Self::Autodiff(tensor) => tensor.shape(),
278 }
279 }
280}
281
282impl QTensorPrimitive for DispatchTensorKind {
283 fn scheme(&self) -> &QuantScheme {
284 match self {
285 #[cfg(feature = "cpu")]
286 Self::Cpu(tensor) => tensor.scheme(),
287 #[cfg(feature = "cuda")]
288 Self::Cuda(tensor) => tensor.scheme(),
289 #[cfg(wgpu_metal)]
290 Self::Metal(tensor) => tensor.scheme(),
291 #[cfg(feature = "rocm")]
292 Self::Rocm(tensor) => tensor.scheme(),
293 #[cfg(wgpu_vulkan)]
294 Self::Vulkan(tensor) => tensor.scheme(),
295 #[cfg(wgpu_webgpu)]
296 Self::Wgpu(tensor) => tensor.scheme(),
297 #[cfg(feature = "flex")]
298 Self::Flex(tensor) => tensor.scheme(),
299 #[cfg(feature = "ndarray")]
300 Self::NdArray(tensor) => tensor.scheme(),
301 #[cfg(feature = "tch")]
302 Self::LibTorch(tensor) => tensor.scheme(),
303 #[cfg(feature = "autodiff")]
304 Self::Autodiff(tensor) => tensor.scheme(),
305 }
306 }
307}
308
309impl TensorMetadata for DispatchTensor {
310 fn dtype(&self) -> DType {
311 self.kind.dtype()
312 }
313
314 fn shape(&self) -> Shape {
315 self.kind.shape()
316 }
317}
318
319impl QTensorPrimitive for DispatchTensor {
320 fn scheme(&self) -> &QuantScheme {
321 self.kind.scheme()
322 }
323}