Skip to main content

burn_dispatch/
tensor.rs

1use burn_backend::{Backend, QTensorPrimitive, TensorMetadata};
2
3use crate::backends::*;
4
5#[cfg(feature = "autodiff")]
6use burn_backend::tensor::FloatTensor;
7
8// TODO: if we reduce the different associated types for float/int/bool/quantized tensor primitives down to a single
9// `B::TensorPrimitive` we can simplify this.
10
11/// Tensor which points to a backend tensor primitive kind.
12#[derive(Clone, Debug)]
13pub enum BackendTensor<B: Backend> {
14    /// Float tensor handle.
15    Float(B::FloatTensorPrimitive),
16    /// Int tensor handle.
17    Int(B::IntTensorPrimitive),
18    /// Bool tensor handle.
19    Bool(B::BoolTensorPrimitive),
20    /// Quantized tensor handle.
21    Quantized(B::QuantizedTensorPrimitive),
22    #[cfg(feature = "autodiff")]
23    /// Autodiff float tensor handle.
24    Autodiff(FloatTensor<Autodiff<B>>),
25}
26
27impl<B: Backend> BackendTensor<B> {
28    /// Returns the inner float tensor primitive.
29    pub(crate) fn float(self) -> B::FloatTensorPrimitive {
30        match self {
31            BackendTensor::Float(tensor) => tensor,
32            BackendTensor::Int(_) => panic!("Should be float, got int"),
33            BackendTensor::Bool(_) => panic!("Should be float, got bool"),
34            BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
35            #[cfg(feature = "autodiff")]
36            BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
37        }
38    }
39    /// Returns the inner float tensor primitive.
40    pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {
41        match self {
42            BackendTensor::Float(tensor) => tensor,
43            BackendTensor::Int(_) => panic!("Should be float, got int"),
44            BackendTensor::Bool(_) => panic!("Should be float, got bool"),
45            BackendTensor::Quantized(_) => panic!("Should be float, got quantized"),
46            #[cfg(feature = "autodiff")]
47            BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"),
48        }
49    }
50
51    /// Returns the inner int tensor primitive.
52    pub(crate) fn int(self) -> B::IntTensorPrimitive {
53        match self {
54            BackendTensor::Int(tensor) => tensor,
55            BackendTensor::Float(_) => panic!("Should be int, got float"),
56            BackendTensor::Bool(_) => panic!("Should be int, got bool"),
57            BackendTensor::Quantized(_) => panic!("Should be int, got quantized"),
58            #[cfg(feature = "autodiff")]
59            BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"),
60        }
61    }
62
63    /// Returns the inner bool tensor primitive.
64    pub(crate) fn bool(self) -> B::BoolTensorPrimitive {
65        match self {
66            BackendTensor::Bool(tensor) => tensor,
67            BackendTensor::Float(_) => panic!("Should be bool, got float"),
68            BackendTensor::Int(_) => panic!("Should be bool, got int"),
69            BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"),
70            #[cfg(feature = "autodiff")]
71            BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"),
72        }
73    }
74
75    /// Returns the inner quantized tensor primitive.
76    pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {
77        match self {
78            BackendTensor::Quantized(tensor) => tensor,
79            _ => unreachable!(),
80        }
81    }
82
83    #[cfg(feature = "autodiff")]
84    /// Returns the inner autodiff tensor primitive.
85    pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {
86        match self {
87            BackendTensor::Autodiff(tensor) => tensor,
88            // NOTE: this is the panicking code reached in tensor.rs:74:18:
89            _ => unreachable!(),
90        }
91    }
92
93    #[cfg(feature = "autodiff")]
94    /// Returns the inner autodiff tensor primitive.
95    pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {
96        match self {
97            BackendTensor::Autodiff(tensor) => tensor,
98            _ => unreachable!(),
99        }
100    }
101
102    #[cfg(feature = "autodiff")]
103    /// Returns the inner autodiff tensor primitive.
104    pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {
105        match self {
106            BackendTensor::Autodiff(tensor) => tensor.primitive,
107            _ => unreachable!(),
108        }
109    }
110
111    /// Returns the backend device.
112    pub(crate) fn device(&self) -> B::Device {
113        match self {
114            BackendTensor::Float(tensor) => B::float_device(tensor),
115            BackendTensor::Int(tensor) => B::int_device(tensor),
116            BackendTensor::Bool(tensor) => B::bool_device(tensor),
117            BackendTensor::Quantized(tensor) => B::q_device(tensor),
118            #[cfg(feature = "autodiff")]
119            BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),
120        }
121    }
122}
123
124impl<B: Backend> TensorMetadata for BackendTensor<B> {
125    fn dtype(&self) -> burn_std::DType {
126        match self {
127            BackendTensor::Float(tensor) => tensor.dtype(),
128            BackendTensor::Int(tensor) => tensor.dtype(),
129            BackendTensor::Bool(tensor) => tensor.dtype(),
130            BackendTensor::Quantized(tensor) => tensor.dtype(),
131            #[cfg(feature = "autodiff")]
132            BackendTensor::Autodiff(tensor) => tensor.dtype(),
133        }
134    }
135
136    fn shape(&self) -> burn_std::Shape {
137        match self {
138            BackendTensor::Float(tensor) => tensor.shape(),
139            BackendTensor::Int(tensor) => tensor.shape(),
140            BackendTensor::Bool(tensor) => tensor.shape(),
141            BackendTensor::Quantized(tensor) => tensor.shape(),
142            #[cfg(feature = "autodiff")]
143            BackendTensor::Autodiff(tensor) => tensor.shape(),
144        }
145    }
146}
147
148impl<B: Backend> QTensorPrimitive for BackendTensor<B> {
149    fn scheme(&self) -> &burn_std::QuantScheme {
150        match self {
151            BackendTensor::Quantized(tensor) => tensor.scheme(),
152            _ => panic!(
153                "Quantization scheme is not valid for dtype {:?}",
154                self.dtype(),
155            ),
156        }
157    }
158}
159
/// A tensor that may live on any of the backends enabled at compile time.
///
/// Every variant wraps the [`BackendTensor`] of one specific backend, so the backend
/// that executes an operation can be chosen at runtime.
#[derive(Debug, Clone)]
pub enum DispatchTensor {
    /// Tensor of the [CPU backend](Cpu).
    #[cfg(feature = "cpu")]
    Cpu(BackendTensor<Cpu>),

    /// Tensor of the [CUDA backend](Cuda).
    #[cfg(feature = "cuda")]
    Cuda(BackendTensor<Cuda>),

    /// Tensor of the [Metal backend](Metal).
    #[cfg(wgpu_metal)]
    Metal(BackendTensor<Metal>),

    /// Tensor of the [ROCm backend](Rocm).
    #[cfg(feature = "rocm")]
    Rocm(BackendTensor<Rocm>),

    /// Tensor of the [Vulkan backend](Vulkan).
    #[cfg(wgpu_vulkan)]
    Vulkan(BackendTensor<Vulkan>),

    /// Tensor of the [WebGPU backend](WebGpu).
    #[cfg(wgpu_webgpu)]
    WebGpu(BackendTensor<WebGpu>),

    /// Tensor of the [NdArray backend](NdArray).
    #[cfg(feature = "ndarray")]
    NdArray(BackendTensor<NdArray>),

    /// Tensor of the [LibTorch backend](LibTorch).
    #[cfg(feature = "tch")]
    LibTorch(BackendTensor<LibTorch>),

    /// Tensor of the [autodiff enabled backend](Autodiff), stored as a boxed
    /// [`DispatchTensor`].
    #[cfg(feature = "autodiff")]
    Autodiff(Box<DispatchTensor>),
}
202
203impl TensorMetadata for DispatchTensor {
204    fn dtype(&self) -> burn_std::DType {
205        match self {
206            #[cfg(feature = "cpu")]
207            DispatchTensor::Cpu(tensor) => tensor.dtype(),
208            #[cfg(feature = "cuda")]
209            DispatchTensor::Cuda(tensor) => tensor.dtype(),
210            #[cfg(wgpu_metal)]
211            DispatchTensor::Metal(tensor) => tensor.dtype(),
212            #[cfg(feature = "rocm")]
213            DispatchTensor::Rocm(tensor) => tensor.dtype(),
214            #[cfg(wgpu_vulkan)]
215            DispatchTensor::Vulkan(tensor) => tensor.dtype(),
216            #[cfg(wgpu_webgpu)]
217            DispatchTensor::WebGpu(tensor) => tensor.dtype(),
218            #[cfg(feature = "ndarray")]
219            DispatchTensor::NdArray(tensor) => tensor.dtype(),
220            #[cfg(feature = "tch")]
221            DispatchTensor::LibTorch(tensor) => tensor.dtype(),
222            #[cfg(feature = "autodiff")]
223            DispatchTensor::Autodiff(tensor) => tensor.dtype(),
224        }
225    }
226
227    fn shape(&self) -> burn_std::Shape {
228        match self {
229            #[cfg(feature = "cpu")]
230            DispatchTensor::Cpu(tensor) => tensor.shape(),
231            #[cfg(feature = "cuda")]
232            DispatchTensor::Cuda(tensor) => tensor.shape(),
233            #[cfg(wgpu_metal)]
234            DispatchTensor::Metal(tensor) => tensor.shape(),
235            #[cfg(feature = "rocm")]
236            DispatchTensor::Rocm(tensor) => tensor.shape(),
237            #[cfg(wgpu_vulkan)]
238            DispatchTensor::Vulkan(tensor) => tensor.shape(),
239            #[cfg(wgpu_webgpu)]
240            DispatchTensor::WebGpu(tensor) => tensor.shape(),
241            #[cfg(feature = "ndarray")]
242            DispatchTensor::NdArray(tensor) => tensor.shape(),
243            #[cfg(feature = "tch")]
244            DispatchTensor::LibTorch(tensor) => tensor.shape(),
245            #[cfg(feature = "autodiff")]
246            DispatchTensor::Autodiff(tensor) => tensor.shape(),
247        }
248    }
249}
250
251impl QTensorPrimitive for DispatchTensor {
252    fn scheme(&self) -> &burn_std::QuantScheme {
253        match self {
254            #[cfg(feature = "cpu")]
255            DispatchTensor::Cpu(tensor) => tensor.scheme(),
256            #[cfg(feature = "cuda")]
257            DispatchTensor::Cuda(tensor) => tensor.scheme(),
258            #[cfg(wgpu_metal)]
259            DispatchTensor::Metal(tensor) => tensor.scheme(),
260            #[cfg(feature = "rocm")]
261            DispatchTensor::Rocm(tensor) => tensor.scheme(),
262            #[cfg(wgpu_vulkan)]
263            DispatchTensor::Vulkan(tensor) => tensor.scheme(),
264            #[cfg(wgpu_webgpu)]
265            DispatchTensor::WebGpu(tensor) => tensor.scheme(),
266            #[cfg(feature = "ndarray")]
267            DispatchTensor::NdArray(tensor) => tensor.scheme(),
268            #[cfg(feature = "tch")]
269            DispatchTensor::LibTorch(tensor) => tensor.scheme(),
270            #[cfg(feature = "autodiff")]
271            DispatchTensor::Autodiff(tensor) => tensor.scheme(),
272        }
273    }
274}