burn_backend/backend/base.rs
use burn_std::DType;
pub use burn_std::backtrace::BackTrace;

use alloc::string::String;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::element::Element;
use crate::ops::*;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{QTensorPrimitive, TensorData, TensorMetadata};

use super::DeviceOps;

/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
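///
/// For illustration only, a hypothetical backend might declare its associated types roughly as
/// in the sketch below (`MyBackend`, `MyDevice`, `MyTensor`, and `MyQuantizedTensor` are made-up
/// names, and the required supertrait implementations are omitted):
///
/// ```rust,ignore
/// #[derive(Clone, Default, Debug)]
/// struct MyBackend;
///
/// impl Backend for MyBackend {
///     type Device = MyDevice;
///
///     type FloatTensorPrimitive = MyTensor<f32>;
///     type FloatElem = f32;
///
///     type IntTensorPrimitive = MyTensor<i32>;
///     type IntElem = i32;
///
///     type BoolTensorPrimitive = MyTensor<u8>;
///     type BoolElem = u8;
///
///     type QuantizedTensorPrimitive = MyQuantizedTensor;
///
///     // ... plus `name`, `seed`, `supports_dtype`, and the supertrait ops.
/// }
/// ```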
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
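///
/// As a rough sketch only (none of these types exist in burn), such a backend could exchange
/// messages like the following with its server thread:
///
/// ```rust,ignore
/// // Hypothetical messages sent from kernel calls to a graph-building server thread.
/// enum BackendMsg {
///     /// Record an operation in the graph currently being built.
///     Record(OpDescription),
///     /// Execute the recorded graph, reusing a cached compilation when available.
///     Execute { inputs: Vec<TensorHandle>, response: Sender<Vec<TensorHandle>> },
/// }
/// ```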
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse a tensor's buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
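///
/// A minimal sketch of that dispatch, assuming a hypothetical backend whose tensor primitive
/// stores its buffer behind an `Arc` (`MyTensor`, `add_inplace`, and `add_owned` are made-up
/// names):
///
/// ```rust,ignore
/// fn add(mut lhs: MyTensor, rhs: MyTensor) -> MyTensor {
///     // If no other tensor shares the buffer, mutate it in place.
///     if let Some(buffer) = alloc::sync::Arc::get_mut(&mut lhs.buffer) {
///         add_inplace(buffer, &rhs);
///         lhs
///     } else {
///         // The buffer is shared: fall back to an allocating implementation.
///         add_owned(&lhs, &rhs)
///     }
/// }
/// ```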
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// Returns `true` if autodiff is enabled.
    fn ad_enabled() -> bool {
        false
    }

    /// Executes the given function with the allocation mode on the given device set to persistent.
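    ///
    /// The default implementation simply calls `func` without changing any allocation
    /// behavior. A sketch of a typical call (the closure body is illustrative):
    ///
    /// ```rust,ignore
    /// let output = B::memory_persistent_allocations(&device, batch, |batch| {
    ///     // Allocations made while running this closure are treated as persistent.
    ///     model.forward(batch)
    /// });
    /// ```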
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Syncs the backend, ensuring that all computations are finished.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfers between the CPU and
    /// accelerators such as GPUs.
    ///
    /// The given data may be moved to pinned memory or another format to improve transfer
    /// speed.
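    ///
    /// A sketch of a typical call, assuming `host_tensors` is a `Vec<TensorData>` that will
    /// soon be uploaded to the device:
    ///
    /// ```rust,ignore
    /// // Hint to the backend that these buffers are used for host-to-device transfers.
    /// B::staging(host_tensors.iter_mut(), &device);
    /// ```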
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the given data type is supported by the specified device.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool;
}

/// An error that can happen when syncing a device.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    WithContext {
        /// The reason for the error.
        reason: String,
    },
    /// A generic error raised during execution from within the Burn project.
    ///
    /// The full context is not captured by the reason string alone; a backtrace is attached
    /// separately.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    Generic {
        /// The reason for the error.
        reason: String,
        /// The backtrace.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}

impl core::fmt::Debug for ExecutionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

/// Trait that allows a backend to support autodiff.
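///
/// A minimal sketch of how the functions below fit together (`loss` and `weight` are
/// illustrative `FloatTensor<B>` values produced by a forward pass):
///
/// ```rust,ignore
/// // Compute the gradients starting from the last node of the graph.
/// let grads = B::backward(loss);
/// // Retrieve the gradient of a parameter as an inner-backend tensor.
/// let weight_grad: Option<FloatTensor<B::InnerBackend>> = B::grad(&weight, &grads);
/// ```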
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The last node of the computational graph from which the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the removed gradient.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replaces the gradient of a tensor with the one provided.
    ///
    /// If no gradient exists for the provided tensor, the provided gradient is registered.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradient is replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}
335}