use super::Optimizer;
use std::cell::Cell;
/// Interface shared by the learning-rate schedulers defined in this file.
///
/// Every method takes `&self`; the implementors in this file keep their
/// mutable state in `Cell`s.
pub trait LRScheduler {
    /// Moves the scheduler forward by one step.
    ///
    /// The implementors in this file bump their epoch counter and, when
    /// their rule applies, hand a new learning rate to the optimizer.
    fn step(&self);

    /// Returns the learning rate the scheduler has recorded as "last".
    fn get_last_lr(&self) -> f32;

    /// Returns the learning rate the scheduler currently holds.
    fn get_current_lr(&self) -> f32;

    /// Returns the scheduler's epoch counter.
    fn get_current_epoch(&self) -> usize;

    /// Overwrites the scheduler's epoch counter with `epoch`.
    fn set_current_epoch(&self, epoch: usize);

    /// Prints the current epoch and current learning rate to standard output.
    fn print_lr(&self) {
        let epoch = self.get_current_epoch();
        let lr = self.get_current_lr();
        println!("epoch {}: learning rate adjusted to [{}]", epoch, lr);
    }
}
/// Bookkeeping shared by every scheduler's `step` in this file.
///
/// Copies the value held by `current_lr` into `last_lr`, then increments
/// `current_epoch` by one. `current_lr` itself is left untouched.
fn prepare_step(last_lr: &Cell<f32>, current_lr: &Cell<f32>, current_epoch: &Cell<usize>) {
    let lr_before_step = current_lr.get();
    last_lr.set(lr_before_step);
    current_epoch.set(current_epoch.get() + 1);
}
/// Scheduler whose learning rate is the rate captured at construction
/// multiplied by a user-supplied function of the epoch.
pub struct LambdaLR<'a, T, F>
where
    T: Optimizer<'a>,
    F: Fn(usize) -> f32,
{
    /// Optimizer whose learning rate is read at construction and written on each step.
    optimizer: &'a T,
    /// Maps the epoch counter to the factor applied to `initial_lr`.
    lr_fn: F,
    /// Epoch counter; starts at 0 and is incremented by every step.
    current_epoch: Cell<usize>,
    /// Learning rate most recently computed by the scheduler.
    current_lr: Cell<f32>,
    /// Value `current_lr` held before the latest step; 0.0 until the first step.
    last_lr: Cell<f32>,
    /// Learning rate read from the optimizer when the scheduler was built.
    initial_lr: Cell<f32>,
}
impl<'a, T: Optimizer<'a>, F: Fn(usize) -> f32> LambdaLR<'a, T, F> {
pub fn new(optimizer: &'a T, lr_fn: F) -> Self {
let current_lr = optimizer.get_lr();
Self {
optimizer,
lr_fn,
current_epoch: Cell::new(0),
current_lr: Cell::new(current_lr),
last_lr: Cell::new(0.0),
initial_lr: Cell::new(current_lr),
}
}
pub fn step(&self) {
LRScheduler::step(self);
}
pub fn get_last_lr(&self) -> f32 {
LRScheduler::get_last_lr(self)
}
pub fn get_current_lr(&self) -> f32 {
LRScheduler::get_current_lr(self)
}
pub fn set_current_epoch(&self, epoch: usize) {
LRScheduler::set_current_epoch(self, epoch);
}
pub fn get_current_epoch(&self) -> usize {
LRScheduler::get_current_epoch(self)
}
pub fn print_lr(&self) {
LRScheduler::print_lr(self);
}
}
impl<'a, T, F> LRScheduler for LambdaLR<'a, T, F>
where
    T: Optimizer<'a>,
    F: Fn(usize) -> f32,
{
    /// Advances one epoch and sets the learning rate to
    /// `initial_lr * lr_fn(epoch)`, where `epoch` is the already-incremented
    /// counter. The new rate is also written to the optimizer.
    fn step(&self) {
        prepare_step(&self.last_lr, &self.current_lr, &self.current_epoch);
        let base_lr = self.initial_lr.get();
        let factor = (self.lr_fn)(self.current_epoch.get());
        let updated_lr = base_lr * factor;
        self.current_lr.set(updated_lr);
        self.optimizer.set_lr(updated_lr);
    }

    /// Returns the rate that was current before the latest step (0.0 before any step).
    fn get_last_lr(&self) -> f32 {
        self.last_lr.get()
    }

    /// Returns the rate most recently computed by the scheduler.
    fn get_current_lr(&self) -> f32 {
        self.current_lr.get()
    }

    /// Overwrites the epoch counter; the learning rate is not recomputed here.
    fn set_current_epoch(&self, epoch: usize) {
        self.current_epoch.set(epoch);
    }

    /// Returns the epoch counter.
    fn get_current_epoch(&self) -> usize {
        self.current_epoch.get()
    }
}
/// Scheduler that multiplies the previous learning rate by a user-supplied
/// function of the epoch on every step.
pub struct MultiplicativeLR<'a, T, F>
where
    T: Optimizer<'a>,
    F: Fn(usize) -> f32,
{
    /// Optimizer whose learning rate is read at construction and written on each step.
    optimizer: &'a T,
    /// Maps the epoch counter to the factor applied to the previous learning rate.
    lr_fn: F,
    /// Epoch counter; starts at 0 and is incremented by every step.
    current_epoch: Cell<usize>,
    /// Learning rate most recently computed by the scheduler.
    current_lr: Cell<f32>,
    /// Value `current_lr` held before the latest step; 0.0 until the first step.
    last_lr: Cell<f32>,
}
impl<'a, T: Optimizer<'a>, F: Fn(usize) -> f32> MultiplicativeLR<'a, T, F> {
pub fn new(optimizer: &'a T, lr_fn: F) -> Self {
let current_lr = optimizer.get_lr();
Self {
optimizer,
lr_fn,
current_epoch: Cell::new(0),
current_lr: Cell::new(current_lr),
last_lr: Cell::new(0.0),
}
}
pub fn step(&self) {
LRScheduler::step(self);
}
pub fn get_last_lr(&self) -> f32 {
LRScheduler::get_last_lr(self)
}
pub fn get_current_lr(&self) -> f32 {
LRScheduler::get_current_lr(self)
}
pub fn set_current_epoch(&self, epoch: usize) {
LRScheduler::set_current_epoch(self, epoch);
}
pub fn get_current_epoch(&self) -> usize {
LRScheduler::get_current_epoch(self)
}
pub fn print_lr(&self) {
LRScheduler::print_lr(self);
}
}
impl<'a, T, F> LRScheduler for MultiplicativeLR<'a, T, F>
where
    T: Optimizer<'a>,
    F: Fn(usize) -> f32,
{
    /// Advances one epoch and sets the learning rate to the previous rate
    /// multiplied by `lr_fn(epoch)`, where `epoch` is the already-incremented
    /// counter. The new rate is also written to the optimizer.
    fn step(&self) {
        prepare_step(&self.last_lr, &self.current_lr, &self.current_epoch);
        let previous_lr = self.last_lr.get();
        let factor = (self.lr_fn)(self.current_epoch.get());
        let updated_lr = previous_lr * factor;
        self.current_lr.set(updated_lr);
        self.optimizer.set_lr(updated_lr);
    }

    /// Returns the rate that was current before the latest step (0.0 before any step).
    fn get_last_lr(&self) -> f32 {
        self.last_lr.get()
    }

    /// Returns the rate most recently computed by the scheduler.
    fn get_current_lr(&self) -> f32 {
        self.current_lr.get()
    }

    /// Overwrites the epoch counter; the learning rate is not recomputed here.
    fn set_current_epoch(&self, epoch: usize) {
        self.current_epoch.set(epoch);
    }

    /// Returns the epoch counter.
    fn get_current_epoch(&self) -> usize {
        self.current_epoch.get()
    }
}
/// Scheduler that multiplies the learning rate by `gamma` whenever the epoch
/// counter reaches a multiple of `step_size`.
pub struct StepLR<'a, T>
where
    T: Optimizer<'a>,
{
    /// Optimizer whose learning rate is read at construction and written on decay.
    optimizer: &'a T,
    /// Factor applied to the learning rate at each decay.
    gamma: f32,
    /// Decay happens when the epoch counter is a multiple of this value.
    step_size: usize,
    /// Epoch counter; starts at 0 and is incremented by every step.
    current_epoch: Cell<usize>,
    /// Learning rate most recently computed by the scheduler.
    current_lr: Cell<f32>,
    /// Value `current_lr` held before the latest step; 0.0 until the first step.
    last_lr: Cell<f32>,
}
impl<'a, T: Optimizer<'a>> StepLR<'a, T> {
pub fn new(optimizer: &'a T, step_size: usize, gamma: f32) -> Self {
let current_lr = optimizer.get_lr();
Self {
optimizer,
gamma,
step_size,
current_epoch: Cell::new(0),
current_lr: Cell::new(current_lr),
last_lr: Cell::new(0.0),
}
}
pub fn step(&self) {
LRScheduler::step(self);
}
pub fn get_last_lr(&self) -> f32 {
LRScheduler::get_last_lr(self)
}
pub fn get_current_lr(&self) -> f32 {
LRScheduler::get_current_lr(self)
}
pub fn set_current_epoch(&self, epoch: usize) {
LRScheduler::set_current_epoch(self, epoch);
}
pub fn get_current_epoch(&self) -> usize {
LRScheduler::get_current_epoch(self)
}
pub fn print_lr(&self) {
LRScheduler::print_lr(self);
}
}
impl<'a, T: Optimizer<'a>> LRScheduler for StepLR<'a, T> {
    /// Advances one epoch and, when the new epoch counter is a multiple of
    /// `step_size`, multiplies the learning rate by `gamma` and writes the
    /// result to the optimizer.
    ///
    /// On epochs that are not a multiple of `step_size` the learning rate and
    /// the optimizer are left untouched. A `step_size` of zero never triggers
    /// a decay; it does not panic.
    fn step(&self) {
        prepare_step(&self.last_lr, &self.current_lr, &self.current_epoch);
        let epoch = self.current_epoch.get();
        // `checked_rem` yields `None` for a zero divisor, so a zero
        // `step_size` simply never matches instead of panicking.
        if epoch.checked_rem(self.step_size) == Some(0) {
            let decayed_lr = self.last_lr.get() * self.gamma;
            self.current_lr.set(decayed_lr);
            self.optimizer.set_lr(decayed_lr);
        }
    }

    /// Returns the rate that was current before the latest step (0.0 before any step).
    fn get_last_lr(&self) -> f32 {
        self.last_lr.get()
    }

    /// Returns the rate most recently computed by the scheduler.
    fn get_current_lr(&self) -> f32 {
        self.current_lr.get()
    }

    /// Overwrites the epoch counter; the learning rate is not recomputed here.
    fn set_current_epoch(&self, epoch: usize) {
        self.current_epoch.set(epoch);
    }

    /// Returns the epoch counter.
    fn get_current_epoch(&self) -> usize {
        self.current_epoch.get()
    }
}
/// Scheduler that multiplies the learning rate by `gamma` whenever the epoch
/// counter equals one of a fixed set of `N` milestones.
pub struct MultiStepLR<'a, T, const N: usize>
where
    T: Optimizer<'a>,
{
    /// Optimizer whose learning rate is read at construction and written on decay.
    optimizer: &'a T,
    /// Factor applied to the learning rate at each milestone.
    gamma: f32,
    /// Epoch values at which a decay is applied.
    milestones: [usize; N],
    /// Epoch counter; starts at 0 and is incremented by every step.
    current_epoch: Cell<usize>,
    /// Learning rate most recently computed by the scheduler.
    current_lr: Cell<f32>,
    /// Value `current_lr` held before the latest step; 0.0 until the first step.
    last_lr: Cell<f32>,
}
impl<'a, T: Optimizer<'a>, const N: usize> MultiStepLR<'a, T, N> {
pub fn new(optimizer: &'a T, milestones: [usize; N], gamma: f32) -> Self {
let current_lr = optimizer.get_lr();
Self {
optimizer,
gamma,
milestones,
current_epoch: Cell::new(0),
current_lr: Cell::new(current_lr),
last_lr: Cell::new(0.0),
}
}
pub fn step(&self) {
LRScheduler::step(self);
}
pub fn get_last_lr(&self) -> f32 {
LRScheduler::get_last_lr(self)
}
pub fn get_current_lr(&self) -> f32 {
LRScheduler::get_current_lr(self)
}
pub fn set_current_epoch(&self, epoch: usize) {
LRScheduler::set_current_epoch(self, epoch);
}
pub fn get_current_epoch(&self) -> usize {
LRScheduler::get_current_epoch(self)
}
pub fn print_lr(&self) {
LRScheduler::print_lr(self);
}
}
impl<'a, T, const N: usize> LRScheduler for MultiStepLR<'a, T, N>
where
    T: Optimizer<'a>,
{
    /// Advances one epoch and, when the new epoch counter equals any entry of
    /// `milestones`, multiplies the learning rate by `gamma` once and writes
    /// the result to the optimizer.
    ///
    /// On other epochs the learning rate and the optimizer are left untouched.
    fn step(&self) {
        prepare_step(&self.last_lr, &self.current_lr, &self.current_epoch);
        let epoch = self.current_epoch.get();
        if self.milestones.contains(&epoch) {
            let decayed_lr = self.last_lr.get() * self.gamma;
            self.current_lr.set(decayed_lr);
            self.optimizer.set_lr(decayed_lr);
        }
    }

    /// Returns the rate that was current before the latest step (0.0 before any step).
    fn get_last_lr(&self) -> f32 {
        self.last_lr.get()
    }

    /// Returns the rate most recently computed by the scheduler.
    fn get_current_lr(&self) -> f32 {
        self.current_lr.get()
    }

    /// Overwrites the epoch counter; the learning rate is not recomputed here.
    fn set_current_epoch(&self, epoch: usize) {
        self.current_epoch.set(epoch);
    }

    /// Returns the epoch counter.
    fn get_current_epoch(&self) -> usize {
        self.current_epoch.get()
    }
}
/// Scheduler that multiplies the learning rate by `gamma` on every step.
pub struct ExponentialLR<'a, T>
where
    T: Optimizer<'a>,
{
    /// Optimizer whose learning rate is read at construction and written on each step.
    optimizer: &'a T,
    /// Factor applied to the learning rate at every step.
    gamma: f32,
    /// Epoch counter; starts at 0 and is incremented by every step.
    current_epoch: Cell<usize>,
    /// Learning rate most recently computed by the scheduler.
    current_lr: Cell<f32>,
    /// Value `current_lr` held before the latest step; 0.0 until the first step.
    last_lr: Cell<f32>,
}
impl<'a, T: Optimizer<'a>> ExponentialLR<'a, T> {
pub fn new(optimizer: &'a T, gamma: f32) -> Self {
let current_lr = optimizer.get_lr();
Self {
optimizer,
gamma,
current_epoch: Cell::new(0),
current_lr: Cell::new(current_lr),
last_lr: Cell::new(0.0),
}
}
pub fn step(&self) {
LRScheduler::step(self);
}
pub fn get_last_lr(&self) -> f32 {
LRScheduler::get_last_lr(self)
}
pub fn get_current_lr(&self) -> f32 {
LRScheduler::get_current_lr(self)
}
pub fn set_current_epoch(&self, epoch: usize) {
LRScheduler::set_current_epoch(self, epoch);
}
pub fn get_current_epoch(&self) -> usize {
LRScheduler::get_current_epoch(self)
}
pub fn print_lr(&self) {
LRScheduler::print_lr(self);
}
}
impl<'a, T> LRScheduler for ExponentialLR<'a, T>
where
    T: Optimizer<'a>,
{
    /// Advances one epoch, multiplies the previous learning rate by `gamma`,
    /// and writes the result to the optimizer.
    fn step(&self) {
        prepare_step(&self.last_lr, &self.current_lr, &self.current_epoch);
        let decayed_lr = self.last_lr.get() * self.gamma;
        self.current_lr.set(decayed_lr);
        self.optimizer.set_lr(decayed_lr);
    }

    /// Returns the rate that was current before the latest step (0.0 before any step).
    fn get_last_lr(&self) -> f32 {
        self.last_lr.get()
    }

    /// Returns the rate most recently computed by the scheduler.
    fn get_current_lr(&self) -> f32 {
        self.current_lr.get()
    }

    /// Overwrites the epoch counter; the learning rate is not recomputed here.
    fn set_current_epoch(&self, epoch: usize) {
        self.current_epoch.set(epoch);
    }

    /// Returns the epoch counter.
    fn get_current_epoch(&self) -> usize {
        self.current_epoch.get()
    }
}
#[cfg(test)]
mod test;