burn_backend/backend/base.rs

use alloc::string::String;
use burn_std::backtrace::BackTrace;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::element::Element;
use crate::ops::*;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{QTensorPrimitive, TensorData, TensorMetadata};

use super::DeviceOps;

/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
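///
/// As a rough, hypothetical sketch (none of these names exist in burn), such a backend could
/// record operations sent over a channel and only execute them on an explicit request:
///
/// ```ignore
/// use std::sync::mpsc;
/// use std::thread;
///
/// // Illustrative operation description; tensors are referred to by plain ids here.
/// enum GraphOp {
///     MatMul { lhs: usize, rhs: usize, out: usize },
///     Execute,
/// }
///
/// fn spawn_graph_server() -> mpsc::Sender<GraphOp> {
///     let (sender, receiver) = mpsc::channel::<GraphOp>();
///     thread::spawn(move || {
///         let mut recorded = Vec::new();
///         for op in receiver {
///             match op {
///                 // On `Execute`, compile the recorded graph (or reuse a cached plan) and run it.
///                 GraphOp::Execute => { /* execute and cache `recorded` */ }
///                 op => recorded.push(op),
///             }
///         }
///     });
///     sender
/// }
/// ```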
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensor buffers without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) gives backends access to an owned or mutable
/// reference to their tensor buffer data structure when the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
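///
/// For instance, a backend whose float tensor wraps an `Arc<Buffer>` could dispatch to an
/// inplace kernel only when the buffer is not shared. A minimal sketch, assuming hypothetical
/// `Buffer`, `add_assign_inplace`, and `add_out_of_place` helpers:
///
/// ```ignore
/// use alloc::sync::Arc;
///
/// struct MyFloatTensor {
///     buffer: Arc<Buffer>,
/// }
///
/// fn add_scalar(mut tensor: MyFloatTensor, scalar: f32) -> MyFloatTensor {
///     match Arc::get_mut(&mut tensor.buffer) {
///         // The buffer is not shared: mutate it in place.
///         Some(buffer) => add_assign_inplace(buffer, scalar),
///         // The buffer is shared by another tensor: allocate a new one.
///         None => tensor.buffer = Arc::new(add_out_of_place(&tensor.buffer, scalar)),
///     }
///     tensor
/// }
/// ```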
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created that can be used by `burn-core` modules.
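///
/// As a usage sketch, generic code can rely on this trait alone to drive any backend `B`
/// (the function below is illustrative, not part of burn):
///
/// ```ignore
/// fn prepare<B: Backend>(device: &B::Device) -> Result<(), ExecutionError> {
///     B::seed(device, 42);       // deterministic execution on this device
///     B::sync(device)?;          // wait for all pending computations
///     B::memory_cleanup(device); // optionally reclaim memory
///     Ok(())
/// }
/// ```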
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// If autodiff is enabled.
    fn ad_enabled() -> bool {
        false
    }

    /// Runs the provided function with the current allocation mode set to persistent on the given device.
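    ///
    /// A usage sketch, assuming a backend `B` and a hypothetical `model.forward`; allocations made
    /// inside the closure may then be treated as persistent by the backend:
    ///
    /// ```ignore
    /// let output = B::memory_persistent_allocations(&device, input, |input| model.forward(input));
    /// ```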
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Syncs the backend, ensuring that all computations are finished.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed.
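    ///
    /// A usage sketch, assuming a backend `B` and a hypothetical local `tensors: Vec<TensorData>`:
    ///
    /// ```ignore
    /// // Mark host-side tensors as staging buffers before uploading them to `device`.
    /// B::staging(tensors.iter_mut(), &device);
    /// ```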
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }
}

/// An error that can happen when syncing a device.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error raised from within the Burn project during execution.
    ///
    /// The full context isn't captured by the reason string alone; the backtrace is kept separately.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}

impl core::fmt::Debug for ExecutionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

/// Trait that allows a backend to support autodiff.
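///
/// As a usage sketch (with `loss` and `weight` as hypothetical `FloatTensor<B>` values),
/// gradients are computed from the last node of the graph and then looked up per tensor:
///
/// ```ignore
/// let grads = B::backward(loss);
/// let weight_grad = B::grad(&weight, &grads); // Option<FloatTensor<B::InnerBackend>>
/// ```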
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The last node of the computational graph from which the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the given gradients.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replaces the gradients of a tensor with the provided ones.
    ///
    /// If no gradient exists for the provided tensor, it is registered.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients are replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}