use crate::module::{Module, ModuleVisitor, ParamId};
use burn_tensor::{backend::ADBackend, Tensor};
use super::GradientsParams;
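
/// Accumulates gradients over multiple backward passes so that a single
/// optimizer update can be applied afterwards (gradient accumulation).
///
/// A minimal usage sketch; `model` and `batches` are placeholder names, and
/// the tests at the bottom of this file show a runnable equivalent:
///
/// ```ignore
/// let mut accumulator = GradientsAccumulator::new();
///
/// for batch in batches {
///     let loss = model.forward(batch);
///     let grads = GradientsParams::from_grads(loss.backward(), &model);
///     accumulator.accumulate(&model, grads);
/// }
///
/// let grads = accumulator.grads(); // Consumes and resets the accumulated state.
/// ```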
pub struct GradientsAccumulator {
    grads: GradientsParams,
}

impl Default for GradientsAccumulator {
    fn default() -> Self {
        Self::new()
    }
}

impl GradientsAccumulator {
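    /// Create a new accumulator with no gradients stored.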
    pub fn new() -> Self {
        Self {
            grads: GradientsParams::new(),
        }
    }
}

impl GradientsAccumulator {
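    /// Accumulate the given gradients for each parameter of the given module,
    /// summing them with any gradients already stored for the same parameter id.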
    pub fn accumulate<B: ADBackend, M>(&mut self, module: &M, grads: GradientsParams)
    where
        M: Module<Backend = B>,
    {
        let mut visitor = ModuleGradsAccumulator::new(&mut self.grads, grads);
        module.visit(&mut visitor);
    }
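
    /// Return the accumulated gradients and reset the accumulator to empty.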
    pub fn grads(&mut self) -> GradientsParams {
        let mut grads = GradientsParams::new();
        core::mem::swap(&mut self.grads, &mut grads);

        grads
    }
}
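
/// Module visitor that merges freshly computed gradients (`grads_new`) into
/// the accumulated gradients (`grads`); the constructor is generated by
/// `#[derive(new)]` from the `derive_new` crate.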
#[derive(new)]
struct ModuleGradsAccumulator<'a> {
    grads: &'a mut GradientsParams,
    grads_new: GradientsParams,
}

impl<'a, B: ADBackend> ModuleVisitor<B> for ModuleGradsAccumulator<'a> {
    fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
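        // Take this parameter's gradient out of both maps: sum the two when
        // both are present, keep whichever exists otherwise, and skip the
        // parameter entirely when neither map holds a gradient for it.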
        let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(id) {
            Some(new) => match self.grads.remove::<B::InnerBackend, D>(id) {
                Some(grad) => grad.add(new),
                None => new,
            },
            None => match self.grads.remove::<B::InnerBackend, D>(id) {
                Some(grad) => grad,
                None => return,
            },
        };

        self.grads
            .register::<B::InnerBackend, D>(id.clone(), grad_updated);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        nn::{Linear, LinearConfig},
        TestADBackend,
    };
    use burn_tensor::Distribution;

    #[test]
    fn test_accumulate_gradients_one_step() {
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer();
        let loss = layer.forward(random_tensor());
        let grads = GradientsParams::from_grads(loss.backward(), &layer);

        accumulator.accumulate(&layer, grads);

        let grads = accumulator.grads();
        assert!(!grads.is_empty())
    }

    #[test]
    fn test_accumulate_gradients_two_steps() {
        let mut accumulator = GradientsAccumulator::new();
        let layer = layer();
        let loss_1 = layer.forward(random_tensor());
        let loss_2 = layer.forward(random_tensor());
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer);

        accumulator.accumulate(&layer, grads_1);
        accumulator.accumulate(&layer, grads_2);

        let grads = accumulator.grads();
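        // The linear layer has two parameters (weight and bias), and the two
        // steps are summed per parameter, so exactly two entries remain.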
        assert_eq!(grads.len(), 2)
    }

    fn layer() -> Linear<TestADBackend> {
        Linear::<TestADBackend>::new(&LinearConfig::new(20, 20).with_bias(true))
    }

    fn random_tensor() -> Tensor<TestADBackend, 2> {
        Tensor::<TestADBackend, 2>::random([2, 20], Distribution::Standard)
    }
}