// burn_optim/optim/decay.rs

1use burn_core as burn;
2
3use burn::config::Config;
4use burn::record::Record;
5use burn::tensor::Tensor;
6use burn::tensor::backend::Backend;
7
/// Configuration to create [weight decay](WeightDecay).
#[derive(Config, Debug)]
pub struct WeightDecayConfig {
    /// L2 penalty coefficient applied to the parameters when transforming a gradient.
    pub penalty: f32,
}
14
/// State of [weight decay](WeightDecay).
#[derive(Record, Clone, new)]
pub struct WeightDecayState<B: Backend, const D: usize> {
    // Gradient recorded at the previous optimization step.
    // NOTE(review): this state is not read by `WeightDecay::transform` in this file —
    // presumably it is consumed by an optimizer elsewhere; confirm against callers.
    pub(crate) grad_last_step: Tensor<B, D>,
}
20
/// Weight decay implementation that transforms gradients.
#[derive(Clone)]
pub struct WeightDecay {
    // L2 penalty coefficient, copied from `WeightDecayConfig::penalty`.
    penalty: f32,
}
26
27impl WeightDecay {
28    /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig).
29    pub fn new(config: &WeightDecayConfig) -> Self {
30        Self {
31            penalty: config.penalty,
32        }
33    }
34
35    /// Transforms a gradient.
36    ///
37    /// # Arguments
38    ///
39    /// * `grad` - Gradient to transform.
40    /// * `tensor` - Tensor param of the last iteration.
41    ///
42    /// # Returns
43    ///
44    /// * `grad` - Transformed gradient.
45    pub fn transform<B: Backend, const D: usize>(
46        &self,
47        grad: Tensor<B, D>,
48        tensor: Tensor<B, D>,
49    ) -> Tensor<B, D> {
50        tensor.mul_scalar(self.penalty).add(grad)
51    }
52}
53
54impl<B: Backend, const D: usize> WeightDecayState<B, D> {
55    /// Moves the state to a device.
56    ///
57    /// # Arguments
58    ///
59    /// * `device` - Device to move the state to.
60    ///
61    /// # Returns
62    ///
63    /// * `self` - Moved state.
64    pub fn to_device(mut self, device: &B::Device) -> Self {
65        self.grad_last_step = self.grad_last_step.to_device(device);
66        self
67    }
68}