burn_backend/backend/base.rs

use burn_std::DType;
pub use burn_std::backtrace::BackTrace;

use alloc::string::String;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::element::Element;
use crate::ops::*;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{QTensorPrimitive, TensorData, TensorMetadata};

use super::DeviceOps;

/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
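///
/// As a rough illustration of that pattern (not an API defined by this trait), kernel calls
/// could be forwarded to a server thread over a channel; the `GraphOp` message type and the
/// `server_loop` function below are purely hypothetical:
///
/// ```ignore
/// use std::sync::mpsc;
///
/// // Hypothetical message describing one recorded operation or a synchronization request.
/// enum GraphOp {
///     MatMul { lhs: usize, rhs: usize, out: usize },
///     Sync(mpsc::Sender<()>),
/// }
///
/// // Hypothetical server loop that records operations into a graph and executes it on sync,
/// // leaving room for caching graphs that are executed repeatedly.
/// fn server_loop(receiver: mpsc::Receiver<GraphOp>) {
///     let mut recorded = Vec::new();
///     for message in receiver {
///         match message {
///             GraphOp::Sync(ack) => {
///                 // Compile and execute `recorded` here, then acknowledge the caller.
///                 recorded.clear();
///                 let _ = ack.send(());
///             }
///             op => recorded.push(op),
///         }
///     }
/// }
/// ```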
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse a tensor's buffer without locking; see the next section on the Mutable API.
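///
/// For example, a backend primitive can share its buffer through an `Arc`, so cloning the
/// tensor and sending it to another thread never copies the underlying data. The
/// `MyFloatTensor` and `MyBuffer` names below are hypothetical, not part of this trait:
///
/// ```ignore
/// use alloc::sync::Arc;
///
/// #[derive(Clone, Debug)]
/// struct MyFloatTensor {
///     // Cloning the tensor only bumps the reference count of the shared buffer.
///     buffer: Arc<MyBuffer>,
///     shape: Vec<usize>,
/// }
/// ```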
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
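///
/// A minimal sketch of that dispatch, reusing the hypothetical `MyFloatTensor` from above
/// (`add_scalar_inplace` and `add_scalar_copy` are illustrative kernels, not part of burn):
///
/// ```ignore
/// fn add_scalar(tensor: MyFloatTensor, value: f32) -> MyFloatTensor {
///     let buffer = match Arc::try_unwrap(tensor.buffer) {
///         // Not shared: we own the buffer and can mutate it in place.
///         Ok(mut owned) => {
///             add_scalar_inplace(&mut owned, value);
///             owned
///         }
///         // Shared: leave the original buffer intact and allocate a new one for the result.
///         Err(shared) => add_scalar_copy(&shared, value),
///     };
///     MyFloatTensor {
///         buffer: Arc::new(buffer),
///         shape: tensor.shape,
///     }
/// }
/// ```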
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// Returns whether autodiff is enabled.
    fn ad_enabled() -> bool {
        false
    }

    /// Sets the current allocation mode to persistent while running the given function.
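    ///
    /// A usage sketch (the concrete backend `B` and `expensive_op` are illustrative only;
    /// the default implementation simply calls the function):
    ///
    /// ```ignore
    /// let output = B::memory_persistent_allocations(&device, input, |input| {
    ///     // Allocations made while this closure runs use the persistent mode.
    ///     expensive_op::<B>(input)
    /// });
    /// ```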
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Syncs the backend, ensuring that all computations are finished.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed.
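    ///
    /// A usage sketch (`host_batch` and `load_host_batch` are illustrative; any iterator over
    /// `&mut TensorData` works):
    ///
    /// ```ignore
    /// // Mark host-side tensor data as staging before uploading it to the device.
    /// let mut host_batch: Vec<TensorData> = load_host_batch();
    /// B::staging(host_batch.iter_mut(), &device);
    /// ```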
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the data type is supported by the specified device.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool;
}

/// An error that can happen when syncing a device.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error that happened during execution, raised from within the Burn project.
    ///
    /// The full context isn't captured by the reason string alone.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}

impl core::fmt::Debug for ExecutionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The last node of the computational graph from which the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
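    ///
    /// A hedged usage sketch at the level of the raw backend API (`loss` and `weight` are
    /// illustrative tensors, not defined by this trait):
    ///
    /// ```ignore
    /// // Run the backward pass from the last node of the graph.
    /// let grads = B::backward(loss);
    /// // Look up the gradient of an earlier tensor, if one was computed for it.
    /// let weight_grad: Option<FloatTensor<B::InnerBackend>> = B::grad(&weight, &grads);
    /// ```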
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the given gradients.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replaces the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, it is registered.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients will be replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}