1use burn_core as burn;
2
3use super::GradientsParams;
4use crate::LearningRate;
5use burn::module::AutodiffModule;
6use burn::record::Record;
7use burn::tensor::backend::AutodiffBackend;
8
9pub trait Optimizer<M, B>: Send + Clone
11where
12 M: AutodiffModule<B>,
13 B: AutodiffBackend,
14{
15 type Record: Record<B>;
17
18 fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;
21
22 fn to_record(&self) -> Self::Record;
24
25 fn load_record(self, record: Self::Record) -> Self;
27}