//! burn-optim 0.20.1
//!
//! Optimizer building blocks for the Burn deep learning framework.
use burn_core as burn;

use super::cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig};
use super::exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig};
use super::linear::{LinearLrScheduler, LinearLrSchedulerConfig};
use super::noam::{NoamLrScheduler, NoamLrSchedulerConfig};
use super::{LrScheduler, String};
use crate::LearningRate;

use burn::config::Config;
use burn::record::Record;
use burn::tensor::backend::Backend;

/// Compose multiple [learning rate schedulers](LrScheduler) together.
///
/// Each call to [step](LrScheduler::step) steps every inner scheduler and
/// combines the produced learning rates using the configured
/// [reduction](SchedulerReduction).
#[derive(Config, Debug)]
pub struct ComposedLrSchedulerConfig {
    /// Configurations of the schedulers to compose, in the order they were appended.
    #[config(default = "Vec::new()")]
    schedulers: Vec<LrSchedulerConfig>,
    /// How the learning rates produced by the schedulers are combined
    /// (product by default).
    #[config(default = "SchedulerReduction::Prod")]
    reduction: SchedulerReduction,
}

/// Compose multiple [learning rate schedulers](LrScheduler) together.
///
/// Created from a [ComposedLrSchedulerConfig]; stepping it steps every inner
/// scheduler and reduces their learning rates into a single value.
#[derive(Clone)]
pub struct ComposedLrScheduler {
    // Initialized inner schedulers, in the same order as the config.
    schedulers: Vec<LrSchedulerItem>,
    // How the per-scheduler learning rates are combined on each step.
    reduction: SchedulerReduction,
}

/// Defines how the learning rates generated by the schedulers are combined.
#[derive(Config, Debug, Copy)]
pub enum SchedulerReduction {
    /// All learning rates are averaged (sum divided by the number of schedulers).
    Avg,
    /// All learning rates are summed.
    Sum,
    /// All learning rates are multiplied.
    Prod,
}

impl ComposedLrSchedulerConfig {
    /// Initialize the composed learning rate scheduler.
    ///
    /// Initializes every inner scheduler in order; returns the first
    /// initialization error encountered, if any.
    pub fn init(&self) -> Result<ComposedLrScheduler, String> {
        let schedulers = self
            .schedulers
            .iter()
            .map(|config| {
                Ok(match config {
                    LrSchedulerConfig::Linear(config) => LrSchedulerItem::Linear(config.init()?),
                    LrSchedulerConfig::Cosine(config) => LrSchedulerItem::Cosine(config.init()?),
                    LrSchedulerConfig::Exponential(config) => {
                        LrSchedulerItem::Exponential(config.init()?)
                    }
                    LrSchedulerConfig::Noam(config) => LrSchedulerItem::Noam(config.init()?),
                })
            })
            .collect::<Result<Vec<_>, String>>()?;

        Ok(ComposedLrScheduler {
            schedulers,
            reduction: self.reduction,
        })
    }

    /// Appends a [linear scheduler](LinearLrScheduler).
    pub fn linear(mut self, config: LinearLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Linear(config));
        self
    }

    /// Appends a [cosine annealing scheduler](CosineAnnealingLrScheduler).
    pub fn cosine(mut self, config: CosineAnnealingLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Cosine(config));
        self
    }

    /// Appends an [exponential scheduler](ExponentialLrScheduler).
    pub fn exponential(mut self, config: ExponentialLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Exponential(config));
        self
    }

    /// Appends a [noam scheduler](NoamLrScheduler).
    pub fn noam(mut self, config: NoamLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Noam(config));
        self
    }
}

/// Configuration of a single scheduler that can take part in a
/// [composed scheduler](ComposedLrScheduler). One variant per supported
/// scheduler kind.
#[derive(Config, Debug)]
enum LrSchedulerConfig {
    /// See [LinearLrSchedulerConfig].
    Linear(LinearLrSchedulerConfig),
    /// See [CosineAnnealingLrSchedulerConfig].
    Cosine(CosineAnnealingLrSchedulerConfig),
    /// See [ExponentialLrSchedulerConfig].
    Exponential(ExponentialLrSchedulerConfig),
    /// See [NoamLrSchedulerConfig].
    Noam(NoamLrSchedulerConfig),
}

/// An initialized scheduler held by a [composed scheduler](ComposedLrScheduler).
/// Mirrors the variants of [LrSchedulerConfig].
#[derive(Clone)]
enum LrSchedulerItem {
    /// See [LinearLrScheduler].
    Linear(LinearLrScheduler),
    /// See [CosineAnnealingLrScheduler].
    Cosine(CosineAnnealingLrScheduler),
    /// See [ExponentialLrScheduler].
    Exponential(ExponentialLrScheduler),
    /// See [NoamLrScheduler].
    Noam(NoamLrScheduler),
}

#[derive(Record)]
/// Record item for the [composed learning rate scheduler](ComposedLrScheduler).
///
/// Each variant holds the record type of the corresponding scheduler so the
/// scheduler state can be saved and restored.
pub enum LrSchedulerRecord<B: Backend> {
    /// The linear variant.
    Linear(<LinearLrScheduler as LrScheduler>::Record<B>),
    /// The cosine variant.
    Cosine(<CosineAnnealingLrScheduler as LrScheduler>::Record<B>),
    /// The exponential variant.
    Exponential(<ExponentialLrScheduler as LrScheduler>::Record<B>),
    /// The noam variant.
    Noam(<NoamLrScheduler as LrScheduler>::Record<B>),
}

#[derive(Record)]
/// Records for the [composed learning rate scheduler](ComposedLrScheduler).
pub struct ComposedLrSchedulerRecord<B: Backend> {
    /// One record per inner scheduler, in the same order as the schedulers.
    schedulers: Vec<LrSchedulerRecord<B>>,
}

impl LrScheduler for ComposedLrScheduler {
    type Record<B: Backend> = ComposedLrSchedulerRecord<B>;

    fn step(&mut self) -> LearningRate {
        let mut step = match self.reduction {
            SchedulerReduction::Avg => 0.0,
            SchedulerReduction::Sum => 0.0,
            SchedulerReduction::Prod => 1.0,
        };
        let num_scheduler = self.schedulers.len() as f64;

        for lr in self.schedulers.iter_mut().map(|s| match s {
            LrSchedulerItem::Linear(item) => item.step(),
            LrSchedulerItem::Cosine(item) => item.step(),
            LrSchedulerItem::Exponential(item) => item.step(),
            LrSchedulerItem::Noam(item) => item.step(),
        }) {
            step = match self.reduction {
                SchedulerReduction::Avg => step + (lr / num_scheduler),
                SchedulerReduction::Sum => step + lr,
                SchedulerReduction::Prod => step * lr,
            }
        }

        step
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        ComposedLrSchedulerRecord::<B> {
            schedulers: self
                .schedulers
                .iter()
                .map(|s| match s {
                    LrSchedulerItem::Linear(item) => {
                        LrSchedulerRecord::Linear(item.to_record::<B>())
                    }
                    LrSchedulerItem::Cosine(item) => {
                        LrSchedulerRecord::Linear(item.to_record::<B>())
                    }
                    LrSchedulerItem::Exponential(item) => {
                        LrSchedulerRecord::Exponential(item.to_record::<B>())
                    }
                    LrSchedulerItem::Noam(item) => LrSchedulerRecord::Noam(item.to_record::<B>()),
                })
                .collect(),
        }
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.schedulers = self
            .schedulers
            .into_iter()
            .zip(record.schedulers)
            .map(|scheduler| match scheduler {
                (LrSchedulerItem::Linear(item), LrSchedulerRecord::Linear(record)) => {
                    LrSchedulerItem::Linear(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Cosine(item), LrSchedulerRecord::Cosine(record)) => {
                    LrSchedulerItem::Cosine(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Exponential(item), LrSchedulerRecord::Exponential(record)) => {
                    LrSchedulerItem::Exponential(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Noam(item), LrSchedulerRecord::Noam(record)) => {
                    LrSchedulerItem::Noam(item.load_record::<B>(record))
                }
                _ => panic!("Invalid state"),
            })
            .collect();

        self
    }
}