use crate::{
    self as burn, LearningRate, grad_clipping::GradientClippingConfig, module::AutodiffModule,
    record::Record,
};

use super::{
    SimpleOptimizer,
    decay::{WeightDecay, WeightDecayConfig},
};
use crate::config::Config;
use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{Tensor, backend::AutodiffBackend};
use burn_tensor::{backend::Backend, ops::Device};

/// AdaGrad configuration.
#[derive(Config)]
pub struct AdaGradConfig {
    /// Learning rate decay.
    #[config(default = 0.)]
    lr_decay: f64,
    /// Epsilon value added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// Optional weight decay configuration.
    weight_decay: Option<WeightDecayConfig>,
    /// Optional gradient clipping configuration.
    grad_clipping: Option<GradientClippingConfig>,
}
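
// Summary of the update implemented by `SimpleOptimizer::step` and `LrDecay::transform` below
// (documentation only, no additional behaviour):
//
//     sum_t   = sum_{t-1} + g_t^2                      accumulated squared gradients
//     lr_t    = lr / (1 + (t - 1) * lr_decay)          decayed learning rate
//     theta_t = theta_{t-1} - lr_t * g_t / (sqrt(sum_t) + epsilon)
//
// Weight decay, when configured, is applied to `g_t` first; gradient clipping is handled by the
// surrounding `OptimizerAdaptor`.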

/// AdaGrad optimizer.
#[derive(Clone)]
pub struct AdaGrad {
    lr_decay: LrDecay,
    weight_decay: Option<WeightDecay>,
}

/// AdaGrad state.
#[derive(Record, Clone, new)]
pub struct AdaGradState<B: Backend, const D: usize> {
    lr_decay: LrDecayState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for AdaGrad {
    type State<const D: usize> = AdaGradState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let mut state_lr_decay = None;

        if let Some(state) = state {
            state_lr_decay = Some(state.lr_decay);
        }

        // Apply weight decay to the gradient before the adaptive update.
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        // Scale the gradient by the accumulated squared-gradient statistics and the decayed rate.
        let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay);

        let state = AdaGradState::new(state_lr_decay);

        (tensor - grad, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.lr_decay = state.lr_decay.to_device(device);
        state
    }
}

impl AdaGradConfig {
    /// Initialize the AdaGrad optimizer for a module `M` on the autodiff backend `B`.
    ///
    /// Gradient clipping, when configured, is applied by the returned [`OptimizerAdaptor`].
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<AdaGrad, M, B> {
        let optim = AdaGrad {
            lr_decay: LrDecay {
                lr_decay: self.lr_decay,
                epsilon: self.epsilon,
            },
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        };

        let mut optim = OptimizerAdaptor::from(optim);
        if let Some(config) = &self.grad_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}
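
// A minimal usage sketch (mirrors the tests below); `MyModule`, `model`, and `loss` are
// placeholders for user code:
//
//     let mut optim = AdaGradConfig::new().with_lr_decay(0.5).init::<B, MyModule<B>>();
//     let grads = GradientsParams::from_grads(loss.backward(), &model);
//     let model = optim.step(1e-2, model, grads);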

/// Learning rate decay state: the step counter and the running sum of squared gradients.
#[derive(Record, new, Clone)]
pub struct LrDecayState<B: Backend, const D: usize> {
    time: usize,
    sum: Tensor<B, D>,
}

#[derive(Clone)]
struct LrDecay {
    lr_decay: f64,
    epsilon: f32,
}

impl LrDecay {
    /// Apply the AdaGrad transform to a gradient: accumulate the squared gradient,
    /// decay the learning rate with the step counter, and scale the gradient by
    /// `lr_t / (sqrt(sum) + epsilon)`.
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        lr: LearningRate,
        lr_decay_state: Option<LrDecayState<B, D>>,
    ) -> (Tensor<B, D>, LrDecayState<B, D>) {
        // Accumulate the squared gradient, initializing the state on the first step.
        let state = if let Some(mut state) = lr_decay_state {
            state.sum = state.sum.add(grad.clone().powf_scalar(2.));
            state.time += 1;
            state
        } else {
            LrDecayState::new(1, grad.clone().powf_scalar(2.))
        };

        let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
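        // The decayed rate above is lr_t = lr / (1 + (t - 1) * lr_decay): e.g. lr = 0.01 with
        // lr_decay = 0.5 gives 0.01, ~0.00667, 0.005, 0.004 over the first four steps, while the
        // default lr_decay = 0. keeps the rate constant (plain AdaGrad).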

        let grad = grad
            .div(state.sum.clone().sqrt().add_scalar(self.epsilon))
            .mul_scalar(new_lr);

        (grad, state)
    }
}

impl<B: Backend, const D: usize> LrDecayState<B, D> {
    /// Move the state to the given device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.sum = self.sum.to_device(device);
        self
    }
}

#[cfg(test)]
mod tests {
    use burn_tensor::Tolerance;
    use burn_tensor::ops::FloatElem;

    use super::*;
    use crate::module::{Module, Param};
    use crate::optim::{GradientsParams, Optimizer};
    use crate::tensor::{Distribution, Tensor, TensorData};
    use crate::{TestAutodiffBackend, nn, nn::Linear};

    const LEARNING_RATE: LearningRate = 0.01;

    #[test]
    fn test_adagrad_optimizer_save_load_state() {
        let device = Default::default();
        let linear = nn::LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adagrad();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adagrad"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use crate::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adagrad();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    #[test]
    fn test_adagrad_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        let mut optimizer = AdaGradConfig::new()
            .with_epsilon(1e-8)
            .with_lr_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        let weights_expected = TensorData::from([
            [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711],
            [
                0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756,
            ],
            [
                -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538,
            ],
            [
                -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964,
            ],
            [
                0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504,
            ],
            [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895],
        ]);
        let bias_expected = TensorData::from([
            -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.val().into_data(),
            state_updated.bias.unwrap().val().into_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> nn::Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = nn::LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        nn::LinearConfig::new(6, 6)
            .init(&device)
            .load_record(record)
    }

    fn create_adagrad()
    -> OptimizerAdaptor<AdaGrad, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdaGradConfig::new();
        AdaGrad {
            lr_decay: LrDecay {
                lr_decay: config.lr_decay,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
        }
        .into()
    }
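
    // A small additional sanity check (not part of the original reference values above): on the
    // first step `transform` leaves the learning rate undecayed and the update should equal
    // lr * g / (sqrt(g^2) + epsilon).
    #[test]
    fn test_lr_decay_first_step() {
        let device = Default::default();
        let lr_decay = LrDecay {
            lr_decay: 0.5,
            epsilon: 1e-8,
        };
        let grad = Tensor::<TestAutodiffBackend, 1>::from_floats([2.0], &device);

        let (update, state) = lr_decay.transform(grad, LEARNING_RATE, None);

        // time == 1 after the first step, so no decay has been applied yet.
        assert_eq!(state.time, 1);
        // update = 0.01 * 2.0 / (2.0 + 1e-8) ≈ 0.01
        update.into_data().assert_approx_eq::<FloatElem<TestAutodiffBackend>>(
            &TensorData::from([0.01]),
            Tolerance::absolute(1e-6),
        );
    }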
}