1use crate::{
2 checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
3 grads::Gradients,
4 tensor::AutodiffTensor,
5};
6use alloc::{format, string::String};
7use core::marker::PhantomData;
8
9use burn_backend::{
10 backend::{AutodiffBackend, Backend, BackendTypes, ExecutionError},
11 tensor::{BoolTensor, IntTensor, QuantizedTensor},
12};
13
14#[cfg(feature = "distributed")]
15use burn_backend::distributed::{DistributedBackend, DistributedParamId, DistributedParams};
16
/// Backend decorator that adds automatic differentiation on top of an inner backend `B`.
///
/// Float tensors of this backend are [`AutodiffTensor`]s wrapping `B`'s float primitive,
/// so gradients can be obtained through the [`AutodiffBackend`] interface (`backward`,
/// `grad`, ...). Int, bool and quantized tensors are `B`'s primitives, unchanged.
///
/// `C` is the [`CheckpointStrategy`] type parameter and defaults to [`NoCheckpointing`].
/// NOTE(review): `C` is not referenced by any method body in this file; presumably it is
/// consumed by the tensor-op implementations elsewhere in the crate — confirm.
///
/// The struct holds only [`PhantomData`] markers, so it is zero-sized and carries no
/// runtime state; all behavior lives in the trait impls.
#[derive(Clone, Copy, Debug, Default)]
pub struct Autodiff<B, C = NoCheckpointing> {
    // Type-level marker for the wrapped (inner) backend.
    _b: PhantomData<B>,
    // Type-level marker for the checkpointing strategy.
    _checkpoint_strategy: PhantomData<C>,
}
26
27impl<B: Backend, C: CheckpointStrategy> BackendTypes for Autodiff<B, C> {
28 type Device = B::Device;
29
30 type FloatTensorPrimitive = AutodiffTensor<B>;
31 type FloatElem = B::FloatElem;
32
33 type IntTensorPrimitive = B::IntTensorPrimitive;
34 type IntElem = B::IntElem;
35
36 type BoolTensorPrimitive = B::BoolTensorPrimitive;
37 type BoolElem = B::BoolElem;
38
39 type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
40}
41
42impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
43 fn ad_enabled(_device: &Self::Device) -> bool {
44 true
45 }
46
47 fn name(device: &Self::Device) -> String {
48 format!("autodiff<{}>", B::name(device))
49 }
50
51 fn seed(device: &B::Device, seed: u64) {
52 B::seed(device, seed)
53 }
54
55 fn sync(device: &B::Device) -> Result<(), ExecutionError> {
56 B::sync(device)
57 }
58
59 fn memory_persistent_allocations<
60 Output: Send,
61 Input: Send,
62 Func: Fn(Input) -> Output + Send,
63 >(
64 device: &Self::Device,
65 input: Input,
66 func: Func,
67 ) -> Output {
68 B::memory_persistent_allocations(device, input, func)
69 }
70
71 fn memory_cleanup(device: &Self::Device) {
72 B::memory_cleanup(device)
73 }
74
75 fn staging<'a, Iter>(data: Iter, device: &Self::Device)
76 where
77 Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
78 {
79 B::staging(data, device);
80 }
81
82 fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {
83 B::supports_dtype(device, dtype)
84 }
85
86 fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {
87 B::dtype_usage(device, dtype)
88 }
89
90 fn device_count(type_id: u16) -> usize {
91 B::device_count(type_id)
92 }
93}
94
95#[cfg(not(feature = "distributed"))]
96impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
97 type InnerBackend = B;
98 type Gradients = Gradients;
99
100 fn backward(tensor: AutodiffTensor<B>) -> Gradients {
101 tensor.backward()
102 }
103
104 fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
105 tensor.grad(grads)
106 }
107
108 fn grad_remove(
109 tensor: &AutodiffTensor<B>,
110 grads: &mut Gradients,
111 ) -> Option<B::FloatTensorPrimitive> {
112 tensor.grad_remove(grads)
113 }
114 fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
115 tensor.primitive
116 }
117
118 fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
119 AutodiffTensor::new(tensor)
120 }
121
122 fn grad_replace(
123 tensor: &AutodiffTensor<B>,
124 grads: &mut Self::Gradients,
125 grad: B::FloatTensorPrimitive,
126 ) {
127 tensor.grad_replace(grads, grad);
128 }
129
130 fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
131 tensor
132 }
133
134 fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
135 tensor
136 }
137
138 fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {
139 tensor
140 }
141
142 fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {
143 tensor
144 }
145
146 fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {
147 tensor
148 }
149
150 fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
151 tensor
152 }
153}
154
/// [`AutodiffBackend`] implementation for [`Autodiff`], used when the `distributed`
/// feature is enabled (requires `B: DistributedBackend`).
///
/// Float tensors are unwrapped through the `primitive` field and re-wrapped with
/// [`AutodiffTensor::new`]. Gradient queries delegate to the corresponding
/// [`AutodiffTensor`] methods. Int, bool and quantized tensors share their
/// representation with `B`, so their conversions are identities.
///
/// Distributed parameter state is read from the tensor node's `distributed_params`
/// field and set through [`AutodiffTensor`]'s `grad_distributed`.
#[cfg(feature = "distributed")]
impl<B, C> AutodiffBackend for Autodiff<B, C>
where
    B: DistributedBackend,
    C: CheckpointStrategy,
{
    type InnerBackend = B;
    type Gradients = Gradients;

    // --- Float tensor <-> inner primitive ---------------------------------------

    fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
        tensor.primitive
    }

    fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
        AutodiffTensor::new(tensor)
    }

    // --- Gradients --------------------------------------------------------------

    fn backward(tensor: AutodiffTensor<B>) -> Self::Gradients {
        tensor.backward()
    }

    fn grad(
        tensor: &AutodiffTensor<B>,
        grads: &Self::Gradients,
    ) -> Option<B::FloatTensorPrimitive> {
        tensor.grad(grads)
    }

    fn grad_remove(
        tensor: &AutodiffTensor<B>,
        grads: &mut Self::Gradients,
    ) -> Option<B::FloatTensorPrimitive> {
        tensor.grad_remove(grads)
    }

    fn grad_replace(
        tensor: &AutodiffTensor<B>,
        grads: &mut Self::Gradients,
        grad: B::FloatTensorPrimitive,
    ) {
        tensor.grad_replace(grads, grad)
    }

    // --- Non-float tensors: same representation as the inner backend ------------

    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
        tensor
    }

    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {
        tensor
    }

    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
        tensor
    }

    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {
        tensor
    }

    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {
        tensor
    }

    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
        tensor
    }

    // --- Distributed parameters -------------------------------------------------

    /// Returns the tensor produced by `grad_distributed(param_id)`.
    fn set_distributed_params(
        tensor: AutodiffTensor<B>,
        param_id: DistributedParamId,
    ) -> AutodiffTensor<B> {
        tensor.grad_distributed(param_id)
    }

    /// Returns a clone of the node's distributed parameters, or `None` if none are set.
    fn distributed_params(tensor: &AutodiffTensor<B>) -> Option<DistributedParams> {
        tensor.node.distributed_params.as_ref().cloned()
    }

    /// Returns `true` when the tensor's node has distributed parameters attached.
    fn is_distributed(tensor: &AutodiffTensor<B>) -> bool {
        matches!(tensor.node.distributed_params, Some(_))
    }
}