1use burn_core as burn;
2
3use burn::{module::AutodiffModule, record::Record};
4
5use burn::config::Config;
6use burn::tensor::{Tensor, backend::AutodiffBackend};
7use burn::tensor::{backend::Backend, ops::Device};
8
9use super::{
10 SimpleOptimizer,
11 adaptor::OptimizerAdaptor,
12 decay::{WeightDecay, WeightDecayConfig},
13};
14use crate::{LearningRate, grad_clipping::GradientClippingConfig};
15
/// Configuration to create the [AdaGrad](AdaGrad) optimizer.
#[derive(Config, Debug)]
pub struct AdaGradConfig {
    /// Learning-rate decay factor applied per update step (0 disables decay).
    #[config(default = 0.)]
    lr_decay: f64,
    /// Small constant added to the denominator for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// [Weight decay](WeightDecayConfig) config, applied to gradients before the update when set.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient clipping](GradientClippingConfig) config, attached to the adaptor when set.
    grad_clipping: Option<GradientClippingConfig>,
}
28
/// AdaGrad optimizer: adapts the step size per parameter using the running
/// sum of squared gradients (see `LrDecay::transform` for the update rule).
#[derive(Clone)]
pub struct AdaGrad {
    /// Learning-rate decay and squared-gradient accumulation logic.
    lr_decay: LrDecay,
    /// Optional weight-decay transform applied to gradients before the update.
    weight_decay: Option<WeightDecay>,
}
35
/// Per-tensor state of the [AdaGrad](AdaGrad) optimizer.
#[derive(Record, Clone, new)]
pub struct AdaGradState<B: Backend, const D: usize> {
    /// Step counter plus the running sum of squared gradients.
    lr_decay: LrDecayState<B, D>,
}
41
42impl<B: Backend> SimpleOptimizer<B> for AdaGrad {
43 type State<const D: usize> = AdaGradState<B, D>;
44
45 fn step<const D: usize>(
46 &self,
47 lr: LearningRate,
48 tensor: Tensor<B, D>,
49 mut grad: Tensor<B, D>,
50 state: Option<Self::State<D>>,
51 ) -> (Tensor<B, D>, Option<Self::State<D>>) {
52 let mut state_lr_decay = None;
53
54 if let Some(state) = state {
55 state_lr_decay = Some(state.lr_decay);
56 }
57
58 if let Some(weight_decay) = &self.weight_decay {
59 grad = weight_decay.transform(grad, tensor.clone());
60 }
61
62 let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay);
63
64 let state = AdaGradState::new(state_lr_decay);
65
66 (tensor - grad, Some(state))
67 }
68
69 fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
70 state.lr_decay = state.lr_decay.to_device(device);
71 state
72 }
73}
74
75impl AdaGradConfig {
76 pub fn build(&self) -> AdaGrad {
78 AdaGrad {
79 lr_decay: LrDecay {
80 lr_decay: self.lr_decay,
81 epsilon: self.epsilon,
82 },
83 weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
84 }
85 }
86
87 pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
93 &self,
94 ) -> OptimizerAdaptor<AdaGrad, M, B> {
95 let mut optim = OptimizerAdaptor::from(self.build());
96 if let Some(config) = &self.grad_clipping {
97 optim = optim.with_grad_clipping(config.init());
98 }
99 optim
100 }
101}
102
/// Accumulated per-tensor state used by [LrDecay].
#[derive(Record, new, Clone)]
pub struct LrDecayState<B: Backend, const D: usize> {
    /// Number of update steps performed so far.
    time: usize,
    /// Running sum of squared gradients.
    sum: Tensor<B, D>,
}
109
/// Core AdaGrad math: squared-gradient accumulation plus time-based
/// learning-rate decay.
#[derive(Clone)]
struct LrDecay {
    // Per-step learning-rate decay factor.
    lr_decay: f64,
    // Numerical-stability constant added to the denominator before division.
    epsilon: f32,
}
115
116impl LrDecay {
117 pub fn transform<B: Backend, const D: usize>(
118 &self,
119 grad: Tensor<B, D>,
120 lr: LearningRate,
121 lr_decay_state: Option<LrDecayState<B, D>>,
122 ) -> (Tensor<B, D>, LrDecayState<B, D>) {
123 let state = if let Some(mut state) = lr_decay_state {
124 state.sum = state.sum.add(grad.clone().square());
125 state.time += 1;
126 state
127 } else {
128 LrDecayState::new(1, grad.clone().square())
129 };
130
131 let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
132
133 let grad = grad
134 .div(state.sum.clone().sqrt().add_scalar(self.epsilon))
135 .mul_scalar(new_lr);
136
137 (grad, state)
138 }
139}
140
141impl<B: Backend, const D: usize> LrDecayState<B, D> {
142 pub fn to_device(mut self, device: &B::Device) -> Self {
152 self.sum = self.sum.to_device(device);
153 self
154 }
155}
156
#[cfg(test)]
mod tests {
    use burn::tensor::Tolerance;
    use burn::tensor::ops::FloatElem;

    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    const LEARNING_RATE: LearningRate = 0.01;

    /// Round-trips the optimizer state through a recorder (file-backed with
    /// `std`, byte-backed otherwise) and checks the reloaded record keeps the
    /// same number of entries.
    #[test]
    fn test_adagrad_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6).init(&device);
        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);
        let mut optimizer = create_adagrad();
        // One step so the optimizer actually holds non-empty state.
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(LEARNING_RATE, linear, grads);

        #[cfg(feature = "std")]
        {
            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

            BinFileRecorder::<FullPrecisionSettings>::default()
                .record(
                    optimizer.to_record(),
                    std::env::temp_dir().as_path().join("test_optim_adagrad"),
                )
                .unwrap();
        }
        #[cfg(not(feature = "std"))]
        {
            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};

            let result = BinBytesRecorder::<FullPrecisionSettings>::default()
                .record(optimizer.to_record(), ())
                .unwrap();
            assert!(!result.is_empty());
        }

        // Load the saved record into a fresh optimizer and compare state sizes.
        let state_optim_before = optimizer.to_record();
        let state_optim_before_copy = optimizer.to_record();
        let optimizer = create_adagrad();
        let optimizer = optimizer.load_record(state_optim_before_copy);
        let state_optim_after = optimizer.to_record();

        assert_eq!(state_optim_before.len(), state_optim_after.len());
    }

    /// Runs two optimizer steps on a fixed linear layer and compares the
    /// updated weights/bias against precomputed reference values.
    #[test]
    fn test_adagrad_optimizer_with_numbers() {
        let device = Default::default();
        let linear = given_linear_layer(
            TensorData::from([
                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
            ]),
            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
        );
        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &device,
        )
        .require_grad();
        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [
                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
            ],
            &device,
        )
        .require_grad();

        // Non-default epsilon/lr_decay so the decay path is exercised.
        let mut optimizer = AdaGradConfig::new()
            .with_epsilon(1e-8)
            .with_lr_decay(0.5)
            .init();

        let grads = linear.forward(x_1).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let grads = linear.forward(x_2).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(LEARNING_RATE, linear, grads);

        let state_updated = linear.into_record();
        // Reference values computed externally for this fixed input sequence.
        let weights_expected = TensorData::from([
            [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711],
            [
                0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756,
            ],
            [
                -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538,
            ],
            [
                -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964,
            ],
            [
                0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504,
            ],
            [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895],
        ]);
        let bias_expected = TensorData::from([
            -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714,
        ]);

        let (weight_updated, bias_updated) = (
            state_updated.weight.val().into_data(),
            state_updated.bias.unwrap().val().into_data(),
        );

        type FT = FloatElem<TestAutodiffBackend>;
        let tolerance = Tolerance::absolute(1e-6);
        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);
        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);
    }

    /// Builds a 6x6 linear layer preloaded with the given weight/bias data.
    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    /// Creates an AdaGrad adaptor with the default configuration.
    fn create_adagrad()
    -> OptimizerAdaptor<AdaGrad, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        let config = AdaGradConfig::new();
        AdaGrad {
            lr_decay: LrDecay {
                lr_decay: config.lr_decay,
                epsilon: config.epsilon,
            },
            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),
        }
        .into()
    }
}