#[cfg(feature = "metrics_integration")]
use crate::callbacks::{Callback, CallbackContext, CallbackTiming};
#[cfg(feature = "metrics_integration")]
use crate::error::Result;
#[cfg(feature = "metrics_integration")]
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
#[cfg(feature = "metrics_integration")]
use scirs2_core::numeric::NumAssign;
#[cfg(feature = "metrics_integration")]
use scirs2_core::numeric::{Float, FromPrimitive};
#[cfg(feature = "metrics_integration")]
use scirs2_core::simd_ops::SimdUnifiedOps;
#[cfg(feature = "metrics_integration")]
use scirs2_metrics::integration::traits::MetricComputation;
#[cfg(feature = "metrics_integration")]
use std::collections::HashMap;
#[cfg(feature = "metrics_integration")]
use std::fmt::{Debug, Display};
/// Training callback that evaluates `scirs2_metrics` neural metric adapters
/// at the end of each epoch.
///
/// The callback does not populate `current_predictions` / `current_targets`
/// itself; they are public, presumably so the training loop can assign them
/// before the `AfterEpoch` event fires (confirm against callers).
#[cfg(feature = "metrics_integration")]
pub struct ScirsMetricsCallback<F>
where
    F: Float
        + Debug
        + Display
        + FromPrimitive
        + Send
        + Sync
        + ScalarOperand
        + NumAssign
        + SimdUnifiedOps,
{
    /// Metric adapters evaluated on every `AfterEpoch` event.
    metrics: Vec<scirs2_metrics::integration::neural::NeuralMetricAdapter<F>>,
    /// Predictions to score at the next `AfterEpoch`; reset to `None` afterwards.
    pub current_predictions: Option<Array<F, IxDyn>>,
    /// Targets to score at the next `AfterEpoch`; reset to `None` afterwards.
    pub current_targets: Option<Array<F, IxDyn>>,
    /// Metric name -> value for the most recent epoch that was evaluated.
    epoch_results: HashMap<String, F>,
    /// One snapshot of `epoch_results` per evaluated epoch, oldest first.
    history: Vec<HashMap<String, F>>,
    /// When `true`, each computed metric is printed to stdout.
    verbose: bool,
}
#[cfg(feature = "metrics_integration")]
impl<F> ScirsMetricsCallback<F>
where
    F: Float
        + Debug
        + Display
        + FromPrimitive
        + Send
        + Sync
        + ScalarOperand
        + NumAssign
        + SimdUnifiedOps,
{
    /// Builds a callback around the given metric adapters.
    ///
    /// Starts with no pending predictions/targets, empty results and history,
    /// and verbose output enabled. This variant always returns `Some`; the
    /// `Option` return type matches the stub constructor compiled when the
    /// `metrics_integration` feature is disabled, which returns `None`.
    pub fn new(
        metrics: Vec<scirs2_metrics::integration::neural::NeuralMetricAdapter<F>>,
    ) -> Option<Self> {
        let callback = Self {
            verbose: true,
            history: Vec::new(),
            epoch_results: HashMap::new(),
            current_targets: None,
            current_predictions: None,
            metrics,
        };
        Some(callback)
    }

    /// Consumes the callback and returns it with the verbose flag replaced.
    pub fn with_verbose(self, verbose: bool) -> Self {
        let mut updated = self;
        updated.verbose = verbose;
        updated
    }

    /// Per-epoch metric snapshots recorded so far, oldest first.
    pub fn history(&self) -> &[HashMap<String, F>] {
        self.history.as_slice()
    }

    /// Metric values from the most recent epoch that was evaluated.
    pub fn epoch_results(&self) -> &HashMap<String, F> {
        let latest = &self.epoch_results;
        latest
    }
}
#[cfg(feature = "metrics_integration")]
impl<F> Callback<F> for ScirsMetricsCallback<F>
where
    F: Float
        + Debug
        + Display
        + FromPrimitive
        + Send
        + Sync
        + ScalarOperand
        + NumAssign
        + SimdUnifiedOps,
{
    /// Evaluates every configured metric when an epoch finishes.
    ///
    /// On `CallbackTiming::AfterEpoch` the pending predictions and targets are
    /// consumed (both fields are `None` afterwards, whether or not they were
    /// set). If both were present, each metric is computed; successful values
    /// are stored in `epoch_results`, appended to `context.metrics`, and a
    /// snapshot of the results is pushed onto `history`. If either was missing,
    /// `epoch_results` and `history` are left untouched.
    ///
    /// All other timings, including `AfterBatch`, are no-ops.
    ///
    /// # Errors
    ///
    /// Never returns an error: a metric that fails to compute is reported on
    /// stderr and skipped, so one failing metric does not abort the others.
    fn on_event(&mut self, timing: CallbackTiming, context: &mut CallbackContext<F>) -> Result<()> {
        match timing {
            CallbackTiming::AfterEpoch => {
                // Taking the values both hands us ownership for the evaluation
                // and performs the end-of-epoch reset in a single step.
                let predictions = self.current_predictions.take();
                let targets = self.current_targets.take();

                if let (Some(predictions), Some(targets)) = (predictions, targets) {
                    self.epoch_results.clear();
                    for metric in &self.metrics {
                        match metric.compute(&predictions, &targets) {
                            Ok(value) => {
                                let metric_name = metric.name().to_string();
                                if self.verbose {
                                    println!(" {}: {:.4}", metric_name, value);
                                }
                                // The name is not needed after this point, so it
                                // is moved into the map rather than cloned.
                                self.epoch_results.insert(metric_name, value);
                                context.metrics.push(value);
                            }
                            Err(err) => {
                                eprintln!("Error computing {}: {}", metric.name(), err);
                            }
                        }
                    }
                    // `epoch_results` stays available through its getter, so the
                    // history keeps an independent copy.
                    self.history.push(self.epoch_results.clone());
                }
            }
            // No other event (including `AfterBatch`) needs handling here.
            _ => {}
        }
        Ok(())
    }
}
/// Stand-in for `ScirsMetricsCallback` when the crate is built without the
/// `metrics_integration` feature.
///
/// It carries no data; `F` is only recorded through `PhantomData` so the type
/// keeps the same generic parameter as the feature-enabled version.
#[cfg(not(feature = "metrics_integration"))]
#[derive(Debug)]
#[allow(dead_code)]
pub struct ScirsMetricsCallback<F> {
    _phantom: std::marker::PhantomData<F>,
}

#[cfg(not(feature = "metrics_integration"))]
#[allow(unused_attributes, dead_code)]
impl<F> ScirsMetricsCallback<F> {
    /// Always returns `None` after printing a two-line warning to stderr
    /// explaining that the `metrics_integration` feature is required.
    ///
    /// `_metrics` is accepted (with any element type) and ignored.
    pub fn new<T>(_metrics: Vec<T>) -> Option<Self> {
        eprintln!("Warning: ScirsMetricsCallback requires the 'metrics_integration' feature.");
        eprintln!("To use it, compile with: --features metrics_integration");
        None
    }

    /// Returns the value unchanged; the stub has no verbosity setting.
    ///
    /// The binding is intentionally not `mut`: nothing is modified, and a
    /// `mut self` here would raise an `unused_mut` warning on every build
    /// without the feature.
    pub fn with_verbose(self, _verbose: bool) -> Self {
        self
    }
}