use burn_core as burn;

use super::cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig};
use super::exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig};
use super::linear::{LinearLrScheduler, LinearLrSchedulerConfig};
use super::noam::{NoamLrScheduler, NoamLrSchedulerConfig};
use super::{LrScheduler, String};
use crate::LearningRate;

use burn::config::Config;
use burn::record::Record;
use burn::tensor::backend::Backend;

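/// Configuration for [`ComposedLrScheduler`], a scheduler that combines multiple learning
/// rate schedulers and reduces their learning rates into a single value at every step.
///
/// # Example
///
/// A minimal sketch, assuming the usual `Config`-derived constructors; the scheduler
/// constructor arguments below are illustrative only:
///
/// ```ignore
/// // Multiply (default `Prod` reduction) an exponential decay by a linear ramp.
/// let scheduler = ComposedLrSchedulerConfig::new()
///     .exponential(ExponentialLrSchedulerConfig::new(1.0e-2, 0.97))
///     .linear(LinearLrSchedulerConfig::new(1.0e-6, 1.0, 1_000))
///     .init()
///     .expect("A valid composed scheduler configuration");
/// ```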
#[derive(Config, Debug)]
pub struct ComposedLrSchedulerConfig {
    /// The configurations of the schedulers to compose.
    #[config(default = "Vec::new()")]
    schedulers: Vec<LrSchedulerConfig>,
    /// How the learning rates of the composed schedulers are reduced into a single value.
    #[config(default = "SchedulerReduction::Prod")]
    reduction: SchedulerReduction,
}

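/// A learning rate scheduler composed of several schedulers, created from a
/// [`ComposedLrSchedulerConfig`]. At every step, each inner scheduler is stepped and the
/// resulting learning rates are reduced with the configured [`SchedulerReduction`].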
#[derive(Clone)]
pub struct ComposedLrScheduler {
    schedulers: Vec<LrSchedulerItem>,
    reduction: SchedulerReduction,
}

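/// The strategy used to reduce the learning rates of the composed schedulers into a single
/// learning rate.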
#[derive(Config, Debug, Copy)]
pub enum SchedulerReduction {
    /// Use the average of the learning rates.
    Avg,
    /// Use the sum of the learning rates.
    Sum,
    /// Use the product of the learning rates.
    Prod,
}

impl ComposedLrSchedulerConfig {
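    /// Initializes a [`ComposedLrScheduler`] from this configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the composed scheduler configurations is invalid.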
    pub fn init(&self) -> Result<ComposedLrScheduler, String> {
        let mut schedulers = Vec::with_capacity(self.schedulers.len());
        for config in self.schedulers.iter() {
            let scheduler = match config {
                LrSchedulerConfig::Linear(config) => LrSchedulerItem::Linear(config.init()?),
                LrSchedulerConfig::Cosine(config) => LrSchedulerItem::Cosine(config.init()?),
                LrSchedulerConfig::Exponential(config) => {
                    LrSchedulerItem::Exponential(config.init()?)
                }
                LrSchedulerConfig::Noam(config) => LrSchedulerItem::Noam(config.init()?),
            };
            schedulers.push(scheduler);
        }

        Ok(ComposedLrScheduler {
            schedulers,
            reduction: self.reduction,
        })
    }

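    /// Adds a linear scheduler to the composition.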
    pub fn linear(mut self, config: LinearLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Linear(config));
        self
    }

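    /// Adds a cosine annealing scheduler to the composition.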
    pub fn cosine(mut self, config: CosineAnnealingLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Cosine(config));
        self
    }

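    /// Adds an exponential scheduler to the composition.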
    pub fn exponential(mut self, config: ExponentialLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Exponential(config));
        self
    }

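    /// Adds a Noam scheduler to the composition.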
    pub fn noam(mut self, config: NoamLrSchedulerConfig) -> Self {
        self.schedulers.push(LrSchedulerConfig::Noam(config));
        self
    }
}

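/// The configuration of a single scheduler in the composition.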
#[derive(Config, Debug)]
enum LrSchedulerConfig {
    Linear(LinearLrSchedulerConfig),
    Cosine(CosineAnnealingLrSchedulerConfig),
    Exponential(ExponentialLrSchedulerConfig),
    Noam(NoamLrSchedulerConfig),
}

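/// A single initialized scheduler in the composition.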
#[derive(Clone)]
enum LrSchedulerItem {
    Linear(LinearLrScheduler),
    Cosine(CosineAnnealingLrScheduler),
    Exponential(ExponentialLrScheduler),
    Noam(NoamLrScheduler),
}

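/// The record of a single scheduler in the composition.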
#[derive(Record)]
pub enum LrSchedulerRecord<B: Backend> {
    /// The record of a linear scheduler.
    Linear(<LinearLrScheduler as LrScheduler>::Record<B>),
    /// The record of a cosine annealing scheduler.
    Cosine(<CosineAnnealingLrScheduler as LrScheduler>::Record<B>),
    /// The record of an exponential scheduler.
    Exponential(<ExponentialLrScheduler as LrScheduler>::Record<B>),
    /// The record of a Noam scheduler.
    Noam(<NoamLrScheduler as LrScheduler>::Record<B>),
}

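/// The record of a [`ComposedLrScheduler`], holding one record per composed scheduler.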
#[derive(Record)]
pub struct ComposedLrSchedulerRecord<B: Backend> {
    /// The records of the composed schedulers, in the same order as the schedulers.
    schedulers: Vec<LrSchedulerRecord<B>>,
}

impl LrScheduler for ComposedLrScheduler {
    type Record<B: Backend> = ComposedLrSchedulerRecord<B>;

    fn step(&mut self) -> LearningRate {
        // Start from the neutral element of the reduction.
        let mut lr_reduced = match self.reduction {
            SchedulerReduction::Avg | SchedulerReduction::Sum => 0.0,
            SchedulerReduction::Prod => 1.0,
        };
        let num_schedulers = self.schedulers.len() as f64;

        // Step every scheduler and fold its learning rate into the reduced value.
        for lr in self.schedulers.iter_mut().map(|s| match s {
            LrSchedulerItem::Linear(item) => item.step(),
            LrSchedulerItem::Cosine(item) => item.step(),
            LrSchedulerItem::Exponential(item) => item.step(),
            LrSchedulerItem::Noam(item) => item.step(),
        }) {
            lr_reduced = match self.reduction {
                SchedulerReduction::Avg => lr_reduced + (lr / num_schedulers),
                SchedulerReduction::Sum => lr_reduced + lr,
                SchedulerReduction::Prod => lr_reduced * lr,
            }
        }

        lr_reduced
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        ComposedLrSchedulerRecord::<B> {
            schedulers: self
                .schedulers
                .iter()
                .map(|s| match s {
                    LrSchedulerItem::Linear(item) => {
                        LrSchedulerRecord::Linear(item.to_record::<B>())
                    }
                    LrSchedulerItem::Cosine(item) => {
                        LrSchedulerRecord::Cosine(item.to_record::<B>())
                    }
                    LrSchedulerItem::Exponential(item) => {
                        LrSchedulerRecord::Exponential(item.to_record::<B>())
                    }
                    LrSchedulerItem::Noam(item) => LrSchedulerRecord::Noam(item.to_record::<B>()),
                })
                .collect(),
        }
    }

    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
        self.schedulers = self
            .schedulers
            .into_iter()
            .zip(record.schedulers)
            .map(|pair| match pair {
                (LrSchedulerItem::Linear(item), LrSchedulerRecord::Linear(record)) => {
                    LrSchedulerItem::Linear(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Cosine(item), LrSchedulerRecord::Cosine(record)) => {
                    LrSchedulerItem::Cosine(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Exponential(item), LrSchedulerRecord::Exponential(record)) => {
                    LrSchedulerItem::Exponential(item.load_record::<B>(record))
                }
                (LrSchedulerItem::Noam(item), LrSchedulerRecord::Noam(record)) => {
                    LrSchedulerItem::Noam(item.load_record::<B>(record))
                }
                _ => panic!("The record variant does not match the scheduler it is loaded into."),
            })
            .collect();

        self
    }
}