// burn_optim/optim/grads.rs
use burn_core as burn;

use burn::{
    Tensor,
    tensor::{
        backend::{AutodiffBackend, Backend},
        container::TensorContainer,
    },
};
#[cfg(feature = "collective")]
use burn_collective::{CollectiveError, PeerId, ReduceOperation, all_reduce};

use burn::module::{AutodiffModule, ParamId};

use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};

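/// Gradients of a module's parameters, indexed by [`ParamId`].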
#[derive(Default, Debug)]
pub struct GradientsParams {
    container: TensorContainer<ParamId>,
}

impl GradientsParams {
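    /// Creates an empty gradients container.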
    pub fn new() -> Self {
        Self::default()
    }

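    /// Extracts each parameter's gradient from the result of a backward pass,
    /// consuming the raw gradients.
    ///
    /// A minimal usage sketch (assumes `model` is an [`AutodiffModule`] and
    /// `loss` a tensor on an autodiff backend, as in the test below):
    ///
    /// ```ignore
    /// let grads = loss.backward();
    /// let grads = GradientsParams::from_grads(grads, &model);
    /// ```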
    pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(
        mut grads: B::Gradients,
        module: &M,
    ) -> Self {
        Self::from_module(&mut grads, module)
    }

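    /// Extracts the gradient of every parameter in the given module, removing
    /// each one from `grads` as it is collected.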
    pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);
        module.visit(&mut visitor);
        grads_params
    }

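    /// Extracts the gradients of the listed parameters only, removing each one
    /// from `grads`; gradients of other parameters are left in place.
    ///
    /// A sketch of selective extraction (assumes `model` is an
    /// [`AutodiffModule`] and `raw_grads` the result of a backward pass;
    /// `list_param_ids` comes from `burn::module`):
    ///
    /// ```ignore
    /// let ids = list_param_ids(&model);
    /// let grads = GradientsParams::from_params(&mut raw_grads, &model, &ids[..1]);
    /// ```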
    pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
        params: &[ParamId],
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor =
            GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));
        module.visit(&mut visitor);
        grads_params
    }

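    /// Returns the gradient registered for the given parameter, if any.
    ///
    /// Prefer [`Self::remove`] when the gradient is only needed once, since
    /// `get` leaves the tensor in the container.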
    pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.get(&id).map(Tensor::from_primitive)
    }

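    /// Removes and returns the gradient registered for the given parameter, if any.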
    pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.remove(&id).map(Tensor::from_primitive)
    }

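    /// Registers a gradient tensor for the given parameter id.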
    pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)
    where
        B: Backend,
    {
        self.container.register(id, value.into_primitive())
    }

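    /// Returns the number of gradient tensors registered.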
    pub fn len(&self) -> usize {
        self.container.len()
    }

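    /// Returns `true` when no gradient is registered.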
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

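    /// Moves each registered gradient of the given module's parameters to the
    /// target device.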
    pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(
        mut self,
        device: &B::Device,
        module: &M,
    ) -> Self {
        let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
        module.visit(&mut visitor);
        self
    }

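    /// All-reduces every registered gradient across the peers of the
    /// collective, applying `op` element-wise.
    ///
    /// Gradients are processed in sorted [`ParamId`] order so that every peer
    /// reduces the same tensors in the same order.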
    #[cfg(feature = "collective")]
    pub fn all_reduce<B: Backend>(
        mut self,
        peer_id: PeerId,
        op: ReduceOperation,
    ) -> Result<Self, CollectiveError> {
        let mut ids = self
            .container
            .ids()
            .into_iter()
            .copied()
            .collect::<Vec<ParamId>>();
        // Sort the ids so that every peer reduces the gradients in the same order.
        ids.sort();

        for id in ids {
            let Some(grad) = self.container.remove::<B>(&id) else {
                // The ids were just collected from the container, so each
                // gradient must still be registered.
                unreachable!("gradient should be registered for every collected id")
            };

            let grad = match grad {
                burn::tensor::TensorPrimitive::Float(grad) => {
                    let grad = all_reduce::<B>(peer_id, grad, op)?;
                    burn::tensor::TensorPrimitive::Float(grad)
                }
                burn::tensor::TensorPrimitive::QFloat(_grad) => {
                    unimplemented!("all-reduce is not implemented for quantized tensors")
                }
            };

            self.container.register::<B>(id, grad);
        }

        Ok(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::{Module, list_param_ids};
    use burn::tensor::{Distribution, backend::Backend};
    use burn_nn::{Linear, LinearConfig};

    #[test]
    fn test_convert_grads() {
        let device = Default::default();
        let layer_1 = layer::<TestAutodiffBackend>(&device);
        let layer_2 = layer_1.clone().fork(&device);
        let loss_1 = layer_1.forward(random_tensor(&device));
        let loss_2 = layer_2.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2);

        let param_ids_1 = list_param_ids(&layer_1);
        let param_ids_2 = list_param_ids(&layer_2);

        assert_eq!(param_ids_1, param_ids_2);
        assert_eq!(grads_1.len(), param_ids_1.len());
        assert_eq!(grads_2.len(), param_ids_2.len());
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}