use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, FromPrimitive};
#[cfg(not(feature = "metrics_integration"))]
use std::fmt::Debug;
#[cfg(feature = "metrics_integration")]
use std::fmt::{Debug, Display};
#[cfg(feature = "metrics_integration")]
use crate::schedulers::LearningRateScheduler;
#[cfg(feature = "metrics_integration")]
/// Learning-rate scheduler driven by an externally supplied metric value.
///
/// Thin wrapper around `scirs2_metrics::integration::optim::MetricLRScheduler`;
/// all scheduling decisions are delegated to that type.
#[derive(Debug, Clone)]
pub struct MetricScheduler<F>
where
    F: Float + Debug + Display + ScalarOperand + FromPrimitive,
{
    /// The wrapped metric-driven scheduler that owns the actual LR state.
    scheduler: scirs2_metrics::integration::optim::MetricLRScheduler<F>,
    /// Local copy of the threshold most recently configured on this wrapper.
    threshold: F,
}
#[cfg(feature = "metrics_integration")]
impl<F: Float + Debug + Display + ScalarOperand + FromPrimitive + Send + Sync> MetricScheduler<F> {
    /// Creates a scheduler backed by `MetricLRScheduler`.
    ///
    /// `initial_lr`, `factor`, `patience`, `min_lr`, `metric_name` and `maximize`
    /// are forwarded unchanged to `MetricLRScheduler::new`; see that type for their
    /// exact semantics. The locally tracked threshold starts at `1e-4`.
    ///
    /// # Panics
    ///
    /// Panics if `1e-4` cannot be converted into `F`.
    pub fn new(
        initial_lr: F,
        factor: F,
        patience: usize,
        min_lr: F,
        metric_name: &str,
        maximize: bool,
    ) -> Self {
        Self {
            scheduler: scirs2_metrics::integration::optim::MetricLRScheduler::new(
                initial_lr,
                factor,
                patience,
                min_lr,
                metric_name,
                maximize,
            ),
            // Only the local copy is set here; the wrapped scheduler's threshold is
            // changed exclusively through `with_threshold`.
            threshold: F::from(1e-4).expect("default threshold 1e-4 must be representable in F"),
        }
    }

    /// Sets the threshold on both this wrapper and the wrapped scheduler
    /// (builder style, consumes and returns `self`).
    pub fn with_threshold(mut self, threshold: F) -> Self {
        self.threshold = threshold;
        self.scheduler.set_threshold(threshold);
        self
    }

    /// Returns the threshold last configured on this wrapper (`1e-4` unless changed
    /// via `with_threshold`).
    pub fn threshold(&self) -> F {
        self.threshold
    }

    /// Feeds one metric observation to the wrapped scheduler and returns the value
    /// it reports (presumably the learning rate after this step — see
    /// `MetricLRScheduler::step_with_metric`).
    pub fn step_with_metric(&mut self, metric_value: F) -> F {
        self.scheduler.step_with_metric(metric_value)
    }

    /// Returns the wrapped scheduler's current learning rate.
    pub fn get_lr(&self) -> F {
        self.scheduler.get_learning_rate()
    }

    /// Learning-rate history as recorded by the wrapped scheduler.
    pub fn history(&self) -> &[F] {
        self.scheduler.history()
    }

    /// Metric history as recorded by the wrapped scheduler.
    pub fn metric_history(&self) -> &[F] {
        self.scheduler.metric_history()
    }

    /// Best metric seen so far according to the wrapped scheduler, if any.
    pub fn best_metric(&self) -> Option<F> {
        self.scheduler.best_metric()
    }
}
#[cfg(feature = "metrics_integration")]
impl<F> LearningRateScheduler<F> for MetricScheduler<F>
where
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + Send + Sync,
{
    /// Current learning rate of the wrapped scheduler.
    fn get_learning_rate(&self) -> F {
        self.scheduler.get_learning_rate()
    }

    /// Metric-free step: does not advance the schedule, it only reports the
    /// wrapped scheduler's current rate. Use `step_with_metric` to feed a metric.
    fn step(&mut self) -> F {
        self.scheduler.get_learning_rate()
    }

    /// Resets the wrapped scheduler.
    fn reset(&mut self) {
        self.scheduler.reset();
    }
}
#[cfg(feature = "metrics_integration")]
/// `ReduceOnPlateau` wrapper that additionally records the metric values it is
/// fed and the learning rates it returns.
#[derive(Debug)]
pub struct MetricBasedReduceOnPlateau<F>
where
    F: Float + Debug + Display + ScalarOperand + FromPrimitive,
{
    /// The crate's plateau scheduler that makes the actual LR decisions.
    scheduler: crate::schedulers::ReduceOnPlateau<F>,
    /// Label for the tracked metric (stored only; not used in scheduling).
    metric_name: String,
    /// Every metric value passed to `step_with_metric`, in call order.
    metric_history: Vec<F>,
    /// Every learning rate returned by `step_with_metric`, in call order.
    lr_history: Vec<F>,
}
#[cfg(feature = "metrics_integration")]
impl<F: Float + Debug + Display + ScalarOperand + FromPrimitive + Send + Sync>
    MetricBasedReduceOnPlateau<F>
{
    /// Creates a plateau scheduler for a named metric.
    ///
    /// `initial_lr`, `factor`, `patience` and `min_lr` are forwarded to
    /// `ReduceOnPlateau::new`. When `maximize` is true the inner scheduler is
    /// switched to `mode_max`, otherwise to `mode_min`. `metric_name` is kept only
    /// as a label retrievable via `metric_name()`.
    pub fn new(
        initial_lr: F,
        factor: F,
        patience: usize,
        min_lr: F,
        metric_name: &str,
        maximize: bool,
    ) -> Self {
        let mut scheduler =
            crate::schedulers::ReduceOnPlateau::new(initial_lr, factor, patience, min_lr);
        if maximize {
            scheduler.mode_max();
        } else {
            scheduler.mode_min();
        }
        Self {
            scheduler,
            metric_name: metric_name.to_string(),
            metric_history: Vec::new(),
            lr_history: Vec::new(),
        }
    }

    /// Records `metric_value`, forwards it to the inner scheduler, records the
    /// learning rate it returns, and returns that learning rate.
    pub fn step_with_metric(&mut self, metric_value: F) -> F {
        self.metric_history.push(metric_value);
        let lr = self.scheduler.step_with_metric(metric_value);
        self.lr_history.push(lr);
        lr
    }

    /// Label given to the tracked metric at construction.
    pub fn metric_name(&self) -> &str {
        &self.metric_name
    }

    /// Metric values passed to `step_with_metric`, oldest first.
    pub fn metric_history(&self) -> &[F] {
        &self.metric_history
    }

    /// Learning rates returned by `step_with_metric`, oldest first.
    pub fn lr_history(&self) -> &[F] {
        &self.lr_history
    }
}
#[cfg(feature = "metrics_integration")]
impl<F> LearningRateScheduler<F> for MetricBasedReduceOnPlateau<F>
where
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + Send + Sync,
{
    /// Current learning rate of the inner plateau scheduler.
    fn get_learning_rate(&self) -> F {
        self.scheduler.get_learning_rate()
    }

    /// Metric-free step: leaves all state untouched (nothing is recorded) and
    /// returns the inner scheduler's current rate.
    fn step(&mut self) -> F {
        self.scheduler.get_learning_rate()
    }

    /// Resets the inner scheduler, then discards both recorded histories.
    fn reset(&mut self) {
        self.scheduler.reset();
        self.metric_history.clear();
        self.lr_history.clear();
    }
}
#[cfg(not(feature = "metrics_integration"))]
/// Placeholder type present when the `metrics_integration` feature is disabled.
///
/// It holds no data; its constructor panics (see `MetricScheduler::new`).
#[derive(Debug)]
pub struct MetricScheduler<F>
where
    F: Float + Debug,
{
    /// Marks the otherwise-unused type parameter.
    _phantom: std::marker::PhantomData<F>,
}
#[cfg(not(feature = "metrics_integration"))]
impl<F> MetricScheduler<F>
where
    F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
{
    /// Stand-in constructor for builds without `metrics_integration`.
    ///
    /// Accepts the same arguments as the feature-enabled constructor so call
    /// sites compile unchanged, but ignores all of them.
    ///
    /// # Panics
    ///
    /// Always panics, telling the caller to enable the feature.
    pub fn new(_: F, _: F, _: usize, _: F, _: &str, _: bool) -> Self {
        panic!("metrics_integration feature is not enabled - enable it in your Cargo.toml");
    }
}