//! Runtime backend-dispatch tensor types (`burn_dispatch/tensor.rs`).

1use alloc::boxed::Box;
2
3use burn_backend::{
4    Backend, DType, QTensorPrimitive, Shape, TensorMetadata, quantization::QuantScheme,
5};
6
7#[cfg(feature = "autodiff")]
8use crate::CheckpointingStrategy;
9use crate::backends::*;
10
11#[cfg(feature = "autodiff")]
12use burn_backend::tensor::FloatTensor;
13
14// TODO: if we reduce the different associated types for float/int/bool/quantized tensor primitives down to a single
15// `B::TensorPrimitive` we can simplify this.
16
/// Tensor which points to a backend tensor primitive kind.
///
/// Wraps exactly one of backend `B`'s primitive tensor types (float, int,
/// bool, or quantized). With the `autodiff` feature enabled, a float tensor
/// can instead be tracked through [`Autodiff<B>`] for gradient computation.
#[derive(Clone, Debug)]
pub enum BackendTensor<B: Backend> {
    /// Float tensor handle.
    Float(B::FloatTensorPrimitive),
    /// Int tensor handle.
    Int(B::IntTensorPrimitive),
    /// Bool tensor handle.
    Bool(B::BoolTensorPrimitive),
    /// Quantized tensor handle.
    Quantized(B::QuantizedTensorPrimitive),
    #[cfg(feature = "autodiff")]
    /// Autodiff float tensor handle (a float primitive tracked for gradients).
    Autodiff(FloatTensor<Autodiff<B>>),
}
32
33impl<B: Backend> BackendTensor<B> {
34    /// Returns the inner float tensor primitive.
35    pub(crate) fn float(self) -> B::FloatTensorPrimitive {
36        match self {
37            BackendTensor::Float(tensor) => tensor,
38            BackendTensor::Int(_) => panic!("Should be float, got int"),
39            BackendTensor::Bool(_) => panic!("Should be float, got bool"),
40            BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
41            #[cfg(feature = "autodiff")]
42            BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
43        }
44    }
45    /// Returns the inner float tensor primitive.
46    pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
47        match self {
48            BackendTensor::Float(tensor) => tensor,
49            BackendTensor::Int(_) => panic!("Should be float, got int"),
50            BackendTensor::Bool(_) => panic!("Should be float, got bool"),
51            BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
52            #[cfg(feature = "autodiff")]
53            BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
54        }
55    }
56
57    /// Returns the inner int tensor primitive.
58    pub(crate) fn int(self) -> B::IntTensorPrimitive {
59        match self {
60            BackendTensor::Int(tensor) => tensor,
61            BackendTensor::Float(_) => panic!("Should be int, got float"),
62            BackendTensor::Bool(_) => panic!("Should be int, got bool"),
63            BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
64            #[cfg(feature = "autodiff")]
65            BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
66        }
67    }
68
69    /// Returns the inner bool tensor primitive.
70    pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
71        match self {
72            BackendTensor::Bool(tensor) => tensor,
73            BackendTensor::Float(_) => panic!("Should be bool, got float"),
74            BackendTensor::Int(_) => panic!("Should be bool, got int"),
75            BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
76            #[cfg(feature = "autodiff")]
77            BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
78        }
79    }
80
81    /// Returns the inner quantized tensor primitive.
82    pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
83        match self {
84            BackendTensor::Quantized(tensor) => tensor,
85            _ => unreachable!(),
86        }
87    }
88
89    #[cfg(feature = "autodiff")]
90    /// Returns the inner autodiff tensor primitive.
91    pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
92        match self {
93            BackendTensor::Autodiff(tensor) => tensor,
94            // NOTE: this is the panicking code reached in tensor.rs:74:18:
95            _ => unreachable!(),
96        }
97    }
98
99    #[cfg(feature = "autodiff")]
100    /// Returns the inner autodiff tensor primitive.
101    pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
102        match self {
103            BackendTensor::Autodiff(tensor) => tensor,
104            _ => unreachable!(),
105        }
106    }
107
108    #[cfg(feature = "autodiff")]
109    /// Returns the inner autodiff tensor primitive.
110    pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
111        match self {
112            BackendTensor::Autodiff(tensor) => tensor.primitive,
113            _ => unreachable!(),
114        }
115    }
116
117    /// Returns the backend device.
118    pub(crate) fn device(&self) -> B::Device {
119        match self {
120            BackendTensor::Float(tensor) => B::float_device(tensor),
121            BackendTensor::Int(tensor) => B::int_device(tensor),
122            BackendTensor::Bool(tensor) => B::bool_device(tensor),
123            BackendTensor::Quantized(tensor) => B::q_device(tensor),
124            #[cfg(feature = "autodiff")]
125            BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
126        }
127    }
128}
129
130impl<B: Backend> TensorMetadata for BackendTensor<B> {
131    fn dtype(&self) -> DType {
132        match self {
133            BackendTensor::Float(tensor) => tensor.dtype(),
134            BackendTensor::Int(tensor) => tensor.dtype(),
135            BackendTensor::Bool(tensor) => tensor.dtype(),
136            BackendTensor::Quantized(tensor) => tensor.dtype(),
137            #[cfg(feature = "autodiff")]
138            BackendTensor::Autodiff(tensor) => tensor.dtype(),
139        }
140    }
141
142    fn shape(&self) -> Shape {
143        match self {
144            BackendTensor::Float(tensor) => tensor.shape(),
145            BackendTensor::Int(tensor) => tensor.shape(),
146            BackendTensor::Bool(tensor) => tensor.shape(),
147            BackendTensor::Quantized(tensor) => tensor.shape(),
148            #[cfg(feature = "autodiff")]
149            BackendTensor::Autodiff(tensor) => tensor.shape(),
150        }
151    }
152}
153
154impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
155    fn scheme(&self) -> &QuantScheme {
156        match self {
157            BackendTensor::Quantized(tensor) => tensor.scheme(),
158            _ => panic!(
159                "Quantization scheme is not valid for dtype {:?}",
160                self.dtype(),
161            ),
162        }
163    }
164}
165
/// A tensor that can dispatch operations to any enabled backend at runtime.
///
/// When the `autodiff` feature is enabled, tensors may carry a checkpointing
/// strategy used to control gradient computation. This is derived from the
/// device used to create the tensor.
#[derive(Clone, Debug)]
pub struct DispatchTensor {
    /// Tensor kind primitive (the concrete per-backend tensor).
    pub(crate) kind: DispatchTensorKind,
    // Technically more of a device property, but device is not a dispatch tensor field.
    #[cfg(feature = "autodiff")]
    /// Gradient-checkpointing strategy applied when this tensor participates
    /// in autodiff (derived from the creating device).
    pub(crate) checkpointing: CheckpointingStrategy,
}
179
/// Internal representation of a [`DispatchTensor`].
///
/// This enum contains the concrete backend tensor for each enabled backend.
/// It is not intended to be used directly; instead, it is manipulated by
/// the dispatch system to route operations to the correct backend.
///
/// Each variant corresponds to a specific backend implementation.
///
/// NOTE(review): the wgpu-based variants are gated on custom cfg flags
/// (`wgpu_metal`, `wgpu_vulkan`, `wgpu_webgpu`) rather than Cargo features —
/// presumably emitted by a build script; confirm before relying on them.
#[derive(Clone, Debug)]
pub enum DispatchTensorKind {
    /// The [CPU backend](Cpu) tensor.
    #[cfg(feature = "cpu")]
    Cpu(BackendTensor<Cpu>),

    /// The [CUDA backend](Cuda) tensor.
    #[cfg(feature = "cuda")]
    Cuda(BackendTensor<Cuda>),

    /// The [Metal backend](Metal) tensor.
    #[cfg(wgpu_metal)]
    Metal(BackendTensor<Metal>),

    /// The [ROCm backend](Rocm) tensor.
    #[cfg(feature = "rocm")]
    Rocm(BackendTensor<Rocm>),

    /// The [Vulkan backend](Vulkan) tensor.
    #[cfg(wgpu_vulkan)]
    Vulkan(BackendTensor<Vulkan>),

    /// The [WebGPU backend](Wgpu) tensor.
    #[cfg(wgpu_webgpu)]
    Wgpu(BackendTensor<Wgpu>),

    /// The [NdArray backend](NdArray) tensor.
    #[cfg(feature = "ndarray")]
    NdArray(BackendTensor<NdArray>),

    /// The [LibTorch backend](LibTorch) tensor.
    #[cfg(feature = "tch")]
    LibTorch(BackendTensor<LibTorch>),

    /// The [autodiff enabled backend](Autodiff) tensor.
    ///
    /// Boxed because it wraps another [`DispatchTensorKind`], which would
    /// otherwise make the enum recursively sized.
    #[cfg(feature = "autodiff")]
    Autodiff(Box<DispatchTensorKind>),
}
225
impl TensorMetadata for DispatchTensorKind {
    /// Returns the dtype reported by the concrete backend tensor.
    fn dtype(&self) -> DType {
        // One arm per compiled-in backend; each delegates to the wrapped
        // `BackendTensor`'s metadata. The `Autodiff` arm recurses through
        // the boxed inner kind.
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(tensor) => tensor.dtype(),
            #[cfg(feature = "cuda")]
            Self::Cuda(tensor) => tensor.dtype(),
            #[cfg(wgpu_metal)]
            Self::Metal(tensor) => tensor.dtype(),
            #[cfg(feature = "rocm")]
            Self::Rocm(tensor) => tensor.dtype(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(tensor) => tensor.dtype(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(tensor) => tensor.dtype(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(tensor) => tensor.dtype(),
            #[cfg(feature = "tch")]
            Self::LibTorch(tensor) => tensor.dtype(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(tensor) => tensor.dtype(),
        }
    }

    /// Returns the shape reported by the concrete backend tensor.
    fn shape(&self) -> Shape {
        // Same per-backend delegation pattern as `dtype`.
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(tensor) => tensor.shape(),
            #[cfg(feature = "cuda")]
            Self::Cuda(tensor) => tensor.shape(),
            #[cfg(wgpu_metal)]
            Self::Metal(tensor) => tensor.shape(),
            #[cfg(feature = "rocm")]
            Self::Rocm(tensor) => tensor.shape(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(tensor) => tensor.shape(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(tensor) => tensor.shape(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(tensor) => tensor.shape(),
            #[cfg(feature = "tch")]
            Self::LibTorch(tensor) => tensor.shape(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(tensor) => tensor.shape(),
        }
    }
}
273
impl QTensorPrimitive for DispatchTensorKind {
    /// Returns the quantization scheme of the concrete backend tensor.
    ///
    /// Delegates to `BackendTensor::scheme`, which panics when the wrapped
    /// tensor is not quantized.
    fn scheme(&self) -> &QuantScheme {
        match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(tensor) => tensor.scheme(),
            #[cfg(feature = "cuda")]
            Self::Cuda(tensor) => tensor.scheme(),
            #[cfg(wgpu_metal)]
            Self::Metal(tensor) => tensor.scheme(),
            #[cfg(feature = "rocm")]
            Self::Rocm(tensor) => tensor.scheme(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(tensor) => tensor.scheme(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(tensor) => tensor.scheme(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(tensor) => tensor.scheme(),
            #[cfg(feature = "tch")]
            Self::LibTorch(tensor) => tensor.scheme(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(tensor) => tensor.scheme(),
        }
    }
}
298
299impl TensorMetadata for DispatchTensor {
300    fn dtype(&self) -> DType {
301        self.kind.dtype()
302    }
303
304    fn shape(&self) -> Shape {
305        self.kind.shape()
306    }
307}
308
309impl QTensorPrimitive for DispatchTensor {
310    fn scheme(&self) -> &QuantScheme {
311        self.kind.scheme()
312    }
313}