1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use crate::{
    checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
    grads::Gradients,
    runtime::AutodiffClient,
    tensor::AutodiffTensor,
    AutodiffBridge,
};
use burn_common::sync_type::SyncType;
use burn_tensor::backend::{AutodiffBackend, Backend};
use core::marker::PhantomData;

/// Enable auto-differentiation on a backend.
///
/// This works as a backend decorator, extending the functionality of any backend with
/// backpropagation.
///
/// # Type parameters
///
/// - `B`: the backend being decorated.
/// - `C`: the checkpointing strategy marker; defaults to [`NoCheckpointing`].
///
/// The struct stores no runtime data: both fields are [`PhantomData`] markers that only carry
/// the type parameters.
#[derive(Clone, Copy, Debug, Default)]
pub struct Autodiff<B, C = NoCheckpointing> {
    // Marker for the decorated backend type.
    _b: PhantomData<B>,
    // Marker for the checkpointing strategy type.
    _checkpoint_strategy: PhantomData<C>,
}

/// [`Backend`] implementation for the decorator.
///
/// Devices, element types and the int/bool/quantized tensor primitives are taken directly
/// from the decorated backend `B`; float tensors use [`AutodiffTensor`] instead.
impl<B, C> Backend for Autodiff<B, C>
where
    B: Backend,
    C: CheckpointStrategy,
{
    // Device and element types are shared with the decorated backend.
    type Device = B::Device;
    type FloatElem = B::FloatElem;
    type IntElem = B::IntElem;

    // The full precision bridge of `B`, wrapped in the autodiff bridge.
    type FullPrecisionBridge = AutodiffBridge<B::FullPrecisionBridge>;

    // Float tensors are the only primitives that differ from the decorated backend.
    type FloatTensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
    type IntTensorPrimitive<const D: usize> = B::IntTensorPrimitive<D>;
    type BoolTensorPrimitive<const D: usize> = B::BoolTensorPrimitive<D>;
    type QuantizedTensorPrimitive<const D: usize> = B::QuantizedTensorPrimitive<D>;

    /// Always reports `true` for this decorator.
    fn ad_enabled() -> bool {
        true
    }

    /// Returns the decorated backend's name wrapped as `autodiff<...>`.
    fn name() -> String {
        let inner_name = B::name();
        format!("autodiff<{}>", inner_name)
    }

    /// Forwards the seed to the decorated backend.
    fn seed(seed: u64) {
        B::seed(seed)
    }

    /// Forwards the synchronization request to the decorated backend.
    fn sync(device: &Self::Device, sync_type: SyncType) {
        B::sync(device, sync_type)
    }
}

/// [`AutodiffBackend`] implementation for the decorator, with `B` as the inner backend.
impl<B, C> AutodiffBackend for Autodiff<B, C>
where
    B: Backend,
    C: CheckpointStrategy,
{
    type InnerBackend = B;
    type Gradients = Gradients;

    // ---- Gradient computation and access ----

    /// Runs the backward pass for `tensor` through the client attached to its node.
    fn backward<const D: usize>(tensor: AutodiffTensor<B, D>) -> Self::Gradients {
        // The client is cloned first because `tensor` is moved into the call below.
        let node_client = tensor.node.client.clone();

        AutodiffClient::backward(&node_client, tensor)
    }

    /// Looks `tensor` up in `grads` and returns the result of that lookup.
    fn grad<const D: usize>(
        tensor: &AutodiffTensor<B, D>,
        grads: &Self::Gradients,
    ) -> Option<B::FloatTensorPrimitive<D>> {
        grads.get(tensor)
    }

    /// Removes the entry for `tensor` from `grads` and returns what was removed, if anything.
    fn grad_remove<const D: usize>(
        tensor: &AutodiffTensor<B, D>,
        grads: &mut Self::Gradients,
    ) -> Option<B::FloatTensorPrimitive<D>> {
        grads.remove(tensor)
    }

    /// Replaces the gradient stored for `tensor` in `grads` with `grad`.
    fn grad_replace<const D: usize>(
        tensor: &AutodiffTensor<B, D>,
        grads: &mut Self::Gradients,
        grad: B::FloatTensorPrimitive<D>,
    ) {
        let node_id = tensor.node.id;

        // Discard whatever was stored before, then register the new gradient under the
        // same node id.
        let _ = grads.remove(tensor);
        grads.register::<B, D>(node_id, grad);
    }

    // ---- Float tensors: unwrap / wrap the autodiff tensor ----

    /// Returns the inner backend primitive held by the autodiff tensor.
    fn inner<const D: usize>(tensor: AutodiffTensor<B, D>) -> B::FloatTensorPrimitive<D> {
        tensor.primitive
    }

    /// Builds a new autodiff tensor from an inner backend primitive.
    fn from_inner<const D: usize>(tensor: B::FloatTensorPrimitive<D>) -> AutodiffTensor<B, D> {
        AutodiffTensor::new(tensor)
    }

    // ---- Int / bool / quantized tensors: returned unchanged ----

    /// Returns the int tensor as is.
    fn int_inner<const D: usize>(
        tensor: burn_tensor::ops::IntTensor<Self, D>,
    ) -> burn_tensor::ops::IntTensor<Self::InnerBackend, D> {
        tensor
    }

    /// Returns the inner backend's int tensor as is.
    fn int_from_inner<const D: usize>(
        tensor: burn_tensor::ops::IntTensor<Self::InnerBackend, D>,
    ) -> burn_tensor::ops::IntTensor<Self, D> {
        tensor
    }

    /// Returns the bool tensor as is.
    fn bool_inner<const D: usize>(
        tensor: burn_tensor::ops::BoolTensor<Self, D>,
    ) -> burn_tensor::ops::BoolTensor<Self::InnerBackend, D> {
        tensor
    }

    /// Returns the inner backend's bool tensor as is.
    fn bool_from_inner<const D: usize>(
        tensor: burn_tensor::ops::BoolTensor<Self::InnerBackend, D>,
    ) -> burn_tensor::ops::BoolTensor<Self, D> {
        tensor
    }

    /// Returns the quantized tensor as is.
    fn q_inner<const D: usize>(
        tensor: burn_tensor::ops::QuantizedTensor<Self, D>,
    ) -> burn_tensor::ops::QuantizedTensor<Self::InnerBackend, D> {
        tensor
    }

    /// Returns the inner backend's quantized tensor as is.
    fn q_from_inner<const D: usize>(
        tensor: burn_tensor::ops::QuantizedTensor<Self::InnerBackend, D>,
    ) -> burn_tensor::ops::QuantizedTensor<Self, D> {
        tensor
    }
}