// burn_backend/backend/base.rs
1use burn_std::DType;
2pub use burn_std::backtrace::BackTrace;
3
4use alloc::string::String;
5use enumset::{EnumSet, EnumSetType};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use crate::element::Element;
10use crate::ops::*;
11use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
12use crate::{QTensorPrimitive, TensorData, TensorMetadata};
13
14use super::DeviceOps;
15
/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensors' buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// If autodiff is enabled.
    ///
    /// The default implementation returns `false`; autodiff-capable backends
    /// are expected to override it.
    fn ad_enabled() -> bool {
        false
    }

    /// Sets the current allocation mode to persistent.
    ///
    /// The default implementation simply invokes `func` with `input` and
    /// ignores the device; backends with a memory manager can override it to
    /// scope a persistent-allocation mode around the call.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    ///
    /// The default implementation is a no-op.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Sync the backend, ensure that all computation are finished.
    ///
    /// The default implementation returns `Ok(())` without waiting, which is
    /// only correct for backends that execute synchronously.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed.
    ///
    /// The default implementation is a no-op.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the type is fully supported by the specified device for general operations.
    ///
    /// A type is considered supported if it can be used for the full suite of tensor
    /// operations, including storage, conversion, and basic arithmetic.
    ///
    /// Returning `false` does not necessarily mean the device cannot handle the type at all.
    /// For instance, a device might support a type only for specialized hardware
    /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
    /// types should return `false` here as they are not globally supported.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())
    }

    /// Returns the [DTypeUsageSet] for the given [DType] on the specified device.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;
}
164
/// An error that can happen when syncing a device.
///
/// `Debug` is implemented manually (below) to reuse the `Display` output
/// generated by `thiserror`, so both renderings match.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error happened during execution thrown in the Burn project.
    ///
    /// The full context isn't captured by the string alone.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        ///
        /// Skipped during (de)serialization; deserializing this variant
        /// yields a default backtrace value.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}
188
189impl core::fmt::Debug for ExecutionError {
190    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
191        f.write_fmt(format_args!("{self}"))
192    }
193}
194
/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the given gradients.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replace the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradient entry will be replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}
350
/// Describes how a data type can be used on a given device.
///
/// A data type may be supported for different classes of operations. Not all
/// data types that appear in hardware or kernel implementations are suitable
/// for general-purpose tensor operations.
///
/// Variants are combined into a [DTypeUsageSet] via `EnumSet` bit-or.
#[derive(Debug, EnumSetType)]
pub enum DTypeUsage {
    /// The type can be stored in device memory and converted to and from
    /// other supported data types.
    Storage,
    /// The type supports general-purpose arithmetic and common tensor
    /// operations (e.g. elementwise ops, reductions, etc.).
    Arithmetic,
    /// The type is supported by hardware-accelerated execution paths.
    ///
    /// This typically indicates support for accelerator-backed compute units (e.g., tensor
    /// cores executing MMA instructions) for high-performance operations such as matrix
    /// multiplication and operations that lower to it.
    ///
    /// # Notes
    /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and
    ///   [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations
    ///   *and* accelerated paths.
    /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not
    ///   suitable for general-purpose tensor operations and may only be used
    ///   in specific accelerated operations.
    ///
    /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which
    /// operations are accelerated or which accelerator features are available.
    Accelerated,
}
382
/// A set of [DTypeUsage] flags representing the total capabilities of a data type on a device.
pub type DTypeUsageSet = EnumSet<DTypeUsage>;
385
386impl DTypeUsage {
387    /// Returns the usage set required for general-purpose tensor support.
388    pub fn general() -> DTypeUsageSet {
389        DTypeUsage::Storage | DTypeUsage::Arithmetic
390    }
391}