burn_tensor/tensor/backend/base.rs

use alloc::string::String;

use crate::TensorMetadata;
use crate::tensor::Element;
use crate::{ops::*, quantization::QTensorPrimitive};

use super::DeviceOps;

/// This trait defines all types and functions needed for a backend to be used with burn.
///
/// ## Design
///
/// This trait aims to be as unopinionated as possible and allows implementations to define
/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
/// into this trait.
///
/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
/// Since we minimize assumptions, we chose to separate these types, as they are used in
/// different contexts. However, some backends may have a generic tensor type that is used
/// for all data types.
///
/// ### Eager Mode
///
/// Because burn supports dynamic graphs, the backend trait is designed around kernel
/// implementations that can be called without any mutable context or graph. This may not be
/// ideal for backends that want to configure their computational graphs and execute them
/// multiple times.
///
/// To implement this kind of backend, channels could be used to communicate with a backend
/// server thread to build the computation graphs and re-execute the ones that are repeated,
/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
/// be extracted from it, allowing other backends of the same kind to be quickly integrated
/// with burn. This pattern could also be used to create an operation fusion trait, which
/// allows backends to define what kind of graph structures can be fused into one operation.
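///
/// A minimal sketch of that channel pattern, using a hypothetical `Msg` type and a plain
/// `Vec` standing in for the graph (neither is part of this crate):
///
/// ```
/// use std::sync::mpsc;
/// use std::thread;
///
/// // Hypothetical message type: operations recorded into a graph.
/// enum Msg {
///     Record(&'static str),
///     Execute,
///     Shutdown,
/// }
///
/// let (tx, rx) = mpsc::channel::<Msg>();
///
/// // Backend server thread: records ops into a graph and could cache
/// // and re-execute repeated graphs on `Execute`.
/// let server = thread::spawn(move || {
///     let mut graph = Vec::new();
///     for msg in rx {
///         match msg {
///             Msg::Record(op) => graph.push(op),
///             Msg::Execute => println!("executing graph: {graph:?}"),
///             Msg::Shutdown => break,
///         }
///     }
/// });
///
/// tx.send(Msg::Record("matmul")).unwrap();
/// tx.send(Msg::Record("relu")).unwrap();
/// tx.send(Msg::Execute).unwrap();
/// tx.send(Msg::Shutdown).unwrap();
/// server.join().unwrap();
/// ```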
///
/// ### Multi-Threaded
///
/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
/// reuse a tensor's buffer without locking; see the next section on the Mutable API.
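///
/// For example, an `Arc`-wrapped buffer can be shared across threads without copying the
/// underlying data (a sketch with a toy buffer type standing in for a backend tensor):
///
/// ```
/// use std::sync::Arc;
/// use std::thread;
///
/// // Toy stand-in for a backend tensor buffer.
/// struct Buffer(Vec<f32>);
///
/// let tensor = Arc::new(Buffer(vec![1.0, 2.0, 3.0]));
///
/// let handles: Vec<_> = (0..2)
///     .map(|_| {
///         // Cloning the `Arc` only bumps a reference count; the
///         // buffer itself is never copied.
///         let shared = Arc::clone(&tensor);
///         thread::spawn(move || shared.0.iter().sum::<f32>())
///     })
///     .collect();
///
/// for handle in handles {
///     assert_eq!(handle.join().unwrap(), 6.0);
/// }
/// ```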
///
/// ### Mutable API
///
/// There is no mutable or in-place operation API to implement, but that does not mean that
/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
/// backends can dispatch to their owned in-place operations for better performance.
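///
/// A sketch of that dispatch, again with a toy buffer type (real backends would hold
/// device memory instead of a `Vec`):
///
/// ```
/// use std::sync::Arc;
///
/// struct Buffer(Vec<f32>);
///
/// // Mutates in place when the buffer is uniquely owned, copies otherwise.
/// fn scale(mut tensor: Arc<Buffer>, factor: f32) -> Arc<Buffer> {
///     match Arc::get_mut(&mut tensor) {
///         Some(buffer) => {
///             // No other reference exists: mutate in place.
///             buffer.0.iter_mut().for_each(|x| *x *= factor);
///             tensor
///         }
///         // Shared: allocate a new buffer instead of mutating.
///         None => Arc::new(Buffer(tensor.0.iter().map(|x| x * factor).collect())),
///     }
/// }
///
/// let a = Arc::new(Buffer(vec![1.0, 2.0]));
/// let b = scale(a, 2.0); // in place: the buffer had a single owner
/// let c = Arc::clone(&b);
/// let d = scale(b, 2.0); // copies: `c` still references the buffer
/// assert_eq!(c.0, vec![2.0, 4.0]);
/// assert_eq!(d.0, vec![4.0, 8.0]);
/// ```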
///
/// ## Documentation
///
/// Most of the documentation for each function can be found on the user API [tensor struct](crate::Tensor).
/// For modules, dedicated public functions are provided, which are used by `burn-core` modules.
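///
/// ## Example
///
/// A sketch of code that is generic over any backend, using only functions defined on this
/// trait (the `burn_tensor::backend::Backend` import path is assumed for illustration):
///
/// ```rust,ignore
/// use burn_tensor::backend::Backend; // path assumed for illustration
///
/// /// Seeds the device, reads the backend name, and waits for all
/// /// computations to finish, on any backend `B`.
/// fn describe<B: Backend>(device: &B::Device) -> String {
///     B::seed(device, 42);
///     let name = B::name(device);
///     B::sync(device);
///     name
/// }
/// ```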
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;

    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;

    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;

    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;

    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// Returns `true` if autodiff is enabled.
    fn ad_enabled() -> bool {
        false
    }

    /// Executes the given function with the device's allocation mode set to persistent.
    ///
    /// The default implementation simply calls `func` without changing any allocation mode.
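    ///
    /// A sketch of the intended usage (`device`, `input`, and the closure body are
    /// illustrative):
    ///
    /// ```rust,ignore
    /// let output = B::memory_persistent_allocations(&device, input, |input| {
    ///     // Allocations made while this closure runs can be kept alive by
    ///     // backends that support a persistent allocation mode.
    ///     run_forward_pass::<B>(input) // hypothetical helper
    /// });
    /// ```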
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }

    /// Manually triggers a memory cleanup on the given device.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Name of the backend.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);

    /// Syncs the backend, ensuring that all computations are finished.
    fn sync(_device: &Self::Device) {}
}

/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;

    /// Gradients type.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The last node of the computational graph from which the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
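    ///
    /// A sketch of a typical use (`loss` and `param` are hypothetical tensors, with `loss`
    /// being the final node of the graph):
    ///
    /// ```rust,ignore
    /// let grads = B::backward(loss);
    /// // Extract the gradient of a leaf tensor (see `grad` below).
    /// let param_grad = B::grad(&param, &grads);
    /// ```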
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// The gradient that was removed, if any.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Replaces the gradients of a tensor with the provided one.
    ///
    /// If no gradient existed for the provided tensor, registers it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients will be replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Returns the float tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
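    ///
    /// A sketch of a round trip through the inner backend (`tensor` is a hypothetical
    /// `FloatTensor<Self>`):
    ///
    /// ```rust,ignore
    /// // Strip the autodiff graph, then re-wrap the primitive. The re-wrapped
    /// // tensor carries no history, so this behaves like a detach.
    /// let inner = B::inner(tensor);
    /// let detached = B::from_inner(inner);
    /// ```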
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the int tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Returns the bool tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Returns the quantized tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts the inner backend float tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts the inner backend int tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts the inner backend bool tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts the inner backend quantized tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}