use burn_core as burn;
use super::cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig};
use super::exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig};
use super::linear::{LinearLrScheduler, LinearLrSchedulerConfig};
use super::noam::{NoamLrScheduler, NoamLrSchedulerConfig};
use super::{LrScheduler, String};
use crate::LearningRate;
use burn::config::Config;
use burn::record::Record;
use burn::tensor::backend::Backend;
/// Configuration used to build a [`ComposedLrScheduler`] out of several inner
/// learning rate scheduler configurations.
///
/// Inner schedulers are appended with the `linear`, `cosine`, `exponential` and
/// `noam` builder methods and are initialized by `init`.
#[derive(Config, Debug)]
pub struct ComposedLrSchedulerConfig {
/// Configurations of the inner schedulers, in the order they were added.
/// Defaults to an empty list.
#[config(default = "Vec::new()")]
schedulers: Vec<LrSchedulerConfig>,
/// How the learning rates produced by the inner schedulers are combined.
/// Defaults to [`SchedulerReduction::Prod`].
#[config(default = "SchedulerReduction::Prod")]
reduction: SchedulerReduction,
}
/// Learning rate scheduler that steps several inner schedulers together and
/// reduces their learning rates into a single value.
#[derive(Clone)]
pub struct ComposedLrScheduler {
// Initialized inner schedulers, stepped in this order on every call to `step`.
schedulers: Vec<LrSchedulerItem>,
// Reduction applied to the learning rates returned by the inner schedulers.
reduction: SchedulerReduction,
}
/// Defines how the learning rates produced by the inner schedulers of a
/// [`ComposedLrScheduler`] are combined into one learning rate.
#[derive(Config, Debug, Copy)]
pub enum SchedulerReduction {
/// The learning rates are averaged: each one is divided by the number of
/// schedulers and the results are summed.
Avg,
/// The learning rates are summed.
Sum,
/// The learning rates are multiplied together.
Prod,
}
impl ComposedLrSchedulerConfig {
pub fn init(&self) -> Result<ComposedLrScheduler, String> {
let mut schedulers = Vec::with_capacity(self.schedulers.len());
for config in self.schedulers.iter() {
let config = match config {
LrSchedulerConfig::Linear(config) => LrSchedulerItem::Linear(config.init()?),
LrSchedulerConfig::Cosine(config) => LrSchedulerItem::Cosine(config.init()?),
LrSchedulerConfig::Exponential(config) => {
LrSchedulerItem::Exponential(config.init()?)
}
LrSchedulerConfig::Noam(config) => LrSchedulerItem::Noam(config.init()?),
};
schedulers.push(config);
}
Ok(ComposedLrScheduler {
schedulers,
reduction: self.reduction,
})
}
pub fn linear(mut self, config: LinearLrSchedulerConfig) -> Self {
self.schedulers.push(LrSchedulerConfig::Linear(config));
self
}
pub fn cosine(mut self, config: CosineAnnealingLrSchedulerConfig) -> Self {
self.schedulers.push(LrSchedulerConfig::Cosine(config));
self
}
pub fn exponential(mut self, config: ExponentialLrSchedulerConfig) -> Self {
self.schedulers.push(LrSchedulerConfig::Exponential(config));
self
}
pub fn noam(mut self, config: NoamLrSchedulerConfig) -> Self {
self.schedulers.push(LrSchedulerConfig::Noam(config));
self
}
}
/// Configuration of one inner scheduler held by a `ComposedLrSchedulerConfig`.
///
/// Each variant wraps the configuration type of the scheduler kind it names.
#[derive(Config, Debug)]
enum LrSchedulerConfig {
/// Configuration of a linear scheduler.
Linear(LinearLrSchedulerConfig),
/// Configuration of a cosine annealing scheduler.
Cosine(CosineAnnealingLrSchedulerConfig),
/// Configuration of an exponential scheduler.
Exponential(ExponentialLrSchedulerConfig),
/// Configuration of a noam scheduler.
Noam(NoamLrSchedulerConfig),
}
/// One initialized inner scheduler held by a `ComposedLrScheduler`.
///
/// Each variant wraps the scheduler type of the kind it names.
#[derive(Clone)]
enum LrSchedulerItem {
/// A linear scheduler.
Linear(LinearLrScheduler),
/// A cosine annealing scheduler.
Cosine(CosineAnnealingLrScheduler),
/// An exponential scheduler.
Exponential(ExponentialLrScheduler),
/// A noam scheduler.
Noam(NoamLrScheduler),
}
/// Record of one inner scheduler of a [`ComposedLrScheduler`].
///
/// Each variant wraps the record type that the corresponding scheduler declares
/// through its [`LrScheduler`] implementation.
#[derive(Record)]
pub enum LrSchedulerRecord<B: Backend> {
/// Record of a linear scheduler.
Linear(<LinearLrScheduler as LrScheduler>::Record<B>),
/// Record of a cosine annealing scheduler.
Cosine(<CosineAnnealingLrScheduler as LrScheduler>::Record<B>),
/// Record of an exponential scheduler.
Exponential(<ExponentialLrScheduler as LrScheduler>::Record<B>),
/// Record of a noam scheduler.
Noam(<NoamLrScheduler as LrScheduler>::Record<B>),
}
/// Record of a [`ComposedLrScheduler`]: the records of its inner schedulers.
#[derive(Record)]
pub struct ComposedLrSchedulerRecord<B: Backend> {
// One record per inner scheduler, in the same order as the schedulers.
schedulers: Vec<LrSchedulerRecord<B>>,
}
impl LrScheduler for ComposedLrScheduler {
    type Record<B: Backend> = ComposedLrSchedulerRecord<B>;

    /// Steps every inner scheduler once and reduces their learning rates into
    /// a single value according to the configured [`SchedulerReduction`].
    ///
    /// When there is no inner scheduler, the neutral starting value of the
    /// reduction is returned: `0.0` for `Avg` and `Sum`, `1.0` for `Prod`.
    fn step(&mut self) -> LearningRate {
        let reduction = self.reduction;

        // Start from the identity of the reduction so the fold below is uniform.
        let mut step = match reduction {
            SchedulerReduction::Avg | SchedulerReduction::Sum => 0.0,
            SchedulerReduction::Prod => 1.0,
        };
        let num_scheduler = self.schedulers.len() as f64;

        for lr in self.schedulers.iter_mut().map(|s| match s {
            LrSchedulerItem::Linear(item) => item.step(),
            LrSchedulerItem::Cosine(item) => item.step(),
            LrSchedulerItem::Exponential(item) => item.step(),
            LrSchedulerItem::Noam(item) => item.step(),
        }) {
            step = match reduction {
                // Dividing each term keeps a running average without a final pass.
                SchedulerReduction::Avg => step + (lr / num_scheduler),
                SchedulerReduction::Sum => step + lr,
                SchedulerReduction::Prod => step * lr,
            }
        }

        step
    }

    /// Builds the record of the composed scheduler: one [`LrSchedulerRecord`]
    /// per inner scheduler, in the same order, each wrapped in the variant that
    /// matches the scheduler's kind.
    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        ComposedLrSchedulerRecord::<B> {
            schedulers: self
                .schedulers
                .iter()
                .map(|s| match s {
                    LrSchedulerItem::Linear(item) => {
                        LrSchedulerRecord::Linear(item.to_record::<B>())
                    }
                    // The variant must be `Cosine`: `load_record` pairs items and
                    // records by variant and panics when they disagree.
                    LrSchedulerItem::Cosine(item) => {
                        LrSchedulerRecord::Cosine(item.to_record::<B>())
                    }
                    LrSchedulerItem::Exponential(item) => {
                        LrSchedulerRecord::Exponential(item.to_record::<B>())
                    }
                    LrSchedulerItem::Noam(item) => LrSchedulerRecord::Noam(item.to_record::<B>()),
                })
                .collect(),
        }
    }

    /// Restores the state of every inner scheduler from `record`.
    ///
    /// Schedulers and records are paired by position. The pairing stops at the
    /// shorter of the two lists, so unmatched schedulers or records are dropped.
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid state"` when a record's variant does not match the
    /// kind of the scheduler at the same position.
    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.schedulers = self
            .schedulers
            .into_iter()
            .zip(record.schedulers)
            .map(|scheduler| match scheduler {
                (LrSchedulerItem::Linear(item), LrSchedulerRecord::Linear(record)) => {
                    LrSchedulerItem::Linear(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Cosine(item), LrSchedulerRecord::Cosine(record)) => {
                    LrSchedulerItem::Cosine(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Exponential(item), LrSchedulerRecord::Exponential(record)) => {
                    LrSchedulerItem::Exponential(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Noam(item), LrSchedulerRecord::Noam(record)) => {
                    LrSchedulerItem::Noam(item.load_record::<B>(record))
                }
                _ => panic!("Invalid state"),
            })
            .collect();

        self
    }
}