use tch::nn::{self};
use tch::Tensor;
mod grad;
pub use grad::ZeroGrad;
mod grad_scale;
pub use grad_scale::GradScaler;
/// Minimal optimizer interface used by this crate.
///
/// It groups three kinds of operations:
/// - parameter updates (`step`, `backward_step`);
/// - per-parameter-group hyperparameter setters;
/// - access to the variables being optimized.
///
/// This module implements it for tch's [`nn::Optimizer`] by forwarding each
/// method to the tch method of the same name.
pub trait Optimizer {
    /// Applies one optimization step to the trainable variables.
    fn step(&mut self);

    /// Back-propagates `loss` and applies an optimization step in one call.
    fn backward_step(&mut self, loss: &Tensor);

    /// Sets the learning rate of parameter group `group` to `learning_rate`.
    fn set_lr_group(&mut self, group: usize, learning_rate: f64);

    /// Sets the weight-decay coefficient of parameter group `group` to
    /// `weight_decay`.
    fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64);

    /// Returns the variables this optimizer updates.
    fn trainable_variables(&self) -> Vec<Tensor>;
}
/// Makes tch's [`nn::Optimizer`] usable through the crate-level
/// [`Optimizer`] trait.
///
/// Every method forwards unchanged to the tch method of the same name. The
/// calls are written as explicit `nn::Optimizer::<C>::method(self, ..)` paths.
/// Inherent associated functions take precedence over trait items when such a
/// path is resolved, so each call reaches tch's implementation rather than
/// recursing back into this impl.
impl<C> Optimizer for nn::Optimizer<C> {
    fn step(&mut self) {
        nn::Optimizer::<C>::step(self)
    }

    fn backward_step(&mut self, loss: &Tensor) {
        nn::Optimizer::<C>::backward_step(self, loss)
    }

    fn set_lr_group(&mut self, group: usize, learning_rate: f64) {
        nn::Optimizer::<C>::set_lr_group(self, group, learning_rate)
    }

    fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64) {
        nn::Optimizer::<C>::set_weight_decay_group(self, group, weight_decay)
    }

    fn trainable_variables(&self) -> Vec<Tensor> {
        nn::Optimizer::<C>::trainable_variables(self)
    }
}