burn_optim/optim/grad_accum.rs

use burn_core as burn;

use core::marker::PhantomData;

use burn::module::{AutodiffModule, ModuleVisitor, Param};
use burn::tensor::{Tensor, backend::AutodiffBackend};

use super::GradientsParams;
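/// Accumulates gradients over several backward passes into a single
/// [`GradientsParams`] container, so that one optimizer update can be applied
/// per group of micro-batches.
///
/// Illustrative sketch, assuming a typical training loop; `model`, `optim`,
/// `lr`, and `batches` are placeholders for user code.
///
/// ```ignore
/// let mut accumulator = GradientsAccumulator::new();
/// for batch in batches {
///     let loss = model.forward(batch);
///     let grads = GradientsParams::from_grads(loss.backward(), &model);
///     accumulator.accumulate(&model, grads);
/// }
/// // Taking the gradients also resets the accumulator to an empty state.
/// let grads = accumulator.grads();
/// model = optim.step(lr, model, grads);
/// ```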
pub struct GradientsAccumulator<M> {
    grads: GradientsParams,
    phantom: PhantomData<M>,
}

impl<M> Default for GradientsAccumulator<M> {
    fn default() -> Self {
        Self::new()
    }
}

impl<M> GradientsAccumulator<M> {
    /// Create a new, empty gradients accumulator.
    pub fn new() -> Self {
        Self {
            grads: GradientsParams::new(),
            phantom: PhantomData,
        }
    }
}

impl<M> GradientsAccumulator<M> {
    /// Accumulate the given gradients for each parameter of the given module.
    pub fn accumulate<B: AutodiffBackend>(&mut self, module: &M, grads: GradientsParams)
    where
        M: AutodiffModule<B>,
    {
        let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);
        module.visit(&mut visitor);
    }

    /// Return the accumulated gradients and reset the accumulator to an empty state.
    pub fn grads(&mut self) -> GradientsParams {
        let mut grads = GradientsParams::new();
        core::mem::swap(&mut self.grads, &mut grads);

        grads
    }
}

#[derive(new)]
struct ModuleGradsAccumulator<'a, M> {
    grads: &'a mut GradientsParams,
    grads_new: GradientsParams,
    phantom: PhantomData<M>,
}

impl<B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'_, M> {
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        // Merge the newly received gradient with any previously accumulated one:
        // add them when both exist, keep whichever one is present otherwise, and
        // skip the parameter entirely when neither is available.
        let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(param.id) {
            Some(new) => match self.grads.remove::<B::InnerBackend, D>(param.id) {
                Some(grad) => grad.add(new),
                None => new,
            },
            None => match self.grads.remove::<B::InnerBackend, D>(param.id) {
                Some(grad) => grad,
                None => return,
            },
        };

        self.grads
            .register::<B::InnerBackend, D>(param.id, grad_updated);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::tensor::{Distribution, backend::Backend};
    use burn_nn::{Linear, LinearConfig};

    #[test]
    fn test_accumulate_gradients_one_step() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer::<TestAutodiffBackend>(&device);
        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = GradientsParams::from_grads(loss.backward(), &layer);

        accumulator.accumulate(&layer, grads);

        let grads = accumulator.grads();
        assert!(!grads.is_empty())
    }

    #[test]
    fn test_accumulate_gradients_two_steps() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer::<TestAutodiffBackend>(&device);
        let loss_1 = layer.forward(random_tensor(&device));
        let loss_2 = layer.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer);

        accumulator.accumulate(&layer, grads_1);
        accumulator.accumulate(&layer, grads_2);

        let grads = accumulator.grads();
        // One entry per parameter: the Linear layer has a weight and a bias.
        assert_eq!(grads.len(), 2)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}