1use burn_tensor::backend::Backend;
2
3use crate as burn;
4use crate::record::Record;
5
6use crate::config::Config;
7use crate::tensor::Tensor;
8
9#[derive(Config)]
11pub struct WeightDecayConfig {
12 pub penalty: f32,
14}
15
16#[derive(Record, Clone, new)]
18pub struct WeightDecayState<B: Backend, const D: usize> {
19 pub(crate) grad_last_step: Tensor<B, D>,
20}
21
/// Gradient transform that applies weight decay.
///
/// Built from a `WeightDecayConfig`; the stored `penalty` scales the parameter tensor
/// before it is added to the gradient.
///
/// `Debug` is derived so the transform can be inspected and logged like other public types.
#[derive(Clone, Debug)]
pub struct WeightDecay {
    /// Weight decay coefficient applied to the parameter tensor.
    penalty: f32,
}
27
28impl WeightDecay {
29 pub fn new(config: &WeightDecayConfig) -> Self {
31 Self {
32 penalty: config.penalty,
33 }
34 }
35
36 pub fn transform<B: Backend, const D: usize>(
47 &self,
48 grad: Tensor<B, D>,
49 tensor: Tensor<B, D>,
50 ) -> Tensor<B, D> {
51 tensor.mul_scalar(self.penalty).add(grad)
52 }
53}
54
55impl<B: Backend, const D: usize> WeightDecayState<B, D> {
56 pub fn to_device(mut self, device: &B::Device) -> Self {
66 self.grad_last_step = self.grad_last_step.to_device(device);
67 self
68 }
69}