1use crate::grad_clipping::GradientClippingConfig;
2use crate::module::AutodiffModule;
3use crate::{self as burn, LearningRate};
4
5use super::SimpleOptimizer;
6use super::decay::{WeightDecay, WeightDecayConfig};
7use super::momentum::{Momentum, MomentumConfig, MomentumState};
8use crate::config::Config;
9use crate::optim::adaptor::OptimizerAdaptor;
10use crate::record::Record;
11use crate::tensor::Tensor;
12use burn_tensor::backend::{AutodiffBackend, Backend};
13
14#[derive(Config)]
16pub struct SgdConfig {
17 weight_decay: Option<WeightDecayConfig>,
19 momentum: Option<MomentumConfig>,
21 gradient_clipping: Option<GradientClippingConfig>,
23}
24
25#[derive(Clone)]
29pub struct Sgd<B: Backend> {
30 momentum: Option<Momentum<B>>,
31 weight_decay: Option<WeightDecay>,
32}
33
34#[derive(Record, Clone, new)]
36pub struct SgdState<B: Backend, const D: usize> {
37 pub momentum: Option<MomentumState<B, D>>,
39}
40
41impl SgdConfig {
42 pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
44 &self,
45 ) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
46 let momentum = self.momentum.as_ref().map(Momentum::new);
47 let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
48
49 let mut optim = OptimizerAdaptor::from(Sgd {
50 momentum,
51 weight_decay,
52 });
53 if let Some(config) = &self.gradient_clipping {
54 optim = optim.with_grad_clipping(config.init());
55 }
56 optim
57 }
58}
59
60impl<B: Backend> SimpleOptimizer<B> for Sgd<B> {
61 type State<const D: usize> = SgdState<B, D>;
62
63 fn step<const D: usize>(
64 &self,
65 lr: LearningRate,
66 tensor: Tensor<B, D>,
67 mut grad: Tensor<B, D>,
68 state: Option<Self::State<D>>,
69 ) -> (Tensor<B, D>, Option<Self::State<D>>) {
70 let mut state_momemtum = None;
71
72 if let Some(state) = state {
73 state_momemtum = state.momentum;
74 }
75
76 if let Some(weight_decay) = &self.weight_decay {
77 grad = weight_decay.transform(grad, tensor.clone());
78 }
79
80 if let Some(momentum) = &self.momentum {
81 let (grad_out, state) = momentum.transform(grad, state_momemtum);
82 state_momemtum = Some(state);
83 grad = grad_out;
84 }
85
86 let state = SgdState::new(state_momemtum);
87 let delta = grad.mul_scalar(lr);
88
89 (tensor - delta, Some(state))
90 }
91
92 fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {
93 state.momentum = state.momentum.map(|state| state.to_device(device));
94 state
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend, TestBackend,
        grad_clipping::GradientClipping,
        nn::{Linear, LinearConfig},
        optim::{GradientsParams, Optimizer},
        tensor::{Distribution, Shape},
    };

    const LEARNING_RATE: LearningRate = 0.02;

    /// Concrete optimizer type exercised by these tests.
    type TestSgd =
        OptimizerAdaptor<Sgd<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend>;

    #[test]
    fn with_updated_params_should_have_state() {
        let optim = stepped_sgd();

        assert!(!optim.to_record().is_empty());
    }

    #[test]
    fn without_updated_params_should_not_have_state() {
        assert!(sgd_with_all().to_record().is_empty());
    }

    #[test]
    fn can_attach_gradient_clipping() {
        let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5));
        assert!(optim.has_gradient_clipping());
    }

    #[test]
    fn should_load_state() {
        let record = stepped_sgd().to_record();

        let fresh = sgd_with_all();
        let fresh_record = fresh.to_record();
        let restored_record = fresh.load_record(record.clone()).to_record();

        assert_ne!(record.len(), fresh_record.len());
        assert_eq!(record.len(), restored_record.len());
    }

    /// Builds an optimizer with every option enabled and runs a single step on a fresh linear
    /// layer, so that per-parameter state gets populated.
    fn stepped_sgd() -> TestSgd {
        let device = Default::default();
        let model = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();

        let output = model.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = GradientsParams::from_grads(output.backward(), &model);
        let _model = optim.step(LEARNING_RATE, model, grads);

        optim
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random(Shape::new([2, 20]), Distribution::Default, device)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).with_bias(true).init(device)
    }

    fn sgd_with_all() -> TestSgd {
        SgdConfig {
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: Some(MomentumConfig {
                momentum: 0.9,
                dampening: 0.1,
                nesterov: true,
            }),
            gradient_clipping: None,
        }
        .init()
    }
}