burn_optim/lr_scheduler/
composed.rs

use burn_core as burn;

use super::cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig};
use super::exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig};
use super::linear::{LinearLrScheduler, LinearLrSchedulerConfig};
use super::noam::{NoamLrScheduler, NoamLrSchedulerConfig};
use super::{LrScheduler, String};
use crate::LearningRate;

use burn::config::Config;
use burn::record::Record;
use burn::tensor::backend::Backend;

/// Compose multiple [learning rate schedulers](LrScheduler) together.
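///
/// # Examples
///
/// A minimal usage sketch; the constructor arguments of the child scheduler
/// configs and the `with_reduction` setter generated by the `Config` derive
/// are assumed here, not taken from this file:
///
/// ```ignore
/// let mut scheduler = ComposedLrSchedulerConfig::new()
///     .linear(LinearLrSchedulerConfig::new(1e-2, 1e-4, 1000))
///     .cosine(CosineAnnealingLrSchedulerConfig::new(1e-2, 1000))
///     .with_reduction(SchedulerReduction::Avg)
///     .init()
///     .expect("A valid composed scheduler configuration");
///
/// let lr = scheduler.step();
/// ```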
#[derive(Config, Debug)]
pub struct ComposedLrSchedulerConfig {
    #[config(default = "Vec::new()")]
    schedulers: Vec<LrSchedulerConfig>,
    #[config(default = "SchedulerReduction::Prod")]
    reduction: SchedulerReduction,
}

/// Compose multiple [learning rate schedulers](LrScheduler) together.
///
/// Created with a [ComposedLrSchedulerConfig].
#[derive(Clone)]
pub struct ComposedLrScheduler {
    schedulers: Vec<LrSchedulerItem>,
    reduction: SchedulerReduction,
}

/// Defines how the learning rates generated by the schedulers are combined.
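///
/// For example, if two schedulers currently yield `1e-2` and `1e-1`, `Avg`
/// produces `5.5e-2`, `Sum` produces `1.1e-1`, and `Prod` produces `1e-3`.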
#[derive(Config, Debug, Copy)]
pub enum SchedulerReduction {
    /// All learning rates are averaged.
    Avg,
    /// All learning rates are summed.
    Sum,
    /// All learning rates are multiplied.
    Prod,
}

impl ComposedLrSchedulerConfig {
    /// Initialize the learning rate scheduler.
    ///
    /// Returns an error if any of the composed scheduler configurations is invalid.
    pub fn init(&self) -> Result<ComposedLrScheduler, String> {
        let mut schedulers = Vec::with_capacity(self.schedulers.len());
        for config in self.schedulers.iter() {
            let scheduler = match config {
                LrSchedulerConfig::Linear(config) => LrSchedulerItem::Linear(config.init()?),
                LrSchedulerConfig::Cosine(config) => LrSchedulerItem::Cosine(config.init()?),
                LrSchedulerConfig::Exponential(config) => {
                    LrSchedulerItem::Exponential(config.init()?)
                }
                LrSchedulerConfig::Noam(config) => LrSchedulerItem::Noam(config.init()?),
            };
            schedulers.push(scheduler);
        }

        Ok(ComposedLrScheduler {
            schedulers,
            reduction: self.reduction,
        })
    }

    /// Appends a [linear scheduler](LinearLrScheduler).
    pub fn linear(mut self, config: LinearLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Linear(config));
        self
    }

    /// Appends a [cosine annealing scheduler](CosineAnnealingLrScheduler).
    pub fn cosine(mut self, config: CosineAnnealingLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Cosine(config));
        self
    }

    /// Appends an [exponential scheduler](ExponentialLrScheduler).
    pub fn exponential(mut self, config: ExponentialLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Exponential(config));
        self
    }

    /// Appends a [Noam scheduler](NoamLrScheduler).
    pub fn noam(mut self, config: NoamLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Noam(config));
        self
    }
}

#[derive(Config, Debug)]
enum LrSchedulerConfig {
    Linear(LinearLrSchedulerConfig),
    Cosine(CosineAnnealingLrSchedulerConfig),
    Exponential(ExponentialLrSchedulerConfig),
    Noam(NoamLrSchedulerConfig),
}

#[derive(Clone)]
enum LrSchedulerItem {
    Linear(LinearLrScheduler),
    Cosine(CosineAnnealingLrScheduler),
    Exponential(ExponentialLrScheduler),
    Noam(NoamLrScheduler),
}

/// Record item for the [composed learning rate scheduler](ComposedLrScheduler).
#[derive(Record)]
pub enum LrSchedulerRecord<B: Backend> {
    /// The linear variant.
    Linear(<LinearLrScheduler as LrScheduler>::Record<B>),
    /// The cosine variant.
    Cosine(<CosineAnnealingLrScheduler as LrScheduler>::Record<B>),
    /// The exponential variant.
    Exponential(<ExponentialLrScheduler as LrScheduler>::Record<B>),
    /// The Noam variant.
    Noam(<NoamLrScheduler as LrScheduler>::Record<B>),
}

/// Record for the [composed learning rate scheduler](ComposedLrScheduler).
#[derive(Record)]
pub struct ComposedLrSchedulerRecord<B: Backend> {
    schedulers: Vec<LrSchedulerRecord<B>>,
}

impl LrScheduler for ComposedLrScheduler {
    type Record<B: Backend> = ComposedLrSchedulerRecord<B>;

    fn step(&mut self) -> LearningRate {
        // Start from the identity element of the selected reduction.
        let mut lr_combined = match self.reduction {
            SchedulerReduction::Avg => 0.0,
            SchedulerReduction::Sum => 0.0,
            SchedulerReduction::Prod => 1.0,
        };
        let num_schedulers = self.schedulers.len() as f64;

        // Step every scheduler and fold its learning rate into the result.
        for lr in self.schedulers.iter_mut().map(|s| match s {
            LrSchedulerItem::Linear(item) => item.step(),
            LrSchedulerItem::Cosine(item) => item.step(),
            LrSchedulerItem::Exponential(item) => item.step(),
            LrSchedulerItem::Noam(item) => item.step(),
        }) {
            lr_combined = match self.reduction {
                SchedulerReduction::Avg => lr_combined + (lr / num_schedulers),
                SchedulerReduction::Sum => lr_combined + lr,
                SchedulerReduction::Prod => lr_combined * lr,
            }
        }

        lr_combined
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        ComposedLrSchedulerRecord::<B> {
            schedulers: self
                .schedulers
                .iter()
                .map(|s| match s {
                    LrSchedulerItem::Linear(item) => {
                        LrSchedulerRecord::Linear(item.to_record::<B>())
                    }
                    LrSchedulerItem::Cosine(item) => {
                        LrSchedulerRecord::Cosine(item.to_record::<B>())
                    }
                    LrSchedulerItem::Exponential(item) => {
                        LrSchedulerRecord::Exponential(item.to_record::<B>())
                    }
                    LrSchedulerItem::Noam(item) => LrSchedulerRecord::Noam(item.to_record::<B>()),
                })
                .collect(),
        }
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        // Pair each scheduler with its saved record; the variants must match pairwise.
        self.schedulers = self
            .schedulers
            .into_iter()
            .zip(record.schedulers)
            .map(|scheduler| match scheduler {
                (LrSchedulerItem::Linear(item), LrSchedulerRecord::Linear(record)) => {
                    LrSchedulerItem::Linear(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Cosine(item), LrSchedulerRecord::Cosine(record)) => {
                    LrSchedulerItem::Cosine(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Exponential(item), LrSchedulerRecord::Exponential(record)) => {
                    LrSchedulerItem::Exponential(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Noam(item), LrSchedulerRecord::Noam(record)) => {
                    LrSchedulerItem::Noam(item.load_record::<B>(record))
                }
                _ => panic!("The record variant should match the scheduler variant"),
            })
            .collect();

        self
    }
}
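
// A minimal sanity test, sketched under the assumption that
// `LinearLrSchedulerConfig::new(initial_lr, final_lr, num_iters)` is the
// constructor signature of the linear scheduler config.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prod_reduction_multiplies_learning_rates() {
        let mut scheduler = ComposedLrSchedulerConfig::new()
            // Two identical linear schedulers; with the default `Prod`
            // reduction the composed learning rate is the product of both.
            .linear(LinearLrSchedulerConfig::new(0.5, 0.1, 10))
            .linear(LinearLrSchedulerConfig::new(0.5, 0.1, 10))
            .init()
            .expect("A valid composed scheduler configuration");

        let lr = scheduler.step();
        // Both schedulers yield the same value at every step, so the product
        // can never exceed the square of the initial learning rate.
        assert!(lr > 0.0 && lr <= 0.5 * 0.5);
    }
}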