burn_optim/optim/grads.rs

use burn_core as burn;

use burn::{
    Tensor,
    tensor::{
        backend::{AutodiffBackend, Backend},
        container::TensorContainer,
    },
};
#[cfg(feature = "collective")]
use burn_collective::{CollectiveError, PeerId, ReduceOperation, all_reduce};

use burn::module::{AutodiffModule, ParamId};

use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};

/// Data type that contains gradients for parameters.
#[derive(Default, Debug)]
pub struct GradientsParams {
    container: TensorContainer<ParamId>,
}

impl GradientsParams {
    /// Creates a new [GradientsParams](GradientsParams).
    pub fn new() -> Self {
        Self::default()
    }

    /// Extracts the gradients of each tensor for the given [module](AutodiffModule).
    ///
    /// Note: this consumes the gradients. See [from_module](GradientsParams::from_module)
    /// to extract gradients only for a specific module.
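    ///
    /// A minimal usage sketch (assuming a `model` implementing [AutodiffModule]
    /// and a `loss` tensor produced by a forward pass):
    ///
    /// ```rust,ignore
    /// // Run the backward pass, then collect the gradients of the model's parameters.
    /// let grads = GradientsParams::from_grads(loss.backward(), &model);
    /// ```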
    pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(
        mut grads: B::Gradients,
        module: &M,
    ) -> Self {
        Self::from_module(&mut grads, module)
    }

    /// Extracts the gradients of each tensor for the given [module](AutodiffModule).
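    ///
    /// Because the gradients are borrowed mutably, the same `grads` value can be
    /// visited for more than one module (sketch; `model_a` and `model_b` are assumed
    /// to be modules whose parameters contributed to `loss`):
    ///
    /// ```rust,ignore
    /// let mut grads = loss.backward();
    /// let grads_a = GradientsParams::from_module(&mut grads, &model_a);
    /// let grads_b = GradientsParams::from_module(&mut grads, &model_b);
    /// ```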
    pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);
        module.visit(&mut visitor);
        grads_params
    }

    /// Extracts tensor gradients for the given [module](AutodiffModule), restricted to the given parameters.
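    ///
    /// Sketch: restrict extraction to a subset of parameter ids (here obtained
    /// with `list_param_ids` from `burn::module`):
    ///
    /// ```rust,ignore
    /// use burn::module::list_param_ids;
    ///
    /// let mut grads = loss.backward();
    /// let ids = list_param_ids(&model);
    /// // Only gradients for the listed parameters are extracted.
    /// let subset = GradientsParams::from_params(&mut grads, &model, &ids);
    /// ```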
    pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
        params: &[ParamId],
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor =
            GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));
        module.visit(&mut visitor);
        grads_params
    }

    /// Gets the gradient for the given [parameter id](ParamId).
    ///
    /// # Notes
    ///
    /// Use [remove](GradientsParams::remove) instead if you only need to read the
    /// gradient once.
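    ///
    /// Sketch (assuming `id` refers to a parameter with a rank-2 gradient):
    ///
    /// ```rust,ignore
    /// // Look up the gradient without removing it from the container.
    /// let grad: Option<Tensor<B, 2>> = grads.get::<B, 2>(id);
    /// ```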
    pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.get(&id).map(Tensor::from_primitive)
    }

    /// Removes the gradient for the given [parameter id](ParamId).
    pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.remove(&id).map(Tensor::from_primitive)
    }

    /// Registers a gradient tensor for the given [parameter id](ParamId).
    ///
    /// # Notes
    ///
    /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced.
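    ///
    /// Sketch: rescaling a gradient in place (`id` is an assumed [ParamId] that is
    /// already registered with a rank-2 gradient):
    ///
    /// ```rust,ignore
    /// if let Some(grad) = grads.remove::<B, 2>(id) {
    ///     // Re-register a modified gradient under the same parameter id.
    ///     grads.register::<B, 2>(id, grad.mul_scalar(0.5));
    /// }
    /// ```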
    pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)
    where
        B: Backend,
    {
        self.container.register(id, value.into_primitive())
    }

    /// The number of gradient tensors registered.
    pub fn len(&self) -> usize {
        self.container.len()
    }

    /// Returns `true` if no gradient tensors are registered.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Changes the device of each gradient tensor registered for the given [module](AutodiffModule).
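    ///
    /// Sketch: moving gradients to the device the optimizer step will run on
    /// (`device` and `model` are assumed from the surrounding training loop):
    ///
    /// ```rust,ignore
    /// let grads = GradientsParams::from_grads(loss.backward(), &model)
    ///     .to_device(&device, &model);
    /// ```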
    pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(
        mut self,
        device: &B::Device,
        module: &M,
    ) -> Self {
        let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
        module.visit(&mut visitor);
        self
    }

    /// Synchronizes the gradients with the other peers in the collective.
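    ///
    /// Sketch (assuming the `collective` feature is enabled, `peer_id` has already
    /// been registered with the collective, and `ReduceOperation::Mean` is an
    /// assumed variant for averaging across peers):
    ///
    /// ```rust,ignore
    /// let grads = GradientsParams::from_grads(loss.backward(), &model)
    ///     .all_reduce::<B>(peer_id, ReduceOperation::Mean)?;
    /// ```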
    #[cfg(feature = "collective")]
    pub fn all_reduce<B: Backend>(
        mut self,
        peer_id: PeerId,
        op: ReduceOperation,
    ) -> Result<Self, CollectiveError> {
        let mut ids = self
            .container
            .ids()
            .into_iter()
            .copied()
            .collect::<Vec<ParamId>>();
        // This is crucial, since the all-reduce operations need to happen in the
        // same order for the same parameters on all nodes!
        ids.sort();

        for id in ids {
            // Every id was just collected from the container, so the gradient is
            // expected to be present; skip defensively rather than panic.
            let Some(grad) = self.container.remove::<B>(&id) else {
                continue;
            };

            let grad = match grad {
                burn::tensor::TensorPrimitive::Float(grad) => {
                    let grad = all_reduce::<B>(peer_id, grad, op)?;
                    burn::tensor::TensorPrimitive::Float(grad)
                }
                burn::tensor::TensorPrimitive::QFloat(_grad) => {
                    unimplemented!("quantized all-reduce unimplemented")
                }
            };

            self.container.register::<B>(id, grad);
        }

        Ok(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::{Module, list_param_ids};
    use burn::tensor::{Distribution, backend::Backend};
    use burn_nn::{Linear, LinearConfig};

    #[test]
    fn test_convert_grads() {
        let device = Default::default();
        let layer_1 = layer::<TestAutodiffBackend>(&device);
        let layer_2 = layer_1.clone().fork(&device);
        let loss_1 = layer_1.forward(random_tensor(&device));
        let loss_2 = layer_2.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2);

        let param_ids_1 = list_param_ids(&layer_1);
        let param_ids_2 = list_param_ids(&layer_2);

        assert_eq!(param_ids_1, param_ids_2);
        assert_eq!(grads_1.len(), param_ids_1.len());
        assert_eq!(grads_2.len(), param_ids_2.len());
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}