// burn_backend/backend/base.rs
1use burn_std::DType;
2pub use burn_std::backtrace::BackTrace;
3
4use alloc::string::String;
5use enumset::{EnumSet, EnumSetType};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use crate::element::Element;
10use crate::ops::*;
11use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
12use crate::{QTensorPrimitive, TensorData, TensorMetadata};
13
14use super::DeviceOps;
15
/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensors' buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type, i.e. the element representation backends use for boolean tensors.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// If autodiff is enabled.
    ///
    /// Defaults to `false`; autodiff wrapper backends override this.
    fn ad_enabled() -> bool {
        false
    }

    /// Sets the current allocation mode to persistent while `func` runs.
    ///
    /// The default implementation is a no-op that simply calls `func(input)`; backends with
    /// their own memory management may override it to change how allocations made inside
    /// `func` are retained.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    ///
    /// No-op by default; backends that pool memory may reclaim it here.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Sync the backend, ensuring that all pending computations are finished.
    ///
    /// The default implementation returns `Ok(())` immediately, which is correct for
    /// backends that execute eagerly and synchronously.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed.
    ///
    /// No-op by default; backends without a faster staging path need not override it.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the type is fully supported by the specified device for general operations.
    ///
    /// A type is considered supported if it can be used for the full suite of tensor
    /// operations, including storage, conversion, and basic arithmetic.
    ///
    /// Returning `false` does not necessarily mean the device cannot handle the type at all.
    /// For instance, a device might support a type only for specialized hardware
    /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
    /// types should return `false` here as they are not globally supported.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        // "General" support means the usage set contains at least Storage + Arithmetic.
        Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())
    }

    /// Returns the [DTypeUsageSet] for the given [DType] on the specified device.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;
}
164
/// An error that can happen when syncing a device.
///
/// The error is serializable so it can cross process boundaries; note that the
/// captured backtrace of the [`Generic`](ExecutionError::Generic) variant is
/// skipped during (de)serialization.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error happened during execution thrown in the Burn project.
    ///
    /// The full context isn't captured by the string alone.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        // Backtraces are process-local, so they are not (de)serialized.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}
188
189impl core::fmt::Debug for ExecutionError {
190 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
191 f.write_fmt(format_args!("{self}"))
192 }
193}
194
/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    ///
    /// The inner backend performs the actual computation; it must share the device and
    /// element types with the autodiff wrapper so tensors can cross the boundary.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the given gradients.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replace the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients should be replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the float tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the int tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the bool tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the quantized tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend float tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend int tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend bool tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend quantized tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}
350
/// Describes how a data type can be used on a given device.
///
/// A data type may be supported for different classes of operations. Not all
/// data types that appear in hardware or kernel implementations are suitable
/// for general-purpose tensor operations.
///
/// Usages combine into a [DTypeUsageSet] via the bit-or operator provided by
/// the `EnumSetType` derive.
#[derive(Debug, EnumSetType)]
pub enum DTypeUsage {
    /// The type can be stored in device memory and converted to and from
    /// other supported data types.
    Storage,
    /// The type supports general-purpose arithmetic and common tensor
    /// operations (e.g. elementwise ops, reductions, etc.).
    Arithmetic,
    /// The type is supported by hardware-accelerated execution paths.
    ///
    /// This typically indicates support for accelerator-backed compute units (e.g., tensor
    /// cores executing MMA instructions) for high-performance operations such as matrix
    /// multiplication and operations that lower to it.
    ///
    /// # Notes
    /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and
    ///   [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations
    ///   *and* accelerated paths.
    /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not
    ///   suitable for general-purpose tensor operations and may only be used
    ///   in specific accelerated operations.
    ///
    /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which
    /// operations are accelerated or which accelerator features are available.
    Accelerated,
}
382
/// A set of [DTypeUsage] representing the total capabilities of a data type on a device.
pub type DTypeUsageSet = EnumSet<DTypeUsage>;

impl DTypeUsage {
    /// Returns the usage set required for general-purpose tensor support.
    ///
    /// A dtype needs both [`Storage`](DTypeUsage::Storage) and
    /// [`Arithmetic`](DTypeUsage::Arithmetic) to qualify; this is the set that
    /// [`Backend::supports_dtype`] checks against.
    pub fn general() -> DTypeUsageSet {
        DTypeUsage::Storage | DTypeUsage::Arithmetic
    }
}
391}