// burn_core/optim/grads.rs

use burn_tensor::{
    Tensor,
    backend::{AutodiffBackend, Backend},
    container::TensorContainer,
};

use crate::module::{AutodiffModule, ParamId};

use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};

/// Data type that contains gradients for parameters.
#[derive(Default, Debug)]
pub struct GradientsParams {
    container: TensorContainer<ParamId>,
}

impl GradientsParams {
    /// Creates a new [GradientsParams](GradientsParams).
    pub fn new() -> Self {
        Self::default()
    }

    /// Extract the gradients of each parameter tensor for the given [module](AutodiffModule).
    ///
    /// Note: This consumes the gradients. See [from_module](Self::from_module) to extract
    /// only the gradients for a specific module.
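    ///
    /// # Examples
    ///
    /// A minimal sketch of a typical training step, assuming a hypothetical `model`
    /// module with a `forward` method that returns a loss tensor:
    ///
    /// ```ignore
    /// let loss = model.forward(input);
    /// // Run the backward pass, then collect the gradients of the model's parameters.
    /// let grads = GradientsParams::from_grads(loss.backward(), &model);
    /// ```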
    pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: B::Gradients,
        module: &M,
    ) -> Self {
        let mut grads = grads;
        Self::from_module(&mut grads, module)
    }

    /// Extract the gradients of each parameter tensor for the given [module](AutodiffModule).
    pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);
        module.visit(&mut visitor);
        grads_params
    }

    /// Extract the gradients for the given [module](AutodiffModule), restricted to the given parameters.
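    ///
    /// # Examples
    ///
    /// A sketch of extracting gradients for a subset of parameters, assuming a
    /// hypothetical `model`, its backward-pass gradients `grads`, and a list of
    /// parameter ids `param_ids` obtained elsewhere:
    ///
    /// ```ignore
    /// // Only the gradients of the listed parameters are extracted.
    /// let subset = GradientsParams::from_params(&mut grads, &model, &param_ids);
    /// ```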
    pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
        params: &[ParamId],
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor =
            GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));
        module.visit(&mut visitor);
        grads_params
    }

    /// Get the gradients for the given [parameter id](ParamId).
    ///
    /// # Notes
    ///
    /// Use [remove](GradientsParams::remove) if you only need to read the gradients once.
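    ///
    /// # Examples
    ///
    /// A sketch of reading a single gradient, assuming a hypothetical backend `B`,
    /// a parameter id `id`, and gradients of 2-dimensional tensors:
    ///
    /// ```ignore
    /// // Returns `None` if no gradient is registered for this parameter id.
    /// if let Some(grad) = grads.get::<B, 2>(id) {
    ///     // use the gradient tensor
    /// }
    /// ```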
    pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.get(&id).map(Tensor::from_primitive)
    }

    /// Remove the gradients for the given [parameter id](ParamId).
    pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.remove(&id).map(Tensor::from_primitive)
    }

    /// Register a gradient tensor for the given [parameter id](ParamId).
    ///
    /// # Notes
    ///
    /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced.
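    ///
    /// # Examples
    ///
    /// A sketch of overriding a gradient (e.g. after scaling or clipping it),
    /// assuming a hypothetical parameter id `id` and an already-computed tensor `grad`:
    ///
    /// ```ignore
    /// // Any gradient previously registered under `id` is replaced.
    /// grads.register(id, grad);
    /// ```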
    pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)
    where
        B: Backend,
    {
        self.container.register(id, value.into_primitive())
    }

    /// The number of gradient tensors registered.
    pub fn len(&self) -> usize {
        self.container.len()
    }

    /// Returns `true` if no gradient tensor is registered.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Move each gradient tensor registered for the given [module](AutodiffModule) to the given device.
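    ///
    /// # Examples
    ///
    /// A sketch of moving gradients to another device before an optimizer step,
    /// assuming a hypothetical `device` and the `model` the gradients were
    /// extracted from:
    ///
    /// ```ignore
    /// let grads = grads.to_device(&device, &model);
    /// ```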
    pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(
        mut self,
        device: &B::Device,
        module: &M,
    ) -> Self {
        let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
        module.visit(&mut visitor);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::{Module, list_param_ids},
        nn::{Linear, LinearConfig},
    };
    use burn_tensor::{Distribution, backend::Backend};

    #[test]
    fn test_convert_grads() {
        let device = Default::default();
        let layer_1 = layer::<TestAutodiffBackend>(&device);
        let mut layer_2 = layer_1.clone();
        layer_2 = layer_2.fork(&device);
        let loss_1 = layer_1.forward(random_tensor(&device));
        let loss_2 = layer_2.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2);

        let param_ids_1 = list_param_ids(&layer_1);
        let param_ids_2 = list_param_ids(&layer_2);

        assert_eq!(param_ids_1, param_ids_2);
        assert_eq!(grads_1.len(), param_ids_1.len());
        assert_eq!(grads_2.len(), param_ids_2.len());
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).with_bias(true).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}