// burn_backend/backend/base.rs
1use burn_std::DType;
2pub use burn_std::backtrace::BackTrace;
3
4use alloc::string::String;
5use enumset::{EnumSet, EnumSetType};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use crate::element::Element;
10use crate::ops::*;
11use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
12use crate::{QTensorPrimitive, TensorData, TensorMetadata};
13
14#[cfg(feature = "distributed")]
15use crate::distributed::{DistributedParamId, DistributedParams};
16
17use super::DeviceOps;
18
/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensors' buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// Returns `true` if autodiff is enabled.
    ///
    /// Defaults to `false`; autodiff-decorating backends are expected to override this.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Runs `func` with the current allocation mode set to persistent on the given device.
    ///
    /// The default implementation simply calls `func(input)` without changing any
    /// allocation mode; backends with a memory pool can override this.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<
        Output: Send,
        Input: Send,
        Func: Fn(Input) -> Output + Send,
    >(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    ///
    /// No-op by default.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Sync the backend, ensure that all computation are finished.
    ///
    /// The default implementation returns `Ok(())` without doing anything, which is
    /// suitable for backends that execute eagerly and synchronously.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed.
    ///
    /// No-op by default.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the type is fully supported by the specified device for general operations.
    ///
    /// A type is considered supported if it can be used for the full suite of tensor
    /// operations, including storage, conversion, and basic arithmetic.
    ///
    /// Returning `false` does not necessarily mean the device cannot handle the type at all.
    /// For instance, a device might support a type only for specialized hardware
    /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
    /// types should return `false` here as they are not globally supported.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        // Supported == the usage set covers both Storage and Arithmetic.
        Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())
    }

    /// Returns the [DTypeUsageSet] for the given [DType] on the specified device.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;

    /// Returns the number of devices available on this backend.
    ///
    /// `type_id` identifies the underlying backend/device type that should be queried:
    /// e.g., a CUDA type id returns all devices available to CUDA, a Vulkan type id
    /// returns all devices available to Vulkan, etc.
    // NOTE(review): the previous doc described a `device` reference parameter, but the
    // signature takes `type_id: u16` — confirm the intended mapping of `type_id` to a
    // backend type against the callers of this function.
    fn device_count(type_id: u16) -> usize;
}
177
/// An error that can happen when syncing a device.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error happened during execution thrown in the Burn project.
    ///
    /// The full context isn't captured by the string alone.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        ///
        /// Skipped during (de)serialization; note it is also not rendered by the
        /// `#[error]` message above, so it is only available programmatically.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}
201
202impl core::fmt::Debug for ExecutionError {
203 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
204 f.write_fmt(format_args!("{self}"))
205 }
206}
207
/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the given gradients.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replace the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients will be replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the float tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the int tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the bool tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the quantized tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend float tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend int tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend bool tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend quantized tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;

    #[cfg(feature = "distributed")]
    /// Mark the tensor as distributed across multiple devices.
    /// The gradients will be aggregated during the backward pass.
    ///
    /// This function does nothing when distributed training is not available;
    /// the default implementation returns the tensor unchanged.
    fn set_distributed_params(
        tensor: FloatTensor<Self>,
        _param_id: DistributedParamId,
    ) -> FloatTensor<Self> {
        tensor
    }

    #[cfg(feature = "distributed")]
    /// Returns the distributed parameters if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `None`.
    fn distributed_params(_tensor: &FloatTensor<Self>) -> Option<DistributedParams> {
        None
    }

    #[cfg(feature = "distributed")]
    /// Returns true if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `false`.
    fn is_distributed(_tensor: &FloatTensor<Self>) -> bool {
        false
    }
}
387
/// Describes how a data type can be used on a given device.
///
/// A data type may be supported for different classes of operations. Not all
/// data types that appear in hardware or kernel implementations are suitable
/// for general-purpose tensor operations.
///
/// Combine flags into a [DTypeUsageSet] to describe a type's full capabilities.
#[derive(Debug, EnumSetType)]
pub enum DTypeUsage {
    /// The type can be stored in device memory and converted to and from
    /// other supported data types.
    Storage,
    /// The type supports general-purpose arithmetic and common tensor
    /// operations (e.g. elementwise ops, reductions, etc.).
    Arithmetic,
    /// The type is supported by hardware-accelerated execution paths.
    ///
    /// This typically indicates support for accelerator-backed compute units (e.g., tensor
    /// cores executing MMA instructions) for high-performance operations such as matrix
    /// multiplication and operations that lower to it.
    ///
    /// # Notes
    /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and
    ///   [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations
    ///   *and* accelerated paths.
    /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not
    ///   suitable for general-purpose tensor operations and may only be used
    ///   in specific accelerated operations.
    ///
    /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which
    /// operations are accelerated or which accelerator features are available.
    Accelerated,
}

/// A set of [DTypeUsage] representing the total capabilities of a data type on a device.
///
/// Backends report this per dtype via [Backend::dtype_usage].
pub type DTypeUsageSet = EnumSet<DTypeUsage>;
422
423impl DTypeUsage {
424 /// Returns the usage set required for general-purpose tensor support.
425 pub fn general() -> DTypeUsageSet {
426 DTypeUsage::Storage | DTypeUsage::Arithmetic
427 }
428}