#[allow(unused_imports)]
use crate::error::Result;
use crate::integration::optim::OptimizationMode;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt;
#[allow(unused_imports)]
use std::marker::PhantomData;
/// Learning-rate scheduler that shrinks the learning rate when a monitored
/// metric stops improving for a configurable number of steps.
#[derive(Debug, Clone)]
pub struct MetricLRScheduler<F>
where
    F: Float + fmt::Debug + fmt::Display + FromPrimitive,
{
    /// Learning rate currently in effect.
    current_lr: F,
    /// Learning rate the scheduler started with (restored on reset).
    initial_lr: F,
    /// Multiplier applied to the learning rate on each reduction.
    factor: F,
    /// Number of consecutive non-improving steps that triggers a reduction.
    patience: usize,
    /// Lower bound the learning rate is clamped to.
    min_lr: F,
    /// Consecutive non-improving steps observed since the last improvement/reduction.
    stagnation_count: usize,
    /// Best metric value seen so far, if any.
    best_metric: Option<F>,
    /// Relative margin a metric must beat the best value by to count as an improvement.
    threshold: F,
    /// Whether smaller or larger metric values are better.
    mode: OptimizationMode,
    /// Label of the monitored metric.
    metric_name: String,
    /// Learning rates: the initial one followed by the value after each reduction.
    history: Vec<F>,
    /// Every metric value passed to the scheduler.
    metric_history: Vec<F>,
}
impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> MetricLRScheduler<F> {
    /// Creates a scheduler that multiplies the learning rate by `factor` after
    /// `patience` consecutive non-improving metric observations, never going
    /// below `min_lr`.
    ///
    /// * `initial_lr` - learning rate to start from (restored by [`reset`](Self::reset)).
    /// * `factor` - multiplier applied to the learning rate on each reduction.
    /// * `patience` - consecutive non-improving steps that trigger a reduction;
    ///   `0` behaves the same as `1`.
    /// * `min_lr` - lower bound for the learning rate.
    /// * `metric_name` - label of the monitored metric.
    /// * `maximize` - `true` if larger metric values are better, `false` if smaller are.
    ///
    /// The relative improvement threshold starts at `1e-4`; change it with
    /// [`set_threshold`](Self::set_threshold).
    ///
    /// # Panics
    ///
    /// Panics if the constant `1e-4` cannot be converted to `F`.
    pub fn new<S: Into<String>>(
        initial_lr: F,
        factor: F,
        patience: usize,
        min_lr: F,
        metric_name: S,
        maximize: bool,
    ) -> Self {
        Self {
            current_lr: initial_lr,
            initial_lr,
            factor,
            patience,
            min_lr,
            stagnation_count: 0,
            best_metric: None,
            threshold: F::from(1e-4).expect("Failed to convert constant to float"),
            mode: if maximize {
                OptimizationMode::Maximize
            } else {
                OptimizationMode::Minimize
            },
            metric_name: metric_name.into(),
            history: vec![initial_lr],
            metric_history: Vec::new(),
        }
    }

    /// Sets the relative improvement threshold and returns `self` for chaining.
    pub fn set_threshold(&mut self, threshold: F) -> &mut Self {
        self.threshold = threshold;
        self
    }

    /// Records `metric` and updates the learning rate, returning the learning
    /// rate to use next.
    ///
    /// An improving metric becomes the new best and resets the stagnation
    /// counter. Otherwise the counter is incremented; once it reaches
    /// `patience`, the learning rate is multiplied by `factor` (clamped to
    /// `min_lr`), appended to [`history`](Self::history), and the counter resets.
    /// NaN metrics are recorded but never count as improvements.
    pub fn step_with_metric(&mut self, metric: F) -> F {
        self.metric_history.push(metric);
        if self.is_improvement(metric) {
            self.best_metric = Some(metric);
            self.stagnation_count = 0;
        } else {
            self.stagnation_count += 1;
            if self.stagnation_count >= self.patience {
                self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
                self.history.push(self.current_lr);
                self.stagnation_count = 0;
            }
        }
        self.current_lr
    }

    /// Returns `true` if `metric` beats the best value so far by more than the
    /// relative `threshold` (or if no valid metric has been seen yet).
    fn is_improvement(&self, metric: F) -> bool {
        // NaN compares false against everything; accepting it as the first
        // "best" would make every later comparison fail forever.
        if metric.is_nan() {
            return false;
        }
        let best = match self.best_metric {
            Some(best) => best,
            None => return true,
        };
        // Scale the margin by |best| so the bar always moves in the "better"
        // direction. The naive `best * (1 -/+ threshold)` moves it the wrong
        // way for negative `best`, letting a worse value count as an improvement.
        // For `best >= 0` this is identical to the multiplicative form.
        let margin = best.abs() * self.threshold;
        match self.mode {
            OptimizationMode::Minimize => metric < best - margin,
            OptimizationMode::Maximize => metric > best + margin,
        }
    }

    /// Returns the learning rate currently in effect.
    pub fn get_learning_rate(&self) -> F {
        self.current_lr
    }

    /// Restores the initial learning rate and clears the best metric, the
    /// stagnation counter and both histories. Threshold and mode are kept.
    pub fn reset(&mut self) {
        self.current_lr = self.initial_lr;
        self.stagnation_count = 0;
        self.best_metric = None;
        self.history = vec![self.initial_lr];
        self.metric_history.clear();
    }

    /// Learning rates: the initial value followed by the value after each reduction.
    pub fn history(&self) -> &[F] {
        &self.history
    }

    /// Every metric passed to [`step_with_metric`](Self::step_with_metric)
    /// since construction or the last reset.
    pub fn metric_history(&self) -> &[F] {
        &self.metric_history
    }

    /// Best metric value seen so far, or `None` if none has been recorded.
    pub fn best_metric(&self) -> Option<F> {
        self.best_metric
    }

    /// Exports the static configuration of this scheduler.
    ///
    /// The threshold is not part of the exported configuration.
    pub fn to_scheduler_config(&self) -> crate::integration::optim::SchedulerConfig<F> {
        use crate::integration::optim::SchedulerConfig;
        SchedulerConfig {
            initial_lr: self.initial_lr,
            factor: self.factor,
            patience: self.patience,
            min_lr: self.min_lr,
            mode: self.mode,
            metric_name: self.metric_name.clone(),
        }
    }

    /// Returns a snapshot of the mutable scheduling state.
    pub fn get_state(&self) -> SchedulerState<F> {
        SchedulerState {
            current_lr: self.current_lr,
            best_metric: self.best_metric,
            stagnation_count: self.stagnation_count,
            threshold: self.threshold,
            mode: self.mode,
        }
    }
}
/// Snapshot of a metric scheduler's mutable state.
#[derive(Debug, Clone)]
pub struct SchedulerState<F>
where
    F: Float + fmt::Debug + fmt::Display + FromPrimitive,
{
    /// Learning rate in effect when the snapshot was taken.
    pub current_lr: F,
    /// Best metric value seen so far, if any.
    pub best_metric: Option<F>,
    /// Consecutive non-improving steps counted so far.
    pub stagnation_count: usize,
    /// Relative improvement threshold.
    pub threshold: F,
    /// Whether smaller or larger metric values are better.
    pub mode: OptimizationMode,
}
/// Common interface for schedulers driven by an observed metric value.
pub trait MetricScheduler<F>
where
    F: Float + fmt::Debug + fmt::Display + FromPrimitive,
{
    /// Feeds a new metric observation and returns the resulting learning rate.
    fn step_with_metric(&mut self, metric: F) -> F;

    /// Returns the learning rate currently in effect.
    fn get_learning_rate(&self) -> F;

    /// Returns the scheduler to its initial state.
    fn reset(&mut self);

    /// Selects whether the metric should be minimized or maximized.
    fn set_mode(&mut self, mode: OptimizationMode);
}
/// Wraps any [`MetricScheduler`] behind a trait object together with
/// metric/learning-rate bookkeeping.
pub struct SchedulerBridge<F>
where
    F: Float + fmt::Debug + fmt::Display + FromPrimitive,
{
    /// The wrapped scheduler, type-erased so any implementation can be used.
    inner: Box<dyn MetricScheduler<F>>,
    /// Label of the monitored metric.
    metric_name: String,
    /// Metric values recorded by the bridge (populated by methods outside this view — confirm).
    metric_history: Vec<F>,
    /// Learning rates recorded by the bridge (populated by methods outside this view — confirm).
    lr_history: Vec<F>,
}