// burn_optim/optim/adagrad.rs

use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::{LearningRate, grad_clipping::GradientClippingConfig};
/// AdaGrad configuration.
#[derive(Config, Debug)]
pub struct AdaGradConfig {
    /// Learning rate decay applied at each step.
    #[config(default = 0.)]
    lr_decay: f64,
    /// Epsilon value added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

/// AdaGrad optimizer
#[derive(Clone)]
pub struct AdaGrad {
    lr_decay: LrDecay,
    weight_decay: Option<WeightDecay>,
}

/// AdaGrad state.
#[derive(Record, Clone, new)]
pub struct AdaGradState<B: Backend, const D: usize> {
    lr_decay: LrDecayState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for AdaGrad {
    type State<const D: usize> = AdaGradState<B, D>;

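    // Performs one optimizer step: applies optional weight decay to the gradient,
    // then the learning-rate-decayed AdaGrad update, and subtracts the result from
    // the parameter tensor.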
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let mut state_lr_decay = None;

        if let Some(state) = state {
            state_lr_decay = Some(state.lr_decay);
        }

        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay);

        let state = AdaGradState::new(state_lr_decay);

        (tensor - grad, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.lr_decay = state.lr_decay.to_device(device);
        state
    }
}

impl AdaGradConfig {
    /// Initialize AdaGrad optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
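    ///
    /// # Example
    ///
    /// A minimal usage sketch (assuming an autodiff backend `B` and a module `M`
    /// are in scope; the `0.01` value is illustrative):
    ///
    /// ```ignore
    /// let optim = AdaGradConfig::new()
    ///     .with_lr_decay(0.01)
    ///     .init::<B, M>();
    /// ```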
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<AdaGrad, M, B> {
        let optim = AdaGrad {
            lr_decay: LrDecay {
                lr_decay: self.lr_decay,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

/// Learning rate decay state (also includes sum state).
#[derive(Record, new, Clone)]
pub struct LrDecayState<B: Backend, const D: usize> {
    time: usize,
    sum: Tensor<B, D>,
}

#[derive(Clone)]
struct LrDecay {
    lr_decay: f64,
    epsilon: f32,
}

impl LrDecay {
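    // Applies the AdaGrad update with learning rate decay to the gradient:
    //   sum_t = sum_{t-1} + g_t^2
    //   lr_t  = lr / (1 + (t - 1) * lr_decay)
    //   out_t = lr_t * g_t / (sqrt(sum_t) + epsilon)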
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        lr: LearningRate,
        lr_decay_state: Option<LrDecayState<B, D>>,
    ) -> (Tensor<B, D>, LrDecayState<B, D>) {
        let state = if let Some(mut state) = lr_decay_state {
            state.sum = state.sum.add(grad.clone().powi_scalar(2));
            state.time += 1;
            state
        } else {
            LrDecayState::new(1, grad.clone().powi_scalar(2))
        };

        let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);

        let grad = grad
            .div(state.sum.clone().sqrt().add_scalar(self.epsilon))
            .mul_scalar(new_lr);

        (grad, state)
    }
}

impl<B: Backend, const D: usize> LrDecayState<B, D> {
    /// Move state to device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move state to.
    ///
    /// # Returns
    ///
    /// Returns state moved to device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.sum = self.sum.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adagrad_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adagrad();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

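        // Record the optimizer state: to a temporary file when `std` is available,
        // to an in-memory byte buffer otherwise.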
        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adagrad"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adagrad();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adagrad_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdaGradConfig::new()
            .with_epsilon(1e-8)
            .with_lr_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711],
            [
                0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756,
            ],
            [
                -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538,
            ],
            [
                -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964,
            ],
            [
                0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504,
            ],
            [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895],
        ]);
        let bias_expected = TensorData::from([
            -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.val().into_data(),
            state_updated.bias.unwrap().val().into_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    fn create_adagrad()
    -> OptimizerAdaptor<AdaGrad, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdaGradConfig::new();
        AdaGrad {
            lr_decay: LrDecay {
                lr_decay: config.lr_decay,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
        }
        .into()
    }
}