1use alloc::format;
2use alloc::string::String;
3
4use burn_backend::Backend;
5use burn_backend::ExecutionError;
6use burn_std::DType;
7
8#[cfg(feature = "autodiff")]
9use burn_autodiff::grads::Gradients;
10#[cfg(feature = "autodiff")]
11use burn_backend::AutodiffBackend;
12
13use crate::backends::*;
14use crate::{DispatchDevice, DispatchTensor};
15
/// Unit backend type that forwards every operation to the concrete backend
/// selected at runtime via [`DispatchDevice`] / [`DispatchTensor`] (see the
/// `dispatch_device!` calls in the `Backend` impl below).
#[derive(Debug, Default, Clone)]
pub struct Dispatch;
42
impl Backend for Dispatch {
    type Device = DispatchDevice;

    type FloatTensorPrimitive = DispatchTensor;

    // Element types are fixed here regardless of which concrete backend the
    // device resolves to at runtime.
    type FloatElem = f32;

    type IntTensorPrimitive = DispatchTensor;

    type IntElem = i32;

    type BoolTensorPrimitive = DispatchTensor;

    type BoolElem = u8;

    type QuantizedTensorPrimitive = DispatchTensor;

    /// Backend name, wrapping the dispatched backend's own name,
    /// e.g. `dispatch<inner-name>`.
    fn name(device: &Self::Device) -> String {
        let inner = dispatch_device!(device, |device| B::name(device));
        format!("dispatch<{inner}>")
    }

    /// Forwards seeding to the backend selected by `device`.
    fn seed(device: &Self::Device, seed: u64) {
        dispatch_device!(device, |device| B::seed(device, seed))
    }

    /// Forwards synchronization to the backend selected by `device`.
    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        dispatch_device!(device, |device| B::sync(device))
    }

    /// Forwards the dtype-usage query to the backend selected by `device`.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        dispatch_device!(device, |device| B::dtype_usage(device, dtype))
    }

    /// Autodiff is enabled only for devices wrapped in the `Autodiff` variant.
    fn ad_enabled(device: &Self::Device) -> bool {
        match device {
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(_) => true,
            // Wildcard keeps the match exhaustive for every other device
            // variant, and when the `autodiff` feature is disabled entirely.
            _ => false,
        }
    }
}
86
#[cfg(feature = "autodiff")]
impl AutodiffBackend for Dispatch {
    // The dispatch backend acts as its own inner backend: `inner` /
    // `from_inner` unwrap/wrap the `Autodiff` tensor variant instead of
    // switching to a different backend type.
    type InnerBackend = Dispatch;

    type Gradients = Gradients;

    // NOTE(review): `DispatchTensor::device` (later in this file) has a
    // `LibTorch` arm gated on `cfg(feature = "tch")`, but none of the matches
    // in this impl do, and the inner matches carry no wildcard arm. With
    // `tch` and `autodiff` enabled together these matches look
    // non-exhaustive — confirm that feature combination is excluded by the
    // crate, or add the missing arms.

    /// Runs the backward pass of the wrapped autodiff tensor and returns the
    /// computed gradients.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, or if the wrapper
    /// contains another `Autodiff` tensor.
    fn backward(tensor: DispatchTensor) -> Self::Gradients {
        match tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match *tensor {
                #[cfg(feature = "cpu")]
                DispatchTensor::Cpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "cuda")]
                DispatchTensor::Cuda(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_metal)]
                DispatchTensor::Metal(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "rocm")]
                DispatchTensor::Rocm(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_vulkan)]
                DispatchTensor::Vulkan(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_webgpu)]
                DispatchTensor::WebGpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "ndarray")]
                DispatchTensor::NdArray(tensor) => tensor.autodiff().backward(),
                // Nesting `Autodiff(Autodiff(..))` is a construction bug.
                DispatchTensor::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Fetches the gradient of `tensor` from `grads`, if present,
    /// re-wrapping it as a plain float tensor of the same backend variant.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, or if the wrapper
    /// contains another `Autodiff` tensor.
    fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {
        match &tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match &**tensor {
                #[cfg(feature = "cpu")]
                DispatchTensor::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensor::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensor::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensor::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensor::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensor::WebGpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::WebGpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensor::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::NdArray(crate::BackendTensor::Float(t))),
                DispatchTensor::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Same as [`Self::grad`], but also removes the gradient entry from
    /// `grads`.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, or if the wrapper
    /// contains another `Autodiff` tensor.
    fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {
        match &tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match &**tensor {
                #[cfg(feature = "cpu")]
                DispatchTensor::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensor::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensor::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensor::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensor::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensor::WebGpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::WebGpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensor::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::NdArray(crate::BackendTensor::Float(t))),
                DispatchTensor::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Replaces the gradient of `tensor` inside `grads` with `grad`.
    /// Both tensors must resolve to the same backend variant.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, if it wraps another
    /// `Autodiff` tensor, or if `tensor` and `grad` live on different
    /// backends.
    fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {
        match &tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match (&**tensor, grad) {
                #[cfg(feature = "cpu")]
                (DispatchTensor::Cpu(tensor), DispatchTensor::Cpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "cuda")]
                (DispatchTensor::Cuda(tensor), DispatchTensor::Cuda(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_metal)]
                (DispatchTensor::Metal(tensor), DispatchTensor::Metal(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "rocm")]
                (DispatchTensor::Rocm(tensor), DispatchTensor::Rocm(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_vulkan)]
                (DispatchTensor::Vulkan(tensor), DispatchTensor::Vulkan(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_webgpu)]
                (DispatchTensor::WebGpu(tensor), DispatchTensor::WebGpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "ndarray")]
                (DispatchTensor::NdArray(tensor), DispatchTensor::NdArray(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                (DispatchTensor::Autodiff(_), _) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
                // Tensor and gradient resolved to two different backend
                // variants — there is no valid cross-backend replacement.
                (t, g) => panic!(
                    "The provided tensors are not on the same backend. Got backends {t:?} and {g:?}."
                ),
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Strips the autodiff wrapper, returning the inner float primitive as a
    /// plain dispatch tensor on the same backend.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, or if the wrapper
    /// contains another `Autodiff` tensor.
    fn inner(tensor: DispatchTensor) -> DispatchTensor {
        match tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match *tensor {
                #[cfg(feature = "cpu")]
                DispatchTensor::Cpu(tensor) => {
                    DispatchTensor::Cpu(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(feature = "cuda")]
                DispatchTensor::Cuda(tensor) => {
                    DispatchTensor::Cuda(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(wgpu_metal)]
                DispatchTensor::Metal(tensor) => {
                    DispatchTensor::Metal(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(feature = "rocm")]
                DispatchTensor::Rocm(tensor) => {
                    DispatchTensor::Rocm(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(wgpu_vulkan)]
                DispatchTensor::Vulkan(tensor) => {
                    DispatchTensor::Vulkan(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(wgpu_webgpu)]
                DispatchTensor::WebGpu(tensor) => {
                    DispatchTensor::WebGpu(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(feature = "ndarray")]
                DispatchTensor::NdArray(tensor) => DispatchTensor::NdArray(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                DispatchTensor::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    // Int, bool, and quantized tensors carry no autodiff wrapper in this
    // backend, so the `*_inner` / `*_from_inner` conversions are identities.
    fn int_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    /// Wraps a plain float tensor in the `Autodiff` variant for its backend,
    /// attaching fresh autodiff tracking via `Autodiff::<B>::from_inner`.
    ///
    /// # Panics
    /// Panics if `tensor` is already the `Autodiff` variant.
    fn from_inner(tensor: DispatchTensor) -> DispatchTensor {
        match tensor {
            #[cfg(feature = "cpu")]
            DispatchTensor::Cpu(tensor) => DispatchTensor::Autodiff(Box::new(DispatchTensor::Cpu(
                crate::BackendTensor::Autodiff(Autodiff::<Cpu<f32>>::from_inner(tensor.float())),
            ))),
            #[cfg(feature = "cuda")]
            DispatchTensor::Cuda(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::Cuda(crate::BackendTensor::Autodiff(
                    Autodiff::<Cuda<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_metal)]
            DispatchTensor::Metal(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::Metal(crate::BackendTensor::Autodiff(
                    Autodiff::<Metal<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "rocm")]
            DispatchTensor::Rocm(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::Rocm(crate::BackendTensor::Autodiff(
                    Autodiff::<Rocm<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_vulkan)]
            DispatchTensor::Vulkan(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::Vulkan(crate::BackendTensor::Autodiff(
                    Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_webgpu)]
            DispatchTensor::WebGpu(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::WebGpu(crate::BackendTensor::Autodiff(
                    Autodiff::<WebGpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "ndarray")]
            DispatchTensor::NdArray(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::NdArray(crate::BackendTensor::Autodiff(
                    Autodiff::<NdArray<f32>>::from_inner(tensor.float()),
                )),
            )),
            DispatchTensor::Autodiff(_) => {
                panic!("Autodiff should not wrap an autodiff tensor.")
            }
        }
    }

    fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }
}
368
369impl DispatchTensor {
370 pub(crate) fn device(&self) -> DispatchDevice {
371 match self {
372 #[cfg(feature = "cpu")]
373 DispatchTensor::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
374 #[cfg(feature = "cuda")]
375 DispatchTensor::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
376 #[cfg(wgpu_metal)]
377 DispatchTensor::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
378 #[cfg(feature = "rocm")]
379 DispatchTensor::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
380 #[cfg(wgpu_vulkan)]
381 DispatchTensor::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
382 #[cfg(wgpu_webgpu)]
383 DispatchTensor::WebGpu(tensor) => DispatchDevice::WebGpu(tensor.device()),
384 #[cfg(feature = "ndarray")]
385 DispatchTensor::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),
386 #[cfg(feature = "tch")]
387 DispatchTensor::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),
388 #[cfg(feature = "autodiff")]
389 DispatchTensor::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),
390 }
391 }
392}