Skip to main content

burn_autodiff/
backend.rs

1use crate::{
2    checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
3    grads::Gradients,
4    tensor::AutodiffTensor,
5};
6use alloc::{format, string::String};
7use core::marker::PhantomData;
8
9use burn_backend::{
10    backend::{AutodiffBackend, Backend, BackendTypes, ExecutionError},
11    tensor::{BoolTensor, IntTensor, QuantizedTensor},
12};
13
14#[cfg(feature = "distributed")]
15use burn_backend::distributed::{DistributedBackend, DistributedParamId, DistributedParams};
16
17/// Enable auto-differentiation on a backend.
18///
19/// This works as a backend decorator, extending the functionality of any backend with
20/// backpropagation.
21#[derive(Clone, Copy, Debug, Default)]
22pub struct Autodiff<B, C = NoCheckpointing> {
23    _b: PhantomData<B>,
24    _checkpoint_strategy: PhantomData<C>,
25}
26
27impl<B: Backend, C: CheckpointStrategy> BackendTypes for Autodiff<B, C> {
28    type Device = B::Device;
29
30    type FloatTensorPrimitive = AutodiffTensor<B>;
31    type FloatElem = B::FloatElem;
32
33    type IntTensorPrimitive = B::IntTensorPrimitive;
34    type IntElem = B::IntElem;
35
36    type BoolTensorPrimitive = B::BoolTensorPrimitive;
37    type BoolElem = B::BoolElem;
38
39    type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
40}
41
42impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
43    fn ad_enabled(_device: &Self::Device) -> bool {
44        true
45    }
46
47    fn name(device: &Self::Device) -> String {
48        format!("autodiff<{}>", B::name(device))
49    }
50
51    fn seed(device: &B::Device, seed: u64) {
52        B::seed(device, seed)
53    }
54
55    fn sync(device: &B::Device) -> Result<(), ExecutionError> {
56        B::sync(device)
57    }
58
59    fn memory_persistent_allocations<
60        Output: Send,
61        Input: Send,
62        Func: Fn(Input) -> Output + Send,
63    >(
64        device: &Self::Device,
65        input: Input,
66        func: Func,
67    ) -> Output {
68        B::memory_persistent_allocations(device, input, func)
69    }
70
71    fn memory_cleanup(device: &Self::Device) {
72        B::memory_cleanup(device)
73    }
74
75    fn staging<'a, Iter>(data: Iter, device: &Self::Device)
76    where
77        Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
78    {
79        B::staging(data, device);
80    }
81
82    fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {
83        B::supports_dtype(device, dtype)
84    }
85
86    fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
87        B::dtype_usage(device, dtype)
88    }
89
90    fn device_count(type_id: u16) -> usize {
91        B::device_count(type_id)
92    }
93}
94
95#[cfg(not(feature = "distributed"))]
96impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
97    type InnerBackend = B;
98    type Gradients = Gradients;
99
100    fn backward(tensor: AutodiffTensor<B>) -> Gradients {
101        tensor.backward()
102    }
103
104    fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
105        tensor.grad(grads)
106    }
107
108    fn grad_remove(
109        tensor: &AutodiffTensor<B>,
110        grads: &mut Gradients,
111    ) -> Option<B::FloatTensorPrimitive> {
112        tensor.grad_remove(grads)
113    }
114    fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
115        tensor.primitive
116    }
117
118    fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
119        AutodiffTensor::new(tensor)
120    }
121
122    fn grad_replace(
123        tensor: &AutodiffTensor<B>,
124        grads: &mut Self::Gradients,
125        grad: B::FloatTensorPrimitive,
126    ) {
127        tensor.grad_replace(grads, grad);
128    }
129
130    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
131        tensor
132    }
133
134    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
135        tensor
136    }
137
138    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {
139        tensor
140    }
141
142    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {
143        tensor
144    }
145
146    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {
147        tensor
148    }
149
150    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
151        tensor
152    }
153}
154
/// [`AutodiffBackend`] implementation compiled when the `distributed` feature is enabled.
///
/// Requires the inner backend to implement [`DistributedBackend`]. Gradient operations are
/// forwarded to the matching [`AutodiffTensor`] methods; int, bool and quantized tensors are
/// returned unchanged by the conversion functions because this backend uses the inner
/// backend's primitives for them.
#[cfg(feature = "distributed")]
impl<B, C> AutodiffBackend for Autodiff<B, C>
where
    B: DistributedBackend,
    C: CheckpointStrategy,
{
    type InnerBackend = B;
    type Gradients = Gradients;

    // --- Gradient operations -------------------------------------------------------------

    /// Calls `backward` on the tensor and returns the gradients it produces.
    fn backward(tensor: AutodiffTensor<B>) -> Self::Gradients {
        tensor.backward()
    }

    /// Forwards to `AutodiffTensor::grad` with the given gradients container.
    fn grad(
        tensor: &AutodiffTensor<B>,
        grads: &Self::Gradients,
    ) -> Option<B::FloatTensorPrimitive> {
        tensor.grad(grads)
    }

    /// Forwards to `AutodiffTensor::grad_remove` with the given gradients container.
    fn grad_remove(
        tensor: &AutodiffTensor<B>,
        grads: &mut Self::Gradients,
    ) -> Option<B::FloatTensorPrimitive> {
        tensor.grad_remove(grads)
    }

    /// Forwards to `AutodiffTensor::grad_replace` with the given container and gradient.
    fn grad_replace(
        tensor: &AutodiffTensor<B>,
        grads: &mut Self::Gradients,
        grad: B::FloatTensorPrimitive,
    ) {
        tensor.grad_replace(grads, grad);
    }

    // --- Float tensor conversions --------------------------------------------------------

    /// Returns the tensor's `primitive` field, i.e. the inner backend's float tensor.
    fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
        tensor.primitive
    }

    /// Wraps an inner float tensor with `AutodiffTensor::new`.
    fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
        AutodiffTensor::new(tensor)
    }

    // --- Identity conversions ------------------------------------------------------------

    /// Returns the int tensor unchanged.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<B> {
        tensor
    }

    /// Returns the int tensor unchanged.
    fn int_from_inner(tensor: IntTensor<B>) -> IntTensor<Self> {
        tensor
    }

    /// Returns the bool tensor unchanged.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<B> {
        tensor
    }

    /// Returns the bool tensor unchanged.
    fn bool_from_inner(tensor: BoolTensor<B>) -> BoolTensor<Self> {
        tensor
    }

    /// Returns the quantized tensor unchanged.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<B> {
        tensor
    }

    /// Returns the quantized tensor unchanged.
    fn q_from_inner(tensor: QuantizedTensor<B>) -> QuantizedTensor<Self> {
        tensor
    }

    // --- Distributed parameters ----------------------------------------------------------

    /// Forwards to `AutodiffTensor::grad_distributed` with `param_id` and returns the
    /// tensor it yields.
    fn set_distributed_params(
        tensor: AutodiffTensor<B>,
        param_id: DistributedParamId,
    ) -> AutodiffTensor<B> {
        tensor.grad_distributed(param_id)
    }

    /// Returns a clone of the `distributed_params` stored on the tensor's node.
    fn distributed_params(tensor: &AutodiffTensor<B>) -> Option<DistributedParams> {
        tensor.node.distributed_params.clone()
    }

    /// Returns `true` when the tensor's node has `distributed_params` set.
    fn is_distributed(tensor: &AutodiffTensor<B>) -> bool {
        tensor.node.distributed_params.is_some()
    }
}