burn_core/optim/sgd.rs

use crate::grad_clipping::GradientClippingConfig;
use crate::module::AutodiffModule;
use crate::{self as burn, LearningRate};

use super::SimpleOptimizer;
use super::decay::{WeightDecay, WeightDecayConfig};
use super::momentum::{Momentum, MomentumConfig, MomentumState};
use crate::config::Config;
use crate::optim::adaptor::OptimizerAdaptor;
use crate::record::Record;
use crate::tensor::Tensor;
use burn_tensor::backend::{AutodiffBackend, Backend};

/// Configuration to create the [Sgd](Sgd) optimizer.
#[derive(Config)]
pub struct SgdConfig {
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Momentum](MomentumConfig) config.
    momentum: Option<MomentumConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    gradient_clipping: Option<GradientClippingConfig>,
}

/// Optimizer that implements stochastic gradient descent with momentum.
///
/// The optimizer can be configured with [SgdConfig](SgdConfig).
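///
/// A minimal usage sketch, mirroring the construction used in this module's
/// tests; the backend (`MyAutodiffBackend`) and model (`MyModel`) types are
/// placeholders, so the snippet is not compiled here:
///
/// ```rust,ignore
/// let mut optim = SgdConfig {
///     weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
///     momentum: Some(MomentumConfig {
///         momentum: 0.9,
///         dampening: 0.1,
///         nesterov: true,
///     }),
///     gradient_clipping: None,
/// }
/// .init::<MyAutodiffBackend, MyModel<MyAutodiffBackend>>();
///
/// // One optimizer step; `model` and `grads` come from a forward/backward pass,
/// // as in the tests at the bottom of this file.
/// let model = optim.step(0.02, model, grads);
/// ```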
#[derive(Clone)]
pub struct Sgd<B: Backend> {
    momentum: Option<Momentum<B>>,
    weight_decay: Option<WeightDecay>,
}

/// State of [Sgd](Sgd).
#[derive(Record, Clone, new)]
pub struct SgdState<B: Backend, const D: usize> {
    /// The current state of the momentum (if any).
    pub momentum: Option<MomentumState<B, D>>,
}

impl SgdConfig {
    /// Initializes a new [Sgd](Sgd) optimizer, wrapped in an
    /// [OptimizerAdaptor](OptimizerAdaptor), from this config.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
        let momentum = self.momentum.as_ref().map(Momentum::new);
        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);

        let mut optim = OptimizerAdaptor::from(Sgd {
            momentum,
            weight_decay,
        });
        if let Some(config) = &self.gradient_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}

impl<B: Backend> SimpleOptimizer<B> for Sgd<B> {
    type State<const D: usize> = SgdState<B, D>;

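    // One `SimpleOptimizer` step for a single parameter tensor: fold the weight
    // decay penalty into the gradient (if configured), apply the momentum
    // transform (if configured), then update the parameter as `tensor - lr * grad`.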
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let mut state_momentum = None;

        if let Some(state) = state {
            state_momentum = state.momentum;
        }

        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        if let Some(momentum) = &self.momentum {
            let (grad_out, state) = momentum.transform(grad, state_momentum);
            state_momentum = Some(state);
            grad = grad_out;
        }

        let state = SgdState::new(state_momentum);
        let delta = grad.mul_scalar(lr);

        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {
        state.momentum = state.momentum.map(|state| state.to_device(device));
        state
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend, TestBackend,
        grad_clipping::GradientClipping,
        nn::{Linear, LinearConfig},
        optim::{GradientsParams, Optimizer},
        tensor::{Distribution, Shape},
    };

    const LEARNING_RATE: LearningRate = 0.02;

    #[test]
    fn with_updated_params_should_have_state() {
        let device = Default::default();
        let layer = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();
        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &layer);
        let _layer = optim.step(LEARNING_RATE, layer, grads);

        let record = optim.to_record();

        assert!(!record.is_empty());
    }

    #[test]
    fn without_updated_params_should_not_have_state() {
        let optim = sgd_with_all();
        let record = optim.to_record();
        assert!(record.is_empty());
    }

    #[test]
    fn can_attach_gradient_clipping() {
        let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5));
        assert!(optim.has_gradient_clipping());
    }

    #[test]
    fn should_load_state() {
        let device = Default::default();
        let layer = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();
        let loss = layer.forward(random_tensor(&device));
        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &layer);
        let _layer = optim.step(LEARNING_RATE, layer, grads);

        let record = optim.to_record();
        let optim_new = sgd_with_all();
        let record_new = optim_new.to_record();
        let optim_new = optim_new.load_record(record.clone());
        let state_restored = optim_new.to_record();

        assert_ne!(record.len(), record_new.len());
        assert_eq!(record.len(), state_restored.len());
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random(Shape::new([2, 20]), Distribution::Default, device)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).with_bias(true).init(device)
    }

    fn sgd_with_all()
    -> OptimizerAdaptor<Sgd<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        SgdConfig {
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: Some(MomentumConfig {
                momentum: 0.9,
                dampening: 0.1,
                nesterov: true,
            }),
            gradient_clipping: None,
        }
        .init()
    }
}