burn_core/optim/grad_accum.rs

use core::marker::PhantomData;

use crate::module::{AutodiffModule, ModuleVisitor, ParamId};

use burn_tensor::{Tensor, backend::AutodiffBackend};

use super::GradientsParams;

/// Accumulate gradients into a single [Gradients](AutodiffBackend::Gradients) object.
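///
/// # Example
///
/// A minimal usage sketch rather than a complete training loop; `model`, `next_batch`,
/// and `accumulation_steps` are illustrative names, not part of this API:
///
/// ```ignore
/// let mut accumulator = GradientsAccumulator::new();
///
/// for _ in 0..accumulation_steps {
///     let loss = model.forward(next_batch());
///     let grads = GradientsParams::from_grads(loss.backward(), &model);
///     accumulator.accumulate(&model, grads);
/// }
///
/// // Take the summed gradients and reset the accumulator.
/// let grads = accumulator.grads();
/// ```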
pub struct GradientsAccumulator<M> {
    grads: GradientsParams,
    phantom: PhantomData<M>,
}

impl<M> Default for GradientsAccumulator<M> {
    fn default() -> Self {
        Self::new()
    }
}

impl<M> GradientsAccumulator<M> {
    /// Create a new gradients accumulator.
    pub fn new() -> Self {
        Self {
            grads: GradientsParams::new(),
            phantom: PhantomData,
        }
    }
}

impl<M> GradientsAccumulator<M> {
    /// Accumulate the given gradients for each parameter in the given module.
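    ///
    /// Gradients for parameters seen for the first time are registered as-is;
    /// gradients for parameters already present are summed with the stored value.
    ///
    /// A minimal sketch; `module` and `loss` come from the surrounding training code:
    ///
    /// ```ignore
    /// let grads = GradientsParams::from_grads(loss.backward(), &module);
    /// accumulator.accumulate(&module, grads);
    /// ```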
    pub fn accumulate<B: AutodiffBackend>(&mut self, module: &M, grads: GradientsParams)
    where
        M: AutodiffModule<B>,
    {
        let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);
        module.visit(&mut visitor);
    }

    /// Return the accumulated gradients and reset the accumulator state.
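    ///
    /// The internal state is swapped with an empty one, so the accumulator can be
    /// reused for the next accumulation window right away. The returned gradients
    /// are typically passed to an optimizer step.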
    pub fn grads(&mut self) -> GradientsParams {
        let mut grads = GradientsParams::new();
        core::mem::swap(&mut self.grads, &mut grads);

        grads
    }
}

/// Visitor that merges newly computed gradients into the accumulated ones,
/// one parameter at a time.
#[derive(new)]
struct ModuleGradsAccumulator<'a, M> {
    grads: &'a mut GradientsParams,
    grads_new: GradientsParams,
    phantom: PhantomData<M>,
}

impl<B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'_, M> {
    fn visit_float<const D: usize>(&mut self, id: ParamId, _tensor: &Tensor<B, D>) {
        // Merge the incoming gradient with any previously accumulated one:
        // sum when both exist, keep whichever one is present otherwise, and
        // skip the parameter entirely when neither has a gradient.
        let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(id) {
            Some(new) => match self.grads.remove::<B::InnerBackend, D>(id) {
                Some(grad) => grad.add(new),
                None => new,
            },
            None => match self.grads.remove::<B::InnerBackend, D>(id) {
                Some(grad) => grad,
                None => return,
            },
        };

        self.grads.register::<B::InnerBackend, D>(id, grad_updated);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        nn::{Linear, LinearConfig},
    };
    use burn_tensor::{Distribution, backend::Backend};

    #[test]
    fn test_accumulate_gradients_one_step() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer::<TestAutodiffBackend>(&device);
        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = GradientsParams::from_grads(loss.backward(), &layer);

        accumulator.accumulate(&layer, grads);

        let grads = accumulator.grads();
        assert!(!grads.is_empty())
    }

    #[test]
    fn test_accumulate_gradients_two_steps() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer::<TestAutodiffBackend>(&device);
        let loss_1 = layer.forward(random_tensor(&device));
        let loss_2 = layer.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer);

        accumulator.accumulate(&layer, grads_1);
        accumulator.accumulate(&layer, grads_2);

        let grads = accumulator.grads();
        // One entry per parameter (the linear layer's weight and bias), not per step.
        assert_eq!(grads.len(), 2)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).with_bias(true).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}