use alloc::string::String;
use crate::ops::*;
use crate::tensor::Element;
use super::BackendBridge;
/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
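///
/// A minimal sketch of that idea, assuming a hypothetical `GraphOp` description and a
/// standard library channel (the names below are illustrative, not part of this trait):
///
/// ```rust,ignore
/// use std::sync::mpsc;
/// use std::thread;
///
/// // Hypothetical description of a recorded operation.
/// enum GraphOp {
///     MatMul { lhs: usize, rhs: usize, out: usize },
///     Relu { input: usize, out: usize },
///     Execute,
/// }
///
/// // The server thread owns the graph, so kernels can still be called without a mutable context.
/// fn spawn_graph_server() -> mpsc::Sender<GraphOp> {
///     let (sender, receiver) = mpsc::channel::<GraphOp>();
///     thread::spawn(move || {
///         let mut graph = Vec::new();
///         for op in receiver {
///             match op {
///                 // Compile (or reuse a cached plan for) `graph`, then run it.
///                 GraphOp::Execute => { /* execute and clear the recorded graph */ }
///                 other => graph.push(other),
///             }
///         }
///     });
///     sender
/// }
/// ```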
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse a tensor's buffer without locking; see the next section on the Mutable API.
///
/// ### Mutable API
///
/// There is no mutable or inplace operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) gives backends an owned value or a mutable
/// reference to their tensor buffer data structure when the tensor is not shared. In that
/// case, backends can dispatch to their owned inplace operations for better performance.
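///
/// A minimal sketch of such a dispatch, assuming the backend stores its buffer behind an
/// [Arc](alloc::sync::Arc) (`Buffer`, `add_scalar_inplace`, and `add_scalar_copy` are
/// illustrative names, not part of this trait):
///
/// ```rust,ignore
/// use alloc::sync::Arc;
///
/// struct MyTensor {
///     buffer: Arc<Buffer>,
/// }
///
/// fn add_scalar(mut tensor: MyTensor, value: f32) -> MyTensor {
///     match Arc::get_mut(&mut tensor.buffer) {
///         // No other handle to the buffer exists: mutate it in place.
///         Some(buffer) => add_scalar_inplace(buffer, value),
///         // The buffer is shared: fall back to an allocating kernel.
///         None => return add_scalar_copy(&tensor, value),
///     }
///     tensor
/// }
/// ```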
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API [tensor struct](crate::Tensor).
/// For modules, public functions are often provided so that they can be used by `burn-core` modules.
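///
/// ## Example
///
/// Code that is generic over this trait only needs the items declared below; a minimal
/// sketch (the `setup` helper is illustrative):
///
/// ```rust,ignore
/// use burn_tensor::backend::Backend;
///
/// /// Seeds the backend and returns its default device, for any `B: Backend`.
/// fn setup<B: Backend>() -> B::Device {
///     let device = B::Device::default();
///     B::seed(42);
///     println!("running on {}", B::name());
///     device
/// }
/// ```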
pub trait Backend:
FloatTensorOps<Self>
+ BoolTensorOps<Self>
+ IntTensorOps<Self>
+ ModuleOps<Self>
+ ActivationOps<Self>
+ Clone
+ Sized
+ Default
+ Send
+ Sync
+ core::fmt::Debug
+ 'static
{
/// Device type.
type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
/// A bridge that can cast tensors to full precision.
type FullPrecisionBridge: BackendBridge<Self> + 'static;
/// Tensor primitive to be used for all float operations.
type FloatTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
/// Float element type.
type FloatElem: Element;
/// Tensor primitive to be used for all int operations.
type IntTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
/// Int element type.
type IntElem: Element;
/// Tensor primitive to be used for all bool operations.
type BoolTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;
/// Returns `true` if automatic differentiation is enabled.
fn ad_enabled() -> bool {
false
}
/// Name of the backend.
fn name() -> String;
/// Seed the backend.
fn seed(seed: u64);
/// Sync the backend, ensuring that all computations are finished.
fn sync(_device: &Self::Device) {}
}
/// Trait that allows a backend to support autodiff.
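///
/// A rough sketch of the intended flow, assuming `loss` and `weight` are float tensor
/// primitives of an autodiff backend `B` (the `step` helper and its names are illustrative):
///
/// ```rust,ignore
/// use burn_tensor::backend::AutodiffBackend;
/// use burn_tensor::ops::FloatTensor;
///
/// fn step<B: AutodiffBackend>(
///     loss: FloatTensor<B, 1>,
///     weight: FloatTensor<B, 2>,
/// ) -> Option<FloatTensor<B::InnerBackend, 2>> {
///     // Run the backward pass from the last node of the graph.
///     let mut grads = B::backward(loss);
///     // Extract (and remove) the gradient registered for `weight`, if any.
///     B::grad_remove(&weight, &mut grads)
/// }
/// ```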
pub trait AutodiffBackend: Backend {
/// The inner backend type.
type InnerBackend: Backend<
Device = Self::Device,
FloatElem = Self::FloatElem,
IntElem = Self::IntElem,
>;
/// Gradients type.
type Gradients: Send;
/// Backward pass.
///
/// # Arguments
///
/// * `tensor` - The last node of the computational graph from which the gradients are computed.
///
/// # Returns
///
/// The gradients.
fn backward<const D: usize>(tensor: FloatTensor<Self, D>) -> Self::Gradients;
/// Returns the gradients of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to extract the gradients from.
///
/// # Returns
///
/// An optional tensor containing the gradient.
fn grad<const D: usize>(
tensor: &FloatTensor<Self, D>,
grads: &Self::Gradients,
) -> Option<FloatTensor<Self::InnerBackend, D>>;
/// Pops the gradients of a tensor and returns them.
///
/// # Arguments
///
/// * `tensor` - The tensor to pop the gradients from.
/// * `grads` - The gradients.
///
/// # Returns
///
/// An optional tensor containing the removed gradient.
fn grad_remove<const D: usize>(
tensor: &FloatTensor<Self, D>,
grads: &mut Self::Gradients,
) -> Option<FloatTensor<Self::InnerBackend, D>>;
/// Replaces the gradient of a tensor with the one provided.
///
/// If no gradient exists for the provided tensor, the given one is registered.
///
/// # Arguments
///
/// * `tensor` - The tensor whose gradient is replaced.
/// * `grads` - The gradients.
/// * `grad` - The updated grad tensor.
fn grad_replace<const D: usize>(
tensor: &FloatTensor<Self, D>,
grads: &mut Self::Gradients,
grad: FloatTensor<Self::InnerBackend, D>,
);
/// Returns the tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend tensor.
fn inner<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self::InnerBackend, D>;
/// Returns the tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend tensor.
fn int_inner<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self::InnerBackend, D>;
/// Returns the tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend tensor.
fn bool_inner<const D: usize>(tensor: BoolTensor<Self, D>)
-> BoolTensor<Self::InnerBackend, D>;
/// Converts the inner backend tensor to the autodiff backend tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend tensor to convert.
///
///
/// # Returns
///
/// The autodiff backend tensor.
fn from_inner<const D: usize>(
tensor: FloatTensor<Self::InnerBackend, D>,
) -> FloatTensor<Self, D>;
/// Converts the inner backend tensor to the autodiff backend tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend tensor to convert.
///
///
/// # Returns
///
/// The autodiff backend tensor.
fn int_from_inner<const D: usize>(
tensor: IntTensor<Self::InnerBackend, D>,
) -> IntTensor<Self, D>;
/// Converts the inner backend tensor to the autodiff backend tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend tensor to convert.
///
///
/// # Returns
///
/// The autodiff backend tensor.
fn bool_from_inner<const D: usize>(
tensor: BoolTensor<Self::InnerBackend, D>,
) -> BoolTensor<Self, D>;
}