use crate::prelude::*;
use burn::optim::{LearningRate, SimpleOptimizer};
/// Extension of [`SimpleOptimizer`] that runs a whole sequence of optimizer steps
/// from a single call, instead of requiring the caller to drive `step` manually.
pub trait LessSimpleOptimizer<B>: SimpleOptimizer<B>
where
    B: Backend,
{
    /// Applies `num_steps` successive optimizer updates to `tensor`.
    ///
    /// * `lr_function` – maps a zero-based step index to the learning rate used for
    ///   that step.
    /// * `num_steps` – how many updates to perform; `0` returns the inputs unchanged.
    /// * `grad_function` – given the current parameter tensor, produces the gradient
    ///   fed to the optimizer for that step.
    /// * `tensor` – the starting parameter values.
    /// * `state` – optimizer state handed to the first underlying `step` call.
    ///
    /// Returns the final parameter tensor together with the final optimizer state.
    fn many_steps<const D: usize>(
        &self,
        lr_function: impl FnMut(usize) -> LearningRate,
        num_steps: usize,
        grad_function: impl FnMut(Tensor<B, D>) -> Tensor<B, D>,
        tensor: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>);
}
/// Blanket implementation: every [`SimpleOptimizer`] gets `many_steps` for free.
impl<B: Backend, T: SimpleOptimizer<B>> LessSimpleOptimizer<B> for T {
    /// Repeatedly calls `SimpleOptimizer::step`, `num_steps` times.
    ///
    /// Each iteration, in this order:
    /// 1. computes the gradient with `grad_function(current tensor)`;
    /// 2. obtains the learning rate with `lr_function(step_index)` (index starts at 0);
    /// 3. performs one optimizer update;
    /// 4. detaches the result and re-enables gradient tracking on it.
    ///
    /// Returns the final tensor and optimizer state. With `num_steps == 0` the
    /// inputs are returned unchanged.
    fn many_steps<const D: usize>(
        &self,
        mut lr_function: impl FnMut(usize) -> LearningRate,
        num_steps: usize,
        mut grad_function: impl FnMut(Tensor<B, D>) -> Tensor<B, D>,
        mut tensor: Tensor<B, D>,
        mut state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        for step in 0..num_steps {
            // `grad_function` consumes its argument, so it gets its own handle;
            // the original `tensor` is still needed for the update below.
            let cur_grad = grad_function(tensor.clone());
            let cur_lr = lr_function(step);
            // Move `tensor` instead of cloning it: it is overwritten right after,
            // and leaving the optimizer with the only handle lets the backend
            // update the buffer in place rather than allocating a copy.
            let (new_x, new_state) = self.step(cur_lr, tensor, cur_grad, state);
            // Detach so the next iteration starts from a fresh graph, then
            // re-enable gradient tracking for the next `grad_function` call.
            tensor = new_x.detach().require_grad();
            state = new_state;
        }
        (tensor, state)
    }
}