// burn_backend/backend/base.rs

1use burn_std::DType;
2pub use burn_std::backtrace::BackTrace;
3
4use alloc::string::String;
5use enumset::{EnumSet, EnumSetType};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use crate::element::Element;
10use crate::ops::*;
11use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
12use crate::{QTensorPrimitive, TensorData, TensorMetadata};
13
14#[cfg(feature = "distributed")]
15use crate::distributed::{DistributedParamId, DistributedParams};
16
17use super::DeviceOps;
18
/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensors' buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// If autodiff is enabled.
    ///
    /// The default implementation returns `false`; autodiff backends override it.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Sets the current allocation mode to persistent.
    ///
    /// The default implementation simply calls `func` with `input`; backends that manage
    /// their own memory pools can override this to keep allocations made inside `func`
    /// alive across calls.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<
        Output: Send,
        Input: Send,
        Func: Fn(Input) -> Output + Send,
    >(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    ///
    /// The default implementation is a no-op, for backends without managed memory.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Sync the backend, ensure that all computation are finished.
    ///
    /// The default implementation returns `Ok(())` immediately, which is correct for
    /// backends that execute operations synchronously.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed. The default implementation does nothing.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the type is fully supported by the specified device for general operations.
    ///
    /// A type is considered supported if it can be used for the full suite of tensor
    /// operations, including storage, conversion, and basic arithmetic.
    ///
    /// Returning `false` does not necessarily mean the device cannot handle the type at all.
    /// For instance, a device might support a type only for specialized hardware
    /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
    /// types should return `false` here as they are not globally supported.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())
    }

    /// Returns the [DTypeUsageSet] for the given [DType] on the specified device.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;

    /// Returns the number of devices available on this backend.
    ///
    /// `type_id` identifies the kind of device to query: querying the CUDA device type
    /// returns all devices available to CUDA, the Vulkan device type all devices available
    /// to Vulkan, etc.
    // NOTE(review): the previous doc described a `device` reference argument that does not
    // exist in the signature; confirm `type_id` corresponds to the `DeviceOps` type id.
    fn device_count(type_id: u16) -> usize;
}
177
/// An error that can happen when syncing a device.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error happened during execution thrown in the Burn project.
    ///
    /// The full context isn't captured by the string alone.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        // Not serialized: `serde(skip)` omits the field on serialize and fills it with
        // its `Default` value on deserialize, so the backtrace only exists in-process.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}
201
202impl core::fmt::Debug for ExecutionError {
203    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
204        f.write_fmt(format_args!("{self}"))
205    }
206}
207
/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the given gradients.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replace the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients will be replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the float tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the int tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the bool tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the quantized tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend float tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend int tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend bool tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend quantized tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;

    #[cfg(feature = "distributed")]
    /// Mark the tensor as distributed across multiple devices.
    /// The gradients will be aggregated during the backward pass.
    ///
    /// This function does nothing when distributed training is not available.
    fn set_distributed_params(
        tensor: FloatTensor<Self>,
        _param_id: DistributedParamId,
    ) -> FloatTensor<Self> {
        tensor
    }

    #[cfg(feature = "distributed")]
    /// Returns the distributed parameters if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `None`.
    fn distributed_params(_tensor: &FloatTensor<Self>) -> Option<DistributedParams> {
        None
    }

    #[cfg(feature = "distributed")]
    /// Returns true if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `false`.
    fn is_distributed(_tensor: &FloatTensor<Self>) -> bool {
        false
    }
}
387
/// Describes how a data type can be used on a given device.
///
/// A data type may be supported for different classes of operations. Not all
/// data types that appear in hardware or kernel implementations are suitable
/// for general-purpose tensor operations.
// NOTE(review): `EnumSetType` assigns set bits by declaration order, so keep the
// variant order stable if these sets are ever persisted or compared across versions.
#[derive(Debug, EnumSetType)]
pub enum DTypeUsage {
    /// The type can be stored in device memory and converted to and from
    /// other supported data types.
    Storage,
    /// The type supports general-purpose arithmetic and common tensor
    /// operations (e.g. elementwise ops, reductions, etc.).
    Arithmetic,
    /// The type is supported by hardware-accelerated execution paths.
    ///
    /// This typically indicates support for accelerator-backed compute units (e.g., tensor
    /// cores executing MMA instructions) for high-performance operations such as matrix
    /// multiplication and operations that lower to it.
    ///
    /// # Notes
    /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and
    ///   [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations
    ///   *and* accelerated paths.
    /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not
    ///   suitable for general-purpose tensor operations and may only be used
    ///   in specific accelerated operations.
    ///
    /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which
    /// operations are accelerated or which accelerator features are available.
    Accelerated,
}

/// A set of [DTypeUsage] representing the total capabilities of a data type on a device.
pub type DTypeUsageSet = EnumSet<DTypeUsage>;
422
423impl DTypeUsage {
424    /// Returns the usage set required for general-purpose tensor support.
425    pub fn general() -> DTypeUsageSet {
426        DTypeUsage::Storage | DTypeUsage::Arithmetic
427    }
428}