
burn_optim/optim/sgd.rs

use burn_core as burn;

use super::SimpleOptimizer;
use super::adaptor::OptimizerAdaptor;
use super::decay::{WeightDecay, WeightDecayConfig};
use super::momentum::{Momentum, MomentumConfig, MomentumState};
use crate::LearningRate;
use crate::grad_clipping::GradientClippingConfig;
use burn::config::Config;
use burn::module::AutodiffModule;
use burn::record::Record;
use burn::tensor::Tensor;
use burn::tensor::backend::{AutodiffBackend, Backend};

/// Configuration to create the [Sgd](Sgd) optimizer.
#[derive(Config, Debug)]
pub struct SgdConfig {
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Momentum](MomentumConfig) config.
    momentum: Option<MomentumConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    gradient_clipping: Option<GradientClippingConfig>,
}
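
// A minimal construction sketch (an assumption, not taken from this file): the
// `#[derive(Config)]` macro is expected to generate a `new()` constructor and
// `with_*` setters for the optional fields, so callers outside this crate would
// typically build the config along these lines (exact generated signatures may
// differ):
//
//     let config = SgdConfig::new()
//         .with_momentum(Some(MomentumConfig::new()))
//         .with_weight_decay(Some(WeightDecayConfig::new(1.0e-4)));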

/// Optimizer that implements stochastic gradient descent, with optional
/// [momentum](Momentum) and [weight decay](WeightDecay).
///
/// The optimizer can be configured with [SgdConfig](SgdConfig).
#[derive(Clone)]
pub struct Sgd<B: Backend> {
    momentum: Option<Momentum<B>>,
    weight_decay: Option<WeightDecay>,
}

/// State of [Sgd](Sgd).
#[derive(Record, Clone, new)]
pub struct SgdState<B: Backend, const D: usize> {
    /// The current state of the momentum (if any).
    pub momentum: Option<MomentumState<B, D>>,
}

impl SgdConfig {
    /// Build a [`Sgd`] from the config.
    pub fn build<B: Backend>(&self) -> Sgd<B> {
        Sgd {
            momentum: self.momentum.as_ref().map(Momentum::new),
            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),
        }
    }

    /// Initialize the [Sgd](Sgd) optimizer from this config, wrapped in an
    /// [OptimizerAdaptor](OptimizerAdaptor) for use with an autodiff module.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
        let mut optim = OptimizerAdaptor::from(self.build());
        if let Some(config) = &self.gradient_clipping {
            optim = optim.with_grad_clipping(config.init());
        }
        optim
    }
}
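
// A usage sketch (an assumption, mirroring the test module below): `init` is the
// entry point when training with an autodiff backend, while `build` returns the
// bare `Sgd` for use behind a custom adaptor. Something like:
//
//     let mut optim = SgdConfig::new()
//         .init::<MyAutodiffBackend, MyModule<MyAutodiffBackend>>();
//     // model = optim.step(lr, model, grads);
//
// where `MyAutodiffBackend` and `MyModule` are placeholders for a concrete
// backend and module type.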

impl<B: Backend> SimpleOptimizer<B> for Sgd<B> {
    type State<const D: usize> = SgdState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        mut grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
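        // Update order: (1) recover any previous momentum state, (2) apply
        // weight decay to the gradient, (3) apply the momentum transform and
        // capture its new state, (4) scale by the learning rate and subtract
        // from the parameter tensor. The state is always returned so the
        // adaptor can record it for checkpointing.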
        let mut state_momentum = None;

        if let Some(state) = state {
            state_momentum = state.momentum;
        }

        if let Some(weight_decay) = &self.weight_decay {
            grad = weight_decay.transform(grad, tensor.clone());
        }

        if let Some(momentum) = &self.momentum {
            let (grad_out, state) = momentum.transform(grad, state_momentum);
            state_momentum = Some(state);
            grad = grad_out;
        }

        let state = SgdState::new(state_momentum);
        let delta = grad.mul_scalar(lr);

        (tensor - delta, Some(state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {
        state.momentum = state.momentum.map(|state| state.to_device(device));
        state
    }
}
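
// A worked example of a single step (a sketch; it assumes `WeightDecay::transform`
// adds `penalty * weight` to the gradient and that no momentum is configured):
// with weight = 1.0, grad = 0.5, penalty = 0.01 and lr = 0.1, the decayed
// gradient is 0.5 + 0.01 * 1.0 = 0.51, so the updated weight is
// 1.0 - 0.1 * 0.51 = 0.949.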

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend, TestBackend,
        grad_clipping::GradientClipping,
        optim::{GradientsParams, Optimizer},
    };
    use burn::tensor::{Distribution, Shape};
    use burn_nn::{Linear, LinearConfig};

    const LEARNING_RATE: LearningRate = 0.02;

    #[test]
    fn with_updated_params_should_have_state() {
        let device = Default::default();
        let layer = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();
        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &layer);
        let _layer = optim.step(LEARNING_RATE, layer, grads);

        let record = optim.to_record();

        assert!(!record.is_empty());
    }

    #[test]
    fn without_updated_params_should_not_have_state() {
        let optim = sgd_with_all();
        let record = optim.to_record();
        assert!(record.is_empty());
    }

    #[test]
    fn can_attach_gradient_clipping() {
        let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5));
        assert!(optim.has_gradient_clipping());
    }

    #[test]
    fn should_load_state() {
        let device = Default::default();
        let layer = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();
        let loss = layer.forward(random_tensor(&device));
        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &layer);
        let _layer = optim.step(LEARNING_RATE, layer, grads);

        let record = optim.to_record();
        let optim_new = sgd_with_all();
        let record_new = optim_new.to_record();
        let optim_new = optim_new.load_record(record.clone());
        let state_restored = optim_new.to_record();

        assert_ne!(record.len(), record_new.len());
        assert_eq!(record.len(), state_restored.len());
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random(Shape::new([2, 20]), Distribution::Default, device)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    fn sgd_with_all()
    -> OptimizerAdaptor<Sgd<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
        SgdConfig {
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: Some(MomentumConfig {
                momentum: 0.9,
                dampening: 0.1,
                nesterov: true,
            }),
            gradient_clipping: None,
        }
        .init()
    }
}