burn-backend 0.21.0

Core backend interfaces and data structures for executing tensor operations in Burn.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
use burn_std::DType;
pub use burn_std::backtrace::BackTrace;

use alloc::string::String;
use enumset::{EnumSet, EnumSetType};
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::element::Element;
use crate::ops::*;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{QTensorPrimitive, TensorData, TensorMetadata};

#[cfg(feature = "distributed")]
use crate::distributed::{DistributedParamId, DistributedParams};

use super::DeviceOps;

/// The mapping of types used by Backend and traits like Numeric, BasicOps
pub trait BackendTypes {
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type (the in-memory representation of booleans on this backend).
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;
}

/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse tensors' buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned inplace operations for better performance.
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
/// struct in the `burn-tensor` crate.
/// For modules, public functions are often created, which can be used by `burn-core` modules.
pub trait Backend:
    BackendTypes
    + FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// If autodiff is enabled.
    ///
    /// Defaults to `false`; autodiff-decorating backends are expected to override this.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Sets the current allocation mode to persistent while `func` runs on `input`.
    ///
    /// The default implementation simply calls `func(input)` with no change to
    /// allocation behavior; backends with a memory pool can override it.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<
        Output: Send,
        Input: Send,
        Func: Fn(Input) -> Output + Send,
    >(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    ///
    /// The default implementation is a no-op.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Sync the backend, ensuring that all computations are finished.
    ///
    /// The default implementation assumes a synchronous backend and returns `Ok(())`.
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed.
    ///
    /// The default implementation is a no-op.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Whether the type is fully supported by the specified device for general operations.
    ///
    /// A type is considered supported if it can be used for the full suite of tensor
    /// operations, including storage, conversion, and basic arithmetic.
    ///
    /// Returning `false` does not necessarily mean the device cannot handle the type at all.
    /// For instance, a device might support a type only for specialized hardware
    /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
    /// types should return `false` here as they are not globally supported.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
        Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())
    }

    /// Returns the [DTypeUsageSet] for the given [DType] on the specified device.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;

    /// Returns the number of devices available on this backend.
    ///
    /// `type_id` selects the kind of underlying device to query (presumably the device type
    /// id exposed by [DeviceOps] — TODO confirm): a CUDA type id will return all devices
    /// available to CUDA, a Vulkan type id will return all devices available to Vulkan, etc.
    fn device_count(type_id: u16) -> usize;
}

/// An error that can happen when syncing a device.
///
/// `Debug` is implemented manually below to delegate to `Display`, so the
/// derive list intentionally omits it.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    ///
    /// The backtrace and context information should be included in the reason string.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    WithContext {
        /// The reason of the error.
        reason: String,
    },
    /// A generic error happened during execution thrown in the Burn project.
    ///
    /// The full context isn't captured by the string alone.
    #[error("An error happened during execution\nCaused by:\n  {reason}")]
    Generic {
        /// The reason of the error.
        reason: String,
        /// The backtrace.
        ///
        /// Skipped during serialization: a backtrace is only meaningful inside
        /// the process that captured it, so deserialization reconstructs it via
        /// `Default`.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}

impl core::fmt::Debug for ExecutionError {
    /// Debug output mirrors `Display`, so logs and `unwrap` panics show the
    /// human-readable message instead of the raw enum structure.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self}")
    }
}

/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    ///
    /// The inner backend must share the same device and element types so tensors can be
    /// moved between the autodiff wrapper and the inner backend without conversion.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient, or `None` if no gradient is registered
    /// for the given tensor.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the removed gradient, or `None` if no gradient was
    /// registered for the given tensor.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replace the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradient should be replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the float tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the int tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the bool tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the quantized tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend float tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend int tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend bool tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend quantized tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;

    #[cfg(feature = "distributed")]
    /// Mark the tensor as distributed across multiple devices.
    /// The gradients will be aggregated during the backward pass.
    ///
    /// This function does nothing when distributed training is not available.
    /// The default implementation returns the tensor unchanged.
    fn set_distributed_params(
        tensor: FloatTensor<Self>,
        _param_id: DistributedParamId,
    ) -> FloatTensor<Self> {
        tensor
    }

    #[cfg(feature = "distributed")]
    /// Returns the distributed parameters if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `None`.
    fn distributed_params(_tensor: &FloatTensor<Self>) -> Option<DistributedParams> {
        None
    }

    #[cfg(feature = "distributed")]
    /// Returns true if the tensor was marked as distributed.
    ///
    /// The default implementation always returns `false`.
    fn is_distributed(_tensor: &FloatTensor<Self>) -> bool {
        false
    }
}

/// Describes how a data type can be used on a given device.
///
/// A data type may be supported for different classes of operations. Not all
/// data types that appear in hardware or kernel implementations are suitable
/// for general-purpose tensor operations.
///
/// NOTE: variant order determines the bit positions assigned by `EnumSetType`,
/// so new variants should be appended rather than inserted.
#[derive(Debug, EnumSetType)]
pub enum DTypeUsage {
    /// The type can be stored in device memory and converted to and from
    /// other supported data types.
    Storage,
    /// The type supports general-purpose arithmetic and common tensor
    /// operations (e.g. elementwise ops, reductions, etc.).
    Arithmetic,
    /// The type is supported by hardware-accelerated execution paths.
    ///
    /// This typically indicates support for accelerator-backed compute units (e.g., tensor
    /// cores executing MMA instructions) for high-performance operations such as matrix
    /// multiplication and operations that lower to it.
    ///
    /// # Notes
    /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and
    ///   [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations
    ///   *and* accelerated paths.
    /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not
    ///   suitable for general-purpose tensor operations and may only be used
    ///   in specific accelerated operations.
    ///
    /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which
    /// operations are accelerated or which accelerator features are available.
    Accelerated,
}

/// A set of [DTypeUsage] representing the total capabilities of a data type on a device.
pub type DTypeUsageSet = EnumSet<DTypeUsage>;

impl DTypeUsage {
    /// Returns the usage set required for general-purpose tensor support.
    ///
    /// General-purpose support means the type can both live in device memory
    /// (`Storage`) and participate in common tensor math (`Arithmetic`).
    pub fn general() -> DTypeUsageSet {
        let mut required = DTypeUsageSet::new();
        required.insert(DTypeUsage::Storage);
        required.insert(DTypeUsage::Arithmetic);
        required
    }
}