burn_backend/backend/base.rs
use alloc::string::String;
use burn_std::backtrace::BackTrace;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::element::Element;
use crate::ops::*;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{QTensorPrimitive, TensorData, TensorMetadata};

use super::DeviceOps;

/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
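///
/// As a rough illustration, a hypothetical backend could declare its types as follows (all
/// `My*` names are placeholders, not part of burn):
///
/// ```ignore
/// #[derive(Clone, Copy, Default, Debug)]
/// struct MyBackend;
///
/// impl Backend for MyBackend {
///     type Device = MyDevice;
///
///     type FloatTensorPrimitive = MyFloatTensor;
///     type FloatElem = f32;
///
///     type IntTensorPrimitive = MyIntTensor;
///     type IntElem = i32;
///
///     type BoolTensorPrimitive = MyBoolTensor;
///     type BoolElem = u8;
///
///     type QuantizedTensorPrimitive = MyQuantizedTensor;
///
///     fn name(_device: &Self::Device) -> String {
///         "my-backend".into()
///     }
///
///     fn seed(_device: &Self::Device, _seed: u64) {}
///
///     // The operation traits (`FloatTensorOps`, `IntTensorOps`, ...) must also be implemented.
/// }
/// ```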
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
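///
/// A minimal sketch of that channel pattern, using hypothetical `GraphOp` and
/// `spawn_graph_server` names that are not part of burn:
///
/// ```ignore
/// use std::sync::mpsc;
/// use std::thread;
///
/// enum GraphOp {
///     Add { lhs: usize, rhs: usize, out: usize },
///     Execute,
/// }
///
/// // The backend server thread owns the recorded graph and could cache compiled graphs.
/// fn spawn_graph_server() -> mpsc::Sender<GraphOp> {
///     let (sender, receiver) = mpsc::channel::<GraphOp>();
///     thread::spawn(move || {
///         let mut recorded = Vec::new();
///         for op in receiver {
///             match op {
///                 GraphOp::Execute => { /* compile (or fetch from cache), then run */ }
///                 other => recorded.push(other),
///             }
///         }
///     });
///     sender
/// }
/// ```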
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensor buffers without locking; see the next section on the Mutable API.
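///
/// For example, assuming a backend's float primitive is a handle around an `Arc`-wrapped
/// buffer (the `MyFloatTensor` layout below is purely illustrative):
///
/// ```ignore
/// use alloc::sync::Arc;
///
/// #[derive(Clone, Debug)]
/// struct MyFloatTensor {
///     buffer: Arc<Vec<f32>>, // shared, cheap to clone
///     shape: Vec<usize>,
/// }
///
/// let tensor = MyFloatTensor {
///     buffer: Arc::new(vec![0.0; 16]),
///     shape: vec![4, 4],
/// };
///
/// // Cloning only bumps the reference count; both threads see the same buffer.
/// let shared = tensor.clone();
/// std::thread::spawn(move || {
///     let _sum: f32 = shared.buffer.iter().sum();
/// });
/// ```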
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
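///
/// A minimal sketch, reusing the hypothetical `MyFloatTensor` from the example above:
///
/// ```ignore
/// fn add_scalar(mut tensor: MyFloatTensor, rhs: f32) -> MyFloatTensor {
///     match Arc::get_mut(&mut tensor.buffer) {
///         // The buffer is not shared: mutate it in place.
///         Some(buffer) => buffer.iter_mut().for_each(|x| *x += rhs),
///         // The buffer is shared: fall back to an allocating implementation.
///         None => {
///             let buffer: Vec<f32> = tensor.buffer.iter().map(|x| x + rhs).collect();
///             tensor.buffer = Arc::new(buffer);
///         }
///     }
///     tensor
/// }
/// ```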
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// If autodiff is enabled.
    fn ad_enabled() -> bool {
        false
    }

    /// Runs the given function with the current allocation mode set to persistent on the
    /// given device. The default implementation simply calls the function.
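    ///
    /// A hypothetical usage sketch (`device`, `batch`, and `warmup` are placeholders):
    ///
    /// ```ignore
    /// let output = B::memory_persistent_allocations(&device, batch, |batch| warmup(batch));
    /// ```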
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Syncs the backend, ensuring that all computations are finished.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data
    /// transfer speed.
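    ///
    /// A hypothetical usage sketch (`load_batch` is a placeholder returning `Vec<TensorData>`):
    ///
    /// ```ignore
    /// let mut batch: Vec<TensorData> = load_batch();
    /// B::staging(batch.iter_mut(), &device);
    /// ```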
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }
}

/// An error that can happen when syncing a device.
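///
/// A hedged handling sketch (the backend type `B` and `device` are placeholders):
///
/// ```ignore
/// match B::sync(&device) {
///     Ok(()) => {}
///     Err(ExecutionError::WithContext { reason }) => panic!("sync failed: {reason}"),
///     Err(ExecutionError::Generic { reason, .. }) => panic!("sync failed: {reason}"),
/// }
/// ```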
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error that happened during execution, thrown from within the Burn project.
    ///
    /// The full context isn't captured by the reason string alone.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}

impl core::fmt::Debug for ExecutionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The last node of the computational graph from which the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
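    ///
    /// A hedged usage sketch (tensor creation is elided; `output` and `input` are placeholders):
    ///
    /// ```ignore
    /// let grads = B::backward(output);
    /// let input_grad = B::grad(&input, &grads);
    /// ```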
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the removed gradient.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replaces the gradients of a tensor with the ones provided.
    ///
    /// If no gradient existed for the provided tensor, registers it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients are replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}