Trait Backend 

Source
pub trait Backend:
    Sized
    + BackendTypes
    + FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Send
    + Sync
    + Debug
    + 'static {
    // Required methods
    fn name(device: &Self::Device) -> String;
    fn seed(device: &Self::Device, seed: u64);
    fn dtype_usage(device: &Self::Device, dtype: DType) -> EnumSet<DTypeUsage>;
    fn device_count(type_id: u16) -> usize;

    // Provided methods
    fn ad_enabled(_device: &Self::Device) -> bool { ... }
    fn memory_persistent_allocations<Output, Input, Func>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output
       where Output: Send,
             Input: Send,
             Func: Fn(Input) -> Output + Send { ... }
    fn memory_cleanup(device: &Self::Device) { ... }
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> { ... }
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
       where Iter: Iterator<Item = &'a mut TensorData> { ... }
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool { ... }
}

This trait defines all types and functions needed for a backend to be used with burn.

§Design

This trait aims to be as unopinionated as possible and allows implementations to define their own types and patterns. Therefore, there are few pre-defined abstractions baked into this trait.

Backends must define their own tensor types for each data type: float, int, and bool. Since we minimize assumptions, we chose to separate these types, as they are used in different contexts. However, some backends may have a generic tensor type that is used for all data types.
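As a sketch of the single-generic-type approach, a backend could back all three tensor kinds with one storage type, distinguished at the type level by a zero-sized marker. The names below are hypothetical illustrations, not burn's actual types:

```rust
use std::marker::PhantomData;

// Hypothetical markers for the three element kinds the trait separates.
struct Float;
struct Int;
struct Bool;

// One generic storage type backs all tensor kinds; the marker costs
// nothing at runtime but keeps the kinds distinct in the type system.
struct Tensor<Kind> {
    bytes: Vec<u8>,
    kind: PhantomData<Kind>,
}

type FloatTensor = Tensor<Float>;
type IntTensor = Tensor<Int>;
type BoolTensor = Tensor<Bool>;

fn float_tensor(bytes: Vec<u8>) -> FloatTensor {
    Tensor { bytes, kind: PhantomData }
}
```

A backend taking this route would implement `FloatTensorOps`, `IntTensorOps`, and `BoolTensorOps` over the corresponding aliases while sharing one allocation and kernel-dispatch path underneath.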

§Eager Mode

Because burn supports dynamic graphs, the backend trait is designed around kernel implementations that can be called without any mutable context or graph. This may not be ideal for backends that want to configure their computational graphs and execute them multiple times.

To implement this kind of backend, channels could be used to communicate with a backend server thread to build the computation graphs and re-execute the ones that are repeated, with some form of cache. Once that pattern has matured, a graph mode backend trait could be extracted from it, allowing other backends of the same kind to be quickly integrated with burn. This pattern could also be used to create an operation fusion trait, which allows backends to define what kind of graph structures can be fused into one operation.
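The channel pattern above can be sketched with `std::sync::mpsc`: a server thread owns a cache of "compiled" graphs and re-executes cached entries on repeated requests. The request type and the compilation step here are illustrative placeholders, not burn APIs; a real backend would carry tensor handles and kernel descriptors instead of strings:

```rust
use std::collections::HashMap;
use std::sync::mpsc;
use std::thread;

// Hypothetical request protocol between the eager API and the server.
enum Request {
    Execute { graph: String, input: f32, reply: mpsc::Sender<f32> },
    Shutdown,
}

// Backend "server" thread: compiles each distinct graph once, caches the
// result, and re-executes cached graphs on later requests.
fn spawn_server() -> mpsc::Sender<Request> {
    let (tx, rx) = mpsc::channel::<Request>();
    thread::spawn(move || {
        // Stand-in cache: "compilation" just derives a multiplier from
        // the graph description.
        let mut cache: HashMap<String, f32> = HashMap::new();
        for request in rx {
            match request {
                Request::Execute { graph, input, reply } => {
                    let factor = *cache
                        .entry(graph.clone())
                        .or_insert_with(|| graph.len() as f32); // "compile" once
                    let _ = reply.send(input * factor);
                }
                Request::Shutdown => break,
            }
        }
    });
    tx
}
```

The same structure extends naturally to fusion: the server can inspect incoming graph descriptions and merge compatible operations before caching the compiled result.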

§Multi-Threaded

Backend tensor types are all Clone + Send, which allows them to be safely sent between threads. It is recommended to wrap tensors with Arc, which avoids copying the tensor’s buffer. Note that it is still possible to mutate and reuse tensors’ buffer without locking; see the next section on the Mutable API.
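A minimal sketch of the recommended layout, with an `Arc`-wrapped buffer so that cloning a tensor and sending it to another thread shares the allocation instead of copying it (the tensor type here is hypothetical, not burn's):

```rust
use std::sync::Arc;

// Hypothetical tensor whose buffer is reference-counted: Clone + Send,
// and cloning shares the underlying allocation.
#[derive(Clone)]
struct FloatTensor {
    buffer: Arc<Vec<f32>>,
}

fn sum(tensor: &FloatTensor) -> f32 {
    tensor.buffer.iter().sum()
}

// Each worker thread gets a clone of the tensor; only the Arc's
// reference count changes, never the buffer contents.
fn parallel_sums(tensor: FloatTensor, workers: usize) -> Vec<f32> {
    let handles: Vec<_> = (0..workers)
        .map(|_| {
            let t = tensor.clone(); // shares the Arc, no buffer copy
            std::thread::spawn(move || sum(&t))
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```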

§Mutable API

There is no mutable or inplace operation API to implement, but that does not mean that backends cannot support them. Using try_unwrap and get_mut allows backends to have access to an owned or mutable reference to their tensor buffer data structure if the tensor is not shared. In that case, backends can dispatch to their owned inplace operations for better performance.
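A minimal sketch of that dispatch, assuming the same hypothetical `Arc`-wrapped tensor as above: `Arc::get_mut` yields a mutable reference only when the tensor is not shared, so the in-place path is taken exactly when it is safe.

```rust
use std::sync::Arc;

// Hypothetical minimal float tensor: a shared, reference-counted buffer.
#[derive(Clone)]
struct FloatTensor {
    buffer: Arc<Vec<f32>>,
}

// Scale every element by `factor`, reusing the buffer in place when the
// tensor is uniquely owned, and copying only when it is shared.
fn scale(mut tensor: FloatTensor, factor: f32) -> FloatTensor {
    match Arc::get_mut(&mut tensor.buffer) {
        // Sole owner: mutate the existing allocation in place.
        Some(buf) => {
            for x in buf.iter_mut() {
                *x *= factor;
            }
            tensor
        }
        // Shared: fall back to an out-of-place copy.
        None => FloatTensor {
            buffer: Arc::new(tensor.buffer.iter().map(|x| x * factor).collect()),
        },
    }
}
```

`Arc::try_unwrap` works the same way when the backend wants full ownership of the buffer rather than a mutable borrow.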

§Documentation

Most of the documentation for each function can be found on the user API Tensor struct in the burn-tensor crate. For modules, public functions are often created, which can be used by burn-core modules.

Required Methods§

Source

fn name(device: &Self::Device) -> String

Name of the backend.

Source

fn seed(device: &Self::Device, seed: u64)

Seeds the backend on the specified device.

There is no guarantee that only the specified device will be seeded, but it is guaranteed that at least the specified device will be seeded.

In all cases, this should ensure deterministic execution for a single-threaded program.

Source

fn dtype_usage(device: &Self::Device, dtype: DType) -> EnumSet<DTypeUsage>

Returns the set of DTypeUsage values for the given DType on the specified device.

Source

fn device_count(type_id: u16) -> usize

Returns the number of devices available on this backend. type_id identifies the device type to query: for a CUDA device type, this returns the number of devices available to CUDA; for a Vulkan device type, the number available to Vulkan; and so on.

Provided Methods§

Source

fn ad_enabled(_device: &Self::Device) -> bool

Returns whether autodiff is enabled for this backend.

Source

fn memory_persistent_allocations<Output, Input, Func>( device: &Self::Device, input: Input, func: Func, ) -> Output
where Output: Send, Input: Send, Func: Fn(Input) -> Output + Send,

Executes func on input with the device's allocation mode set to persistent, returning the function's output.

Source

fn memory_cleanup(device: &Self::Device)

Manually triggers a memory cleanup on the given device.

Source

fn sync(_device: &Self::Device) -> Result<(), ExecutionError>

Syncs the backend, ensuring that all computations are finished.

Source

fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
where Iter: Iterator<Item = &'a mut TensorData>,

Marks the given data as being used as a staging buffer for transfer between CPU and accelerators like GPUs.

The given data might be transferred to pinned memory or another format to improve data transfer speed.

Source

fn supports_dtype(device: &Self::Device, dtype: DType) -> bool

Whether the type is fully supported by the specified device for general operations.

A type is considered supported if it can be used for the full suite of tensor operations, including storage, conversion, and basic arithmetic.

Returning false does not necessarily mean the device cannot handle the type at all. For instance, a device might support a type only for specialized hardware acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such types should return false here as they are not globally supported.
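A common consumer pattern is to treat this query as a gate and fall back to a widely supported type when the requested one is not generally usable. The enum and support check below are self-contained stand-ins for burn's DType and this trait's supports_dtype, so the sketch stays runnable:

```rust
// Stand-ins for burn's DType and a device's support query; the real
// Backend::supports_dtype also takes a &Self::Device.
#[derive(Clone, Copy, PartialEq, Debug)]
enum DType {
    F16,
    F32,
}

fn supports_dtype(dtype: DType) -> bool {
    // Pretend this device only has general f32 support. It might still
    // run f16 in specialized kernels (e.g. matmul), but lacking general
    // arithmetic support, f16 reports false here.
    matches!(dtype, DType::F32)
}

// Use the requested dtype when it is generally supported, otherwise
// fall back to f32.
fn resolve_dtype(requested: DType) -> DType {
    if supports_dtype(requested) {
        requested
    } else {
        DType::F32
    }
}
```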

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors§