use super::GradientsParams;
use crate::module::ADModule;
use crate::record::Record;
use crate::tensor::backend::ADBackend;
use crate::LearningRate;
/// General trait to optimize [module](ADModule).
pub trait Optimizer<M, B>: Send + Sync
where
M: ADModule<B>,
B: ADBackend,
{
/// Optimizer associative type to be used when saving and loading the state.
type Record: Record;
/// Perform the optimizer step using the given learning rate and gradients.
/// The updated module is returned.
fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;
/// Get the current state of the optimizer as a [record](Record).
fn to_record(&self) -> Self::Record;
/// Load the state of the optimizer as a [record](Record).
fn load_record(self, record: Self::Record) -> Self;
}