use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use crate::schedulers::LearningRateScheduler;
/// Converts an `f64` value into the generic float type `A`.
///
/// If `A::from` reports that the value cannot be converted, `A::zero()` is
/// returned instead, so this helper never panics.
fn from_f64<A: Float>(v: f64) -> A {
    A::from(v).unwrap_or_else(A::zero)
}
/// Converts a `usize` (for example a step count) into the generic float type `A`.
///
/// If `A::from` reports that the value cannot be converted, `A::zero()` is
/// returned instead, so this helper never panics.
fn from_usize<A: Float>(v: usize) -> A {
    A::from(v).unwrap_or_else(A::zero)
}
/// Kinds of transformer components that can be given their own learning-rate
/// scale by [`AttentionAwareScheduler`].
///
/// The type is `Copy` and hashable so it can be used directly as a `HashMap` key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TransformerComponentType {
    /// Attention parameters (default scale in `AttentionAwareScheduler::new`: 1.0).
    Attention,
    /// Feed-forward parameters (default scale: 1.0).
    FeedForward,
    /// Embedding parameters (default scale: 0.1).
    Embedding,
    /// Layer-normalisation parameters (default scale: 0.01).
    LayerNorm,
    /// Output parameters (default scale: 0.5).
    Output,
}
/// Builder for [`AttentionAwareScheduler`].
///
/// Every setting is optional; anything left unset falls back to the defaults
/// applied in `build`.
#[derive(Debug, Clone)]
pub struct AttentionAwareSchedulerBuilder<A>
where
    A: Float + Debug,
{
    /// Base learning rate, if one was supplied.
    base_lr: Option<A>,
    /// Number of warmup steps, if one was supplied.
    warmup_steps: Option<usize>,
    /// Total number of scheduled steps, if one was supplied.
    total_steps: Option<usize>,
    /// Per-component scale overrides collected so far.
    component_scales: HashMap<TransformerComponentType, A>,
}
impl<A> AttentionAwareSchedulerBuilder<A>
where
    A: Float + Debug + ScalarOperand + Send + Sync,
{
    /// Creates a builder with nothing configured.
    ///
    /// Private: callers obtain a builder through `AttentionAwareScheduler::builder()`.
    fn new() -> Self {
        Self {
            base_lr: None,
            warmup_steps: None,
            total_steps: None,
            component_scales: HashMap::new(),
        }
    }

    /// Sets the base learning rate.
    pub fn base_lr(self, lr: A) -> Self {
        Self {
            base_lr: Some(lr),
            ..self
        }
    }

    /// Sets the number of warmup steps.
    pub fn warmup_steps(self, steps: usize) -> Self {
        Self {
            warmup_steps: Some(steps),
            ..self
        }
    }

    /// Sets the total number of scheduled steps.
    pub fn total_steps(self, steps: usize) -> Self {
        Self {
            total_steps: Some(steps),
            ..self
        }
    }

    /// Registers a scale for `component`, replacing any scale previously
    /// registered for the same component on this builder.
    pub fn component_scale(mut self, component: TransformerComponentType, scale: A) -> Self {
        self.component_scales.insert(component, scale);
        self
    }

    /// Consumes the builder and produces the scheduler.
    ///
    /// Unset options default to a base learning rate of `0.001`, `0` warmup
    /// steps and `1000` total steps. Scales registered on the builder are
    /// applied on top of the defaults installed by `AttentionAwareScheduler::new`.
    pub fn build(self) -> AttentionAwareScheduler<A> {
        let Self {
            base_lr,
            warmup_steps,
            total_steps,
            component_scales,
        } = self;

        let mut scheduler = AttentionAwareScheduler::new(
            base_lr.unwrap_or_else(|| from_f64(0.001)),
            warmup_steps.unwrap_or(0),
            total_steps.unwrap_or(1000),
        );
        component_scales
            .into_iter()
            .for_each(|(component, scale)| scheduler.set_component_scale(component, scale));
        scheduler
    }
}
/// Learning-rate scheduler with linear warmup followed by cosine decay, plus
/// a per-component multiplier for the different parts of a transformer.
#[derive(Debug, Clone)]
pub struct AttentionAwareScheduler<A>
where
    A: Float + Debug,
{
    /// Peak learning rate that the schedule factor is multiplied with.
    base_lr: A,
    /// Number of steps over which the learning rate ramps up linearly.
    warmup_steps: usize,
    /// Step count at which the cosine decay reaches its end.
    total_steps: usize,
    /// Number of `step` calls since construction or the last `reset`.
    current_step: usize,
    /// Learning rate computed by the most recent `step` (zero before any step).
    current_lr: A,
    /// Multiplier applied to `current_lr` for each component type.
    component_scales: HashMap<TransformerComponentType, A>,
}
impl<A> AttentionAwareScheduler<A>
where
    A: Float + Debug + ScalarOperand + Send + Sync,
{
    /// Creates a scheduler at step 0 with a current learning rate of zero.
    ///
    /// The component scales are pre-populated with these defaults:
    /// attention `1.0`, feed-forward `1.0`, embedding `0.1`, layer norm `0.01`
    /// and output `0.5`.
    pub fn new(base_lr: A, warmup_steps: usize, total_steps: usize) -> Self {
        const DEFAULT_SCALES: [(TransformerComponentType, f64); 5] = [
            (TransformerComponentType::Attention, 1.0),
            (TransformerComponentType::FeedForward, 1.0),
            (TransformerComponentType::Embedding, 0.1),
            (TransformerComponentType::LayerNorm, 0.01),
            (TransformerComponentType::Output, 0.5),
        ];

        let component_scales: HashMap<TransformerComponentType, A> = DEFAULT_SCALES
            .iter()
            .map(|&(component, scale)| (component, from_f64::<A>(scale)))
            .collect();

        Self {
            base_lr,
            warmup_steps,
            total_steps,
            current_step: 0,
            current_lr: A::zero(),
            component_scales,
        }
    }

    /// Returns a builder for configuring a scheduler step by step.
    pub fn builder() -> AttentionAwareSchedulerBuilder<A> {
        AttentionAwareSchedulerBuilder::new()
    }

    /// Computes the multiplier applied to `base_lr` at the current step.
    ///
    /// * Step 0 yields `0`.
    /// * Steps `1..=warmup_steps` ramp linearly: `current_step / warmup_steps`.
    /// * Later steps follow `0.5 * (1 + cos(pi * progress))`, where `progress`
    ///   is the fraction of the decay window (`total_steps - warmup_steps`)
    ///   already consumed, capped at `1`.
    /// * If the decay window is empty (`total_steps <= warmup_steps`), the
    ///   factor stays at `1` after warmup.
    fn compute_schedule_factor(&self) -> A {
        let step = self.current_step;
        if step == 0 {
            return A::zero();
        }

        if step <= self.warmup_steps {
            // `step >= 1` here, so `warmup_steps >= 1` and the division is
            // well defined; no separate zero-warmup guard is needed.
            return from_usize::<A>(step) / from_usize::<A>(self.warmup_steps);
        }

        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
        if decay_steps == 0 {
            return A::one();
        }

        let steps_since_warmup = step.saturating_sub(self.warmup_steps);
        // Cap progress at 1 so the factor stays at the cosine minimum once
        // `total_steps` has been passed.
        let progress = if steps_since_warmup >= decay_steps {
            A::one()
        } else {
            from_usize::<A>(steps_since_warmup) / from_usize::<A>(decay_steps)
        };

        let pi = from_f64::<A>(std::f64::consts::PI);
        let half = from_f64::<A>(0.5);
        half * (A::one() + (pi * progress).cos())
    }

    /// Returns the current learning rate multiplied by the scale registered
    /// for `component_type`; a component without an entry uses a scale of `1`.
    pub fn get_component_lr(&self, component_type: TransformerComponentType) -> A {
        let scale = self
            .component_scales
            .get(&component_type)
            .map_or_else(A::one, |scale| *scale);
        self.current_lr * scale
    }

    /// Sets (or replaces) the scale used for `component_type`.
    pub fn set_component_scale(&mut self, component_type: TransformerComponentType, scale: A) {
        self.component_scales.insert(component_type, scale);
    }

    /// Returns the full component-to-scale table.
    pub fn component_scales(&self) -> &HashMap<TransformerComponentType, A> {
        &self.component_scales
    }

    /// Returns the base learning rate.
    pub fn base_lr(&self) -> A {
        self.base_lr
    }

    /// Returns the number of steps taken since construction or the last reset.
    pub fn current_step(&self) -> usize {
        self.current_step
    }
}
impl<A> LearningRateScheduler<A> for AttentionAwareScheduler<A>
where
    A: Float + Debug + ScalarOperand + Send + Sync,
{
    /// Returns the learning rate computed by the most recent `step`
    /// (zero if no step has been taken since construction or `reset`).
    fn get_learning_rate(&self) -> A {
        self.current_lr
    }

    /// Advances the schedule by one step, stores the new learning rate
    /// (`base_lr` times the schedule factor) and returns it.
    fn step(&mut self) -> A {
        self.current_step += 1;
        self.current_lr = self.base_lr * self.compute_schedule_factor();
        self.current_lr
    }

    /// Rewinds the scheduler to step 0 with a learning rate of zero.
    ///
    /// The component scales are left untouched.
    fn reset(&mut self) {
        self.current_lr = A::zero();
        self.current_step = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Walks one scheduler through warmup, the first decay step and a reset.
    #[test]
    fn test_attention_aware_basic() {
        let mut sched = AttentionAwareScheduler::<f64>::new(0.001, 100, 1000);

        // Before any step the learning rate is zero.
        assert_abs_diff_eq!(sched.get_learning_rate(), 0.0);

        let first_lr = sched.step();
        assert!(first_lr > 0.0, "LR should be positive after first step");
        assert!(
            first_lr < 0.001,
            "LR should be less than base_lr during warmup"
        );

        // 99 further steps bring the scheduler to the end of the 100-step warmup.
        for _ in 1..100 {
            sched.step();
        }
        assert_abs_diff_eq!(sched.get_learning_rate(), 0.001, epsilon = 1e-9);

        // The next step is the first one inside the cosine decay.
        let lr_after_decay = sched.step();
        assert!(
            lr_after_decay < 0.001,
            "LR should decrease after warmup ends"
        );

        sched.reset();
        assert_abs_diff_eq!(sched.get_learning_rate(), 0.0);
        assert_eq!(sched.current_step(), 0);
    }

    /// Checks the default per-component scales at the peak learning rate.
    #[test]
    fn test_component_scales() {
        let mut sched = AttentionAwareScheduler::<f64>::new(0.001, 100, 1000);
        for _ in 0..100 {
            sched.step();
        }

        let base_lr = sched.get_learning_rate();
        assert_abs_diff_eq!(base_lr, 0.001, epsilon = 1e-9);

        let attn_lr = sched.get_component_lr(TransformerComponentType::Attention);
        assert_abs_diff_eq!(attn_lr, 0.001, epsilon = 1e-9);

        let ff_lr = sched.get_component_lr(TransformerComponentType::FeedForward);
        assert_abs_diff_eq!(ff_lr, 0.001, epsilon = 1e-9);

        let embed_lr = sched.get_component_lr(TransformerComponentType::Embedding);
        assert_abs_diff_eq!(embed_lr, 0.0001, epsilon = 1e-9);

        let ln_lr = sched.get_component_lr(TransformerComponentType::LayerNorm);
        assert_abs_diff_eq!(ln_lr, 0.00001, epsilon = 1e-9);

        let output_lr = sched.get_component_lr(TransformerComponentType::Output);
        assert_abs_diff_eq!(output_lr, 0.0005, epsilon = 1e-9);

        // Relative ordering implied by the default scales.
        assert!(attn_lr > output_lr);
        assert!(output_lr > embed_lr);
        assert!(embed_lr > ln_lr);
    }

    /// Verifies the overall shape: strictly rising warmup, non-increasing
    /// cosine decay, and a learning rate of zero at `total_steps`.
    #[test]
    fn test_warmup_cosine_schedule() {
        let warmup_steps = 100;
        let total_steps = 1000;
        let mut sched = AttentionAwareScheduler::<f64>::new(0.001, warmup_steps, total_steps);

        let mut prev_lr = 0.0;
        for step in 1..=warmup_steps {
            let lr = sched.step();
            assert!(
                lr > prev_lr,
                "LR should increase during warmup at step {}: {} vs {}",
                step,
                lr,
                prev_lr
            );
            prev_lr = lr;
        }
        assert_abs_diff_eq!(sched.get_learning_rate(), 0.001, epsilon = 1e-9);

        // First 100 decay steps: the rate must not go up (tiny float slack allowed).
        prev_lr = sched.get_learning_rate();
        for _ in 0..100 {
            let lr = sched.step();
            assert!(
                lr <= prev_lr + 1e-12,
                "LR should decrease during cosine decay"
            );
            prev_lr = lr;
        }

        // Run the remaining steps up to `total_steps`.
        for _ in 0..(total_steps - warmup_steps - 100) {
            sched.step();
        }
        assert_abs_diff_eq!(sched.get_learning_rate(), 0.0, epsilon = 1e-9);
    }

    /// Custom scales replace the defaults; untouched components keep theirs.
    #[test]
    fn test_custom_component_scales() {
        let mut sched = AttentionAwareScheduler::<f64>::new(0.001, 100, 1000);
        sched.set_component_scale(TransformerComponentType::Attention, 2.0);
        sched.set_component_scale(TransformerComponentType::Embedding, 0.05);

        for _ in 0..100 {
            sched.step();
        }

        let attn_lr = sched.get_component_lr(TransformerComponentType::Attention);
        assert_abs_diff_eq!(attn_lr, 0.002, epsilon = 1e-9);

        let embed_lr = sched.get_component_lr(TransformerComponentType::Embedding);
        assert_abs_diff_eq!(embed_lr, 0.00005, epsilon = 1e-9);

        let ff_lr = sched.get_component_lr(TransformerComponentType::FeedForward);
        assert_abs_diff_eq!(ff_lr, 0.001, epsilon = 1e-9);
    }

    /// The builder forwards explicit settings and falls back to defaults.
    #[test]
    fn test_builder_pattern() {
        let sched = AttentionAwareScheduler::<f64>::builder()
            .base_lr(0.0005)
            .warmup_steps(200)
            .total_steps(5000)
            .component_scale(TransformerComponentType::Attention, 1.5)
            .component_scale(TransformerComponentType::Embedding, 0.05)
            .build();

        assert_abs_diff_eq!(sched.base_lr(), 0.0005);
        assert_eq!(sched.current_step(), 0);

        let scales = sched.component_scales();
        assert_abs_diff_eq!(
            scales
                .get(&TransformerComponentType::Attention)
                .copied()
                .unwrap_or(0.0),
            1.5,
            epsilon = 1e-9
        );
        assert_abs_diff_eq!(
            scales
                .get(&TransformerComponentType::Embedding)
                .copied()
                .unwrap_or(0.0),
            0.05,
            epsilon = 1e-9
        );

        let default_sched = AttentionAwareScheduler::<f64>::builder().build();
        assert_abs_diff_eq!(default_sched.base_lr(), 0.001);
    }
}