use crate::callbacks::core::Callback;
use crate::{TrainResult, TrainingState};
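
/// Learning-rate range test: sweeps the LR from `start_lr` to `end_lr` over
/// `num_steps` batches while recording `(lr, loss)` pairs, so a good training
/// LR can be read off the loss curve afterwards.
///
/// A minimal usage sketch (marked `ignore` because the trainer wiring that
/// registers callbacks and applies each step's LR is assumed to exist
/// elsewhere in the crate, and is not shown here):
///
/// ```ignore
/// let mut finder = LearningRateFinder::new(1e-6, 1.0, 100).with_smoothing(0.98);
/// // ... register `finder` as a callback and train for ~100 batches ...
/// finder.print_results();
/// ```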
pub struct LearningRateFinder {
    /// First LR in the sweep.
    start_lr: f64,
    /// Last LR in the sweep.
    end_lr: f64,
    /// Number of batches the sweep covers.
    num_steps: usize,
    /// Batches recorded so far.
    current_step: usize,
    /// Recorded `(learning_rate, loss)` pairs.
    pub history: Vec<(f64, f64)>,
    /// Geometric (`true`) vs. arithmetic (`false`) interpolation.
    exponential: bool,
    /// EMA factor in `[0.0, 1.0]`; `0.0` disables smoothing.
    smoothing: f64,
    /// Running smoothed loss, seeded by the first observation.
    smoothed_loss: Option<f64>,
}

impl LearningRateFinder {
    /// Creates a finder that sweeps from `start_lr` to `end_lr` over
    /// `num_steps` batches. Exponential scaling is the default, so both
    /// endpoints should be positive unless linear scaling is selected.
    pub fn new(start_lr: f64, end_lr: f64, num_steps: usize) -> Self {
        Self {
            start_lr,
            end_lr,
            num_steps,
            current_step: 0,
            history: Vec::with_capacity(num_steps),
            exponential: true,
            smoothing: 0.0,
            smoothed_loss: None,
        }
    }

    /// Sweeps the LR geometrically between the endpoints (the default).
    pub fn with_exponential_scaling(mut self) -> Self {
        self.exponential = true;
        self
    }

    /// Sweeps the LR arithmetically instead of geometrically.
    pub fn with_linear_scaling(mut self) -> Self {
        self.exponential = false;
        self
    }

    /// Sets the EMA smoothing factor, clamped to `[0.0, 1.0]`.
    pub fn with_smoothing(mut self, smoothing: f64) -> Self {
        self.smoothing = smoothing.clamp(0.0, 1.0);
        self
    }

    /// Learning rate for the current step. Exponential scaling interpolates
    /// geometrically (a constant ratio per step), which is the usual choice
    /// for a range test; linear scaling interpolates arithmetically.
    fn compute_lr(&self) -> f64 {
        if self.num_steps <= 1 {
            return self.start_lr;
        }
        let step_ratio = self.current_step as f64 / (self.num_steps - 1) as f64;
        if self.exponential {
            // start_lr * (end_lr / start_lr)^t for t in [0, 1].
            self.start_lr * (self.end_lr / self.start_lr).powf(step_ratio)
        } else {
            // start_lr + (end_lr - start_lr) * t for t in [0, 1].
            self.start_lr + (self.end_lr - self.start_lr) * step_ratio
        }
    }

    /// Applies exponential moving-average smoothing: with factor `s`, the
    /// running value is `s * prev + (1 - s) * loss`. A factor of `0.0`
    /// disables smoothing entirely.
    fn smooth_loss(&mut self, loss: f64) -> f64 {
        if self.smoothing == 0.0 {
            return loss;
        }
        match self.smoothed_loss {
            // Seed the running average with the first observed loss.
            None => {
                self.smoothed_loss = Some(loss);
                loss
            }
            Some(prev) => {
                let smoothed = self.smoothing * prev + (1.0 - self.smoothing) * loss;
                self.smoothed_loss = Some(smoothed);
                smoothed
            }
        }
    }

    /// Suggests a learning rate via the steepest-descent heuristic: the LR
    /// at the end of the segment where the (smoothed) loss fell fastest.
    /// Returns `None` when fewer than three points have been recorded.
    pub fn suggest_lr(&self) -> Option<f64> {
        if self.history.len() < 3 {
            return None;
        }
        let mut best_lr = None;
        let mut best_gradient = f64::INFINITY;
        for i in 1..self.history.len() {
            let (lr1, loss1) = self.history[i - 1];
            let (lr2, loss2) = self.history[i];
            let dlr = lr2 - lr1;
            // Skip degenerate segments (e.g. start_lr == end_lr) to avoid a
            // division by zero, and ignore non-finite slopes.
            if dlr == 0.0 {
                continue;
            }
            let gradient = (loss2 - loss1) / dlr;
            if gradient.is_finite() && gradient < best_gradient {
                best_gradient = gradient;
                best_lr = Some(lr2);
            }
        }
        best_lr
    }

    /// Prints a summary of the sweep: the range tested, the suggested LR if
    /// one could be computed, and the raw `(lr, loss)` pairs as CSV.
    pub fn print_results(&self) {
        println!("\n=== Learning Rate Finder Results ===");
        println!(
            "Tested {} learning rates from {:.2e} to {:.2e}",
            self.history.len(),
            self.start_lr,
            self.end_lr
        );
        if let Some(suggested_lr) = self.suggest_lr() {
            println!("Suggested optimal LR: {:.2e}", suggested_lr);
            println!(
                "Consider using an LR between {:.2e} and {:.2e}",
                suggested_lr / 10.0,
                suggested_lr
            );
        }
        println!("\nLR, Loss:");
        for (lr, loss) in &self.history {
            println!("{:.6e}, {:.6}", lr, loss);
        }
        println!("====================================\n");
    }
}

impl Callback for LearningRateFinder {
    /// Records one `(lr, loss)` pair per batch until the sweep is complete.
    /// The LR is only recorded here, not applied: the state is borrowed
    /// immutably, so updating the optimizer is left to the training loop.
    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
        if self.current_step >= self.num_steps {
            return Ok(());
        }
        let loss = self.smooth_loss(state.batch_loss);
        let lr = self.compute_lr();
        self.history.push((lr, loss));
        self.current_step += 1;
        Ok(())
    }

    /// Signals that training can stop once the sweep has covered every step.
    fn should_stop(&self) -> bool {
        self.current_step >= self.num_steps
    }
}
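
// A minimal sketch of the behaviour under test. The loss curve in
// `suggest_lr_picks_the_steepest_descent` is synthetic, illustrative data,
// not output from a real training run.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exponential_schedule_hits_both_endpoints() {
        let mut finder = LearningRateFinder::new(1e-6, 1.0, 100);
        // Step 0 yields start_lr exactly; the final step lands on end_lr up
        // to floating-point rounding.
        assert!((finder.compute_lr() - 1e-6).abs() < 1e-15);
        finder.current_step = 99;
        assert!((finder.compute_lr() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn suggest_lr_picks_the_steepest_descent() {
        let mut finder = LearningRateFinder::new(1e-4, 1e-1, 4);
        // Loss falls fastest (per unit of LR) on the segment ending at 1e-2,
        // then rises again, so 1e-2 should be suggested.
        finder.history = vec![(1e-4, 1.00), (1e-3, 0.99), (1e-2, 0.49), (1e-1, 0.80)];
        assert_eq!(finder.suggest_lr(), Some(1e-2));
    }
}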