burn_core/optim/decay.rs

use burn_tensor::backend::Backend;

use crate as burn;
use crate::record::Record;

use crate::config::Config;
use crate::tensor::Tensor;

/// Configuration to create [weight decay](WeightDecay).
#[derive(Config)]
pub struct WeightDecayConfig {
    /// L2 penalty.
    pub penalty: f32,
}

/// State of [weight decay](WeightDecay).
#[derive(Record, Clone, new)]
pub struct WeightDecayState<B: Backend, const D: usize> {
    pub(crate) grad_last_step: Tensor<B, D>,
}

/// Weight decay implementation that transforms gradients.
#[derive(Clone)]
pub struct WeightDecay {
    penalty: f32,
}

impl WeightDecay {
    /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig).
    pub fn new(config: &WeightDecayConfig) -> Self {
        Self {
            penalty: config.penalty,
        }
    }

    /// Transforms a gradient by adding the weight decay (L2) penalty term.
    ///
    /// # Arguments
    ///
    /// * `grad` - Gradient to transform.
    /// * `tensor` - Parameter tensor from the last iteration.
    ///
    /// # Returns
    ///
    /// * `grad` - Transformed gradient, computed as `tensor * penalty + grad`.
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        tensor: Tensor<B, D>,
    ) -> Tensor<B, D> {
        tensor.mul_scalar(self.penalty).add(grad)
    }
}
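
// A minimal usage sketch (not part of the original file): build `WeightDecay` from its
// config and check that `transform` returns `tensor * penalty + grad`. Assumptions:
// `crate::TestBackend` is the backend alias burn-core defines for its tests, and
// `Tensor::from_floats` / `Data::assert_approx_eq` follow the burn tensor API at the
// time of writing (these signatures have changed between burn versions).
#[cfg(test)]
mod transform_tests {
    use super::*;
    use crate::tensor::Data;
    use crate::TestBackend;

    #[test]
    fn transform_adds_scaled_param_to_grad() {
        let device = Default::default();

        // penalty = 0.5, so the transformed gradient is tensor * 0.5 + grad.
        let weight_decay = WeightDecay::new(&WeightDecayConfig::new(0.5));

        let tensor = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0], &device);
        let grad = Tensor::<TestBackend, 1>::from_floats([0.1, 0.2, 0.3], &device);

        let transformed = weight_decay.transform(grad, tensor);

        // Expected: [1.0, 2.0, 3.0] * 0.5 + [0.1, 0.2, 0.3] = [0.6, 1.2, 1.8].
        transformed
            .into_data()
            .assert_approx_eq(&Data::from([0.6, 1.2, 1.8]), 3);
    }
}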

impl<B: Backend, const D: usize> WeightDecayState<B, D> {
    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.grad_last_step = self.grad_last_step.to_device(device);
        self
    }
}
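
// A second small sketch (also not in the original file) of holding and moving the
// optimizer state. `WeightDecayState::new` is generated by the `new` derive above;
// `TestBackend` and the tensor constructors are the same assumptions as in
// `transform_tests`. With the single-device test backend this is mostly a smoke test
// that the recorded gradient survives the move unchanged.
#[cfg(test)]
mod state_tests {
    use super::*;
    use crate::tensor::Data;
    use crate::TestBackend;

    #[test]
    fn to_device_keeps_the_recorded_gradient() {
        let device = Default::default();
        let grad_last_step = Tensor::<TestBackend, 1>::from_floats([0.1, 0.2, 0.3], &device);

        // Move the state to the test device; the stored gradient must be unchanged.
        let state = WeightDecayState::new(grad_last_step).to_device(&device);

        state
            .grad_last_step
            .into_data()
            .assert_approx_eq(&Data::from([0.1, 0.2, 0.3]), 3);
    }
}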