Trait burn::prelude::Backend

pub trait Backend:
    Sized
    + FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + Clone
    + Default
    + Send
    + Sync
    + Debug
    + 'static
{
    type Device: Clone + Default + PartialEq + Debug + Send + Sync;
    type FullPrecisionBridge: BackendBridge<Self> + 'static;
    type FloatTensorPrimitive<const D: usize>: Clone + Send + 'static + Debug;
    type FloatElem: Element;
    type IntTensorPrimitive<const D: usize>: Clone + Send + 'static + Debug;
    type IntElem: Element;
    type BoolTensorPrimitive<const D: usize>: Clone + Send + 'static + Debug;

    // Required methods
    fn name() -> String;
    fn seed(seed: u64);

    // Provided methods
    fn ad_enabled() -> bool { ... }
    fn sync(_device: &Self::Device) { ... }
}

This trait defines all types and functions needed for a backend to be used with burn.

Design

This trait aims to be as unopinionated as possible and allows implementations to define their own types and patterns. Therefore, there are few pre-defined abstractions baked into this trait.

Backends must define their own tensor types for each data type: float, int, and bool. Since we minimize assumptions, we chose to separate these types, as they are used in different contexts. However, some backends may have a generic tensor type that is used for all data types.
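
A simplified sketch of this choice, using a hypothetical `MiniBackend` trait that is not burn's API: each data type gets its own associated tensor type. `SplitBackend` defines three unrelated structs, while `GenericBackend` maps all three associated types onto one generic `Tensor<T>`.

```rust
use std::fmt::Debug;

/// Hypothetical, heavily simplified stand-in for a backend trait:
/// one associated tensor type per data type.
pub trait MiniBackend {
    type FloatTensor: Clone + Debug;
    type IntTensor: Clone + Debug;
    type BoolTensor: Clone + Debug;

    fn float_from_vec(data: Vec<f32>) -> Self::FloatTensor;
    fn int_from_vec(data: Vec<i64>) -> Self::IntTensor;
    /// Element-wise `x > 0` on a float tensor, producing a bool tensor.
    fn float_greater_zero(tensor: &Self::FloatTensor) -> Self::BoolTensor;
}

// Backend with a dedicated type for each data type.
#[derive(Clone, Debug)]
pub struct FloatBuf(pub Vec<f32>);
#[derive(Clone, Debug)]
pub struct IntBuf(pub Vec<i64>);
#[derive(Clone, Debug)]
pub struct BoolBuf(pub Vec<bool>);

pub struct SplitBackend;

impl MiniBackend for SplitBackend {
    type FloatTensor = FloatBuf;
    type IntTensor = IntBuf;
    type BoolTensor = BoolBuf;

    fn float_from_vec(data: Vec<f32>) -> FloatBuf { FloatBuf(data) }
    fn int_from_vec(data: Vec<i64>) -> IntBuf { IntBuf(data) }
    fn float_greater_zero(t: &FloatBuf) -> BoolBuf {
        BoolBuf(t.0.iter().map(|&x| x > 0.0).collect())
    }
}

// Backend that reuses one generic tensor type for every data type.
#[derive(Clone, Debug)]
pub struct Tensor<T> {
    pub data: Vec<T>,
}

pub struct GenericBackend;

impl MiniBackend for GenericBackend {
    type FloatTensor = Tensor<f32>;
    type IntTensor = Tensor<i64>;
    type BoolTensor = Tensor<bool>;

    fn float_from_vec(data: Vec<f32>) -> Tensor<f32> { Tensor { data } }
    fn int_from_vec(data: Vec<i64>) -> Tensor<i64> { Tensor { data } }
    fn float_greater_zero(t: &Tensor<f32>) -> Tensor<bool> {
        Tensor { data: t.data.iter().map(|&x| x > 0.0).collect() }
    }
}
```

Both implementations satisfy the same trait, so code written against `MiniBackend` does not care which layout a backend picked.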

Eager Mode

Because burn supports dynamic graphs, the backend trait is designed around kernel implementations that can be called without any mutable context or graph. This may not be ideal for backends that want to configure their computational graphs and execute them multiple times.

To implement this kind of backend, channels could be used to communicate with a backend server thread that builds the computation graphs and, using some form of cache, re-executes the ones that are repeated. Once that pattern has matured, a graph mode backend trait could be extracted from it, allowing other backends of the same kind to be quickly integrated with burn. This pattern could also be used to create an operation fusion trait, which would allow backends to define what kinds of graph structures can be fused into a single operation.
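
A minimal sketch of the channel-based pattern, assuming a hypothetical `GraphServer` that receives graph descriptions (plain strings here) over an `std::sync::mpsc` channel, "compiles" each distinct graph once, and reuses the cached entry when the same graph is submitted again. The compile step is only a placeholder; a real graph-mode backend would build and optimize an executable graph.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Sender};
use std::thread::{self, JoinHandle};

/// Messages sent from the client side to the server thread.
enum Request {
    /// Execute a graph; the reply says whether it was a cache hit.
    Execute { graph: String, reply: Sender<bool> },
    Shutdown,
}

/// Hypothetical graph server: compiles each distinct graph once.
pub struct GraphServer {
    sender: Sender<Request>,
    handle: JoinHandle<usize>,
}

impl GraphServer {
    pub fn start() -> Self {
        let (sender, receiver) = channel::<Request>();
        let handle = thread::spawn(move || {
            // Cache from graph description to "compiled" graph (placeholder String).
            let mut cache: HashMap<String, String> = HashMap::new();
            let mut compilations: usize = 0;
            for request in receiver {
                match request {
                    Request::Execute { graph, reply } => {
                        let cached = cache.contains_key(&graph);
                        if !cached {
                            compilations += 1;
                            cache.insert(graph.clone(), format!("compiled({graph})"));
                        }
                        // A real backend would run the compiled graph here.
                        let _ = reply.send(cached);
                    }
                    Request::Shutdown => break,
                }
            }
            compilations
        });
        GraphServer { sender, handle }
    }

    /// Submits a graph and blocks until the server replies; `true` means cache hit.
    pub fn execute(&self, graph: &str) -> bool {
        let (reply, response) = channel();
        self.sender
            .send(Request::Execute { graph: graph.to_string(), reply })
            .expect("server thread stopped");
        response.recv().expect("server thread stopped")
    }

    /// Stops the server and returns how many graphs were compiled.
    pub fn shutdown(self) -> usize {
        let _ = self.sender.send(Request::Shutdown);
        self.handle.join().expect("server thread panicked")
    }
}
```

Because the server owns the cache, callers never need a mutable context: every call goes through `&self` and the channel.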

Multi-Threaded

Backend tensor types are all Clone + Send, which allows them to be cloned and safely sent between threads. It is recommended to wrap tensors with Arc so that cloning a tensor does not copy its buffer. Note that it is still possible to mutate and reuse a tensor's buffer without locking; see the next section on the Mutable API.
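
A small sketch of that recommendation, assuming a hypothetical `CpuTensor` primitive whose buffer lives behind an `Arc`: cloning copies only the shape and the pointer, and because the type is `Clone + Send`, a clone can be moved into another thread.

```rust
use std::sync::Arc;
use std::thread;

/// Hypothetical tensor primitive: shape plus a reference-counted buffer.
#[derive(Clone, Debug)]
pub struct CpuTensor {
    pub shape: Vec<usize>,
    pub buffer: Arc<Vec<f32>>,
}

impl CpuTensor {
    pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Self {
        assert_eq!(shape.iter().product::<usize>(), data.len(), "shape/data mismatch");
        CpuTensor { shape, buffer: Arc::new(data) }
    }
}

/// Sums a tensor on a worker thread; only the Arc pointer is cloned, not the data.
pub fn sum_on_worker(tensor: &CpuTensor) -> f32 {
    let shared = tensor.clone(); // bumps the reference count, no buffer copy
    thread::spawn(move || shared.buffer.iter().sum::<f32>())
        .join()
        .expect("worker panicked")
}
```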

Mutable API

There is no mutable or in-place operation API to implement, but that does not mean that backends cannot support such operations. Using Arc::try_unwrap and Arc::get_mut, a backend can obtain an owned value or a mutable reference to its tensor buffer data structure whenever the tensor is not shared. In that case, the backend can dispatch to its own in-place operations for better performance.
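
The idea can be sketched with a hypothetical `add_scalar` operation on an `Arc<Vec<f32>>` buffer: when `Arc::get_mut` succeeds (the buffer is not shared), the values are updated in place; otherwise a new buffer is allocated and the shared one is left untouched. The hypothetical `into_owned` helper shows `Arc::try_unwrap` for operations that need to take ownership of the buffer.

```rust
use std::sync::Arc;

/// Hypothetical element-wise `tensor + value`: reuses the buffer when it is
/// uniquely owned, allocates a new one when other tensors share it.
pub fn add_scalar(mut buffer: Arc<Vec<f32>>, value: f32) -> Arc<Vec<f32>> {
    if let Some(data) = Arc::get_mut(&mut buffer) {
        // Not shared: mutate in place, no allocation.
        for x in data.iter_mut() {
            *x += value;
        }
        buffer
    } else {
        // Shared: write the result into a fresh buffer.
        Arc::new(buffer.iter().map(|x| x + value).collect())
    }
}

/// Hypothetical helper: takes ownership without copying when the buffer is
/// not shared, and clones the data only when it is.
pub fn into_owned(buffer: Arc<Vec<f32>>) -> Vec<f32> {
    Arc::try_unwrap(buffer).unwrap_or_else(|shared| (*shared).clone())
}
```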

Documentation

Most of the documentation for each function can be found on the Tensor struct of the user API. For module operations, public functions are often provided so that burn-core modules can use them.

Required Associated Types

type Device: Clone + Default + PartialEq + Debug + Send + Sync

Device type.
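
As an illustration of these bounds, here is a hypothetical device type for a backend with one CPU and several indexed GPUs (the variants are assumptions, not a burn type). Deriving the listed traits is enough; in this sketch, the `Default` implementation is what selects the CPU when no device is specified.

```rust
use std::fmt::Debug;

/// Hypothetical device type satisfying
/// `Clone + Default + PartialEq + Debug + Send + Sync`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum MyDevice {
    #[default]
    Cpu,
    Gpu(usize),
}

/// Compile-time check that a type meets the bounds listed for `Device`.
pub fn assert_device_bounds<T: Clone + Default + PartialEq + Debug + Send + Sync>() {}
```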

type FullPrecisionBridge: BackendBridge<Self> + 'static

A bridge that can cast tensors to full precision.

type FloatTensorPrimitive<const D: usize>: Clone + Send + 'static + Debug

Tensor primitive to be used for all float operations.

type FloatElem: Element

Float element type.

type IntTensorPrimitive<const D: usize>: Clone + Send + 'static + Debug

Tensor primitive to be used for all int operations.

type IntElem: Element

Int element type.

type BoolTensorPrimitive<const D: usize>: Clone + Send + 'static + Debug

Tensor primitive to be used for all bool operations.
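
The three tensor primitives are generic associated types parameterized by the rank `D`. The sketch below, built on a hypothetical `RankedBackend` trait rather than burn's, declares such a const-generic associated type with the same bounds and shows one implementation that carries the rank in a fixed-size shape array.

```rust
use std::fmt::Debug;
use std::sync::Arc;

/// Hypothetical trait mirroring the shape of `FloatTensorPrimitive<const D: usize>`.
pub trait RankedBackend {
    type FloatTensorPrimitive<const D: usize>: Clone + Send + 'static + Debug;

    /// Creates a zero-filled float tensor of rank `D`.
    fn float_zeros<const D: usize>(shape: [usize; D]) -> Self::FloatTensorPrimitive<D>;
    /// Returns the shape of a float tensor of rank `D`.
    fn float_shape<const D: usize>(tensor: &Self::FloatTensorPrimitive<D>) -> [usize; D];
}

/// Hypothetical primitive: the rank is part of the type through `[usize; D]`.
#[derive(Clone, Debug)]
pub struct DenseTensor<const D: usize> {
    pub shape: [usize; D],
    pub data: Arc<Vec<f32>>,
}

pub struct Dense;

impl RankedBackend for Dense {
    type FloatTensorPrimitive<const D: usize> = DenseTensor<D>;

    fn float_zeros<const D: usize>(shape: [usize; D]) -> Self::FloatTensorPrimitive<D> {
        let len = shape.iter().product::<usize>();
        DenseTensor { shape, data: Arc::new(vec![0.0; len]) }
    }

    fn float_shape<const D: usize>(tensor: &Self::FloatTensorPrimitive<D>) -> [usize; D] {
        tensor.shape
    }
}
```

Encoding the rank in the type lets the compiler reject, for example, passing a rank-3 tensor where a rank-2 one is expected.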

Required Methods

fn name() -> String

Name of the backend.

fn seed(seed: u64)

Seeds the backend's random number generator.
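
Because `seed` takes no receiver and no device, an implementation has to keep its random state in some global location. A minimal sketch, assuming a hypothetical `ToyBackend` that stores xorshift64 state in a process-wide `Mutex` (xorshift is used here only to keep the example dependency-free):

```rust
use std::sync::Mutex;

/// Global RNG state for the hypothetical backend; `None` until seeded.
static RNG_STATE: Mutex<Option<u64>> = Mutex::new(None);

pub struct ToyBackend;

impl ToyBackend {
    /// Mirrors `fn seed(seed: u64)`: no receiver, so the state must be global.
    pub fn seed(seed: u64) {
        // xorshift state must be non-zero, so replace 0 with a fixed constant.
        let state = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };
        *RNG_STATE.lock().unwrap() = Some(state);
    }

    /// Draws a uniform value in [0, 1) with one xorshift64 step.
    pub fn random_uniform() -> f64 {
        let mut guard = RNG_STATE.lock().unwrap();
        let mut x = (*guard).unwrap_or(0x2545_F491_4F6C_DD1D);
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        *guard = Some(x);
        // Keep the top 53 bits to fill an f64 mantissa.
        (x >> 11) as f64 / (1u64 << 53) as f64
    }
}
```

Re-seeding with the same value reproduces the same sequence, which is what makes seeded experiments repeatable.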

Provided Methods

fn ad_enabled() -> bool

Returns whether autodiff is enabled.
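
A provided method ships with a default body that implementations may override. A simplified sketch of how that can work with a decorator backend, using hypothetical `SketchBackend`, `Plain`, and `WithAutodiff` types; they only echo the shape of the `Autodiff<B, C>` wrapper listed under Implementors, and the default of `false` is an assumption of this sketch.

```rust
use std::marker::PhantomData;

/// Hypothetical trait with a provided method, echoing `Backend::ad_enabled`.
pub trait SketchBackend {
    fn name() -> String;

    /// Provided method: by default, no gradients are tracked.
    fn ad_enabled() -> bool {
        false
    }
}

/// A plain backend that keeps the default `ad_enabled`.
pub struct Plain;

impl SketchBackend for Plain {
    fn name() -> String {
        "plain".to_string()
    }
}

/// A decorator that wraps another backend and overrides `ad_enabled`.
pub struct WithAutodiff<B> {
    _inner: PhantomData<B>,
}

impl<B: SketchBackend> SketchBackend for WithAutodiff<B> {
    fn name() -> String {
        format!("autodiff<{}>", B::name())
    }

    fn ad_enabled() -> bool {
        true
    }
}

/// Generic code can query the capability without knowing the concrete backend.
pub fn describe<B: SketchBackend>() -> String {
    format!("{} (autodiff: {})", B::name(), B::ad_enabled())
}
```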

fn sync(_device: &Self::Device)

Syncs the backend, ensuring that all computations are finished.
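
For a backend that queues work asynchronously, syncing means blocking until the queue drains. A minimal sketch, assuming a hypothetical `QueueDevice` whose single worker thread processes messages in order: `sync` sends a barrier message and waits for the worker to acknowledge it, which implies every earlier job has finished.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{channel, Sender};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

enum Message {
    /// A unit of work; here it sleeps briefly and bumps a counter.
    Job,
    /// Barrier: reply once every previously queued job is done.
    Barrier(Sender<()>),
}

/// Hypothetical device with an in-order work queue served by one thread.
pub struct QueueDevice {
    queue: Sender<Message>,
    pub completed: Arc<AtomicUsize>,
}

impl QueueDevice {
    pub fn new() -> Self {
        let (queue, receiver) = channel::<Message>();
        let completed = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&completed);
        // The worker exits when the `QueueDevice` (and its sender) is dropped.
        thread::spawn(move || {
            for message in receiver {
                match message {
                    Message::Job => {
                        thread::sleep(Duration::from_millis(5));
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                    // Messages are handled in order, so earlier jobs are done here.
                    Message::Barrier(done) => {
                        let _ = done.send(());
                    }
                }
            }
        });
        QueueDevice { queue, completed }
    }

    /// Enqueues work and returns immediately, like an asynchronous kernel launch.
    pub fn launch(&self) {
        self.queue.send(Message::Job).expect("worker stopped");
    }

    /// Blocks until all previously launched jobs have finished.
    pub fn sync(&self) {
        let (done, wait) = channel();
        self.queue.send(Message::Barrier(done)).expect("worker stopped");
        wait.recv().expect("worker stopped");
    }
}
```

In the trait itself, `sync` is an associated function that receives `&Self::Device`, so a backend would locate the device's queue from that argument; the method form here is a simplification.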

Object Safety

This trait is not object safe.
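
Because the trait has `Sized`, `Clone`, and `Default` among its supertraits and associated functions such as `name()` without a `self` receiver, code cannot hold a `Box<dyn Backend>`; it is written generically over `B: Backend` instead. When a backend must be chosen at runtime, one common Rust pattern is to erase it behind a small object-safe adapter trait. A sketch with hypothetical `StaticBackend`, `BackendInfo`, and `Handle` types (not part of burn):

```rust
use std::marker::PhantomData;

/// Hypothetical trait that, like `Backend`, is not object safe:
/// it has a `Sized` supertrait and an associated function without `self`.
pub trait StaticBackend: Sized {
    fn name() -> String;
}

pub struct CpuBackend;
pub struct GpuBackend;

impl StaticBackend for CpuBackend {
    fn name() -> String { "cpu".into() }
}
impl StaticBackend for GpuBackend {
    fn name() -> String { "gpu".into() }
}

/// Object-safe adapter: methods take `&self`, so `dyn BackendInfo` is allowed.
pub trait BackendInfo {
    fn name(&self) -> String;
}

/// Zero-sized handle that forwards to the associated functions of `B`.
pub struct Handle<B>(PhantomData<B>);

impl<B: StaticBackend> BackendInfo for Handle<B> {
    fn name(&self) -> String {
        B::name()
    }
}

/// Runtime selection goes through the adapter, never through `dyn StaticBackend`.
pub fn available_backends() -> Vec<Box<dyn BackendInfo>> {
    vec![
        Box::new(Handle::<CpuBackend>(PhantomData)),
        Box::new(Handle::<GpuBackend>(PhantomData)),
    ]
}
```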

Implementations on Foreign Types

impl<B> Backend for Fusion<B>
where B: FusionBackend,

Implementors

impl<B, C> Backend for Autodiff<B, C>

impl<E> Backend for LibTorch<E>
where E: TchElement,

impl<E> Backend for NdArray<E>

impl<F, I> Backend for Candle<F, I>
where F: FloatCandleElement, I: IntCandleElement,

impl<R> Backend for JitBackend<R>
where R: Runtime,