use burn_core as burn;
use burn::config::Config;
use burn::tensor::backend::Backend;
use super::{LrScheduler, String};
use crate::LearningRate;
/// Configuration for building a [`StepLrScheduler`].
///
/// A scheduler built from this configuration returns `initial_lr` for the first `step_size` calls
/// to `step`, then multiplies the rate by `gamma` once after every further run of `step_size`
/// calls.
///
/// `init` rejects a `step_size` of 0. A non-positive `initial_lr` or a `gamma` outside
/// (0.0, 1.0) is accepted, but a warning is logged for it.
#[derive(Config, Debug)]
pub struct StepLrSchedulerConfig {
/// Learning rate returned before any decay has been applied.
initial_lr: LearningRate,
/// Number of consecutive steps that share the same learning rate.
step_size: usize,
/// Factor the learning rate is multiplied by after each run of `step_size` steps. Default: 0.1.
#[config(default = 0.1)]
gamma: f64,
}
impl StepLrSchedulerConfig {
    /// Builds a [`StepLrScheduler`] from this configuration.
    ///
    /// Unusual `initial_lr` and `gamma` values do not prevent construction; they are only
    /// reported through `log::warn!`, since they may be deliberate.
    ///
    /// # Errors
    ///
    /// Returns an error message when `step_size` is 0.
    pub fn init(&self) -> Result<StepLrScheduler, String> {
        // Every field is `Copy`, so pull them out once instead of going through `self` each time.
        let &Self {
            initial_lr,
            step_size,
            gamma,
        } = self;

        // The only hard failure: `step` divides by the step size.
        if step_size == 0 {
            return Err("Step size must be greater than 0".into());
        }

        let lr_not_positive = initial_lr <= 0.0;
        if lr_not_positive {
            log::warn!(
                "Initial learning rate value of {} is not a positive number. Ignore this warning \
                 if it is intended.",
                initial_lr
            );
        }

        let gamma_out_of_range = gamma <= 0.0 || gamma >= 1.0;
        if gamma_out_of_range {
            log::warn!(
                "Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \
                 intended.",
                gamma
            );
        }

        // The index starts at -1 so that the first `step` call moves it to 0.
        Ok(StepLrScheduler {
            init_lr: initial_lr,
            step_size,
            gamma,
            iter_idx: -1,
        })
    }
}
/// Learning rate scheduler that scales the rate by `gamma` once every `step_size` steps.
///
/// Built by [`StepLrSchedulerConfig::init`].
#[derive(Clone, Debug)]
pub struct StepLrScheduler {
/// Learning rate before any decay is applied.
init_lr: LearningRate,
/// Number of consecutive steps that share the same learning rate.
step_size: usize,
/// Factor applied to the learning rate after each run of `step_size` steps.
gamma: f64,
/// Index of the most recent step; `init` sets it to -1 and `step` increments it before use.
iter_idx: i32,
}
impl LrScheduler for StepLrScheduler {
    /// Only the iteration index changes while stepping, so it is the whole record.
    type Record<B: Backend> = i32;

    /// Advances the scheduler by one iteration and returns the learning rate for it.
    ///
    /// The returned value is `init_lr * gamma^(iter_idx / step_size)`, where `iter_idx` is the
    /// freshly incremented index and the division is an integer division.
    ///
    /// # Panics
    ///
    /// Panics if incrementing the iteration index would overflow `i32`.
    fn step(&mut self) -> LearningRate {
        let current = self
            .iter_idx
            .checked_add(1)
            .expect("`.step()` should be called no more than `i32::MAX + 1` times");
        self.iter_idx = current;

        // Both casts are lossless while `current` is non-negative, which holds for an index that
        // started at -1 and only grew through `step`: the quotient can never exceed `current`
        // because `init` rejects a zero `step_size`. An index restored by `load_record` is
        // assumed to originate from `to_record`.
        let decay_count = (current as usize / self.step_size) as i32;
        self.init_lr * self.gamma.powi(decay_count)
    }

    /// Returns the current iteration index as the record.
    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        self.iter_idx
    }

    /// Returns the scheduler with its iteration index replaced by `record`; every other field is
    /// kept as is.
    fn load_record<B: Backend>(self, record: Self::Record<B>) -> Self {
        Self {
            iter_idx: record,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::test_utils;
    use super::*;
    use crate::TestBackend;

    /// `init` must refuse a step size of zero.
    #[test]
    fn test_config_step_size_zero() {
        assert!(
            StepLrSchedulerConfig::new(1.0, 0).init().is_err(),
            "Should return an error"
        );
    }

    /// `init` must accept a non-zero step size.
    #[test]
    fn test_config_step_size_nonzero() {
        assert!(
            StepLrSchedulerConfig::new(1.0, 1).init().is_ok(),
            "Should return a success value"
        );
    }

    /// Leaving `gamma` unset must behave exactly like setting it to 0.1.
    #[test]
    fn test_config_default_gamma() {
        const LR: LearningRate = 0.4;
        const PERIOD: usize = 2;

        let mut implicit_gamma = StepLrSchedulerConfig::new(LR, PERIOD).init().unwrap();
        let mut explicit_gamma = StepLrSchedulerConfig::new(LR, PERIOD)
            .with_gamma(0.1)
            .init()
            .unwrap();

        test_utils::compare_steps(&mut implicit_gamma, &mut explicit_gamma, 3 * PERIOD);
    }

    /// A gamma below 1 shrinks the rate after every `step_size` steps.
    #[test]
    fn test_lr_decreasing() {
        let config = StepLrSchedulerConfig::new(0.5, 3).with_gamma(0.1);
        let sequence = [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005];
        test_utils::check_lr_sequence(config.init().unwrap(), sequence);
    }

    /// A gamma above 1 grows the rate after every `step_size` steps.
    #[test]
    fn test_lr_increasing() {
        let config = StepLrSchedulerConfig::new(0.1, 2).with_gamma(2.0);
        let sequence = [0.1, 0.1, 0.2, 0.2, 0.4, 0.4];
        test_utils::check_lr_sequence(config.init().unwrap(), sequence);
    }

    /// A gamma of exactly 1 leaves the rate untouched.
    #[test]
    fn test_lr_unchanging() {
        let config = StepLrSchedulerConfig::new(3.1, 1).with_gamma(1.0);
        let sequence = [3.1, 3.1, 3.1];
        test_utils::check_lr_sequence(config.init().unwrap(), sequence);
    }

    /// Saving and reloading part-way through must be handled by the shared save/load check.
    #[test]
    fn test_save_and_load() {
        const PERIOD: usize = 10;

        let config = StepLrSchedulerConfig::new(0.007, PERIOD).with_gamma(0.03);
        test_utils::check_save_load(config.init().unwrap(), 3 * PERIOD / 2);
    }

    /// Stepping up to an index of `i32::MAX` is still allowed.
    #[test]
    fn test_number_of_calls_within_limit() {
        let mut scheduler = StepLrSchedulerConfig::new(0.1, 2)
            .init()
            .unwrap()
            .load_record::<TestBackend>(i32::MAX - 1);

        scheduler.step();
    }

    /// Stepping past an index of `i32::MAX` must panic with the documented message.
    #[test]
    #[should_panic(expected = "i32::MAX")]
    fn test_number_of_calls_over_limit() {
        let mut scheduler = StepLrSchedulerConfig::new(0.1, 2)
            .init()
            .unwrap()
            .load_record::<TestBackend>(i32::MAX - 1);

        scheduler.step();
        scheduler.step();
    }
}