#[allow(unused_imports)]
use crate::error::MetricsError;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::HashMap;
use std::fmt;
use std::marker::PhantomData;
/// Direction in which a tracked metric must move to count as "better".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationMode {
    /// Smaller values are better.
    Minimize,
    /// Larger values are better.
    Maximize,
}

impl fmt::Display for OptimizationMode {
    /// Writes the lowercase mode name (`"minimize"` / `"maximize"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            OptimizationMode::Minimize => "minimize",
            OptimizationMode::Maximize => "maximize",
        };
        // `pad` (unlike `write!` with a literal) honors width, fill, alignment and
        // precision flags, so e.g. `format!("{:>10}", mode)` lines up in tables/logs.
        f.pad(label)
    }
}
/// Tracks the values of one named metric and remembers the best value seen,
/// where "best" is decided by an [`OptimizationMode`].
///
/// Auxiliary metrics can be recorded by name alongside the primary one.
#[derive(Debug, Clone)]
pub struct MetricOptimizer<F = f64>
where
    F: Float + fmt::Debug + fmt::Display + FromPrimitive,
{
    /// Name of the primary metric being optimized.
    metric_name: String,
    /// Whether smaller or larger primary values are preferred.
    mode: OptimizationMode,
    /// Recorded primary-metric values, in insertion order.
    history: Vec<F>,
    /// Best primary-metric value recorded so far, if any.
    best_value: Option<F>,
    /// Histories of auxiliary metrics, keyed by metric name.
    additional_metrics: HashMap<String, Vec<F>>,
    /// Zero-sized type marker for `F`.
    _phantom: PhantomData<F>,
}
impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> MetricOptimizer<F> {
    /// Creates an optimizer for the metric called `name`.
    ///
    /// When `maximize` is `true`, larger values are better; otherwise smaller
    /// values are better. The history starts empty and there is no best value yet.
    pub fn new<S: Into<String>>(name: S, maximize: bool) -> Self {
        Self {
            metric_name: name.into(),
            mode: if maximize {
                OptimizationMode::Maximize
            } else {
                OptimizationMode::Minimize
            },
            history: Vec::new(),
            best_value: None,
            additional_metrics: HashMap::new(),
            _phantom: PhantomData,
        }
    }

    /// Returns the name of the primary metric.
    pub fn metric_name(&self) -> &str {
        &self.metric_name
    }

    /// Returns whether the primary metric is minimized or maximized.
    pub fn mode(&self) -> OptimizationMode {
        self.mode
    }

    /// Returns every value passed to [`add_value`](Self::add_value), oldest first.
    ///
    /// This includes NaN values, even though they never become the best value.
    pub fn history(&self) -> &[F] {
        &self.history
    }

    /// Returns the best non-NaN value recorded so far, or `None` if there is none.
    pub fn best_value(&self) -> Option<F> {
        self.best_value
    }

    /// Appends `value` to the history and makes it the best value if it is an
    /// improvement (see [`is_improvement`](Self::is_improvement)).
    ///
    /// Ties do not replace the current best. NaN is recorded in the history but is
    /// never adopted as best. Otherwise a leading NaN would compare `false` against
    /// everything and freeze the best value permanently.
    pub fn add_value(&mut self, value: F) {
        self.history.push(value);
        if self.is_improvement(value) {
            self.best_value = Some(value);
        }
    }

    /// Appends `value` to the history of the auxiliary metric `metricname`,
    /// creating that history on first use. Does not affect the best value.
    pub fn add_additional_value(&mut self, metricname: &str, value: F) {
        self.additional_metrics
            .entry(metricname.to_string())
            .or_default()
            .push(value);
    }

    /// Returns the recorded history of the auxiliary metric `metricname`,
    /// or `None` if nothing was ever recorded under that name.
    pub fn additional_metric_history(&self, metricname: &str) -> Option<&[F]> {
        self.additional_metrics
            .get(metricname)
            .map(Vec::as_slice)
    }

    /// Clears the primary history, the best value, and all auxiliary histories.
    ///
    /// The metric name and optimization mode are kept.
    pub fn reset(&mut self) {
        self.history.clear();
        self.best_value = None;
        self.additional_metrics.clear();
    }

    /// Returns `true` if `current` is strictly better than `previous` under this
    /// optimizer's mode.
    ///
    /// Because the comparison is strict, equal values and any comparison
    /// involving NaN yield `false`.
    pub fn is_better(&self, current: F, previous: F) -> bool {
        match self.mode {
            OptimizationMode::Maximize => current > previous,
            OptimizationMode::Minimize => current < previous,
        }
    }

    /// Returns `true` if recording `value` would make it the new best value.
    ///
    /// NaN is never an improvement. Any other value is an improvement when no
    /// best exists yet, or when it is strictly better than the current best.
    pub fn is_improvement(&self, value: F) -> bool {
        if value.is_nan() {
            return false;
        }
        match self.best_value {
            None => true,
            Some(best) => self.is_better(value, best),
        }
    }

    /// Builds a [`SchedulerConfig`] from the given learning-rate parameters,
    /// carrying over this optimizer's mode and metric name.
    ///
    /// The numeric arguments are passed through unvalidated.
    pub fn create_scheduler_config(
        &self,
        initial_lr: F,
        factor: F,
        patience: usize,
        min_lr: F,
    ) -> SchedulerConfig<F> {
        SchedulerConfig {
            initial_lr,
            factor,
            patience,
            min_lr,
            mode: self.mode,
            metric_name: self.metric_name.clone(),
        }
    }
}
/// Parameters for a learning-rate scheduler driven by a tracked metric.
#[derive(Debug, Clone)]
pub struct SchedulerConfig<F>
where
    F: Float + fmt::Debug + fmt::Display + FromPrimitive,
{
    /// Starting learning rate.
    pub initial_lr: F,
    /// Multiplicative factor applied to the learning rate.
    pub factor: F,
    /// Patience count, in whatever unit the consuming scheduler uses.
    pub patience: usize,
    /// Lower bound intended for the learning rate.
    pub min_lr: F,
    /// Whether the monitored metric is minimized or maximized.
    pub mode: OptimizationMode,
    /// Name of the monitored metric.
    pub metric_name: String,
}
impl<F> SchedulerConfig<F>
where
    F: Float + fmt::Debug + fmt::Display + FromPrimitive,
{
    /// Assembles a configuration from its individual fields, stored as given.
    pub fn new(
        initial_lr: F,
        factor: F,
        patience: usize,
        min_lr: F,
        mode: OptimizationMode,
        metric_name: String,
    ) -> Self {
        Self {
            metric_name,
            mode,
            min_lr,
            patience,
            factor,
            initial_lr,
        }
    }

    /// Returns the numeric settings and mode as
    /// `(initial_lr, factor, patience, min_lr, mode)`, leaving out the metric name.
    pub fn as_tuple(&self) -> (F, F, usize, F, OptimizationMode) {
        let Self {
            initial_lr,
            factor,
            patience,
            min_lr,
            mode,
            ..
        } = self;
        (*initial_lr, *factor, *patience, *min_lr, *mode)
    }
}