// burn_tensor/tensor/backend/base.rs
use alloc::string::String;

use crate::tensor::Element;
use crate::TensorMetadata;
use crate::{ops::*, quantization::QTensorPrimitive};

use super::DeviceOps;

9/// This trait defines all types and functions needed for a backend to be used with burn.
10///
11/// ## Design
12///
13/// This trait aims to be as unopinionated as possible and allows implementations to define
14/// their own types and patterns. Therefore, there are few pre-defined abstractions baked
15/// into this trait.
16///
17/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.
18/// Since we minimize assumptions, we chose to separate these types, as they are used in
19/// different contexts. However, some backends may have a generic tensor type that is used
20/// for all data types.
21///
22/// ### Eager Mode
23///
24/// Because burn supports dynamic graphs, the backend trait is designed around kernel
25/// implementations that can be called without any mutable context or graph. This may not be
26/// ideal for backends that want to configure their computational graphs and execute them
27/// multiple times.
28///
29/// To implement this kind of backend, channels could be used to communicate with a backend
30/// server thread to build the computation graphs and re-execute the ones that are repeated,
31/// with some form of cache. Once that pattern has matured, a graph mode backend trait could
32/// be extracted from it, allowing other backends of the same kind to be quickly integrated
33/// with burn. This pattern could also be used to create an operation fusion trait, which
34/// allows backends to define what kind of graph structures can be fused into one operation.
35///
36/// ### Multi-Threaded
37///
38/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely
39/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),
40/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and
41/// reuse tensors' buffer without locking; see the next section on the Mutable API.
42///
43/// ### Mutable API
44///
45/// There is no mutable or inplace operation API to implement, but that does not mean that
46/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and
47/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable
48/// reference to their tensor buffer data structure if the tensor is not shared. In that case,
49/// backends can dispatch to their owned inplace operations for better performance.
50///
51/// ## Documentation
52///
53/// Most of the documentation for each function can be found on the user API [tensor struct](crate::Tensor).
54/// For modules, public functions are often created, which can be used by `burn-core` modules.
55pub trait Backend:
56 FloatTensorOps<Self>
57 + BoolTensorOps<Self>
58 + IntTensorOps<Self>
59 + ModuleOps<Self>
60 + ActivationOps<Self>
61 + QTensorOps<Self>
62 + TransactionOps<Self>
63 + Clone
64 + Default
65 + Sized
66 + Send
67 + Sync
68 + core::fmt::Debug
69 + 'static
70{
71 /// Device type.
72 type Device: DeviceOps;
73
74 /// Tensor primitive to be used for all float operations.
75 type FloatTensorPrimitive: TensorMetadata + 'static;
76 /// Default float element type.
77 type FloatElem: Element;
78
79 /// Tensor primitive to be used for all int operations.
80 type IntTensorPrimitive: TensorMetadata + 'static;
81 /// Int element type.
82 type IntElem: Element;
83
84 /// Tensor primitive to be used for all bool operations.
85 type BoolTensorPrimitive: TensorMetadata + 'static;
86 /// Tensor primitive to be used for all bool operations.
87 type BoolElem: Element;
88
89 /// Tensor primitive to be used for all quantized operations.
90 type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;
91 /// Quantized tensor encoding type.
92 type QuantizedEncoding: Element;
93
94 /// If autodiff is enabled.
95 fn ad_enabled() -> bool {
96 false
97 }
98
99 /// Name of the backend.
100 fn name() -> String;
101
102 /// Seed the backend.
103 fn seed(seed: u64);
104
105 /// Sync the backend, ensure that all computation are finished.
106 fn sync(_device: &Self::Device) {}
107}
108
109/// Trait that allows a backend to support autodiff.
110pub trait AutodiffBackend: Backend {
111 /// The inner backend type.
112 type InnerBackend: Backend<
113 Device = Self::Device,
114 FloatElem = Self::FloatElem,
115 IntElem = Self::IntElem,
116 >;
117
118 /// Gradients type.
119 type Gradients: Send;
120
121 /// Backward pass.
122 ///
123 /// # Arguments
124 ///
125 /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
126 ///
127 /// # Returns
128 ///
129 /// The gradients.
130 fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;
131
132 /// Returns the gradients of a tensor.
133 ///
134 /// # Arguments
135 ///
136 /// * `tensor` - The tensor to extract the gradients from.
137 ///
138 /// # Returns
139 ///
140 /// An optional tensor containing the gradient.
141 fn grad(
142 tensor: &FloatTensor<Self>,
143 grads: &Self::Gradients,
144 ) -> Option<FloatTensor<Self::InnerBackend>>;
145
146 /// Pops the gradients of a tensor and returns them.
147 ///
148 /// # Arguments
149 ///
150 /// * `tensor` - The tensor to pop the gradients from.
151 /// * `grads` - The gradients.
152 ///
153 /// # Returns
154 ///
155 /// An optional tensor containing the given gradients.
156 fn grad_remove(
157 tensor: &FloatTensor<Self>,
158 grads: &mut Self::Gradients,
159 ) -> Option<FloatTensor<Self::InnerBackend>>;
160
161 /// Replace the gradients of a tensor with the one provided.
162 ///
163 /// If no gradient existed for the provided tensor, register it.
164 ///
165 /// # Arguments
166 ///
167 /// * `tensor` - The tensor to pop the gradients from.
168 /// * `grads` - The gradients.
169 /// * `grad` - The updated grad tensor.
170 fn grad_replace(
171 tensor: &FloatTensor<Self>,
172 grads: &mut Self::Gradients,
173 grad: FloatTensor<Self::InnerBackend>,
174 );
175
176 /// Returns the tensor with inner backend type.
177 ///
178 /// # Arguments
179 ///
180 /// * `tensor` - The tensor to get the inner backend tensor for.
181 ///
182 /// # Returns
183 ///
184 /// The inner backend tensor.
185 fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;
186
187 /// Returns the tensor with inner backend type.
188 ///
189 /// # Arguments
190 ///
191 /// * `tensor` - The tensor to get the inner backend tensor for.
192 ///
193 /// # Returns
194 ///
195 /// The inner backend tensor.
196 fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;
197
198 /// Returns the tensor with inner backend type.
199 ///
200 /// # Arguments
201 ///
202 /// * `tensor` - The tensor to get the inner backend tensor for.
203 ///
204 /// # Returns
205 ///
206 /// The inner backend tensor.
207 fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;
208
209 /// Returns the tensor with inner backend type.
210 ///
211 /// # Arguments
212 ///
213 /// * `tensor` - The tensor to get the inner backend tensor for.
214 ///
215 /// # Returns
216 ///
217 /// The inner backend tensor.
218 fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;
219
220 /// Converts the inner backend tensor to the autodiff backend tensor.
221 ///
222 /// # Arguments
223 ///
224 /// * `tensor` - The inner backend tensor to convert.
225 ///
226 ///
227 /// # Returns
228 ///
229 /// The autodiff backend tensor.
230 fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;
231
232 /// Converts the inner backend tensor to the autodiff backend tensor.
233 ///
234 /// # Arguments
235 ///
236 /// * `tensor` - The inner backend tensor to convert.
237 ///
238 ///
239 /// # Returns
240 ///
241 /// The autodiff backend tensor.
242 fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;
243
244 /// Converts the inner backend tensor to the autodiff backend tensor.
245 ///
246 /// # Arguments
247 ///
248 /// * `tensor` - The inner backend tensor to convert.
249 ///
250 ///
251 /// # Returns
252 ///
253 /// The autodiff backend tensor.
254 fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;
255
256 /// Converts the inner backend tensor to the autodiff backend tensor.
257 ///
258 /// # Arguments
259 ///
260 /// * `tensor` - The inner backend tensor to convert.
261 ///
262 ///
263 /// # Returns
264 ///
265 /// The autodiff backend tensor.
266 fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
267}