burn_optim/optim/
grads.rs

1use burn_core as burn;
2
3#[cfg(feature = "collective")]
4use burn_collective::{CollectiveError, PeerId, ReduceOperation, all_reduce};
5
6use burn::{
7    Tensor,
8    tensor::{
9        backend::{AutodiffBackend, Backend},
10        container::TensorContainer,
11    },
12};
13
14use burn::module::{AutodiffModule, ParamId};
15
16use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};
17
18/// Data type that contains gradients for parameters.
19#[derive(Default, Debug)]
20pub struct GradientsParams {
21    container: TensorContainer<ParamId>,
22}
23
24impl GradientsParams {
25    /// Creates a new [GradientsParams](GradientsParams).
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    /// Extract each tensor gradients for the given [module](AutodiffModule).
31    ///
32    /// Note: This consumes the gradients. See ['from_module'] to extract gradients only for
33    ///  a specific module.
34    pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(
35        grads: B::Gradients,
36        module: &M,
37    ) -> Self {
38        let mut grads = grads;
39        Self::from_module(&mut grads, module)
40    }
41
42    /// Extract each tensor gradients for the given [module](AutodiffModule).
43    pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(
44        grads: &mut B::Gradients,
45        module: &M,
46    ) -> Self {
47        let mut grads_params = GradientsParams::new();
48        let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);
49        module.visit(&mut visitor);
50        grads_params
51    }
52
53    /// Extract tensor gradients for the given [module](AutodiffModule) and given parameters.
54    pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(
55        grads: &mut B::Gradients,
56        module: &M,
57        params: &[ParamId],
58    ) -> Self {
59        let mut grads_params = GradientsParams::new();
60        let mut visitor =
61            GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));
62        module.visit(&mut visitor);
63        grads_params
64    }
65
66    /// Get the gradients for the given [parameter id](ParamId).
67    ///
68    /// # Notes
69    ///
70    /// You should use [remove](GradientsParams::remove) if you want to get the gradients
71    /// only one time.
72    pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>
73    where
74        B: Backend,
75    {
76        self.container.get(&id).map(Tensor::from_primitive)
77    }
78
79    /// Remove the gradients for the given [parameter id](ParamId).
80    pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>
81    where
82        B: Backend,
83    {
84        self.container.remove(&id).map(Tensor::from_primitive)
85    }
86
87    /// Register a gradients tensor for the given [parameter id](ParamId).
88    ///
89    /// # Notes
90    ///
91    /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced.
92    pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)
93    where
94        B: Backend,
95    {
96        self.container.register(id, value.into_primitive())
97    }
98
99    /// The number of gradients tensors registered.
100    pub fn len(&self) -> usize {
101        self.container.len()
102    }
103
104    /// If any tensor is contained.
105    pub fn is_empty(&self) -> bool {
106        self.len() == 0
107    }
108
109    /// Change the device of each tensor gradients registered for the given [module](AutodiffModule).
110    pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(
111        mut self,
112        device: &B::Device,
113        module: &M,
114    ) -> Self {
115        let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
116        module.visit(&mut visitor);
117        self
118    }
119
120    /// Syncs the gradient params with the other peers in the collective.
121    #[cfg(feature = "collective")]
122    pub fn all_reduce<B: Backend>(
123        mut self,
124        peer_id: PeerId,
125        op: ReduceOperation,
126    ) -> Result<Self, CollectiveError> {
127        let mut ids = self
128            .container
129            .ids()
130            .into_iter()
131            .copied()
132            .collect::<Vec<ParamId>>();
133        // This is crucial, since the all-reduce operations need to happen in the same order for the same parameters on all nodes!
134        ids.sort();
135
136        for id in ids {
137            let Some(grad) = self.container.remove::<B>(&id) else {
138                todo!()
139            };
140
141            let grad = match grad {
142                burn::tensor::TensorPrimitive::Float(grad) => {
143                    let grad = all_reduce::<B>(peer_id, grad, op)?;
144                    burn::tensor::TensorPrimitive::Float(grad)
145                }
146                burn::tensor::TensorPrimitive::QFloat(_grad) => {
147                    unimplemented!("quantized all-reduce unimplemented")
148                }
149            };
150
151            self.container.register::<B>(id, grad);
152        }
153
154        Ok(self)
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::{Module, list_param_ids};
    use burn::tensor::{Distribution, backend::Backend};
    use burn_nn::{Linear, LinearConfig};

    /// A module and its forked copy expose the same parameter ids, and converting each one's
    /// backward gradients yields exactly one gradient per parameter.
    #[test]
    fn test_convert_grads() {
        let device = Default::default();
        let original = layer::<TestAutodiffBackend>(&device);
        let forked = original.clone().fork(&device);

        let output_original = original.forward(random_tensor(&device));
        let output_forked = forked.forward(random_tensor(&device));

        let grads_original = GradientsParams::from_grads(output_original.backward(), &original);
        let grads_forked = GradientsParams::from_grads(output_forked.backward(), &forked);

        let ids_original = list_param_ids(&original);
        let ids_forked = list_param_ids(&forked);

        assert_eq!(ids_original, ids_forked);
        assert_eq!(grads_original.len(), ids_original.len());
        assert_eq!(grads_forked.len(), ids_forked.len());
    }

    /// Builds a 20 -> 20 linear layer on `device`.
    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        let config = LinearConfig::new(20, 20);
        config.init(device)
    }

    /// Samples a `[2, 20]` input tensor from the default distribution.
    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        let shape = [2, 20];
        Tensor::<B, 2>::random(shape, Distribution::Default, device)
    }
}