1use alloc::boxed::Box;
2
3use burn_backend::{
4 Backend, DType, QTensorPrimitive, Shape, TensorMetadata, quantization::QuantScheme,
5};
6
7#[cfg(feature = "autodiff")]
8use crate::CheckpointingStrategy;
9use crate::backends::*;
10
11#[cfg(feature = "autodiff")]
12use burn_backend::tensor::FloatTensor;
13
/// A tensor primitive belonging to a single backend `B`, tagged by the kind of
/// primitive it holds.
///
/// The accessors on this type (`float`, `int`, `bool`, ...) unwrap a specific
/// variant and panic when a different one is present.
#[derive(Clone, Debug)]
pub enum BackendTensor<B: Backend> {
    /// A float tensor primitive of backend `B`.
    Float(B::FloatTensorPrimitive),
    /// An integer tensor primitive of backend `B`.
    Int(B::IntTensorPrimitive),
    /// A boolean tensor primitive of backend `B`.
    Bool(B::BoolTensorPrimitive),
    /// A quantized tensor primitive of backend `B`.
    Quantized(B::QuantizedTensorPrimitive),
    /// A float tensor of the `Autodiff<B>` backend, i.e. `B` wrapped for automatic
    /// differentiation. Only present with the `autodiff` feature.
    #[cfg(feature = "autodiff")]
    Autodiff(FloatTensor<Autodiff<B>>),
}
32
33impl<B: Backend> BackendTensor<B> {
34 pub(crate) fn float(self) -> B::FloatTensorPrimitive {
36 match self {
37 BackendTensor::Float(tensor) => tensor,
38 BackendTensor::Int(_) => panic!("Should be float, got int"),
39 BackendTensor::Bool(_) => panic!("Should be float, got bool"),
40 BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
41 #[cfg(feature = "autodiff")]
42 BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
43 }
44 }
45 pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
47 match self {
48 BackendTensor::Float(tensor) => tensor,
49 BackendTensor::Int(_) => panic!("Should be float, got int"),
50 BackendTensor::Bool(_) => panic!("Should be float, got bool"),
51 BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
52 #[cfg(feature = "autodiff")]
53 BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
54 }
55 }
56
57 pub(crate) fn int(self) -> B::IntTensorPrimitive {
59 match self {
60 BackendTensor::Int(tensor) => tensor,
61 BackendTensor::Float(_) => panic!("Should be int, got float"),
62 BackendTensor::Bool(_) => panic!("Should be int, got bool"),
63 BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
64 #[cfg(feature = "autodiff")]
65 BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
66 }
67 }
68
69 pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
71 match self {
72 BackendTensor::Bool(tensor) => tensor,
73 BackendTensor::Float(_) => panic!("Should be bool, got float"),
74 BackendTensor::Int(_) => panic!("Should be bool, got int"),
75 BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
76 #[cfg(feature = "autodiff")]
77 BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
78 }
79 }
80
81 pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
83 match self {
84 BackendTensor::Quantized(tensor) => tensor,
85 _ => unreachable!(),
86 }
87 }
88
89 #[cfg(feature = "autodiff")]
90 pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
92 match self {
93 BackendTensor::Autodiff(tensor) => tensor,
94 _ => unreachable!(),
96 }
97 }
98
99 #[cfg(feature = "autodiff")]
100 pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
102 match self {
103 BackendTensor::Autodiff(tensor) => tensor,
104 _ => unreachable!(),
105 }
106 }
107
108 #[cfg(feature = "autodiff")]
109 pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
111 match self {
112 BackendTensor::Autodiff(tensor) => tensor.primitive,
113 _ => unreachable!(),
114 }
115 }
116
117 pub(crate) fn device(&self) -> B::Device {
119 match self {
120 BackendTensor::Float(tensor) => B::float_device(tensor),
121 BackendTensor::Int(tensor) => B::int_device(tensor),
122 BackendTensor::Bool(tensor) => B::bool_device(tensor),
123 BackendTensor::Quantized(tensor) => B::q_device(tensor),
124 #[cfg(feature = "autodiff")]
125 BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
126 }
127 }
128}
129
130impl<B: Backend> TensorMetadata for BackendTensor<B> {
131 fn dtype(&self) -> DType {
132 match self {
133 BackendTensor::Float(tensor) => tensor.dtype(),
134 BackendTensor::Int(tensor) => tensor.dtype(),
135 BackendTensor::Bool(tensor) => tensor.dtype(),
136 BackendTensor::Quantized(tensor) => tensor.dtype(),
137 #[cfg(feature = "autodiff")]
138 BackendTensor::Autodiff(tensor) => tensor.dtype(),
139 }
140 }
141
142 fn shape(&self) -> Shape {
143 match self {
144 BackendTensor::Float(tensor) => tensor.shape(),
145 BackendTensor::Int(tensor) => tensor.shape(),
146 BackendTensor::Bool(tensor) => tensor.shape(),
147 BackendTensor::Quantized(tensor) => tensor.shape(),
148 #[cfg(feature = "autodiff")]
149 BackendTensor::Autodiff(tensor) => tensor.shape(),
150 }
151 }
152}
153
154impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
155 fn scheme(&self) -> &QuantScheme {
156 match self {
157 BackendTensor::Quantized(tensor) => tensor.scheme(),
158 _ => panic!(
159 "Quantization scheme is not valid for dtype {:?}",
160 self.dtype(),
161 ),
162 }
163 }
164}
165
/// A tensor whose concrete backend is selected at runtime among the backends
/// compiled into this crate.
///
/// The backend-specific payload lives in [`DispatchTensorKind`]; metadata queries on
/// this type forward to it.
#[derive(Clone, Debug)]
pub struct DispatchTensor {
    /// The backend-tagged tensor payload.
    pub(crate) kind: DispatchTensorKind,
    /// Checkpointing strategy carried alongside the tensor (only with the `autodiff`
    /// feature).
    /// NOTE(review): presumably used when the tensor is wrapped for autodiff — confirm
    /// against the code that reads this field.
    #[cfg(feature = "autodiff")]
    pub(crate) checkpointing: CheckpointingStrategy,
}
179
/// Backend-tagged tensor payload of a [`DispatchTensor`].
///
/// There is one variant per backend enabled at compile time; each variant is gated by
/// its own `cfg`. Some gates are Cargo features (`cpu`, `cuda`, `rocm`, `ndarray`,
/// `tch`), while the wgpu-based ones use the custom cfgs `wgpu_metal`, `wgpu_vulkan`
/// and `wgpu_webgpu`.
/// NOTE(review): those custom cfgs are presumably emitted by a build script — confirm.
#[derive(Clone, Debug)]
pub enum DispatchTensorKind {
    /// Tensor on the `Cpu` backend.
    #[cfg(feature = "cpu")]
    Cpu(BackendTensor<Cpu>),

    /// Tensor on the `Cuda` backend.
    #[cfg(feature = "cuda")]
    Cuda(BackendTensor<Cuda>),

    /// Tensor on the `Metal` backend.
    #[cfg(wgpu_metal)]
    Metal(BackendTensor<Metal>),

    /// Tensor on the `Rocm` backend.
    #[cfg(feature = "rocm")]
    Rocm(BackendTensor<Rocm>),

    /// Tensor on the `Vulkan` backend.
    #[cfg(wgpu_vulkan)]
    Vulkan(BackendTensor<Vulkan>),

    /// Tensor on the `Wgpu` backend.
    #[cfg(wgpu_webgpu)]
    Wgpu(BackendTensor<Wgpu>),

    /// Tensor on the `NdArray` backend.
    #[cfg(feature = "ndarray")]
    NdArray(BackendTensor<NdArray>),

    /// Tensor on the `LibTorch` backend.
    #[cfg(feature = "tch")]
    LibTorch(BackendTensor<LibTorch>),

    /// A tensor wrapped for automatic differentiation, holding another
    /// `DispatchTensorKind`. It is boxed because the type refers to itself.
    /// NOTE(review): the inner kind presumably holds a `BackendTensor::Autodiff` —
    /// confirm against the constructors.
    #[cfg(feature = "autodiff")]
    Autodiff(Box<DispatchTensorKind>),
}
225
226impl TensorMetadata for DispatchTensorKind {
227 fn dtype(&self) -> DType {
228 match self {
229 #[cfg(feature = "cpu")]
230 Self::Cpu(tensor) => tensor.dtype(),
231 #[cfg(feature = "cuda")]
232 Self::Cuda(tensor) => tensor.dtype(),
233 #[cfg(wgpu_metal)]
234 Self::Metal(tensor) => tensor.dtype(),
235 #[cfg(feature = "rocm")]
236 Self::Rocm(tensor) => tensor.dtype(),
237 #[cfg(wgpu_vulkan)]
238 Self::Vulkan(tensor) => tensor.dtype(),
239 #[cfg(wgpu_webgpu)]
240 Self::Wgpu(tensor) => tensor.dtype(),
241 #[cfg(feature = "ndarray")]
242 Self::NdArray(tensor) => tensor.dtype(),
243 #[cfg(feature = "tch")]
244 Self::LibTorch(tensor) => tensor.dtype(),
245 #[cfg(feature = "autodiff")]
246 Self::Autodiff(tensor) => tensor.dtype(),
247 }
248 }
249
250 fn shape(&self) -> Shape {
251 match self {
252 #[cfg(feature = "cpu")]
253 Self::Cpu(tensor) => tensor.shape(),
254 #[cfg(feature = "cuda")]
255 Self::Cuda(tensor) => tensor.shape(),
256 #[cfg(wgpu_metal)]
257 Self::Metal(tensor) => tensor.shape(),
258 #[cfg(feature = "rocm")]
259 Self::Rocm(tensor) => tensor.shape(),
260 #[cfg(wgpu_vulkan)]
261 Self::Vulkan(tensor) => tensor.shape(),
262 #[cfg(wgpu_webgpu)]
263 Self::Wgpu(tensor) => tensor.shape(),
264 #[cfg(feature = "ndarray")]
265 Self::NdArray(tensor) => tensor.shape(),
266 #[cfg(feature = "tch")]
267 Self::LibTorch(tensor) => tensor.shape(),
268 #[cfg(feature = "autodiff")]
269 Self::Autodiff(tensor) => tensor.shape(),
270 }
271 }
272}
273
274impl QTensorPrimitive for DispatchTensorKind {
275 fn scheme(&self) -> &QuantScheme {
276 match self {
277 #[cfg(feature = "cpu")]
278 Self::Cpu(tensor) => tensor.scheme(),
279 #[cfg(feature = "cuda")]
280 Self::Cuda(tensor) => tensor.scheme(),
281 #[cfg(wgpu_metal)]
282 Self::Metal(tensor) => tensor.scheme(),
283 #[cfg(feature = "rocm")]
284 Self::Rocm(tensor) => tensor.scheme(),
285 #[cfg(wgpu_vulkan)]
286 Self::Vulkan(tensor) => tensor.scheme(),
287 #[cfg(wgpu_webgpu)]
288 Self::Wgpu(tensor) => tensor.scheme(),
289 #[cfg(feature = "ndarray")]
290 Self::NdArray(tensor) => tensor.scheme(),
291 #[cfg(feature = "tch")]
292 Self::LibTorch(tensor) => tensor.scheme(),
293 #[cfg(feature = "autodiff")]
294 Self::Autodiff(tensor) => tensor.scheme(),
295 }
296 }
297}
298
299impl TensorMetadata for DispatchTensor {
300 fn dtype(&self) -> DType {
301 self.kind.dtype()
302 }
303
304 fn shape(&self) -> Shape {
305 self.kind.shape()
306 }
307}
308
309impl QTensorPrimitive for DispatchTensor {
310 fn scheme(&self) -> &QuantScheme {
311 self.kind.scheme()
312 }
313}