#[cfg(feature = "metrics_integration")]
use crate::error::Result;
use crate::optimizers::Optimizer;
#[cfg(feature = "metrics_integration")]
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
#[cfg(not(feature = "metrics_integration"))]
use scirs2_core::ndarray::{Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
#[cfg(feature = "metrics_integration")]
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
/// Optimizer wrapper that delegates parameter updates to another optimizer
/// while tracking a named metric and a per-step history of updates.
#[cfg(feature = "metrics_integration")]
pub struct MetricOptimizer<
    F: Float + Debug + Display + FromPrimitive + ScalarOperand,
    D: Dimension,
> {
    /// Optimizer that performs the actual parameter update in `step`.
    base_optimizer: Box<dyn Optimizer<F, D>>,
    /// Learning rate reported by `get_learning_rate`; seeded from the wrapped
    /// optimizer when the wrapper is constructed.
    current_lr: F,
    /// Metric tracker from `scirs2_metrics`; receives metric values and
    /// provides the optimization mode, value history and best value.
    metric_adapter: scirs2_metrics::integration::optim::MetricOptimizer<F>,
    /// One map per `step`, keyed by `"params"` (updated parameters) and
    /// `"gradients"` (the gradients passed to that step).
    history: Vec<HashMap<String, Array<F, D>>>,
    /// Parameters captured when a step was judged to coincide with an
    /// improvement of the tracked metric; `None` until then or after `reset`.
    best_params: Option<HashMap<String, Array<F, D>>>,
    /// Marker tying the generic parameters to the type.
    _phantom: PhantomData<(F, D)>,
}
#[cfg(feature = "metrics_integration")]
impl<F, D> MetricOptimizer<F, D>
where
    F: Float + Debug + Display + FromPrimitive + ScalarOperand + 'static,
    D: Dimension + 'static,
{
    /// Wraps `optimizer` and starts tracking the metric named `metric_name`.
    ///
    /// The wrapper's learning rate starts at `optimizer.get_learning_rate()`.
    /// `maximize` is forwarded to the metric adapter (presumably selecting
    /// `OptimizationMode::Maximize` when `true` — confirm in scirs2_metrics).
    /// History and best parameters start empty.
    pub fn new<O>(optimizer: O, metric_name: &str, maximize: bool) -> Self
    where
        O: Optimizer<F, D> + 'static,
    {
        let current_lr = optimizer.get_learning_rate();
        let metric_adapter =
            scirs2_metrics::integration::optim::MetricOptimizer::new(metric_name, maximize);
        Self {
            base_optimizer: Box::new(optimizer),
            current_lr,
            metric_adapter,
            history: Vec::new(),
            best_params: None,
            _phantom: PhantomData,
        }
    }

    /// Records a new value of the tracked metric. Always returns `Ok(())`.
    pub fn update_metric(&mut self, metric: F) -> Result<()> {
        self.metric_adapter.add_value(metric);
        Ok(())
    }

    /// Records a batch of named metric values.
    ///
    /// The entry whose key equals the tracked metric's name (if present) is
    /// recorded via `add_value` first; every other entry is then recorded via
    /// `add_additional_value`. Always returns `Ok(())`.
    pub fn update_metrics(&mut self, metrics: HashMap<String, F>) -> Result<()> {
        if let Some(&primary_value) = metrics.get(self.metric_adapter.metric_name()) {
            self.metric_adapter.add_value(primary_value);
        }
        for (name, value) in metrics {
            // The tracked metric was already handled above.
            if name == self.metric_adapter.metric_name() {
                continue;
            }
            self.metric_adapter.add_additional_value(&name, value);
        }
        Ok(())
    }

    /// Shared access to the underlying metric adapter.
    pub fn metric_adapter(&self) -> &scirs2_metrics::integration::optim::MetricOptimizer<F> {
        &self.metric_adapter
    }

    /// Mutable access to the underlying metric adapter.
    pub fn metric_adapter_mut(
        &mut self,
    ) -> &mut scirs2_metrics::integration::optim::MetricOptimizer<F> {
        &mut self.metric_adapter
    }

    /// Shared access to the wrapped optimizer.
    pub fn base_optimizer(&self) -> &dyn Optimizer<F, D> {
        self.base_optimizer.as_ref()
    }

    /// Mutable access to the wrapped optimizer.
    pub fn base_optimizer_mut(&mut self) -> &mut dyn Optimizer<F, D> {
        self.base_optimizer.as_mut()
    }

    /// The parameters captured at the most recent improving step, if any.
    pub fn best_params(&self) -> Option<&HashMap<String, Array<F, D>>> {
        self.best_params.as_ref()
    }

    /// Per-step snapshots recorded by `step`, oldest first.
    pub fn history(&self) -> &[HashMap<String, Array<F, D>>] {
        self.history.as_slice()
    }

    /// Clears the step history, the best parameters and the metric adapter.
    ///
    /// The learning rate and the wrapped optimizer are left untouched.
    pub fn reset(&mut self) {
        self.history.clear();
        self.best_params = None;
        self.metric_adapter.reset();
    }

    /// Builds a `ReduceOnPlateau` scheduler from the given settings, switched
    /// to `mode_max` or `mode_min` to match the adapter's optimization mode.
    pub fn create_lr_scheduler(
        &self,
        initial_lr: F,
        factor: F,
        patience: usize,
        min_lr: F,
    ) -> crate::schedulers::ReduceOnPlateau<F> {
        use scirs2_metrics::integration::optim::OptimizationMode;

        let mut scheduler =
            crate::schedulers::ReduceOnPlateau::new(initial_lr, factor, patience, min_lr);
        match self.metric_adapter.mode() {
            OptimizationMode::Maximize => {
                scheduler.mode_max();
            }
            OptimizationMode::Minimize => {
                scheduler.mode_min();
            }
        }
        scheduler
    }
}
#[cfg(feature = "metrics_integration")]
impl<F, D> Optimizer<F, D> for MetricOptimizer<F, D>
where
    F: Float + Debug + Display + FromPrimitive + ScalarOperand + 'static,
    D: Dimension + 'static,
{
    /// Runs one update through the wrapped optimizer and records it.
    ///
    /// The update itself is delegated to `base_optimizer`. A snapshot holding the
    /// updated parameters (`"params"`) and the incoming gradients (`"gradients"`)
    /// is then appended to `history`. If the most recently recorded metric value
    /// is at least as good as the adapter's best value (ties included), the
    /// updated parameters are stored as `best_params`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped optimizer's `step`; in that case
    /// nothing is recorded.
    fn step(&mut self, params: &Array<F, D>, gradients: &Array<F, D>) -> Result<Array<F, D>> {
        use scirs2_metrics::integration::optim::OptimizationMode;

        let updated_params = self.base_optimizer.step(params, gradients)?;

        // NOTE(review): every step keeps full copies of the parameters and the
        // gradients, so `history` grows without bound until `reset` is called.
        let mut param_update = HashMap::with_capacity(2);
        param_update.insert("params".to_string(), updated_params.clone());
        param_update.insert("gradients".to_string(), gradients.clone());
        self.history.push(param_update);

        // Use `>=` / `<=` rather than strict inequalities. If the adapter folds each
        // newly added value into `best_value()` (presumably what `add_value` does —
        // confirm against scirs2_metrics), the latest value can at most *equal* the
        // best, so a strict comparison would never fire and `best_params` would stay
        // `None` forever. NaN values compare false and never count as improvements.
        let latest = self.metric_adapter.history().last().copied();
        let is_improvement = match (latest, self.metric_adapter.best_value()) {
            (Some(last), Some(best)) => match self.metric_adapter.mode() {
                OptimizationMode::Maximize => last >= best,
                OptimizationMode::Minimize => last <= best,
            },
            _ => false,
        };

        if is_improvement {
            let mut best_params = HashMap::with_capacity(1);
            best_params.insert("params".to_string(), updated_params.clone());
            self.best_params = Some(best_params);
        }
        Ok(updated_params)
    }

    /// Returns the learning rate last set on this wrapper (initially the wrapped
    /// optimizer's learning rate at construction time).
    fn get_learning_rate(&self) -> F {
        self.current_lr
    }

    /// Sets the learning rate on this wrapper and on the wrapped optimizer.
    ///
    /// `step` delegates the actual update to `base_optimizer`. Updating only
    /// `current_lr` would leave the wrapped optimizer on its old rate, and any
    /// scheduler driving this wrapper would be silently ignored.
    fn set_learning_rate(&mut self, learning_rate: F) {
        self.current_lr = learning_rate;
        self.base_optimizer.set_learning_rate(learning_rate);
    }
}
/// Placeholder type compiled when the `metrics_integration` feature is
/// disabled. It carries no data.
#[cfg(not(feature = "metrics_integration"))]
#[derive(Debug)]
pub struct MetricOptimizer<
    F: Float + Debug + Display + FromPrimitive + ScalarOperand,
    D: Dimension,
> {
    /// Marker tying the generic parameters to the type.
    _phantom: PhantomData<(F, D)>,
}
#[cfg(not(feature = "metrics_integration"))]
impl<F: Float + Debug + Display + FromPrimitive + ScalarOperand, D: Dimension>
    MetricOptimizer<F, D>
{
    /// Stand-in constructor compiled when `metrics_integration` is disabled.
    ///
    /// All arguments are ignored.
    ///
    /// # Panics
    ///
    /// Always panics, telling the caller to enable the `metrics_integration`
    /// feature.
    pub fn new<O: Optimizer<F, D>>(_optimizer: O, _metric_name: &str, _maximize: bool) -> Self {
        panic!("metrics_integration feature is not enabled - enable it in your Cargo.toml");
    }
}