use burn_tensor::{
    Tensor,
    backend::{AutodiffBackend, Backend},
    container::TensorContainer,
};

use crate::module::{AutodiffModule, ParamId};

use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};

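/// Container mapping each [parameter id](ParamId) to its gradient tensor.
///
/// Gradients produced by the backward pass are indexed by tensor, while
/// optimizers need to look them up per parameter; converting with
/// [from_grads](GradientsParams::from_grads) re-indexes them by [ParamId].
///
/// A minimal usage sketch, assuming `model` is an [AutodiffModule], `loss` is
/// the tensor the loss was computed into, and `optimizer` implements this
/// crate's optimizer trait:
///
/// ```ignore
/// let grads = loss.backward();
/// let grads = GradientsParams::from_grads(grads, &model);
/// let model = optimizer.step(lr, model, grads);
/// ```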
#[derive(Default, Debug)]
pub struct GradientsParams {
    container: TensorContainer<ParamId>,
}

impl GradientsParams {
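    /// Creates a new empty [GradientsParams](GradientsParams).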
    pub fn new() -> Self {
        Self::default()
    }

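    /// Extract the gradients of each parameter tensor in the given
    /// [module](AutodiffModule), consuming the backward pass gradients.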
    pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(
        mut grads: B::Gradients,
        module: &M,
    ) -> Self {
        Self::from_module(&mut grads, module)
    }

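    /// Extract the gradients of each parameter tensor in the given
    /// [module](AutodiffModule), removing them from `grads`.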
    pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);
        module.visit(&mut visitor);
        grads_params
    }

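    /// Extract the gradients of the listed parameters only, removing them
    /// from `grads`.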
    pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(
        grads: &mut B::Gradients,
        module: &M,
        params: &[ParamId],
    ) -> Self {
        let mut grads_params = GradientsParams::new();
        let mut visitor =
            GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));
        module.visit(&mut visitor);
        grads_params
    }

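    /// Get the gradients for the given [parameter id](ParamId).
    ///
    /// # Notes
    ///
    /// Use [remove](GradientsParams::remove) instead if the gradients are only
    /// needed once, to avoid keeping the tensor in the container.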
    pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.get(&id).map(Tensor::from_primitive)
    }

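    /// Remove the gradients for the given [parameter id](ParamId).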
    pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>
    where
        B: Backend,
    {
        self.container.remove(&id).map(Tensor::from_primitive)
    }

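    /// Register a gradients tensor for the given [parameter id](ParamId).
    ///
    /// # Notes
    ///
    /// If a tensor is already registered for the given [parameter id](ParamId),
    /// it will be replaced.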
    pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)
    where
        B: Backend,
    {
        self.container.register(id, value.into_primitive())
    }

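    /// The number of gradients tensors registered.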
    pub fn len(&self) -> usize {
        self.container.len()
    }

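    /// Returns true if no gradients are registered.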
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

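    /// Move the gradients of each parameter in the given
    /// [module](AutodiffModule) to the given device.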
    pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(
        mut self,
        device: &B::Device,
        module: &M,
    ) -> Self {
        let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
        module.visit(&mut visitor);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend,
        module::{Module, list_param_ids},
        nn::{Linear, LinearConfig},
    };
    use burn_tensor::{Distribution, backend::Backend};

    #[test]
    fn test_convert_grads() {
        let device = Default::default();
        let layer_1 = layer::<TestAutodiffBackend>(&device);
        let layer_2 = layer_1.clone().fork(&device);
        let loss_1 = layer_1.forward(random_tensor(&device));
        let loss_2 = layer_2.forward(random_tensor(&device));
        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);
        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2);

        let param_ids_1 = list_param_ids(&layer_1);
        let param_ids_2 = list_param_ids(&layer_2);

        // Forking preserves parameter ids, and each parameter should end up
        // with exactly one gradient entry.
        assert_eq!(param_ids_1, param_ids_2);
        assert_eq!(grads_1.len(), param_ids_1.len());
        assert_eq!(grads_2.len(), param_ids_2.len());
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).with_bias(true).init(device)
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)
    }
}