burn_optim/optim/grads.rs

use burn_core as burn;

#[cfg(feature = "collective")]
use burn_collective::{CollectiveError, PeerId, ReduceOperation, all_reduce};

use burn::{
    Tensor,
    tensor::{
        backend::{AutodiffBackend, Backend},
        container::TensorContainer,
    },
};

use burn::module::{AutodiffModule, ParamId};

use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};

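/// Gradients for each parameter of a module, indexed by [ParamId].
/// The gradients are stored as backend tensor primitives in a [TensorContainer].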
#[derive(Default, Debug)]
pub struct GradientsParams {
    container: TensorContainer<ParamId>,
}

impl GradientsParams {
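    /// Creates a new, empty container of gradients.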
    pub fn new() -> Self {
        Self::default()
    }

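    /// Extracts the gradients of all parameters in `module` from the backward-pass gradients.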
    pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: B::Gradients,
        module: &M,
    ) -> Self {
        let mut grads = grads;
        Self::from_module(&mut grads, module)
    }

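    /// Extracts the gradients of all parameters in `module` from the given backend gradients.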
    pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);
        module.visit(&mut visitor);
        grads_params
    }

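    /// Extracts the gradients of the parameters listed in `params` from the given backend gradients.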
    pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
        params: &[ParamId],
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor =
            GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));
        module.visit(&mut visitor);
        grads_params
    }

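    /// Returns the gradient registered for the parameter `id`, if any.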
    pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.get(&id).map(Tensor::from_primitive)
    }

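    /// Removes and returns the gradient registered for the parameter `id`, if any.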
    pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.remove(&id).map(Tensor::from_primitive)
    }

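    /// Registers the gradient `value` for the parameter `id`.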
    pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)
    where
        B: Backend,
    {
        self.container.register(id, value.into_primitive())
    }

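    /// Returns the number of registered gradients.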
    pub fn len(&self) -> usize {
        self.container.len()
    }

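    /// Returns `true` when no gradients are registered.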
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

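    /// Moves the gradients of all parameters in `module` to the given device.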
    pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(
        mut self,
        device: &B::Device,
        module: &M,
    ) -> Self {
        let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
        module.visit(&mut visitor);
        self
    }

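    /// All-reduces every registered gradient across collective peers using the given
    /// reduce operation. Quantized gradients are not supported.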
    #[cfg(feature = "collective")]
    pub fn all_reduce<B: Backend>(
        mut self,
        peer_id: PeerId,
        op: ReduceOperation,
    ) -> Result<Self, CollectiveError> {
        let mut ids = self
            .container
            .ids()
            .into_iter()
            .copied()
            .collect::<Vec<ParamId>>();
        // Sort the ids so every peer reduces the gradients in the same order.
        ids.sort();

        for id in ids {
            let Some(grad) = self.container.remove::<B>(&id) else {
                // The id was just returned by the container, so its gradient must exist.
                unreachable!("gradient must be registered for an id returned by the container")
            };

            let grad = match grad {
                burn::tensor::TensorPrimitive::Float(grad) => {
                    let grad = all_reduce::<B>(peer_id, grad, op)?;
                    burn::tensor::TensorPrimitive::Float(grad)
                }
                burn::tensor::TensorPrimitive::QFloat(_grad) => {
                    unimplemented!("All-reduce is not implemented for quantized gradients")
                }
            };

            self.container.register::<B>(id, grad);
        }

        Ok(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::{Module, list_param_ids};
    use burn::tensor::{Distribution, backend::Backend};
    use burn_nn::{Linear, LinearConfig};

    #[test]
    fn test_convert_grads() {
        let device = Default::default();
        let layer_1 = layer::<TestAutodiffBackend>(&device);
        let mut layer_2 = layer_1.clone();
        layer_2 = layer_2.fork(&device);
        let loss_1 = layer_1.forward(random_tensor(&device));
        let loss_2 = layer_2.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2);

        let param_ids_1 = list_param_ids(&layer_1);
        let param_ids_2 = list_param_ids(&layer_2);

        assert_eq!(param_ids_1, param_ids_2);
        assert_eq!(grads_1.len(), param_ids_1.len());
        assert_eq!(grads_2.len(), param_ids_2.len());
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}