burn_optim/optim/grad_accum.rs

use burn_core as burn;

use core::marker::PhantomData;

use burn::module::{AutodiffModule, ModuleVisitor, Param};
use burn::tensor::{Tensor, backend::AutodiffBackend};

use super::GradientsParams;

/// Accumulate gradients into a single [GradientsParams] object.
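///
/// A minimal usage sketch of gradient accumulation (the `model`, `batches`, and
/// `loss_fn` names are illustrative, not part of this API):
///
/// ```ignore
/// let mut accumulator = GradientsAccumulator::new();
///
/// // Accumulate gradients over several micro-batches before applying them.
/// for batch in batches {
///     let loss = loss_fn(model.forward(batch));
///     let grads = GradientsParams::from_grads(loss.backward(), &model);
///     accumulator.accumulate(&model, grads);
/// }
///
/// // Take the summed gradients and reset the accumulator for the next cycle,
/// // e.g. to pass them to an optimizer step.
/// let grads = accumulator.grads();
/// ```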
pub struct GradientsAccumulator<M> {
    grads: GradientsParams,
    phantom: PhantomData<M>,
}

impl<M> Default for GradientsAccumulator<M> {
    fn default() -> Self {
        Self::new()
    }
}

impl<M> GradientsAccumulator<M> {
    /// Create a new gradients accumulator.
    pub fn new() -> Self {
        Self {
            grads: GradientsParams::new(),
            phantom: PhantomData,
        }
    }
}

impl<M> GradientsAccumulator<M> {
    /// Accumulate the given gradients for each parameter in the given module.
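    ///
    /// Gradients provided for a parameter that already has an accumulated value are
    /// summed with it; parameters without a new gradient keep their current value.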
    pub fn accumulate<B: AutodiffBackend>(&mut self, module: &M, grads: GradientsParams)
    where
        M: AutodiffModule<B>,
    {
        let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);
        module.visit(&mut visitor);
    }

    /// Return the accumulated gradients and reset the accumulator state.
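    ///
    /// Calling it again without accumulating new gradients yields an empty
    /// [GradientsParams], since the internal state is swapped out. A minimal sketch,
    /// assuming `accumulator` has already accumulated some gradients:
    ///
    /// ```ignore
    /// let first = accumulator.grads();  // the accumulated gradients
    /// let second = accumulator.grads(); // empty: the state was reset
    /// assert!(second.is_empty());
    /// ```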
    pub fn grads(&mut self) -> GradientsParams {
        let mut grads = GradientsParams::new();
        // Swap out the accumulated gradients, leaving a fresh, empty container behind.
        core::mem::swap(&mut self.grads, &mut grads);

        grads
    }
}

/// Visits each parameter of a module and merges the newly provided gradients
/// (`grads_new`) into the accumulated gradients (`grads`).
#[derive(new)]
struct ModuleGradsAccumulator<'a, M> {
    grads: &'a mut GradientsParams,
    grads_new: GradientsParams,
    phantom: PhantomData<M>,
}

impl<B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'_, M> {
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(param.id) {
            // A new gradient for this parameter: add it to the accumulated value if
            // one exists, otherwise start accumulating from the new gradient.
            Some(new) => match self.grads.remove::<B::InnerBackend, D>(param.id) {
                Some(grad) => grad.add(new),
                None => new,
            },
            // No new gradient: keep the previously accumulated value, if any.
            None => match self.grads.remove::<B::InnerBackend, D>(param.id) {
                Some(grad) => grad,
                None => return,
            },
        };

        self.grads
            .register::<B::InnerBackend, D>(param.id, grad_updated);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::tensor::{Distribution, backend::Backend};
    use burn_nn::{Linear, LinearConfig};

    #[test]
    fn test_accumulate_gradients_one_step() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer::<TestAutodiffBackend>(&device);
        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = GradientsParams::from_grads(loss.backward(), &layer);

        accumulator.accumulate(&layer, grads);

        let grads = accumulator.grads();
        assert!(!grads.is_empty())
    }

    #[test]
    fn test_accumulate_gradients_two_steps() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer::<TestAutodiffBackend>(&device);
        let loss_1 = layer.forward(random_tensor(&device));
        let loss_2 = layer.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer);

        accumulator.accumulate(&layer, grads_1);
        accumulator.accumulate(&layer, grads_2);

        let grads = accumulator.grads();
        assert_eq!(grads.len(), 2)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}