use serde::{Deserialize, Serialize};
use crate::error::NetworkError;
use super::{LearningRateScheduler, LearningRateSchedulerClone};
/// Scheduler that multiplicatively decays the learning rate every `step_size` epochs.
///
/// Constructed through the `Step` builder, which validates both fields; it can also be
/// produced by deserialization.
#[derive(Clone, Serialize, Deserialize)]
struct StepLRScheduler {
    /// Factor the learning rate is multiplied by at each decay step.
    decay_rate: f32,
    /// Number of epochs between decay steps.
    step_size: usize,
}
impl StepLRScheduler {
fn new(decay_rate: f32, step_size: usize) -> Self {
Self { decay_rate, step_size }
}
}
#[typetag::serde]
impl LearningRateScheduler for StepLRScheduler {
    /// Returns the learning rate to use at `epoch`, given the rate currently in effect.
    ///
    /// The rate is multiplied by `decay_rate` when `epoch` is a positive multiple of
    /// `step_size`, and returned unchanged otherwise.
    ///
    /// Two edge cases are handled explicitly:
    /// * Epoch 0 never triggers a decay. With 0-indexed epochs, the initial learning rate is
    ///   therefore not decayed before any training has happened; this mirrors PyTorch's
    ///   `StepLR`.
    /// * A `step_size` of 0 disables decay instead of panicking on a remainder by zero. The
    ///   `Step` builder rejects 0, but a deserialized scheduler bypasses that validation.
    fn schedule(&self, epoch: usize, current_learning_rate: f32) -> f32 {
        // `checked_rem` yields `None` for a zero divisor, which is treated as "no boundary".
        let at_decay_boundary = epoch != 0 && epoch.checked_rem(self.step_size) == Some(0);
        if at_decay_boundary {
            current_learning_rate * self.decay_rate
        } else {
            current_learning_rate
        }
    }
}
impl LearningRateSchedulerClone for StepLRScheduler {
fn clone_box(&self) -> Box<dyn LearningRateScheduler> {
Box::new(self.clone())
}
}
/// Builder for a step-decay learning-rate scheduler.
///
/// Configure it with `decay_rate` and `step_size`, then call `build` to validate the
/// settings and obtain the boxed scheduler.
#[derive(Debug, Clone)]
pub struct Step {
    /// Factor the learning rate is multiplied by at each decay step; `build` requires it to
    /// lie strictly between 0 and 1.
    decay_rate: f32,
    /// Number of epochs between decay steps; `build` requires it to be greater than 0.
    step_size: usize,
}
impl Step {
    /// Creates a builder with decay rate `0.9` and step size `10`.
    fn new() -> Self {
        Self {
            decay_rate: 0.9,
            step_size: 10,
        }
    }

    /// Sets the factor the learning rate is multiplied by at each decay step.
    ///
    /// The value must lie strictly between 0 and 1; this is checked by `build`.
    #[must_use]
    pub fn decay_rate(mut self, decay_rate: f32) -> Self {
        self.decay_rate = decay_rate;
        self
    }

    /// Sets the number of epochs between decay steps.
    ///
    /// The value must be greater than 0; this is checked by `build`.
    #[must_use]
    pub fn step_size(mut self, step_size: usize) -> Self {
        self.step_size = step_size;
        self
    }

    /// Checks that the configured values describe a usable scheduler.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConfigError` in two cases:
    /// * `decay_rate` is not strictly inside `(0, 1)`, which includes NaN.
    /// * `step_size` is 0.
    fn validate(&self) -> Result<(), NetworkError> {
        // Phrased as a positive range check so that NaN is rejected. Every comparison with
        // NaN is false, so a `<= 0.0 || >= 1.0` test would let NaN through.
        let decay_rate_in_range = self.decay_rate > 0.0 && self.decay_rate < 1.0;
        if !decay_rate_in_range {
            return Err(NetworkError::ConfigError(format!(
                "Decay rate for Step must be in the range (0, 1), but was {}",
                self.decay_rate
            )));
        }
        if self.step_size == 0 {
            return Err(NetworkError::ConfigError(format!(
                "Step size for Step must be greater than 0, but was {}",
                self.step_size
            )));
        }
        Ok(())
    }

    /// Validates the configuration and builds the step-decay scheduler.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConfigError` when the decay rate or the step size is invalid.
    pub fn build(self) -> Result<Box<dyn LearningRateScheduler>, NetworkError> {
        self.validate()?;
        Ok(Box::new(StepLRScheduler::new(self.decay_rate, self.step_size)))
    }
}
impl Default for Step {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two learning rates agree within `f32::EPSILON`, so the tests do not
    /// depend on the exact rounding of `f32` multiplication.
    fn assert_lr_eq(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() <= f32::EPSILON,
            "expected learning rate {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_step_lr_scheduler() {
        let scheduler = StepLRScheduler::new(0.5, 10);
        assert_lr_eq(scheduler.schedule(10, 0.1), 0.05);
        assert_lr_eq(scheduler.schedule(15, 0.1), 0.1);
        assert_lr_eq(scheduler.schedule(20, 0.1), 0.05);
    }

    #[test]
    fn test_step_builder() {
        let scheduler = Step::new()
            .decay_rate(0.8)
            .step_size(5)
            .build()
            .expect("Failed to build StepLRScheduler");
        assert_lr_eq(scheduler.schedule(5, 0.1), 0.08);
        assert_lr_eq(scheduler.schedule(6, 0.1), 0.1);
        assert_lr_eq(scheduler.schedule(10, 0.1), 0.08);
    }

    #[test]
    fn test_step_builder_defaults() {
        let scheduler = Step::default()
            .build()
            .expect("Default Step configuration should be valid");
        assert_lr_eq(scheduler.schedule(5, 1.0), 1.0);
        assert_lr_eq(scheduler.schedule(10, 1.0), 0.9);
    }

    #[test]
    fn test_step_builder_invalid_decay_rate() {
        let err = Step::new()
            .decay_rate(1.0)
            .step_size(5)
            .build()
            .err()
            .expect("decay rate 1.0 must be rejected");
        assert_eq!(
            err.to_string(),
            "Configuration error: Decay rate for Step must be in the range (0, 1), but was 1"
        );
    }

    #[test]
    fn test_step_builder_zero_decay_rate() {
        let err = Step::new()
            .decay_rate(0.0)
            .build()
            .err()
            .expect("decay rate 0.0 must be rejected");
        assert_eq!(
            err.to_string(),
            "Configuration error: Decay rate for Step must be in the range (0, 1), but was 0"
        );
    }

    #[test]
    fn test_step_builder_invalid_step_size() {
        let err = Step::new()
            .step_size(0)
            .build()
            .err()
            .expect("step size 0 must be rejected");
        assert_eq!(
            err.to_string(),
            "Configuration error: Step size for Step must be greater than 0, but was 0"
        );
    }
}