use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
use std::fmt::{self, Debug, Display};
/// How the learning rate is swept from `min_lr` to `max_lr` during a range test.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LRScheduleType {
    /// Geometric progression: equal multiplicative steps per iteration.
    Exponential,
    /// Arithmetic progression: equal additive steps per iteration.
    Linear,
}
impl Display for LRScheduleType {
    /// Writes the human-readable schedule name used in summaries.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Exponential => "Exponential",
            Self::Linear => "Linear",
        };
        f.write_str(name)
    }
}
/// Configuration for a learning-rate range test (LR finder).
#[derive(Debug, Clone)]
pub struct LRFinderConfig {
    /// Smallest learning rate tried (must be > 0).
    pub min_lr: f64,
    /// Largest learning rate tried (must be > `min_lr`).
    pub max_lr: f64,
    /// Number of steps used to sweep from `min_lr` to `max_lr` (must be > 0).
    pub num_iterations: usize,
    /// Shape of the sweep (exponential or linear).
    pub schedule: LRScheduleType,
    /// EMA coefficient in [0.0, 1.0] for smoothing the loss curve;
    /// higher values smooth more.
    pub smoothing_factor: f64,
    /// The test stops once the smoothed loss exceeds
    /// `best_loss * divergence_threshold` (must be > 1.0).
    pub divergence_threshold: f64,
    /// Gradient-accumulation steps per update (must be > 0).
    /// NOTE(review): validated but not consumed anywhere in this file —
    /// presumably read by the training loop driving the finder; confirm.
    pub accumulation_steps: usize,
}
impl Default for LRFinderConfig {
    /// Default range test: exponential sweep over eight decades
    /// (1e-7 to 10) in 100 iterations with heavy smoothing.
    fn default() -> Self {
        Self {
            schedule: LRScheduleType::Exponential,
            num_iterations: 100,
            min_lr: 1e-7,
            max_lr: 10.0,
            smoothing_factor: 0.98,
            divergence_threshold: 5.0,
            accumulation_steps: 1,
        }
    }
}
impl LRFinderConfig {
pub fn builder() -> LRFinderConfigBuilder {
LRFinderConfigBuilder::default()
}
pub fn validate(&self) -> Result<()> {
if self.min_lr <= 0.0 {
return Err(NeuralError::InvalidArgument(
"min_lr must be positive".to_string(),
));
}
if self.max_lr <= self.min_lr {
return Err(NeuralError::InvalidArgument(
"max_lr must be greater than min_lr".to_string(),
));
}
if self.num_iterations == 0 {
return Err(NeuralError::InvalidArgument(
"num_iterations must be positive".to_string(),
));
}
if !(0.0..=1.0).contains(&self.smoothing_factor) {
return Err(NeuralError::InvalidArgument(
"smoothing_factor must be in [0.0, 1.0]".to_string(),
));
}
if self.divergence_threshold <= 1.0 {
return Err(NeuralError::InvalidArgument(
"divergence_threshold must be > 1.0".to_string(),
));
}
if self.accumulation_steps == 0 {
return Err(NeuralError::InvalidArgument(
"accumulation_steps must be positive".to_string(),
));
}
Ok(())
}
pub fn lr_at_step(&self, step: usize) -> f64 {
let total = self.num_iterations.max(1) as f64;
let t = (step as f64) / total;
match self.schedule {
LRScheduleType::Exponential => {
self.min_lr * (self.max_lr / self.min_lr).powf(t)
}
LRScheduleType::Linear => self.min_lr + (self.max_lr - self.min_lr) * t,
}
}
}
/// Fluent builder for [`LRFinderConfig`]; `build()` runs validation.
#[derive(Debug, Clone, Default)]
pub struct LRFinderConfigBuilder {
    // Starts from `LRFinderConfig::default()` via the derived `Default`.
    config: LRFinderConfig,
}
impl LRFinderConfigBuilder {
pub fn min_lr(mut self, lr: f64) -> Self {
self.config.min_lr = lr;
self
}
pub fn max_lr(mut self, lr: f64) -> Self {
self.config.max_lr = lr;
self
}
pub fn num_iterations(mut self, n: usize) -> Self {
self.config.num_iterations = n;
self
}
pub fn schedule(mut self, s: LRScheduleType) -> Self {
self.config.schedule = s;
self
}
pub fn smoothing_factor(mut self, f: f64) -> Self {
self.config.smoothing_factor = f;
self
}
pub fn divergence_threshold(mut self, t: f64) -> Self {
self.config.divergence_threshold = t;
self
}
pub fn accumulation_steps(mut self, n: usize) -> Self {
self.config.accumulation_steps = n;
self
}
pub fn build(self) -> Result<LRFinderConfig> {
self.config.validate()?;
Ok(self.config)
}
}
/// One sample of the range test: the loss observed at a single step.
#[derive(Debug, Clone)]
pub struct LRFinderPoint {
    /// Zero-based iteration index.
    pub step: usize,
    /// Learning rate used at this step.
    pub lr: f64,
    /// Loss exactly as reported by the caller.
    pub raw_loss: f64,
    /// Smoothed (EMA) loss at this step.
    pub smoothed_loss: f64,
    /// d(smoothed loss)/d(ln lr) versus the previous point;
    /// `None` for the first point (no predecessor).
    pub loss_gradient: Option<f64>,
}
/// Outcome of a learning-rate range test.
#[derive(Debug, Clone)]
pub struct LRFinderResult {
    /// Recorded points, truncated just after the divergence step if the
    /// test diverged.
    pub points: Vec<LRFinderPoint>,
    /// Whether the smoothed loss exceeded the divergence threshold.
    pub diverged: bool,
    /// Step at which divergence was detected, if any.
    pub divergence_step: Option<usize>,
    /// Lowest smoothed loss observed.
    pub best_loss: f64,
    /// Learning rate at which `best_loss` occurred.
    pub best_loss_lr: f64,
    /// The configuration the test was run with.
    pub config: LRFinderConfig,
}
impl LRFinderResult {
    /// Build a result from a pre-recorded loss curve.
    ///
    /// Losses are smoothed with a bias-corrected exponential moving average:
    /// the EMA accumulator starts at zero and each value is divided by
    /// `1 - beta^(i+1)` to remove the cold-start bias. The scan stops right
    /// after the point at which the smoothed loss first exceeds
    /// `best_loss * divergence_threshold`, so `points` may be shorter than
    /// `losses`.
    pub fn from_losses(losses: &[f64], config: &LRFinderConfig) -> Self {
        let beta = config.smoothing_factor;
        // EMA accumulator; starts at zero, corrected below.
        let mut smoothed = 0.0_f64;
        let mut best_loss = f64::MAX;
        let mut best_loss_lr = config.min_lr;
        let mut points = Vec::with_capacity(losses.len());
        let mut diverged = false;
        let mut divergence_step = None;
        for (i, &raw_loss) in losses.iter().enumerate() {
            let lr = config.lr_at_step(i);
            // Standard bias-corrected EMA. (The previous implementation
            // seeded the EMA with the first raw loss AND applied the bias
            // correction; the two schemes are mutually exclusive, and mixing
            // them inflated early smoothed losses by up to 1/(1 - beta).)
            smoothed = beta * smoothed + (1.0 - beta) * raw_loss;
            let correction = 1.0 - beta.powi((i + 1) as i32);
            let corrected = if correction.abs() > 1e-15 {
                smoothed / correction
            } else {
                // Degenerate beta == 1.0: fall back to the raw loss.
                raw_loss
            };
            if corrected < best_loss {
                best_loss = corrected;
                best_loss_lr = lr;
            }
            // Divergence is only meaningful while the best loss is positive,
            // since the check multiplies it by the threshold.
            if !diverged && best_loss > 0.0 && corrected > best_loss * config.divergence_threshold
            {
                diverged = true;
                divergence_step = Some(i);
            }
            points.push(LRFinderPoint {
                step: i,
                lr,
                raw_loss,
                smoothed_loss: corrected,
                loss_gradient: None,
            });
            // The diverging point itself is kept; everything after is dropped.
            if diverged {
                break;
            }
        }
        // Finite-difference gradient of the smoothed loss in log-LR space.
        // The first point has no predecessor and keeps `None`.
        if points.len() >= 2 {
            for i in 1..points.len() {
                let d_log_lr = points[i].lr.ln() - points[i - 1].lr.ln();
                if d_log_lr.abs() > f64::EPSILON {
                    let d_loss = points[i].smoothed_loss - points[i - 1].smoothed_loss;
                    points[i].loss_gradient = Some(d_loss / d_log_lr);
                }
            }
        }
        Self {
            points,
            diverged,
            divergence_step,
            best_loss,
            best_loss_lr,
            config: config.clone(),
        }
    }

    /// Learning rate at the steepest descent of the smoothed loss (most
    /// negative gradient in log-LR space), or `None` if no gradient exists.
    pub fn suggested_lr(&self) -> Option<f64> {
        let mut min_gradient = f64::MAX;
        let mut best_lr = None;
        for point in &self.points {
            if let Some(grad) = point.loss_gradient {
                if grad < min_gradient {
                    min_gradient = grad;
                    best_lr = Some(point.lr);
                }
            }
        }
        best_lr
    }

    /// Conservative heuristic: one tenth of the LR at the best loss.
    pub fn suggested_lr_conservative(&self) -> f64 {
        self.best_loss_lr / 10.0
    }

    /// Learning rates of all recorded points, in order.
    pub fn learning_rates(&self) -> Vec<f64> {
        self.points.iter().map(|p| p.lr).collect()
    }

    /// Raw (unsmoothed) losses of all recorded points, in order.
    pub fn raw_losses(&self) -> Vec<f64> {
        self.points.iter().map(|p| p.raw_loss).collect()
    }

    /// Smoothed losses of all recorded points, in order.
    pub fn smoothed_losses(&self) -> Vec<f64> {
        self.points.iter().map(|p| p.smoothed_loss).collect()
    }

    /// Loss gradients per point; points without a gradient yield `NaN`.
    pub fn loss_gradients(&self) -> Vec<f64> {
        self.points
            .iter()
            .map(|p| p.loss_gradient.unwrap_or(f64::NAN))
            .collect()
    }

    /// Multi-line, human-readable report of the range test.
    pub fn summary(&self) -> String {
        let mut out = String::new();
        out.push_str("=== Learning Rate Range Test Summary ===\n");
        out.push_str(&format!("Schedule: {}\n", self.config.schedule));
        out.push_str(&format!(
            "LR range: [{:.2e}, {:.2e}]\n",
            self.config.min_lr, self.config.max_lr
        ));
        out.push_str(&format!("Iterations: {}\n", self.points.len()));
        out.push_str(&format!(
            "Best loss: {:.6} at lr={:.2e}\n",
            self.best_loss, self.best_loss_lr
        ));
        if self.diverged {
            out.push_str(&format!(
                "Diverged at step {} (lr={:.2e})\n",
                self.divergence_step.unwrap_or(0),
                // The last recorded point is the diverging one.
                self.points
                    .last()
                    .map(|p| p.lr)
                    .unwrap_or(self.config.max_lr)
            ));
        }
        if let Some(lr) = self.suggested_lr() {
            out.push_str(&format!("Suggested LR (steepest decrease): {lr:.2e}\n"));
        }
        out.push_str(&format!(
            "Suggested LR (conservative): {:.2e}\n",
            self.suggested_lr_conservative()
        ));
        out
    }
}
/// Stateful, step-by-step LR range test driver.
///
/// Call [`LRFinder::next_lr`] before each training step and
/// [`LRFinder::record_loss`] after it, until the finder reports completion
/// or divergence.
#[derive(Debug, Clone)]
pub struct LRFinder {
    config: LRFinderConfig,
    // Next step index; equals the number of losses recorded so far.
    step: usize,
    // Every raw loss seen, kept for building the final result.
    raw_losses: Vec<f64>,
    // True once the sweep completed or diverged; further calls are no-ops.
    finished: bool,
    diverged: bool,
    // Running EMA of recorded losses.
    ema_loss: f64,
    // Lowest smoothed loss so far; divergence is measured against this.
    best_loss: f64,
}
/// What the finder asks the caller to do after a loss is recorded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LRFinderStatus {
    /// Keep going: request the next LR and record another loss.
    Continue,
    /// All iterations consumed without divergence.
    Complete,
    /// Smoothed loss exceeded the divergence threshold; stop early.
    Diverged,
}
impl LRFinder {
    /// Create a finder after validating `config`.
    ///
    /// # Errors
    /// Propagates the validation error if `config` is invalid.
    pub fn new(config: LRFinderConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self {
            config,
            step: 0,
            raw_losses: Vec::new(),
            finished: false,
            diverged: false,
            ema_loss: 0.0,
            best_loss: f64::MAX,
        })
    }

    /// Learning rate for the upcoming step, or `None` once finished.
    pub fn next_lr(&self) -> Option<f64> {
        if self.finished {
            return None;
        }
        Some(self.config.lr_at_step(self.step))
    }

    /// Record the loss for the current step and advance.
    ///
    /// Returns `Continue` while more steps remain, `Complete` when all
    /// iterations have been consumed, and `Diverged` when the smoothed loss
    /// exceeds `best_loss * divergence_threshold`. Calls after the finder is
    /// finished are no-ops that re-report the terminal status.
    pub fn record_loss(&mut self, loss: f64) -> LRFinderStatus {
        if self.finished {
            return if self.diverged {
                LRFinderStatus::Diverged
            } else {
                LRFinderStatus::Complete
            };
        }
        self.raw_losses.push(loss);
        let beta = self.config.smoothing_factor;
        // Standard bias-corrected EMA: the accumulator starts at zero (see
        // `new`/`reset`) and dividing by 1 - beta^(step+1) removes the
        // cold-start bias. (The previous implementation seeded the EMA with
        // the first loss AND divided by the correction, inflating early
        // smoothed values by up to 1/(1 - beta).)
        self.ema_loss = beta * self.ema_loss + (1.0 - beta) * loss;
        let correction = 1.0 - beta.powi((self.step + 1) as i32);
        let corrected = if correction.abs() > 1e-15 {
            self.ema_loss / correction
        } else {
            // Degenerate beta == 1.0: fall back to the raw loss.
            loss
        };
        if corrected < self.best_loss {
            self.best_loss = corrected;
        }
        // Divergence is only meaningful while the best loss is positive.
        if self.best_loss > 0.0 && corrected > self.best_loss * self.config.divergence_threshold {
            self.finished = true;
            self.diverged = true;
            return LRFinderStatus::Diverged;
        }
        self.step += 1;
        if self.step >= self.config.num_iterations {
            self.finished = true;
            return LRFinderStatus::Complete;
        }
        LRFinderStatus::Continue
    }

    /// True once the sweep completed or diverged.
    pub fn is_finished(&self) -> bool {
        self.finished
    }

    /// True if the sweep ended because the loss diverged.
    pub fn is_diverged(&self) -> bool {
        self.diverged
    }

    /// Number of losses recorded so far.
    pub fn current_step(&self) -> usize {
        self.step
    }

    /// Configured total number of iterations.
    pub fn total_iterations(&self) -> usize {
        self.config.num_iterations
    }

    /// Build a full [`LRFinderResult`] from the losses recorded so far.
    pub fn result(&self) -> LRFinderResult {
        LRFinderResult::from_losses(&self.raw_losses, &self.config)
    }

    /// Restore the finder to its initial state, discarding recorded losses.
    pub fn reset(&mut self) {
        self.step = 0;
        self.raw_losses.clear();
        self.finished = false;
        self.diverged = false;
        self.ema_loss = 0.0;
        self.best_loss = f64::MAX;
    }
}
/// One-shot convenience: validate `config`, analyze `losses`, and return
/// the suggested learning rate (if any) together with the full result.
///
/// # Errors
/// Propagates the validation error if `config` is invalid.
pub fn find_optimal_lr(
    losses: &[f64],
    config: &LRFinderConfig,
) -> Result<(Option<f64>, LRFinderResult)> {
    config.validate()?;
    let result = LRFinderResult::from_losses(losses, config);
    Ok((result.suggested_lr(), result))
}
/// Generic-float wrapper around [`LRFinder`] that can also stash a snapshot
/// of model parameters so they can be restored after the sweep.
pub struct TypedLRFinder<F: Float + Debug + ScalarOperand + NumAssign> {
    // Underlying f64-based finder; losses are converted via `to_f64`.
    inner: LRFinder,
    // Snapshot taken by `save_params`, if any.
    original_params: Option<Vec<Array<F, IxDyn>>>,
}
impl<F: Float + Debug + ScalarOperand + NumAssign> TypedLRFinder<F> {
    /// Create a typed finder; fails if `config` is invalid.
    ///
    /// # Errors
    /// Propagates the validation error from [`LRFinder::new`].
    pub fn new(config: LRFinderConfig) -> Result<Self> {
        let inner = LRFinder::new(config)?;
        Ok(Self {
            inner,
            original_params: None,
        })
    }

    /// Snapshot the given parameters so they can be restored later.
    pub fn save_params(&mut self, params: &[Array<F, IxDyn>]) {
        self.original_params = Some(params.to_owned());
    }

    /// The snapshot taken by [`Self::save_params`], if any.
    pub fn original_params(&self) -> Option<&[Array<F, IxDyn>]> {
        self.original_params.as_deref()
    }

    /// Next learning rate converted to `F`; `None` when the sweep is
    /// finished or the value is not representable in `F`.
    pub fn next_lr(&self) -> Option<F> {
        let lr = self.inner.next_lr()?;
        F::from(lr)
    }

    /// Record a loss, converting it to `f64` (NaN if conversion fails).
    pub fn record_loss(&mut self, loss: F) -> LRFinderStatus {
        self.inner.record_loss(loss.to_f64().unwrap_or(f64::NAN))
    }

    /// True once the underlying sweep has finished.
    pub fn is_finished(&self) -> bool {
        self.inner.is_finished()
    }

    /// Build the result from the losses recorded so far.
    pub fn result(&self) -> LRFinderResult {
        self.inner.result()
    }

    /// Reset the sweep and drop any saved parameter snapshot.
    pub fn reset(&mut self) {
        self.inner.reset();
        self.original_params = None;
    }
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Debug for TypedLRFinder<F> {
    /// Debug output reports whether a parameter snapshot exists rather than
    /// dumping the (potentially large) arrays themselves.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("TypedLRFinder");
        dbg.field("inner", &self.inner);
        dbg.field("has_original_params", &self.original_params.is_some());
        dbg.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default config matches the documented defaults.
    #[test]
    fn test_config_defaults() {
        let config = LRFinderConfig::default();
        assert!((config.min_lr - 1e-7).abs() < 1e-15);
        assert!((config.max_lr - 10.0).abs() < 1e-10);
        assert_eq!(config.num_iterations, 100);
        assert_eq!(config.schedule, LRScheduleType::Exponential);
    }

    // Builder setters land in the built config.
    #[test]
    fn test_config_builder() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(50)
            .schedule(LRScheduleType::Linear)
            .smoothing_factor(0.1)
            .divergence_threshold(4.0)
            .build()
            .expect("build should succeed");
        assert!((config.min_lr - 1e-5).abs() < 1e-15);
        assert!((config.max_lr - 1.0).abs() < 1e-10);
        assert_eq!(config.num_iterations, 50);
        assert_eq!(config.schedule, LRScheduleType::Linear);
    }

    // Every validation rule rejects its out-of-range value.
    #[test]
    fn test_config_validation_errors() {
        assert!(LRFinderConfig::builder().min_lr(-1.0).build().is_err());
        assert!(LRFinderConfig::builder().min_lr(0.0).build().is_err());
        assert!(LRFinderConfig::builder()
            .min_lr(1.0)
            .max_lr(0.5)
            .build()
            .is_err());
        assert!(LRFinderConfig::builder().num_iterations(0).build().is_err());
        assert!(LRFinderConfig::builder()
            .smoothing_factor(1.5)
            .build()
            .is_err());
        assert!(LRFinderConfig::builder()
            .smoothing_factor(-0.1)
            .build()
            .is_err());
        assert!(LRFinderConfig::builder()
            .divergence_threshold(1.0)
            .build()
            .is_err());
        assert!(LRFinderConfig::builder()
            .divergence_threshold(0.5)
            .build()
            .is_err());
        assert!(LRFinderConfig::builder()
            .accumulation_steps(0)
            .build()
            .is_err());
    }

    // Exponential schedule: endpoints exact, midpoint is the geometric mean.
    #[test]
    fn test_exponential_lr_schedule() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-4)
            .max_lr(1.0)
            .num_iterations(100)
            .schedule(LRScheduleType::Exponential)
            .build()
            .expect("build should succeed");
        let lr_start = config.lr_at_step(0);
        let lr_end = config.lr_at_step(100);
        let lr_mid = config.lr_at_step(50);
        assert!((lr_start - 1e-4).abs() < 1e-10);
        assert!((lr_end - 1.0).abs() < 1e-6);
        let expected_mid = (1e-4_f64 * 1.0).sqrt();
        assert!((lr_mid - expected_mid).abs() < 1e-6);
    }

    // Linear schedule: endpoints exact, midpoint is the arithmetic mean.
    #[test]
    fn test_linear_lr_schedule() {
        let config = LRFinderConfig::builder()
            .min_lr(0.0001)
            .max_lr(1.0)
            .num_iterations(100)
            .schedule(LRScheduleType::Linear)
            .build()
            .expect("build should succeed");
        let lr_start = config.lr_at_step(0);
        let lr_end = config.lr_at_step(100);
        let lr_mid = config.lr_at_step(50);
        assert!((lr_start - 0.0001).abs() < 1e-10);
        assert!((lr_end - 1.0).abs() < 1e-6);
        let expected_mid = 0.0001 + (1.0 - 0.0001) * 0.5;
        assert!((lr_mid - expected_mid).abs() < 1e-6);
    }

    // A dip-then-rise loss curve yields a suggestion strictly inside the range.
    // smoothing_factor(0.0) makes smoothed == raw losses.
    #[test]
    fn test_lr_finder_decreasing_then_diverging_loss() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(100)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let losses: Vec<f64> = (0..100)
            .map(|i| {
                let t = i as f64 / 100.0;
                if t < 0.4 {
                    1.0 - t * 2.0
                } else {
                    0.2 + (t - 0.4).powi(2) * 50.0
                }
            })
            .collect();
        let result = LRFinderResult::from_losses(&losses, &config);
        let suggested = result.suggested_lr();
        assert!(suggested.is_some());
        let lr = suggested.expect("should have suggested lr");
        assert!(lr > 1e-5, "lr={lr} should be > 1e-5");
        assert!(lr < 1.0, "lr={lr} should be < 1.0");
    }

    // A quadratic blow-up triggers divergence and truncates the points.
    #[test]
    fn test_lr_finder_divergence_detection() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(100.0)
            .num_iterations(100)
            .smoothing_factor(0.0)
            .divergence_threshold(3.0)
            .build()
            .expect("build should succeed");
        let losses: Vec<f64> = (0..100)
            .map(|i| {
                if i < 30 {
                    1.0 - i as f64 * 0.01
                } else {
                    0.7 + (i as f64 - 30.0).powi(2) * 0.1
                }
            })
            .collect();
        let result = LRFinderResult::from_losses(&losses, &config);
        assert!(result.diverged);
        assert!(result.divergence_step.is_some());
        assert!(result.points.len() < 100);
    }

    // A monotonically decreasing loss never diverges; all points kept.
    #[test]
    fn test_lr_finder_no_divergence() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(0.1)
            .num_iterations(50)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let losses: Vec<f64> = (0..50).map(|i| 1.0 - i as f64 * 0.015).collect();
        let result = LRFinderResult::from_losses(&losses, &config);
        assert!(!result.diverged);
        assert_eq!(result.points.len(), 50);
    }

    // Full step-by-step drive of LRFinder: Continue x9, then Complete.
    #[test]
    fn test_stateful_lr_finder() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(10)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let mut finder = LRFinder::new(config).expect("should create finder");
        for i in 0..10 {
            assert!(!finder.is_finished());
            let lr = finder.next_lr().expect("should have lr");
            assert!(lr > 0.0);
            let loss = 1.0 - (i as f64) * 0.05;
            let status = finder.record_loss(loss);
            if i < 9 {
                assert_eq!(status, LRFinderStatus::Continue);
            } else {
                assert_eq!(status, LRFinderStatus::Complete);
            }
        }
        assert!(finder.is_finished());
        assert!(!finder.is_diverged());
        let result = finder.result();
        assert_eq!(result.points.len(), 10);
    }

    // The stateful finder also stops early on a blowing-up loss.
    #[test]
    fn test_stateful_lr_finder_divergence() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(100)
            .smoothing_factor(0.0)
            .divergence_threshold(3.0)
            .build()
            .expect("build should succeed");
        let mut finder = LRFinder::new(config).expect("should create finder");
        let mut step = 0;
        loop {
            if finder.is_finished() {
                break;
            }
            let _lr = finder.next_lr().expect("should have lr");
            let loss = if step < 5 {
                1.0 - step as f64 * 0.05
            } else {
                0.75 + (step as f64 - 5.0).powi(2) * 0.5
            };
            let status = finder.record_loss(loss);
            if status != LRFinderStatus::Continue {
                break;
            }
            step += 1;
        }
        assert!(finder.is_finished());
        assert!(finder.is_diverged());
    }

    // reset() makes a finished finder reusable from step 0.
    #[test]
    fn test_stateful_lr_finder_reset() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(5)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let mut finder = LRFinder::new(config).expect("should create finder");
        for i in 0..5 {
            finder.record_loss(1.0 - i as f64 * 0.1);
        }
        assert!(finder.is_finished());
        finder.reset();
        assert!(!finder.is_finished());
        assert_eq!(finder.current_step(), 0);
        for i in 0..5 {
            finder.record_loss(0.5 - i as f64 * 0.05);
        }
        assert!(finder.is_finished());
        let result = finder.result();
        assert_eq!(result.points.len(), 5);
    }

    // The one-shot helper validates, analyzes, and suggests in one call.
    #[test]
    fn test_find_optimal_lr_convenience() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(50)
            .smoothing_factor(0.05)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let losses: Vec<f64> = (0..50)
            .map(|i| {
                let t = i as f64 / 50.0;
                0.5 + (t - 0.4).powi(2) * 2.0
            })
            .collect();
        let (suggested, result) = find_optimal_lr(&losses, &config).expect("should succeed");
        assert!(suggested.is_some());
        assert!(!result.diverged);
    }

    // All accessor vectors have the same length; the first gradient is NaN.
    #[test]
    fn test_result_accessors() {
        let config = LRFinderConfig::default();
        let losses: Vec<f64> = (0..100).map(|i| 1.0 - i as f64 * 0.005).collect();
        let result = LRFinderResult::from_losses(&losses, &config);
        let lrs = result.learning_rates();
        assert_eq!(lrs.len(), result.points.len());
        let raw = result.raw_losses();
        assert_eq!(raw.len(), result.points.len());
        let smoothed = result.smoothed_losses();
        assert_eq!(smoothed.len(), result.points.len());
        let grads = result.loss_gradients();
        assert_eq!(grads.len(), result.points.len());
        assert!(grads[0].is_nan());
    }

    // summary() contains the expected headline strings.
    #[test]
    fn test_summary_generation() {
        let config = LRFinderConfig::default();
        let losses: Vec<f64> = (0..20).map(|i| 1.0 - i as f64 * 0.01).collect();
        let result = LRFinderResult::from_losses(&losses, &config);
        let summary = result.summary();
        assert!(summary.contains("Learning Rate Range Test Summary"));
        assert!(summary.contains("Exponential"));
        assert!(summary.contains("Best loss"));
    }

    // Conservative suggestion is exactly best_loss_lr / 10.
    #[test]
    fn test_conservative_suggestion() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(20)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let losses: Vec<f64> = (0..20)
            .map(|i| {
                let t = i as f64 / 20.0;
                (t - 0.5).powi(2) + 0.1
            })
            .collect();
        let result = LRFinderResult::from_losses(&losses, &config);
        let conservative = result.suggested_lr_conservative();
        assert!((conservative - result.best_loss_lr / 10.0).abs() < 1e-15);
    }

    // Typed wrapper (f64): parameter snapshotting and a full sweep.
    #[test]
    fn test_typed_lr_finder() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(10)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let mut finder = TypedLRFinder::<f64>::new(config).expect("should create finder");
        let params = vec![Array::<f64, IxDyn>::zeros(IxDyn(&[3, 3]))];
        finder.save_params(&params);
        assert!(finder.original_params().is_some());
        for i in 0..10 {
            let lr: f64 = finder.next_lr().expect("should have lr");
            assert!(lr > 0.0);
            let loss = 1.0 - i as f64 * 0.05;
            finder.record_loss(loss);
        }
        assert!(finder.is_finished());
        let result = finder.result();
        assert_eq!(result.points.len(), 10);
    }

    // Typed wrapper also works with f32 losses/learning rates.
    #[test]
    fn test_typed_lr_finder_f32() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(5)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let mut finder = TypedLRFinder::<f32>::new(config).expect("should create finder");
        for i in 0..5 {
            let lr: f32 = finder.next_lr().expect("should have lr");
            assert!(lr > 0.0);
            let loss: f32 = 1.0 - i as f32 * 0.1;
            finder.record_loss(loss);
        }
        assert!(finder.is_finished());
    }

    // Smoothing reduces step-to-step variance of a noisy loss curve.
    // divergence_threshold(100.0) keeps the noisy curve from stopping early.
    #[test]
    fn test_ema_smoothing() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(100)
            .smoothing_factor(0.1)
            .divergence_threshold(100.0)
            .build()
            .expect("build should succeed");
        let losses: Vec<f64> = (0..100)
            .map(|i| {
                let trend = 1.0 - i as f64 * 0.005;
                let noise = if i % 2 == 0 { 0.1 } else { -0.1 };
                trend + noise
            })
            .collect();
        let result = LRFinderResult::from_losses(&losses, &config);
        let raw = result.raw_losses();
        let smoothed = result.smoothed_losses();
        // Mean squared first difference as a variance proxy.
        fn variance(data: &[f64]) -> f64 {
            if data.len() < 2 {
                return 0.0;
            }
            let diffs: Vec<f64> = data.windows(2).map(|w| (w[1] - w[0]).powi(2)).collect();
            diffs.iter().sum::<f64>() / diffs.len() as f64
        }
        let raw_var = variance(&raw);
        let smoothed_var = variance(&smoothed);
        assert!(
            smoothed_var < raw_var,
            "smoothed_var={smoothed_var} should be < raw_var={raw_var}"
        );
    }

    // Display impl prints the plain variant name.
    #[test]
    fn test_lr_schedule_display() {
        assert_eq!(format!("{}", LRScheduleType::Exponential), "Exponential");
        assert_eq!(format!("{}", LRScheduleType::Linear), "Linear");
    }

    // After completion: no more LRs, and record_loss re-reports Complete.
    #[test]
    fn test_lr_finder_after_finished() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(3)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let mut finder = LRFinder::new(config).expect("should create finder");
        for _ in 0..3 {
            finder.record_loss(0.5);
        }
        assert!(finder.is_finished());
        assert!(finder.next_lr().is_none());
        assert_eq!(finder.record_loss(0.5), LRFinderStatus::Complete);
    }

    // Typed reset drops the saved parameter snapshot too.
    #[test]
    fn test_typed_lr_finder_reset() {
        let config = LRFinderConfig::builder()
            .min_lr(1e-5)
            .max_lr(1.0)
            .num_iterations(3)
            .smoothing_factor(0.0)
            .divergence_threshold(5.0)
            .build()
            .expect("build should succeed");
        let mut finder = TypedLRFinder::<f64>::new(config).expect("should create finder");
        let params = vec![Array::<f64, IxDyn>::zeros(IxDyn(&[2]))];
        finder.save_params(&params);
        for _ in 0..3 {
            finder.record_loss(0.5);
        }
        assert!(finder.is_finished());
        finder.reset();
        assert!(!finder.is_finished());
        assert!(finder.original_params().is_none());
    }
}