use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::schedulers::LearningRateScheduler;
/// Converts an `f64` constant into the scheduler's float type `A`.
///
/// If the conversion reports failure (`None`), `A::zero()` is returned
/// instead of panicking.
fn from_f64<A: Float>(v: f64) -> A {
    match A::from(v) {
        Some(converted) => converted,
        None => A::zero(),
    }
}
/// Converts a `usize` count into the scheduler's float type `A`.
///
/// If the conversion reports failure (`None`), `A::zero()` is returned
/// instead of panicking.
fn from_usize<A: Float>(v: usize) -> A {
    if let Some(converted) = A::from(v) {
        converted
    } else {
        A::zero()
    }
}
/// Builder that collects optional settings for a [`ViTLayerDecay`] scheduler.
///
/// `new` initialises every field to `None`; `build` replaces any field that is
/// still unset with a default value.
#[derive(Debug, Clone)]
pub struct ViTLayerDecayBuilder<A: Float + Debug> {
/// Base learning rate; `build` falls back to `0.001` when unset.
base_lr: Option<A>,
/// Per-layer decay rate; `build` falls back to `0.75` when unset.
decay_rate: Option<A>,
/// Number of layers; `build` falls back to `12` when unset.
num_layers: Option<usize>,
/// Number of warmup steps; `build` falls back to `0` when unset.
warmup_steps: Option<usize>,
/// Total number of schedule steps; `build` falls back to `1000` when unset.
total_steps: Option<usize>,
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> ViTLayerDecayBuilder<A> {
    /// Creates a builder in which no setting has been chosen yet.
    fn new() -> Self {
        Self {
            base_lr: None,
            decay_rate: None,
            num_layers: None,
            warmup_steps: None,
            total_steps: None,
        }
    }

    /// Sets the base learning rate.
    pub fn base_lr(self, lr: A) -> Self {
        Self {
            base_lr: Some(lr),
            ..self
        }
    }

    /// Sets the per-layer decay rate.
    pub fn decay_rate(self, rate: A) -> Self {
        Self {
            decay_rate: Some(rate),
            ..self
        }
    }

    /// Sets the number of layers.
    pub fn num_layers(self, n: usize) -> Self {
        Self {
            num_layers: Some(n),
            ..self
        }
    }

    /// Sets the number of warmup steps.
    pub fn warmup_steps(self, steps: usize) -> Self {
        Self {
            warmup_steps: Some(steps),
            ..self
        }
    }

    /// Sets the total number of schedule steps.
    pub fn total_steps(self, steps: usize) -> Self {
        Self {
            total_steps: Some(steps),
            ..self
        }
    }

    /// Consumes the builder and constructs the scheduler.
    ///
    /// Settings that were never provided default to: base learning rate
    /// `0.001`, decay rate `0.75`, `12` layers, `0` warmup steps and `1000`
    /// total steps.
    pub fn build(self) -> ViTLayerDecay<A> {
        let Self {
            base_lr,
            decay_rate,
            num_layers,
            warmup_steps,
            total_steps,
        } = self;

        ViTLayerDecay::new(
            base_lr.unwrap_or_else(|| from_f64::<A>(0.001)),
            decay_rate.unwrap_or_else(|| from_f64::<A>(0.75)),
            num_layers.unwrap_or(12),
            warmup_steps.unwrap_or(0),
            total_steps.unwrap_or(1000),
        )
    }
}
/// Learning-rate scheduler with linear warmup, cosine decay and per-layer scaling.
///
/// The shared rate is `base_lr` times a schedule factor that is recomputed on
/// every `step`; `get_layer_learning_rate` further scales that rate by
/// `decay_rate^(num_layers - layer_idx - 1)`.
#[derive(Debug, Clone)]
pub struct ViTLayerDecay<A: Float + Debug> {
/// Learning rate that the schedule factor is multiplied with on each `step`.
base_lr: A,
/// Base of the per-layer power used to scale the shared learning rate.
decay_rate: A,
/// Number of layers; layer indices at or above this value yield a zero rate.
num_layers: usize,
/// Number of steps over which the factor ramps as `current_step / warmup_steps`.
warmup_steps: usize,
/// Step count at which the cosine decay reaches the end of its progress range.
total_steps: usize,
/// Number of `step` calls since construction or the last `reset`.
current_step: usize,
/// Rate computed by the most recent `step`; zero after `new` and after `reset`.
current_lr: A,
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> ViTLayerDecay<A> {
    /// Creates a scheduler positioned at step 0 with a current learning rate of zero.
    ///
    /// # Arguments
    ///
    /// * `base_lr` - learning rate multiplied by the schedule factor on each step.
    /// * `decay_rate` - base of the per-layer power; layer `i` is scaled by
    ///   `decay_rate^(num_layers - i - 1)`.
    /// * `num_layers` - number of layers that receive a learning rate.
    /// * `warmup_steps` - number of steps of linear warmup.
    /// * `total_steps` - step count at which the cosine decay finishes.
    pub fn new(
        base_lr: A,
        decay_rate: A,
        num_layers: usize,
        warmup_steps: usize,
        total_steps: usize,
    ) -> Self {
        Self {
            base_lr,
            decay_rate,
            num_layers,
            warmup_steps,
            total_steps,
            current_step: 0,
            current_lr: A::zero(),
        }
    }

    /// Returns a builder whose unset options fall back to defaults in `build`.
    pub fn builder() -> ViTLayerDecayBuilder<A> {
        ViTLayerDecayBuilder::new()
    }

    /// Computes the multiplier applied to `base_lr` at the current step.
    ///
    /// * Step 0: zero.
    /// * Steps `1..=warmup_steps`: linear ramp `current_step / warmup_steps`.
    /// * Later steps: `0.5 * (1 + cos(pi * progress))`, where `progress` is the
    ///   fraction of the `total_steps - warmup_steps` decay steps already taken,
    ///   capped at one. If there are no decay steps the factor stays at one.
    fn compute_schedule_factor(&self) -> A {
        if self.current_step == 0 {
            return A::zero();
        }

        if self.current_step <= self.warmup_steps {
            // `current_step >= 1` here, hence `warmup_steps >= 1`: the divisor
            // is never zero and no separate zero-warmup guard is needed.
            return from_usize::<A>(self.current_step) / from_usize::<A>(self.warmup_steps);
        }

        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
        if decay_steps == 0 {
            return A::one();
        }

        let steps_since_warmup = self.current_step.saturating_sub(self.warmup_steps);
        // Stepping past `total_steps` pins the progress at one.
        let progress = if steps_since_warmup >= decay_steps {
            A::one()
        } else {
            from_usize::<A>(steps_since_warmup) / from_usize::<A>(decay_steps)
        };

        let pi = from_f64::<A>(std::f64::consts::PI);
        let half = from_f64::<A>(0.5);
        half * (A::one() + (pi * progress).cos())
    }

    /// Returns the learning rate for layer `layer_idx`.
    ///
    /// The current shared rate is scaled by `decay_rate^(num_layers - layer_idx - 1)`,
    /// so the last layer (`num_layers - 1`) receives the unscaled rate. Indices
    /// at or beyond `num_layers` yield zero.
    pub fn get_layer_learning_rate(&self, layer_idx: usize) -> A {
        if layer_idx >= self.num_layers {
            return A::zero();
        }
        // No underflow: `layer_idx < num_layers` was checked above.
        let exponent = self.num_layers - layer_idx - 1;
        // `powi` takes an `i32`. A bare `as i32` would wrap for exponents above
        // `i32::MAX` and could produce a negative power (growth instead of
        // decay), so clamp before converting.
        let exponent = exponent.min(i32::MAX as usize) as i32;
        let decay_factor = self.decay_rate.powi(exponent);
        self.current_lr * decay_factor
    }

    /// Returns the learning rates of all layers, ordered by layer index.
    pub fn get_all_layer_rates(&self) -> Vec<A> {
        (0..self.num_layers)
            .map(|i| self.get_layer_learning_rate(i))
            .collect()
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Returns how many steps have been taken since construction or the last reset.
    pub fn current_step(&self) -> usize {
        self.current_step
    }

    /// Returns the per-layer decay rate.
    pub fn decay_rate(&self) -> A {
        self.decay_rate
    }

    /// Returns the base learning rate.
    pub fn base_lr(&self) -> A {
        self.base_lr
    }
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> LearningRateScheduler<A> for ViTLayerDecay<A> {
    /// Returns the rate computed by the most recent `step`.
    ///
    /// The value is zero for a freshly constructed or freshly reset scheduler.
    fn get_learning_rate(&self) -> A {
        self.current_lr
    }

    /// Advances the schedule by one step and returns the new learning rate.
    fn step(&mut self) -> A {
        self.current_step += 1;
        let updated = self.base_lr * self.compute_schedule_factor();
        self.current_lr = updated;
        updated
    }

    /// Returns the scheduler to step 0 with a zero learning rate.
    fn reset(&mut self) {
        self.current_lr = A::zero();
        self.current_step = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Scheduler shared by most tests: base LR 0.001, decay 0.75, 12 layers,
    /// 100 warmup steps and 1000 total steps.
    fn standard_scheduler() -> ViTLayerDecay<f64> {
        ViTLayerDecay::<f64>::new(0.001, 0.75, 12, 100, 1000)
    }

    /// Steps `scheduler` forward `steps` times, discarding the returned rates.
    fn advance(scheduler: &mut ViTLayerDecay<f64>, steps: usize) {
        for _ in 0..steps {
            scheduler.step();
        }
    }

    #[test]
    fn test_vit_layer_decay_basic() {
        let mut scheduler = standard_scheduler();
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.0);

        // The first warmup step lands strictly between zero and the base rate.
        let first_lr = scheduler.step();
        assert!(first_lr > 0.0, "LR should be positive after first step");
        assert!(
            first_lr < 0.001,
            "LR should be less than base_lr during warmup"
        );

        // Finish the remaining 99 warmup steps.
        advance(&mut scheduler, 99);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.001, epsilon = 1e-9);

        let first_layer_lr = scheduler.get_layer_learning_rate(0);
        let last_layer_lr = scheduler.get_layer_learning_rate(11);
        assert!(
            last_layer_lr > first_layer_lr,
            "Last layer should have higher LR than first"
        );
        assert_abs_diff_eq!(last_layer_lr, 0.001, epsilon = 1e-9);
    }

    #[test]
    fn test_warmup_phase() {
        let mut scheduler = standard_scheduler();
        let mut previous = 0.0;

        for step in 1..=100 {
            let current = scheduler.step();
            assert!(
                current > previous,
                "LR should increase during warmup at step {}: {} vs {}",
                step,
                current,
                previous
            );
            let expected = 0.001 * step as f64 / 100.0;
            assert_abs_diff_eq!(current, expected, epsilon = 1e-12);
            previous = current;
        }

        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.001, epsilon = 1e-9);
    }

    #[test]
    fn test_cosine_decay_phase() {
        let warmup_steps = 100;
        let total_steps = 1000;
        let observed_steps = 100;
        let mut scheduler = ViTLayerDecay::<f64>::new(0.001, 0.75, 12, warmup_steps, total_steps);

        advance(&mut scheduler, warmup_steps);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.001, epsilon = 1e-9);

        // The rate must not increase over the first stretch of the decay.
        let mut previous = scheduler.get_learning_rate();
        for _ in 0..observed_steps {
            let current = scheduler.step();
            assert!(
                current <= previous + 1e-12,
                "LR should decrease during cosine decay"
            );
            previous = current;
        }

        // Run to the end of the schedule, where the rate reaches zero.
        advance(&mut scheduler, total_steps - warmup_steps - observed_steps);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.0, epsilon = 1e-9);
    }

    #[test]
    fn test_layer_rates_ordering() {
        let mut scheduler = standard_scheduler();
        advance(&mut scheduler, 100);

        let rates = scheduler.get_all_layer_rates();
        assert_eq!(rates.len(), 12);

        // Each layer must receive a strictly higher rate than the one before it.
        for (lower, pair) in rates.windows(2).enumerate() {
            let upper = lower + 1;
            assert!(
                pair[1] > pair[0],
                "Layer {} (lr={}) should have higher LR than layer {} (lr={})",
                upper,
                pair[1],
                lower,
                pair[0]
            );
        }

        assert_abs_diff_eq!(rates[11], 0.001, epsilon = 1e-9);
        let expected_layer_0 = 0.001 * 0.75_f64.powi(11);
        assert_abs_diff_eq!(rates[0], expected_layer_0, epsilon = 1e-9);

        // Out-of-range layer indices yield a zero rate.
        assert_abs_diff_eq!(scheduler.get_layer_learning_rate(12), 0.0);
    }

    #[test]
    fn test_builder_pattern() {
        let configured = ViTLayerDecay::<f64>::builder()
            .base_lr(0.001)
            .decay_rate(0.75)
            .num_layers(12)
            .warmup_steps(500)
            .total_steps(10000)
            .build();
        assert_abs_diff_eq!(configured.base_lr(), 0.001);
        assert_abs_diff_eq!(configured.decay_rate(), 0.75);
        assert_eq!(configured.num_layers(), 12);
        assert_eq!(configured.current_step(), 0);

        // A builder with nothing set falls back to the defaults.
        let defaulted = ViTLayerDecay::<f64>::builder().build();
        assert_abs_diff_eq!(defaulted.base_lr(), 0.001);
        assert_abs_diff_eq!(defaulted.decay_rate(), 0.75);
        assert_eq!(defaulted.num_layers(), 12);
    }
}