use burn_core as burn;

use super::SimpleOptimizer;
use super::adaptor::OptimizerAdaptor;
use super::decay::{WeightDecay, WeightDecayConfig};
use super::momentum::{Momentum, MomentumConfig, MomentumState};
use crate::LearningRate;
use crate::grad_clipping::GradientClippingConfig;
use burn::config::Config;
use burn::module::AutodiffModule;
use burn::record::Record;
use burn::tensor::Tensor;
use burn::tensor::backend::{AutodiffBackend, Backend};

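/// Configuration to create the [Sgd] optimizer.
///
/// A minimal usage sketch, mirroring the construction used in this module's
/// tests; `model` and `grads` are placeholders assumed to come from an
/// [AutodiffModule] and its backward pass:
///
/// ```rust,ignore
/// let mut optim = SgdConfig {
///     weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
///     momentum: None,
///     gradient_clipping: None,
/// }
/// .init();
/// // One optimization step: returns the updated module.
/// let model = optim.step(1e-2, model, grads);
/// ```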
#[derive(Config, Debug)]
pub struct SgdConfig {
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Momentum](MomentumConfig) config.
    momentum: Option<MomentumConfig>,
    /// [Gradient clipping](GradientClippingConfig) config.
    gradient_clipping: Option<GradientClippingConfig>,
}

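/// Optimizer that implements stochastic gradient descent, with optional
/// weight decay and momentum.
///
/// Each step applies the configured gradient transforms and then performs the
/// plain SGD update `w = w - lr * grad`, as implemented in
/// `SimpleOptimizer::step` below.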
#[derive(Clone)]
pub struct Sgd<B: Backend> {
    momentum: Option<Momentum<B>>,
    weight_decay: Option<WeightDecay>,
}

/// State of [Sgd].
#[derive(Record, Clone, new)]
pub struct SgdState<B: Backend, const D: usize> {
    /// The current momentum state, if momentum is enabled.
    pub momentum: Option<MomentumState<B, D>>,
}

impl SgdConfig {
    /// Creates the bare [Sgd] optimizer from this configuration.
    pub fn build<B: Backend>(&self) -> Sgd<B> {
        Sgd {
            momentum: self.momentum.as_ref().map(Momentum::new),
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        }
    }

    /// Initializes the [Sgd] optimizer, wrapped in an [OptimizerAdaptor] with
    /// gradient clipping attached when configured.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
        let mut optim = OptimizerAdaptor::from(self.build());
        if let Some(config) = &self.gradient_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

impl<B: Backend> SimpleOptimizer<B> for Sgd<B> {
    type State<const D: usize> = SgdState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let mut state_momentum = None;

        if let Some(state) = state {
            state_momentum = state.momentum;
        }

        // Apply weight decay to the gradient before the momentum update.
        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        // Fold the gradient into the momentum buffer and use the result as
        // the effective gradient for this step.
        if let Some(momentum) = &self.momentum {
            let (grad_out, state) = momentum.transform(grad, state_momentum);
            state_momentum = Some(state);
            grad = grad_out;
        }

        let state = SgdState::new(state_momentum);
        // Standard SGD update: w = w - lr * grad.
        let delta = grad.mul_scalar(lr);

        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {
        state.momentum = state.momentum.map(|state| state.to_device(device));
        state
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend, TestBackend,
        grad_clipping::GradientClipping,
        optim::{GradientsParams, Optimizer},
    };
    use burn::tensor::{Distribution, Shape};
    use burn_nn::{Linear, LinearConfig};

    const LEARNING_RATE: LearningRate = 0.02;

    #[test]
    fn with_updated_params_should_have_state() {
        let device = Default::default();
        let layer = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();
        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &layer);
        let _layer = optim.step(LEARNING_RATE, layer, grads);

        let record = optim.to_record();

        assert!(!record.is_empty());
    }

    #[test]
    fn without_updated_params_should_not_have_state() {
        let optim = sgd_with_all();
        let record = optim.to_record();
        assert!(record.is_empty());
    }

    #[test]
    fn can_attach_gradient_clipping() {
        let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5));
        assert!(optim.has_gradient_clipping());
    }

    #[test]
    fn should_load_state() {
        let device = Default::default();
        let layer = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();
        let loss = layer.forward(random_tensor(&device));
        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &layer);
        let _layer = optim.step(LEARNING_RATE, layer, grads);

        let record = optim.to_record();
        let optim_new = sgd_with_all();
        let record_new = optim_new.to_record();
        let optim_new = optim_new.load_record(record.clone());
        let state_restored = optim_new.to_record();

        assert_ne!(record.len(), record_new.len());
        assert_eq!(record.len(), state_restored.len());
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random(Shape::new([2, 20]), Distribution::Default, device)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    fn sgd_with_all()
    -> OptimizerAdaptor<Sgd<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        SgdConfig {
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: Some(MomentumConfig {
                momentum: 0.9,
                dampening: 0.1,
                nesterov: true,
            }),
            gradient_clipping: None,
        }
        .init()
    }
}