use burn_core as burn;
use crate::LearningRate;
use burn::record::Record;
use burn::tensor::{Tensor, backend::Backend};
/// Per-tensor optimizer interface, generic over a tensor [`Backend`].
///
/// An implementor supplies a single-tensor update rule ([`step`](Self::step))
/// together with the state type that rule carries between calls. Every
/// implementor must be `Send + Sync + Clone`.
pub trait SimpleOptimizer<B: Backend>: Send + Sync + Clone {
    /// State kept by the optimizer for one tensor of rank `D`.
    ///
    /// The state must implement [`Record`] for the backend `B`, be `Clone`,
    /// and contain no non-`'static` borrows.
    type State<const D: usize>: Record<B> + Clone + 'static;

    /// Runs one optimization step for a single tensor.
    ///
    /// # Arguments
    ///
    /// * `lr` - Learning rate to use for this step.
    /// * `tensor` - The tensor being optimized, taken by value.
    /// * `grad` - Gradient paired with `tensor`; it has the same rank `D`.
    /// * `state` - State associated with this tensor, or `None` when the
    ///   caller has none to pass (presumably on the first step; confirm
    ///   against the caller).
    ///
    /// # Returns
    ///
    /// A tuple of the resulting tensor and the state to hand back on the next
    /// call; the state slot is `None` when the implementation keeps no state.
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>);

    /// Consumes `state` and returns a state for the given `device`.
    ///
    /// This is an associated function (no `self` receiver), so it depends only
    /// on the state value and the target device.
    // NOTE(review): presumably used to keep the state on the same device as
    // the tensor and gradient passed to `step`; confirm against callers.
    fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
}