burn_core/optim/grad_accum.rs

use core::marker::PhantomData;

use crate::module::{AutodiffModule, ModuleVisitor, ParamId};

use burn_tensor::{Tensor, backend::AutodiffBackend};

use super::GradientsParams;

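/// Accumulates gradients from successive backward passes into a single
/// [`GradientsParams`], summing them parameter by parameter.
///
/// Call [`Self::accumulate`] for each new set of gradients, then [`Self::grads`]
/// to retrieve the accumulated sum and reset the accumulator.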
pub struct GradientsAccumulator<M> {
    grads: GradientsParams,
    phantom: PhantomData<M>,
}

impl<M> Default for GradientsAccumulator<M> {
    fn default() -> Self {
        Self::new()
    }
}

impl<M> GradientsAccumulator<M> {
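    /// Creates a new, empty accumulator.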
    pub fn new() -> Self {
        Self {
            grads: GradientsParams::new(),
            phantom: PhantomData,
        }
    }
}

impl<M> GradientsAccumulator<M> {
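    /// Accumulates the given gradients for each parameter of the given module.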
    pub fn accumulate<B: AutodiffBackend>(&mut self, module: &M, grads: GradientsParams)
    where
        M: AutodiffModule<B>,
    {
        let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);
        module.visit(&mut visitor);
    }

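    /// Returns the accumulated gradients and resets the accumulator.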
    pub fn grads(&mut self) -> GradientsParams {
        let mut grads = GradientsParams::new();
        core::mem::swap(&mut self.grads, &mut grads);

        grads
    }
}

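/// Module visitor that merges newly computed gradients into the accumulated ones.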
#[derive(new)]
struct ModuleGradsAccumulator<'a, M> {
    grads: &'a mut GradientsParams,
    grads_new: GradientsParams,
    phantom: PhantomData<M>,
}

impl<B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'_, M> {
    fn visit_float<const D: usize>(&mut self, id: ParamId, _tensor: &Tensor<B, D>) {
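        // Merge the incoming gradient with the accumulated one for this parameter:
        // sum them when both exist, otherwise keep whichever is present.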
        let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(id) {
            Some(new) => match self.grads.remove::<B::InnerBackend, D>(id) {
                Some(grad) => grad.add(new),
                None => new,
            },
            None => match self.grads.remove::<B::InnerBackend, D>(id) {
                Some(grad) => grad,
                None => return,
            },
        };

        self.grads.register::<B::InnerBackend, D>(id, grad_updated);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        nn::{Linear, LinearConfig},
    };
    use burn_tensor::{Distribution, backend::Backend};

    #[test]
    fn test_accumulate_gradients_one_step() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer::<TestAutodiffBackend>(&device);
        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = GradientsParams::from_grads(loss.backward(), &layer);

        accumulator.accumulate(&layer, grads);

        let grads = accumulator.grads();
        assert!(!grads.is_empty())
    }

    #[test]
    fn test_accumulate_gradients_two_steps() {
        let device = Default::default();
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer::<TestAutodiffBackend>(&device);
        let loss_1 = layer.forward(random_tensor(&device));
        let loss_2 = layer.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer);

        accumulator.accumulate(&layer, grads_1);
        accumulator.accumulate(&layer, grads_2);

        let grads = accumulator.grads();
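        // One accumulated entry per parameter: the linear layer has a weight and a bias.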
        assert_eq!(grads.len(), 2)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).with_bias(true).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}