use crate::{
    checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
    grads::Gradients,
    runtime::AutodiffClient,
    tensor::AutodiffTensor,
};
use alloc::{format, string::String};
use burn_backend::{
    backend::{AutodiffBackend, Backend, ExecutionError},
    tensor::{BoolTensor, IntTensor, QuantizedTensor},
};
use core::marker::PhantomData;

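/// Enable auto-differentiation on a backend.
///
/// This works as a backend decorator, extending the functionality of any backend with
/// backpropagation.
///
/// A minimal usage sketch (the `burn_ndarray::NdArray` inner backend is only an
/// illustrative assumption; any type implementing `Backend` works):
///
/// ```ignore
/// use burn_ndarray::NdArray;
///
/// // Wrapping an inner backend yields a backend whose float tensors record the
/// // operations needed by `backward`.
/// type DiffBackend = Autodiff<NdArray>;
/// ```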
#[derive(Clone, Copy, Debug, Default)]
pub struct Autodiff<B, C = NoCheckpointing> {
    _b: PhantomData<B>,
    _checkpoint_strategy: PhantomData<C>,
}

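// Float tensors are wrapped in `AutodiffTensor` so the computation graph can be recorded;
// int, bool and quantized tensors reuse the inner backend's primitives unchanged, and the
// remaining backend plumbing is forwarded directly to `B`.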
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
    type Device = B::Device;

    type FloatTensorPrimitive = AutodiffTensor<B>;
    type FloatElem = B::FloatElem;

    type IntTensorPrimitive = B::IntTensorPrimitive;
    type IntElem = B::IntElem;

    type BoolTensorPrimitive = B::BoolTensorPrimitive;
    type BoolElem = B::BoolElem;

    type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;

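    // Marks this decorated backend as autodiff-enabled.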
    fn ad_enabled() -> bool {
        true
    }

    fn name(device: &Self::Device) -> String {
        format!("autodiff<{}>", B::name(device))
    }

    fn seed(device: &B::Device, seed: u64) {
        B::seed(device, seed)
    }

    fn sync(device: &B::Device) -> Result<(), ExecutionError> {
        B::sync(device)
    }

    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        B::memory_persistent_allocations(device, input, func)
    }

    fn memory_cleanup(device: &Self::Device) {
        B::memory_cleanup(device)
    }

    fn staging<'a, Iter>(data: Iter, device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut burn_backend::TensorData>,
    {
        B::staging(data, device);
    }

    fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {
        B::supports_dtype(device, dtype)
    }
}

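// `AutodiffBackend` exposes the backpropagation API: `backward` consumes the root tensor
// and returns the accumulated `Gradients`, which can then be queried, taken out, or
// overwritten per tensor; `inner`/`from_inner` convert between tracked tensors and plain
// inner-backend tensors.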
impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
    type InnerBackend = B;
    type Gradients = Gradients;

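    // Runs the backward pass by delegating to the autodiff runtime client that owns the
    // tensor's graph node.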
    fn backward(tensor: AutodiffTensor<B>) -> Gradients {
        let client = tensor.node.client.clone();

        AutodiffClient::backward::<B>(&client, tensor)
    }

    fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
        grads.get::<B>(tensor)
    }

    fn grad_remove(
        tensor: &AutodiffTensor<B>,
        grads: &mut Gradients,
    ) -> Option<B::FloatTensorPrimitive> {
        grads.remove::<B>(tensor)
    }

    fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
        tensor.primitive
    }

    fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
        AutodiffTensor::new(tensor)
    }

    fn grad_replace(
        tensor: &AutodiffTensor<B>,
        grads: &mut Self::Gradients,
        grad: B::FloatTensorPrimitive,
    ) {
        grads.remove::<B>(tensor);
        grads.register::<B>(tensor.node.id, grad);
    }

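    // Int, bool and quantized tensors carry no gradient information, so converting them
    // to and from the inner backend is the identity.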
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
        tensor
    }

    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
        tensor
    }

    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {
        tensor
    }

    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {
        tensor
    }

    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {
        tensor
    }

    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
        tensor
    }
}