burn_dispatch/tensor.rs

use alloc::boxed::Box;

use burn_backend::{
    Backend, DType, QTensorPrimitive, Shape, TensorMetadata, quantization::QuantScheme,
};

#[cfg(feature = "autodiff")]
use crate::CheckpointingStrategy;
use crate::backends::*;

#[cfg(feature = "autodiff")]
use burn_backend::tensor::FloatTensor;

// TODO: if we reduce the different associated types for float/int/bool/quantized
// tensor primitives down to a single `B::TensorPrimitive` we can simplify this.
/// Tensor wrapping one of a backend's tensor primitive kinds.
#[derive(Clone, Debug)]
pub enum BackendTensor<B: Backend> {
    /// Float tensor handle.
    Float(B::FloatTensorPrimitive),
    /// Int tensor handle.
    Int(B::IntTensorPrimitive),
    /// Bool tensor handle.
    Bool(B::BoolTensorPrimitive),
    /// Quantized tensor handle.
    Quantized(B::QuantizedTensorPrimitive),
    /// Autodiff float tensor handle.
    #[cfg(feature = "autodiff")]
    Autodiff(FloatTensor<Autodiff<B>>),
}
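
// A `BackendTensor` is created by wrapping whichever primitive an operation
// produced, and the typed accessors below recover it, panicking on a kind
// mismatch. A minimal sketch (hypothetical: `float_zeros` stands in for any
// backend op that yields a float primitive):
//
//     let primitive = B::float_zeros(shape, &device);
//     let tensor = BackendTensor::<B>::Float(primitive);
//     let primitive = tensor.float(); // ok: the kinds match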

impl<B: Backend> BackendTensor<B> {
    /// Returns the inner float tensor primitive.
    pub(crate) fn float(self) -> B::FloatTensorPrimitive {
        match self {
            BackendTensor::Float(tensor) => tensor,
            BackendTensor::Int(_) => panic!("Should be float, got int"),
            BackendTensor::Bool(_) => panic!("Should be float, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
        }
    }

    /// Returns a reference to the inner float tensor primitive.
    pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
        match self {
            BackendTensor::Float(tensor) => tensor,
            BackendTensor::Int(_) => panic!("Should be float, got int"),
            BackendTensor::Bool(_) => panic!("Should be float, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
        }
    }

    /// Returns the inner int tensor primitive.
    pub(crate) fn int(self) -> B::IntTensorPrimitive {
        match self {
            BackendTensor::Int(tensor) => tensor,
            BackendTensor::Float(_) => panic!("Should be int, got float"),
            BackendTensor::Bool(_) => panic!("Should be int, got bool"),
            BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
        }
    }

    /// Returns the inner bool tensor primitive.
    pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
        match self {
            BackendTensor::Bool(tensor) => tensor,
            BackendTensor::Float(_) => panic!("Should be bool, got float"),
            BackendTensor::Int(_) => panic!("Should be bool, got int"),
            BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
        }
    }

    /// Returns the inner quantized tensor primitive.
    pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
        match self {
            BackendTensor::Quantized(tensor) => tensor,
            // Reachable on misuse, so report the actual kind instead of
            // claiming the arm is unreachable.
            other => panic!("Should be quantized, got {:?}", other.dtype()),
        }
    }

    /// Returns the inner autodiff tensor primitive.
    #[cfg(feature = "autodiff")]
    pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
        match self {
            BackendTensor::Autodiff(tensor) => tensor,
            other => panic!("Should be autodiff, got {:?}", other.dtype()),
        }
    }

    /// Returns a reference to the inner autodiff tensor primitive.
    #[cfg(feature = "autodiff")]
    pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
        match self {
            BackendTensor::Autodiff(tensor) => tensor,
            other => panic!("Should be autodiff, got {:?}", other.dtype()),
        }
    }

    /// Returns the float tensor primitive wrapped by the autodiff tensor,
    /// discarding the autodiff handle.
    #[cfg(feature = "autodiff")]
    pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
        match self {
            BackendTensor::Autodiff(tensor) => tensor.primitive,
            other => panic!("Should be autodiff, got {:?}", other.dtype()),
        }
    }
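
    // Unlike the typed accessors above, `device` is total over all kinds:
    // every tensor lives on some backend device, and for autodiff tensors the
    // device is read from the wrapped inner primitive.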

    /// Returns the backend device.
    pub(crate) fn device(&self) -> B::Device {
        match self {
            BackendTensor::Float(tensor) => B::float_device(tensor),
            BackendTensor::Int(tensor) => B::int_device(tensor),
            BackendTensor::Bool(tensor) => B::bool_device(tensor),
            BackendTensor::Quantized(tensor) => B::q_device(tensor),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
        }
    }
}

impl<B: Backend> TensorMetadata for BackendTensor<B> {
    fn dtype(&self) -> DType {
        match self {
            BackendTensor::Float(tensor) => tensor.dtype(),
            BackendTensor::Int(tensor) => tensor.dtype(),
            BackendTensor::Bool(tensor) => tensor.dtype(),
            BackendTensor::Quantized(tensor) => tensor.dtype(),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(tensor) => tensor.dtype(),
        }
    }

    fn shape(&self) -> Shape {
        match self {
            BackendTensor::Float(tensor) => tensor.shape(),
            BackendTensor::Int(tensor) => tensor.shape(),
            BackendTensor::Bool(tensor) => tensor.shape(),
            BackendTensor::Quantized(tensor) => tensor.shape(),
            #[cfg(feature = "autodiff")]
            BackendTensor::Autodiff(tensor) => tensor.shape(),
        }
    }
}

impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
    fn scheme(&self) -> &QuantScheme {
        match self {
            BackendTensor::Quantized(tensor) => tensor.scheme(),
            _ => panic!(
                "Quantization scheme is not valid for dtype {:?}",
                self.dtype(),
            ),
        }
    }
}
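
// `scheme` is only meaningful for the `Quantized` kind; every other kind
// panics with its dtype, so callers are expected to check the kind (or the
// dtype) first, e.g. (hypothetical guard, assuming burn's `DType::QFloat`
// variant):
//
//     if matches!(tensor.dtype(), DType::QFloat(_)) {
//         let scheme = tensor.scheme();
//     }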

/// A tensor that can dispatch operations to any enabled backend at runtime.
///
/// When the `autodiff` feature is enabled, tensors may carry a checkpointing
/// strategy used to control gradient computation. This is derived from the
/// device used to create the tensor.
#[derive(Clone, Debug)]
pub struct DispatchTensor {
    /// Tensor kind primitive.
    pub(crate) kind: DispatchTensorKind,
    // Technically a device property, but the device is not stored as a field
    // on the dispatch tensor.
    #[cfg(feature = "autodiff")]
    pub(crate) checkpointing: CheckpointingStrategy,
}
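
// Conceptually, a float tensor on the CUDA backend travels through dispatch
// as (sketch; the checkpointing value shown is hypothetical):
//
//     DispatchTensor {
//         kind: DispatchTensorKind::Cuda(BackendTensor::Float(primitive)),
//         #[cfg(feature = "autodiff")]
//         checkpointing: CheckpointingStrategy::default(),
//     }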

/// Internal representation of a [`DispatchTensor`].
///
/// This enum contains the concrete backend tensor for each enabled backend.
/// It is not intended to be used directly; instead, it is manipulated by
/// the dispatch system to route operations to the correct backend.
///
/// Each variant corresponds to a specific backend implementation.
#[derive(Clone, Debug)]
pub enum DispatchTensorKind {
    /// The [CPU backend](Cpu) tensor.
    #[cfg(feature = "cpu")]
    Cpu(BackendTensor<Cpu>),

    /// The [CUDA backend](Cuda) tensor.
    #[cfg(feature = "cuda")]
    Cuda(BackendTensor<Cuda>),

    /// The [Metal backend](Metal) tensor.
    #[cfg(wgpu_metal)]
    Metal(BackendTensor<Metal>),

    /// The [ROCm backend](Rocm) tensor.
    #[cfg(feature = "rocm")]
    Rocm(BackendTensor<Rocm>),

    /// The [Vulkan backend](Vulkan) tensor.
    #[cfg(wgpu_vulkan)]
    Vulkan(BackendTensor<Vulkan>),

    /// The [WebGPU backend](Wgpu) tensor.
    #[cfg(wgpu_webgpu)]
    Wgpu(BackendTensor<Wgpu>),

    /// The [Flex backend](Flex) tensor.
    #[cfg(feature = "flex")]
    Flex(BackendTensor<Flex>),

    /// The [NdArray backend](NdArray) tensor.
    #[cfg(feature = "ndarray")]
    NdArray(BackendTensor<NdArray>),

    /// The [LibTorch backend](LibTorch) tensor.
    #[cfg(feature = "tch")]
    LibTorch(BackendTensor<LibTorch>),

    /// The [autodiff enabled backend](Autodiff) tensor.
    #[cfg(feature = "autodiff")]
    Autodiff(Box<DispatchTensorKind>),
}
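
// The `Autodiff` variant boxes another `DispatchTensorKind` rather than a
// concrete backend tensor: the indirection breaks what would otherwise be an
// infinitely sized recursive enum, while the inner kind still records which
// backend holds the data. An autodiff-tracked CUDA tensor presumably nests as:
//
//     DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Cuda(inner)))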

impl TensorMetadata for DispatchTensorKind {
    fn dtype(&self) -> DType {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(tensor) => tensor.dtype(),
            #[cfg(feature = "cuda")]
            Self::Cuda(tensor) => tensor.dtype(),
            #[cfg(wgpu_metal)]
            Self::Metal(tensor) => tensor.dtype(),
            #[cfg(feature = "rocm")]
            Self::Rocm(tensor) => tensor.dtype(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(tensor) => tensor.dtype(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(tensor) => tensor.dtype(),
            #[cfg(feature = "flex")]
            Self::Flex(tensor) => tensor.dtype(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(tensor) => tensor.dtype(),
            #[cfg(feature = "tch")]
            Self::LibTorch(tensor) => tensor.dtype(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(tensor) => tensor.dtype(),
        }
    }

    fn shape(&self) -> Shape {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(tensor) => tensor.shape(),
            #[cfg(feature = "cuda")]
            Self::Cuda(tensor) => tensor.shape(),
            #[cfg(wgpu_metal)]
            Self::Metal(tensor) => tensor.shape(),
            #[cfg(feature = "rocm")]
            Self::Rocm(tensor) => tensor.shape(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(tensor) => tensor.shape(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(tensor) => tensor.shape(),
            #[cfg(feature = "flex")]
            Self::Flex(tensor) => tensor.shape(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(tensor) => tensor.shape(),
            #[cfg(feature = "tch")]
            Self::LibTorch(tensor) => tensor.shape(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(tensor) => tensor.shape(),
        }
    }
}
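
// These trait impls repeat the same per-backend match because each variant
// wraps a `BackendTensor` of a different concrete backend type: there is no
// single inner value to defer to, and each `#[cfg]`-gated arm only compiles
// when its backend is enabled.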

impl QTensorPrimitive for DispatchTensorKind {
    fn scheme(&self) -> &QuantScheme {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(tensor) => tensor.scheme(),
            #[cfg(feature = "cuda")]
            Self::Cuda(tensor) => tensor.scheme(),
            #[cfg(wgpu_metal)]
            Self::Metal(tensor) => tensor.scheme(),
            #[cfg(feature = "rocm")]
            Self::Rocm(tensor) => tensor.scheme(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(tensor) => tensor.scheme(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(tensor) => tensor.scheme(),
            #[cfg(feature = "flex")]
            Self::Flex(tensor) => tensor.scheme(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(tensor) => tensor.scheme(),
            #[cfg(feature = "tch")]
            Self::LibTorch(tensor) => tensor.scheme(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(tensor) => tensor.scheme(),
        }
    }
}

impl TensorMetadata for DispatchTensor {
    fn dtype(&self) -> DType {
        self.kind.dtype()
    }

    fn shape(&self) -> Shape {
        self.kind.shape()
    }
}

impl QTensorPrimitive for DispatchTensor {
    fn scheme(&self) -> &QuantScheme {
        self.kind.scheme()
    }
}
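
// Metadata queries on a `DispatchTensor` thus cross two dispatch layers, e.g.:
//
//     let shape = tensor.shape();
//     // DispatchTensor::shape -> DispatchTensorKind::shape
//     //   -> BackendTensor::shape -> <primitive as TensorMetadata>::shape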