pub trait Backend:
Sized
+ BackendTypes
+ FloatTensorOps<Self>
+ BoolTensorOps<Self>
+ IntTensorOps<Self>
+ ModuleOps<Self>
+ ActivationOps<Self>
+ QTensorOps<Self>
+ TransactionOps<Self>
+ Clone
+ Default
+ Send
+ Sync
+ Debug
+ 'static {
// Required methods
fn name(device: &Self::Device) -> String;
fn seed(device: &Self::Device, seed: u64);
fn dtype_usage(device: &Self::Device, dtype: DType) -> EnumSet<DTypeUsage>;
fn device_count(type_id: u16) -> usize;
// Provided methods
fn ad_enabled(_device: &Self::Device) -> bool { ... }
fn memory_persistent_allocations<Output, Input, Func>(
device: &Self::Device,
input: Input,
func: Func,
) -> Output
where Output: Send,
Input: Send,
Func: Fn(Input) -> Output + Send { ... }
fn memory_cleanup(device: &Self::Device) { ... }
fn sync(_device: &Self::Device) -> Result<(), ExecutionError> { ... }
fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
where Iter: Iterator<Item = &'a mut TensorData> { ... }
fn supports_dtype(device: &Self::Device, dtype: DType) -> bool { ... }
}
This trait defines all types and functions needed for a backend to be used with burn.
§Design
This trait aims to be as unopinionated as possible and allows implementations to define their own types and patterns. Therefore, there are few pre-defined abstractions baked into this trait.
Backends must define their own tensor types for each data type: float, int, and bool.
Since we minimize assumptions, we chose to separate these types, as they are used in
different contexts. However, some backends may have a generic tensor type that is used
for all data types.
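For illustration, a backend with a single generic primitive can still expose the distinct types expected here through aliases. The GenericTensor type below is a hypothetical sketch, not part of burn:

```rust
// Hypothetical sketch: one generic primitive reused for every data type.
#[derive(Clone, Debug)]
pub struct GenericTensor<E> {
    pub data: Vec<E>,
    pub shape: Vec<usize>,
}

// The backend still exposes distinct float, int, and bool tensor types,
// which is what the separate ops traits operate on.
pub type FloatTensor = GenericTensor<f32>;
pub type IntTensor = GenericTensor<i64>;
pub type BoolTensor = GenericTensor<bool>;
```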
§Eager Mode
Because burn supports dynamic graphs, the backend trait is designed around kernel implementations that can be called without any mutable context or graph. This may not be ideal for backends that want to configure their computational graphs and execute them multiple times.
To implement this kind of backend, channels could be used to communicate with a backend server thread to build the computation graphs and re-execute the ones that are repeated, with some form of cache. Once that pattern has matured, a graph mode backend trait could be extracted from it, allowing other backends of the same kind to be quickly integrated with burn. This pattern could also be used to create an operation fusion trait, which allows backends to define what kind of graph structures can be fused into one operation.
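A rough sketch of that pattern, using a placeholder message and op representation rather than any actual burn API:

```rust
use std::sync::mpsc;
use std::thread;

// Placeholder messages for a hypothetical graph-building backend server.
enum Message {
    Record(&'static str),          // Append an op to the current graph.
    Execute(mpsc::Sender<String>), // Run the (possibly cached) graph.
}

fn main() {
    let (tx, rx) = mpsc::channel::<Message>();

    // The server thread owns the mutable graph state that the eager,
    // context-free trait methods cannot hold themselves.
    thread::spawn(move || {
        let mut graph: Vec<&'static str> = Vec::new();
        for msg in rx {
            match msg {
                Message::Record(op) => graph.push(op),
                Message::Execute(reply) => {
                    // A real backend would compile, cache, and re-execute here.
                    let _ = reply.send(format!("executed {graph:?}"));
                }
            }
        }
    });

    tx.send(Message::Record("matmul")).unwrap();
    tx.send(Message::Record("relu")).unwrap();
    let (reply_tx, reply_rx) = mpsc::channel();
    tx.send(Message::Execute(reply_tx)).unwrap();
    println!("{}", reply_rx.recv().unwrap());
}
```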
§Multi-Threaded
Backend tensor types are all Clone + Send, which allows them to be safely
sent between threads. It is recommended to wrap tensors with Arc,
which avoids copying the tensor’s buffer. Note that it is still possible to mutate and
reuse a tensor’s buffer without locking; see the next section on the Mutable API.
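A minimal sketch of that recommendation, using a hypothetical tensor type whose buffer sits behind an Arc so that clones are cheap and the type is Send:

```rust
use std::sync::Arc;
use std::thread;

// Hypothetical tensor: Clone + Send because the buffer is reference-counted.
#[derive(Clone)]
struct MyTensor {
    buffer: Arc<Vec<f32>>,
}

fn main() {
    let tensor = MyTensor { buffer: Arc::new(vec![1.0; 1024]) };

    // Cloning bumps the refcount; the 1024-element buffer is not copied.
    let shared = tensor.clone();
    let handle = thread::spawn(move || shared.buffer.iter().sum::<f32>());

    assert_eq!(handle.join().unwrap(), 1024.0);
}
```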
§Mutable API
There is no mutable or inplace operation API to implement, but that does not mean that backends cannot support them. Using try_unwrap and get_mut allows backends to have access to an owned or mutable reference to their tensor buffer data structure if the tensor is not shared. In that case, backends can dispatch to their owned inplace operations for better performance.
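A hedged sketch of that dispatch with the same hypothetical Arc-backed tensor: Arc::get_mut only succeeds when the buffer is not shared, which is exactly when an in-place kernel is safe:

```rust
use std::sync::Arc;

// Hypothetical tensor type, not part of burn's API.
struct MyTensor {
    buffer: Arc<Vec<f32>>,
}

fn add_scalar(mut tensor: MyTensor, rhs: f32) -> MyTensor {
    match Arc::get_mut(&mut tensor.buffer) {
        // Unique owner: run the in-place kernel, no new allocation.
        Some(buffer) => {
            for x in buffer.iter_mut() {
                *x += rhs;
            }
            tensor
        }
        // Buffer is shared: fall back to an out-of-place kernel.
        None => MyTensor {
            buffer: Arc::new(tensor.buffer.iter().map(|x| x + rhs).collect()),
        },
    }
}
```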
§Documentation
Most of the documentation for each function can be found on the user API Tensor struct in the burn-tensor crate.
For modules, public functions are often created, which can be used by burn-core modules.
Required Methods§
fn seed(device: &Self::Device, seed: u64)
Seeds the backend on the specified device.
There is no guarantee that only the specified device will be seeded, but it is guaranteed that at least the specified device will be seeded.
In all cases, this should ensure deterministic execution for a single-threaded program.
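A usage sketch (the import path is an assumption about the crate layout):

```rust
use burn::tensor::backend::Backend;

// Seeding up front makes subsequent random tensor creation reproducible
// in a single-threaded program.
fn reproducible_setup<B: Backend>(device: &B::Device) {
    B::seed(device, 42);
    // ... random initialization after this point is deterministic.
}
```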
fn dtype_usage(device: &Self::Device, dtype: DType) -> EnumSet<DTypeUsage>
Returns the set of DTypeUsage values for the given DType on the specified device.
fn device_count(type_id: u16) -> usize
Returns the number of devices available on this backend.
The type_id identifies the kind of device to query and thus the underlying backend: a CUDA
type id will return all devices available to CUDA, a Vulkan type id all devices available
to Vulkan, etc.
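A usage sketch; note that a real caller would typically obtain the type id from an existing device rather than hard-coding one:

```rust
use burn::tensor::backend::Backend;

// Hedged usage sketch: how type ids map to device kinds (CUDA, Vulkan, ...)
// is backend-specific.
fn count_peers<B: Backend>(type_id: u16) -> usize {
    B::device_count(type_id)
}
```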
Provided Methods§
fn ad_enabled(_device: &Self::Device) -> bool
Returns whether autodiff is enabled.
fn memory_persistent_allocations<Output, Input, Func>(
    device: &Self::Device,
    input: Input,
    func: Func,
) -> Output
Sets the current allocation mode to persistent while executing func, then returns its output.
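A usage sketch against the signature above, with an assumed import path:

```rust
use burn::tensor::backend::Backend;

// Allocations made while the closure runs are treated as persistent, which
// can help when the same buffers are reused across repeated calls.
fn sum_with_persistent_allocs<B: Backend>(device: &B::Device, values: Vec<f32>) -> f32 {
    B::memory_persistent_allocations(device, values, |v| v.iter().sum())
}
```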
fn memory_cleanup(device: &Self::Device)
Manually triggers a memory cleanup on the given device.
fn sync(_device: &Self::Device) -> Result<(), ExecutionError>
Syncs the backend, ensuring that all computations are finished.
fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
where Iter: Iterator<Item = &'a mut TensorData>
Marks the given data as being used as a staging buffer for transfer between CPU and accelerators like GPUs.
The given data might be transferred to pinned memory or another format to improve data transfer speed.
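A usage sketch, assuming TensorData and the trait are reachable via the paths below:

```rust
use burn::tensor::backend::Backend;
use burn::tensor::TensorData;

// Hedged sketch: mark host-side batches as staging buffers before a
// training loop, so repeated host-to-device copies can be faster.
fn prepare_transfer<B: Backend>(device: &B::Device, batches: &mut [TensorData]) {
    B::staging(batches.iter_mut(), device);
}
```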
fn supports_dtype(device: &Self::Device, dtype: DType) -> bool
Whether the type is fully supported by the specified device for general operations.
A type is considered supported if it can be used for the full suite of tensor operations, including storage, conversion, and basic arithmetic.
Returning false does not necessarily mean the device cannot handle the type at all.
For instance, a device might support a type only for specialized hardware
acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such
types should return false here as they are not globally supported.
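A hedged usage sketch: prefer bf16 when it is fully supported, fall back to f32 otherwise. The import paths are assumptions about the crate layout:

```rust
use burn::tensor::backend::Backend;
use burn::tensor::DType;

// Pick a float dtype that the device supports for general operations.
fn pick_float_dtype<B: Backend>(device: &B::Device) -> DType {
    if B::supports_dtype(device, DType::BF16) {
        DType::BF16
    } else {
        DType::F32
    }
}
```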
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.