burn_optim/optim/
decay.rs1use burn_core as burn;
2
3use burn::config::Config;
4use burn::record::Record;
5use burn::tensor::Tensor;
6use burn::tensor::backend::Backend;
7
/// Configuration used to build a [`WeightDecay`] via [`WeightDecay::new`].
#[derive(Config, Debug)]
pub struct WeightDecayConfig {
    /// Scalar penalty factor. [`WeightDecay::transform`] multiplies its `tensor` argument by
    /// this value before adding the gradient to it.
    pub penalty: f32,
}
14
/// State attached to weight decay, holding a single tensor.
///
/// Generic over the backend `B` and the const parameter `D`, both of which are forwarded
/// to the stored [`Tensor<B, D>`].
#[derive(Record, Clone, new)]
pub struct WeightDecayState<B: Backend, const D: usize> {
    // NOTE(review): presumably the gradient observed at the previous optimizer step; the
    // only use visible in this file is `to_device` — confirm against callers.
    pub(crate) grad_last_step: Tensor<B, D>,
}
20
/// Weight decay transformation parameterized by a single scalar penalty.
///
/// Derives `Debug` in addition to `Clone` so that this public type (and any type that
/// embeds it) can be formatted for diagnostics, matching [`WeightDecayConfig`].
#[derive(Clone, Debug)]
pub struct WeightDecay {
    /// Scalar factor applied to the tensor passed to `transform`.
    penalty: f32,
}
26
27impl WeightDecay {
28 pub fn new(config: &WeightDecayConfig) -> Self {
30 Self {
31 penalty: config.penalty,
32 }
33 }
34
35 pub fn transform<B: Backend, const D: usize>(
46 &self,
47 grad: Tensor<B, D>,
48 tensor: Tensor<B, D>,
49 ) -> Tensor<B, D> {
50 tensor.mul_scalar(self.penalty).add(grad)
51 }
52}
53
54impl<B: Backend, const D: usize> WeightDecayState<B, D> {
55 pub fn to_device(mut self, device: &B::Device) -> Self {
65 self.grad_last_step = self.grad_last_step.to_device(device);
66 self
67 }
68}