use burn_core as burn;
use burn::config::Config;
use burn::record::Record;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Configuration consumed by [`WeightDecay::new`].
#[derive(Config, Debug)]
pub struct WeightDecayConfig {
/// Scalar that [`WeightDecay::transform`] multiplies its `tensor` argument by
/// before adding the gradient.
pub penalty: f32,
}
/// Recordable state holding a single tensor of rank `D` on backend `B`.
#[derive(Record, Clone, new)]
pub struct WeightDecayState<B: Backend, const D: usize> {
// NOTE(review): presumably the gradient seen at the previous optimizer step;
// nothing in this section reads it besides `to_device` — confirm against callers.
pub(crate) grad_last_step: Tensor<B, D>,
}
/// Gradient transformation that adds a scaled copy of a tensor to its gradient.
///
/// Built from a [`WeightDecayConfig`] via [`WeightDecay::new`].
#[derive(Clone)]
pub struct WeightDecay {
// Scaling factor copied from `WeightDecayConfig::penalty`.
penalty: f32,
}
impl WeightDecay {
    /// Builds a [`WeightDecay`] by copying the `penalty` value out of `config`.
    pub fn new(config: &WeightDecayConfig) -> Self {
        let penalty = config.penalty;
        Self { penalty }
    }

    /// Computes `tensor * penalty + grad` and returns the result.
    ///
    /// # Arguments
    ///
    /// * `grad` - Tensor added to the scaled `tensor`.
    /// * `tensor` - Tensor that is multiplied by the stored penalty.
    ///
    /// # Returns
    ///
    /// The sum of the scaled `tensor` and `grad`.
    pub fn transform<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        tensor: Tensor<B, D>,
    ) -> Tensor<B, D> {
        // Scale first, then add the gradient, keeping the operand order
        // (scaled tensor on the left-hand side of `add`).
        let scaled = tensor.mul_scalar(self.penalty);
        scaled.add(grad)
    }
}
impl<B: Backend, const D: usize> WeightDecayState<B, D> {
    /// Consumes the state and returns it with `grad_last_step` replaced by the
    /// result of calling `to_device(device)` on it.
    ///
    /// # Arguments
    ///
    /// * `device` - Device handed to the tensor's `to_device` call.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        // Move the tensor out, transfer it, then store the result back.
        let relocated = self.grad_last_step.to_device(device);
        self.grad_last_step = relocated;
        self
    }
}