burn_optim/optim/
sgd.rs

1use burn_core as burn;
2
3use super::SimpleOptimizer;
4use super::adaptor::OptimizerAdaptor;
5use super::decay::{WeightDecay, WeightDecayConfig};
6use super::momentum::{Momentum, MomentumConfig, MomentumState};
7use crate::LearningRate;
8use crate::grad_clipping::GradientClippingConfig;
9use burn::config::Config;
10use burn::module::AutodiffModule;
11use burn::record::Record;
12use burn::tensor::Tensor;
13use burn::tensor::backend::{AutodiffBackend, Backend};
14
15/// Configuration to create the [Sgd](Sgd) optimizer.
16#[derive(Config, Debug)]
17pub struct SgdConfig {
18    /// [Weight decay](WeightDecayConfig) config.
19    weight_decay: Option<WeightDecayConfig>,
20    /// [Momentum](MomentumConfig) config.
21    momentum: Option<MomentumConfig>,
22    /// [Gradient Clipping](GradientClippingConfig) config.
23    gradient_clipping: Option<GradientClippingConfig>,
24}
25
26/// Optimizer that implements stochastic gradient descent with momentum.
27///
28/// The optimizer can be configured with [SgdConfig](SgdConfig).
29#[derive(Clone)]
30pub struct Sgd<B: Backend> {
31    momentum: Option<Momentum<B>>,
32    weight_decay: Option<WeightDecay>,
33}
34
35/// State of [Sgd](Sgd).
36#[derive(Record, Clone, new)]
37pub struct SgdState<B: Backend, const D: usize> {
38    /// The current state of the momentum (if any).
39    pub momentum: Option<MomentumState<B, D>>,
40}
41
42impl SgdConfig {
43    /// Creates a new [SgdConfig](SgdConfig) with default values.
44    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
45        &self,
46    ) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
47        let momentum = self.momentum.as_ref().map(Momentum::new);
48        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
49
50        let mut optim = OptimizerAdaptor::from(Sgd {
51            momentum,
52            weight_decay,
53        });
54        if let Some(config) = &self.gradient_clipping {
55            optim = optim.with_grad_clipping(config.init());
56        }
57        optim
58    }
59}
60
61impl<B: Backend> SimpleOptimizer<B> for Sgd<B> {
62    type State<const D: usize> = SgdState<B, D>;
63
64    fn step<const D: usize>(
65        &self,
66        lr: LearningRate,
67        tensor: Tensor<B, D>,
68        mut grad: Tensor<B, D>,
69        state: Option<Self::State<D>>,
70    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
71        let mut state_momentum = None;
72
73        if let Some(state) = state {
74            state_momentum = state.momentum;
75        }
76
77        if let Some(weight_decay) = &self.weight_decay {
78            grad = weight_decay.transform(grad, tensor.clone());
79        }
80
81        if let Some(momentum) = &self.momentum {
82            let (grad_out, state) = momentum.transform(grad, state_momentum);
83            state_momentum = Some(state);
84            grad = grad_out;
85        }
86
87        let state = SgdState::new(state_momentum);
88        let delta = grad.mul_scalar(lr);
89
90        (tensor - delta, Some(state))
91    }
92
93    fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {
94        state.momentum = state.momentum.map(|state| state.to_device(device));
95        state
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TestAutodiffBackend, TestBackend,
        grad_clipping::GradientClipping,
        optim::{GradientsParams, Optimizer},
    };
    use burn::tensor::{Distribution, Shape};
    use burn_nn::{Linear, LinearConfig};

    const LEARNING_RATE: LearningRate = 0.02;

    /// The concrete SGD optimizer type exercised by these tests.
    type TestSgd =
        OptimizerAdaptor<Sgd<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend>;

    #[test]
    fn with_updated_params_should_have_state() {
        let optim = stepped_sgd();
        assert!(!optim.to_record().is_empty());
    }

    #[test]
    fn without_updated_params_should_not_have_state() {
        assert!(sgd_with_all().to_record().is_empty());
    }

    #[test]
    fn can_attach_gradient_clipping() {
        let clipped = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5));
        assert!(clipped.has_gradient_clipping());
    }

    #[test]
    fn should_load_state() {
        let trained = stepped_sgd().to_record();

        let fresh = sgd_with_all();
        let untouched = fresh.to_record();
        let restored = fresh.load_record(trained.clone()).to_record();

        assert_ne!(trained.len(), untouched.len());
        assert_eq!(trained.len(), restored.len());
    }

    /// Builds a fresh layer and a [`sgd_with_all`] optimizer, runs one
    /// forward/backward pass, applies a single optimizer step, and returns the optimizer.
    fn stepped_sgd() -> TestSgd {
        let device = Default::default();
        let model = layer::<TestAutodiffBackend>(&device);
        let mut optim = sgd_with_all();
        let loss = model.forward(random_tensor::<TestAutodiffBackend>(&device));
        let grads = GradientsParams::from_grads(loss.backward(), &model);
        let _model = optim.step(LEARNING_RATE, model, grads);
        optim
    }

    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::random(Shape::new([2, 20]), Distribution::Default, device)
    }

    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {
        LinearConfig::new(20, 20).init(device)
    }

    /// SGD with weight decay and Nesterov momentum enabled, and no gradient clipping.
    fn sgd_with_all() -> TestSgd {
        let momentum = MomentumConfig {
            momentum: 0.9,
            dampening: 0.1,
            nesterov: true,
        };

        SgdConfig {
            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
            momentum: Some(momentum),
            gradient_clipping: None,
        }
        .init()
    }
}