// burn_backend/backend/base.rs
1use burn_std::DType;
2pub use burn_std::backtrace::BackTrace;
3
4use alloc::string::String;
5use enumset::{EnumSet, EnumSetType};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use crate::element::Element;
10use crate::ops::*;
11use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
12use crate::{QTensorPrimitive, TensorData, TensorMetadata};
13
14#[cfg(feature = "distributed")]
15use crate::distributed::{DistributedParamId, DistributedParams};
16
17use super::DeviceOps;
18
19/// The mapping of types used by Backend and traits like Numeric, BasicOps
pub trait BackendTypes {
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;
}
42
43/// This trait defines all types and functions needed for a backend to be used with burn.
44///
45/// ## Design
46///
47/// This trait aims to be as unopinionated as possible and allows implementations to define
48/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
49/// into this trait.
50///
51/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
52/// Since we minimize assumptions, we chose to separate these types, as they are used in
53/// different contexts. However, some backends may have a generic tensor type that is used
54/// for all data types.
55///
56/// ### Eager Mode
57///
58/// Because burn supports dynamic graphs, the backend trait is designed around kernel
59/// implementations that can be called without any mutable context or graph. This may not be
60/// ideal for backends that want to configure their computational graphs and execute them
61/// multiple times.
62///
63/// To implement this kind of backend, channels could be used to communicate with a backend
64/// server thread to build the computation graphs and re-execute the ones that are repeated,
65/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
66/// be extracted from it, allowing other backends of the same kind to be quickly integrated
67/// with burn. This pattern could also be used to create an operation fusion trait, which
68/// allows backends to define what kind of graph structures can be fused into one operation.
69///
70/// ### Multi-Threaded
71///
72/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
73/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
74/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
75/// reuse tensors' buffer without locking; see the next section on the Mutable API.
76///
77/// ### Mutable API
78///
79/// There is no mutable or inplace operation API to implement, but that does not mean that
80/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
81/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
82/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
83/// backends can dispatch to their owned inplace operations for better performance.
84///
85/// ## Documentation
86///
87/// Most of the documentation for each function can be found on the user API
88#[cfg_attr(doc, doc = crate::doc_tensor!())]
89#[cfg_attr(not(doc), doc = "`Tensor`")]
90/// struct in the `burn-tensor` crate.
91/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    BackendTypes
    + FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Returns `true` if autodiff is enabled for this backend.
    ///
    /// Defaults to `false`; autodiff-capable backends override this.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Runs `func` with the current allocation mode set to persistent.
    ///
    /// The default implementation simply calls `func(input)` with no change in allocation
    /// behavior; backends with custom memory management may override this so that
    /// allocations performed inside `func` are treated as persistent.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<
        Output: Send,
        Input: Send,
        Func: Fn(Input) -> Output + Send,
    >(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    ///
    /// The default implementation is a no-op.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Syncs the backend, ensuring that all pending computations are finished.
    ///
    /// The default implementation assumes eager execution and returns `Ok(())` immediately.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed. The default implementation is a no-op.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the type is fully supported by the specified device for general operations.
    ///
    /// A type is considered supported if it can be used for the full suite of tensor
    /// operations, including storage, conversion, and basic arithmetic.
    ///
    /// Returning `false` does not necessarily mean the device cannot handle the type at all.
    /// For instance, a device might support a type only for specialized hardware
    /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
    /// types should return `false` here as they are not globally supported.
    ///
    /// The default implementation checks that [Self::dtype_usage] covers
    /// [DTypeUsage::general] (storage + arithmetic).
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())
    }

    /// Returns the [DTypeUsageSet] for the given [DType] on the specified device.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;

    /// Returns the number of devices available on this backend.
    ///
    /// `type_id` identifies the underlying device type that should be queried: a CUDA type id
    /// will return all devices available to CUDA, a Vulkan type id will return all devices
    /// available to Vulkan, etc.
    fn device_count(type_id: u16) -> usize;
}
181
/// An error that can happen when syncing a device.
// NOTE: `Debug` is implemented manually (below) to delegate to `Display`,
// so it must not be derived here.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error happened during execution thrown in the Burn project.
    ///
    /// The full context isn't captured by the string alone.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        // Skipped during (de)serialization: a backtrace is process-local and
        // cannot be meaningfully reconstructed on the receiving side.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}
205
206impl core::fmt::Debug for ExecutionError {
207 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
208 f.write_fmt(format_args!("{self}"))
209 }
210}
211
212/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner (non-autodiff) backend type that actually executes the operations.
    ///
    /// It must share the device and element types with this backend so tensors can
    /// move between the two representations without conversion.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// The removed gradient tensor, if one was registered for `tensor`.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replace the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients are replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the float tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the int tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the bool tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the quantized tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend float tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend int tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend bool tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend quantized tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;

    #[cfg(feature = "distributed")]
    /// Mark the tensor as distributed across multiple devices.
    /// The gradients will be aggregated during the backward pass.
    ///
    /// This function does nothing when distributed training is not available;
    /// the default implementation returns the tensor unchanged.
    fn set_distributed_params(
        tensor: FloatTensor<Self>,
        _param_id: DistributedParamId,
    ) -> FloatTensor<Self> {
        tensor
    }

    #[cfg(feature = "distributed")]
    /// Returns the distributed parameters if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `None`.
    fn distributed_params(_tensor: &FloatTensor<Self>) -> Option<DistributedParams> {
        None
    }

    #[cfg(feature = "distributed")]
    /// Returns true if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `false`.
    fn is_distributed(_tensor: &FloatTensor<Self>) -> bool {
        false
    }
}
391
/// Describes how a data type can be used on a given device.
///
/// A data type may be supported for different classes of operations. Not all
/// data types that appear in hardware or kernel implementations are suitable
/// for general-purpose tensor operations.
// `EnumSetType` makes the variants usable as bit flags combined in a
// [DTypeUsageSet] (e.g. `Storage | Arithmetic`).
#[derive(Debug, EnumSetType)]
pub enum DTypeUsage {
    /// The type can be stored in device memory and converted to and from
    /// other supported data types.
    Storage,
    /// The type supports general-purpose arithmetic and common tensor
    /// operations (e.g. elementwise ops, reductions, etc.).
    Arithmetic,
    /// The type is supported by hardware-accelerated execution paths.
    ///
    /// This typically indicates support for accelerator-backed compute units (e.g., tensor
    /// cores executing MMA instructions) for high-performance operations such as matrix
    /// multiplication and operations that lower to it.
    ///
    /// # Notes
    /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and
    ///   [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations
    ///   *and* accelerated paths.
    /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not
    ///   suitable for general-purpose tensor operations and may only be used
    ///   in specific accelerated operations.
    ///
    /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which
    /// operations are accelerated or which accelerator features are available.
    Accelerated,
}
423
/// A set of [DTypeUsage] flags representing the total capabilities of a data type on a device.
pub type DTypeUsageSet = EnumSet<DTypeUsage>;
426
427impl DTypeUsage {
428 /// Returns the usage set required for general-purpose tensor support.
429 pub fn general() -> DTypeUsageSet {
430 DTypeUsage::Storage | DTypeUsage::Arithmetic
431 }
432}