use scivex_core::Float;
use crate::error::Result;
/// Outcome of a learning-rate range sweep produced by [`LrFinder::run`].
///
/// The three vectors are index-aligned: entry `i` of each one describes the
/// `i`-th training step that was executed.
#[derive(Debug, Clone)]
pub struct LrFinderResult<T: Float> {
/// Learning rate passed to the training callback at each step.
pub lrs: Vec<T>,
/// Exponentially smoothed loss after each step.
pub losses: Vec<T>,
/// Loss exactly as returned by the training callback at each step.
pub raw_losses: Vec<T>,
/// Learning rate picked from the sweep by the steepest-descent heuristic
/// (most negative slope of the smoothed loss against `ln(lr)`).
pub suggested_lr: T,
}
/// Configuration for a learning-rate range test.
///
/// The finder sweeps the learning rate geometrically from `start_lr` to
/// `end_lr`, records the loss reported by a training callback at every step
/// and suggests the learning rate at which the smoothed loss falls fastest.
///
/// Build one with [`LrFinder::new`] (or `Default`) and adjust it through the
/// `with_*` methods.
#[derive(Debug, Clone)]
pub struct LrFinder<T: Float> {
    /// Learning rate used for the first step of the sweep.
    start_lr: T,
    /// Learning rate reached on the last step of a full sweep.
    end_lr: T,
    /// Maximum number of training steps to run.
    num_steps: usize,
    /// Weight given to the newest raw loss in the exponential moving average.
    smoothing: T,
    /// Multiple of the best smoothed loss beyond which the sweep is aborted.
    diverge_threshold: T,
}
impl<T: Float> LrFinder<T> {
    /// Creates a finder with the default sweep: learning rates from `1e-7`
    /// to `10` over `100` steps, a smoothing factor of `0.05` and a
    /// divergence threshold of `4`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            start_lr: T::from_f64(1e-7),
            end_lr: T::from_f64(10.0),
            num_steps: 100,
            smoothing: T::from_f64(0.05),
            diverge_threshold: T::from_f64(4.0),
        }
    }

    /// Sets the learning rate used for the first step.
    #[must_use]
    pub fn with_start_lr(mut self, lr: T) -> Self {
        self.start_lr = lr;
        self
    }

    /// Sets the learning rate reached on the last step of a full sweep.
    #[must_use]
    pub fn with_end_lr(mut self, lr: T) -> Self {
        self.end_lr = lr;
        self
    }

    /// Sets the maximum number of training steps.
    #[must_use]
    pub fn with_num_steps(mut self, n: usize) -> Self {
        self.num_steps = n;
        self
    }

    /// Sets the weight of the newest raw loss in the moving average
    /// (`smoothed = s * raw + (1 - s) * previous`).
    #[must_use]
    pub fn with_smoothing(mut self, s: T) -> Self {
        self.smoothing = s;
        self
    }

    /// Sets the multiple of the best smoothed loss beyond which the sweep
    /// stops early.
    #[must_use]
    pub fn with_diverge_threshold(mut self, t: T) -> Self {
        self.diverge_threshold = t;
        self
    }

    /// Runs the range test.
    ///
    /// `train_step` is called once per step with the learning rate to use
    /// and must return the loss it observed. At step `i` the learning rate
    /// is `start_lr * (end_lr / start_lr)^(i / (num_steps - 1))`, i.e. the
    /// sweep is evenly spaced in log space; a single-step sweep uses
    /// `start_lr`.
    ///
    /// The first smoothed loss equals the first raw loss; afterwards it is
    /// the exponential moving average controlled by the smoothing factor.
    ///
    /// From the second step on, the sweep ends early (after recording the
    /// current step) when the smoothed loss exceeds
    /// `best_smoothed_loss * diverge_threshold`, or when it cannot be
    /// ordered against that limit at all, which is what happens once a loss
    /// turns NaN. Because the limit is a multiple of the best loss, the test
    /// is meaningful for positive losses only: with a threshold above `1` a
    /// negative best loss puts the limit below the best loss itself and the
    /// sweep stops at its second step.
    ///
    /// The suggested learning rate comes from the steepest-descent
    /// heuristic. When no step ran at all (`num_steps == 0`) the result
    /// vectors are empty and `suggested_lr` falls back to `start_lr`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `train_step`; everything
    /// recorded up to that point is discarded.
    pub fn run<F>(&self, mut train_step: F) -> Result<LrFinderResult<T>>
    where
        F: FnMut(T) -> Result<T>,
    {
        let mut lrs = Vec::with_capacity(self.num_steps);
        let mut losses = Vec::with_capacity(self.num_steps);
        let mut raw_losses = Vec::with_capacity(self.num_steps);

        let one = T::one();
        let ratio = self.end_lr / self.start_lr;
        let keep = one - self.smoothing;
        // A single-step sweep would otherwise divide by zero.
        let denom = if self.num_steps > 1 {
            T::from_usize(self.num_steps - 1)
        } else {
            one
        };

        let mut best_loss = T::infinity();
        let mut smoothed_loss = T::zero();

        for step in 0..self.num_steps {
            let t = T::from_usize(step) / denom;
            let lr = self.start_lr * ratio.powf(t);
            let raw_loss = train_step(lr)?;

            smoothed_loss = if step == 0 {
                raw_loss
            } else {
                self.smoothing * raw_loss + keep * smoothed_loss
            };

            lrs.push(lr);
            losses.push(smoothed_loss);
            raw_losses.push(raw_loss);

            if smoothed_loss < best_loss {
                best_loss = smoothed_loss;
            }

            if step > 0 {
                let limit = best_loss * self.diverge_threshold;
                // `partial_cmp` yields `None` when a NaN is involved, so a
                // NaN loss ends the sweep instead of silently passing a
                // plain `>` comparison for every remaining step.
                let within_limit = matches!(
                    smoothed_loss.partial_cmp(&limit),
                    Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal)
                );
                if !within_limit {
                    break;
                }
            }
        }

        // With zero steps there is no curve to analyse; never hand an empty
        // slice to the heuristic.
        let suggested_lr = if lrs.is_empty() {
            self.start_lr
        } else {
            find_steepest_descent(&lrs, &losses)
        };

        Ok(LrFinderResult {
            lrs,
            losses,
            raw_losses,
            suggested_lr,
        })
    }
}
impl<T: Float> Default for LrFinder<T> {
fn default() -> Self {
Self::new()
}
}
/// Picks the learning rate at which the loss curve falls fastest.
///
/// For every interior point `i`, the slope of `losses` with respect to
/// `ln(lr)` is estimated with a central difference between its two
/// neighbours, and the `lrs[i]` with the most negative slope wins (the
/// earliest one on ties). Points whose neighbours are closer than `1e-12`
/// in log space are skipped to avoid dividing by (almost) zero, and a slope
/// that is NaN never compares as smaller, so it is never selected. If no
/// point qualifies, `lrs[1]` is returned.
///
/// With fewer than three learning rates no slope can be estimated: the
/// middle element (`lrs[len / 2]`) is returned, or zero when `lrs` is
/// empty.
///
/// `lrs` and `losses` are expected to have the same length; if they differ,
/// only the windows present in both are examined.
fn find_steepest_descent<T: Float>(lrs: &[T], losses: &[T]) -> T {
    if lrs.len() < 3 {
        // `get` instead of indexing so that an empty sweep cannot panic.
        return lrs.get(lrs.len() / 2).copied().unwrap_or_else(T::zero);
    }

    let min_log_span = T::from_f64(1e-12);
    let mut best_lr = lrs[1];
    let mut best_grad = T::infinity();

    // Each window is `[previous, current, next]`.
    for (lr_win, loss_win) in lrs.windows(3).zip(losses.windows(3)) {
        let d_log_lr = lr_win[2].ln() - lr_win[0].ln();
        if d_log_lr.abs() < min_log_span {
            continue;
        }
        let grad = (loss_win[2] - loss_win[0]) / d_log_lr;
        if grad < best_grad {
            best_grad = grad;
            best_lr = lr_win[1];
        }
    }

    best_lr
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Convex loss with its minimum at `lr = 0.01`.
    fn bowl_loss(lr: f64) -> f64 {
        (lr - 0.01) * (lr - 0.01) + 0.1
    }

    /// Loss that improves slowly and blows up once `lr` reaches `0.1`.
    fn exploding_loss(lr: f64) -> f64 {
        if lr < 0.1 {
            1.0 - lr * 5.0
        } else {
            lr * 100.0
        }
    }

    /// A well-behaved loss yields aligned vectors and a suggestion strictly
    /// between 0 and 1.
    #[test]
    fn test_lr_finder_basic() {
        let outcome = LrFinder::<f64>::new()
            .with_start_lr(1e-5)
            .with_end_lr(1.0)
            .with_num_steps(50)
            .with_diverge_threshold(10.0)
            .run(|lr| Ok(bowl_loss(lr)))
            .unwrap();

        assert!(!outcome.lrs.is_empty());
        assert_eq!(outcome.lrs.len(), outcome.losses.len());
        assert!(outcome.suggested_lr > 0.0);
        assert!(outcome.suggested_lr < 1.0);
    }

    /// A diverging loss must cut the sweep short of its 100 steps.
    #[test]
    fn test_lr_finder_early_stop_on_divergence() {
        let outcome = LrFinder::<f64>::new()
            .with_start_lr(1e-5)
            .with_end_lr(10.0)
            .with_num_steps(100)
            .with_diverge_threshold(4.0)
            .run(|lr| Ok(exploding_loss(lr)))
            .unwrap();

        assert!(outcome.lrs.len() < 100);
    }

    /// The heuristic picks an interior point of a hand-made loss curve.
    #[test]
    fn test_lr_finder_steepest_descent() {
        let mut rates: Vec<f64> = Vec::with_capacity(10);
        for i in 0..10 {
            rates.push(10.0_f64.powf(-4.0 + 0.5 * f64::from(i)));
        }
        let curve: Vec<f64> = vec![1.0, 0.9, 0.7, 0.4, 0.2, 0.15, 0.2, 0.5, 1.0, 2.0];

        let pick = find_steepest_descent(&rates, &curve);
        assert!(pick > 1e-4);
        assert!(pick < 1.0);
    }

    /// `Default` produces the stock 100-step configuration.
    #[test]
    fn test_lr_finder_default() {
        let LrFinder { num_steps, .. } = LrFinder::<f32>::default();
        assert_eq!(num_steps, 100);
    }
}