//! Backend trait definitions (`burn_backend/backend/base.rs`).
1use burn_std::DType;
2pub use burn_std::backtrace::BackTrace;
3
4use alloc::string::String;
5use enumset::{EnumSet, EnumSetType};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use crate::element::Element;
10use crate::ops::*;
11use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
12use crate::{QTensorPrimitive, TensorData, TensorMetadata};
13
14#[cfg(feature = "distributed")]
15use crate::distributed::{DistributedParamId, DistributedParams};
16
17use super::DeviceOps;
18
/// The mapping of types used by Backend and traits like Numeric, BasicOps
pub trait BackendTypes {
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;
}
42
/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensors' buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    BackendTypes
    + FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// If autodiff is enabled.
    ///
    /// Defaults to `false`; autodiff backends override this.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Sets the current allocation mode to persistent while running `func`.
    ///
    /// The default implementation simply calls `func(input)` without changing any
    /// allocation behavior; backends with memory pools may override it.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<
        Output: Send,
        Input: Send,
        Func: Fn(Input) -> Output + Send,
    >(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    ///
    /// The default implementation is a no-op.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Sync the backend, ensuring that all computations are finished.
    ///
    /// The default implementation returns `Ok(())` immediately, which is correct for
    /// backends that execute eagerly and synchronously.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed.
    ///
    /// The default implementation is a no-op.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the type is fully supported by the specified device for general operations.
    ///
    /// A type is considered supported if it can be used for the full suite of tensor
    /// operations, including storage, conversion, and basic arithmetic.
    ///
    /// Returning `false` does not necessarily mean the device cannot handle the type at all.
    /// For instance, a device might support a type only for specialized hardware
    /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
    /// types should return `false` here as they are not globally supported.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())
    }

    /// Returns the [DTypeUsageSet] for the given [DType] on the specified device.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;

    /// Returns the number of devices available on this backend.
    ///
    /// `type_id` identifies the device type whose backend should be queried: the type id of
    /// a CUDA device returns all devices available to CUDA, that of a Vulkan device returns
    /// all devices available to Vulkan, etc.
    fn device_count(type_id: u16) -> usize;
}
181
/// An error that can happen when syncing a device.
///
/// Serializable via serde; note that the captured [BackTrace] in [ExecutionError::Generic]
/// is skipped during (de)serialization.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error happened during execution thrown in the Burn project.
    ///
    /// The full context isn't captured by the string alone.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        // Skipped: a backtrace is process-local and cannot round-trip through serde.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}
205
206impl core::fmt::Debug for ExecutionError {
207    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
208        f.write_fmt(format_args!("{self}"))
209    }
210}
211
/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the removed gradients.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replace the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients are replaced or registered.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The int tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The bool tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend float tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend int tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend bool tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend quantized tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;

    #[cfg(feature = "distributed")]
    /// Mark the tensor as distributed across multiple devices.
    /// The gradients will be aggregated during the backward pass.
    ///
    /// This function does nothing when distributed training is not available:
    /// the default implementation returns the tensor unchanged.
    fn set_distributed_params(
        tensor: FloatTensor<Self>,
        _param_id: DistributedParamId,
    ) -> FloatTensor<Self> {
        tensor
    }

    #[cfg(feature = "distributed")]
    /// Returns the distributed parameters if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `None`.
    fn distributed_params(_tensor: &FloatTensor<Self>) -> Option<DistributedParams> {
        None
    }

    #[cfg(feature = "distributed")]
    /// Returns true if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `false`.
    fn is_distributed(_tensor: &FloatTensor<Self>) -> bool {
        false
    }
}
391
/// Describes how a data type can be used on a given device.
///
/// A data type may be supported for different classes of operations. Not all
/// data types that appear in hardware or kernel implementations are suitable
/// for general-purpose tensor operations.
///
/// Derives [EnumSetType] so usages can be combined into a [DTypeUsageSet].
#[derive(Debug, EnumSetType)]
pub enum DTypeUsage {
    /// The type can be stored in device memory and converted to and from
    /// other supported data types.
    Storage,
    /// The type supports general-purpose arithmetic and common tensor
    /// operations (e.g. elementwise ops, reductions, etc.).
    Arithmetic,
    /// The type is supported by hardware-accelerated execution paths.
    ///
    /// This typically indicates support for accelerator-backed compute units (e.g., tensor
    /// cores executing MMA instructions) for high-performance operations such as matrix
    /// multiplication and operations that lower to it.
    ///
    /// # Notes
    /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and
    ///   [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations
    ///   *and* accelerated paths.
    /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not
    ///   suitable for general-purpose tensor operations and may only be used
    ///   in specific accelerated operations.
    ///
    /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which
    /// operations are accelerated or which accelerator features are available.
    Accelerated,
}
423
/// A set of [DTypeUsage] representing the total capabilities of a data type on a device.
pub type DTypeUsageSet = EnumSet<DTypeUsage>;
426
427impl DTypeUsage {
428    /// Returns the usage set required for general-purpose tensor support.
429    pub fn general() -> DTypeUsageSet {
430        DTypeUsage::Storage | DTypeUsage::Arithmetic
431    }
432}