use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
/// Trainer callback that forwards training progress to the monitoring subsystem.
///
/// It owns a metrics collector, which receives the values recorded by the
/// callback hooks, and an Andon system, which receives critical alerts and is
/// asked whether training should stop.
#[derive(Debug)]
pub struct MonitorCallback {
/// Stores the loss, learning-rate, and epoch values recorded by the hooks.
collector: crate::monitor::MetricsCollector,
/// Receives critical alerts and is queried for the stop decision at epoch end.
andon: crate::monitor::AndonSystem,
}
impl MonitorCallback {
    /// Creates a callback backed by a new metrics collector and a new Andon system.
    pub fn new() -> Self {
        Self {
            collector: crate::monitor::MetricsCollector::new(),
            andon: crate::monitor::AndonSystem::new(),
        }
    }

    /// Returns a shared reference to the metrics collector owned by this callback.
    pub fn collector(&self) -> &crate::monitor::MetricsCollector {
        &self.collector
    }

    /// Serializes the collector's summary as pretty-printed JSON keyed by metric name.
    ///
    /// Keys are emitted in sorted order, so the same set of metrics always
    /// produces the same key layout.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` reported by the serializer if the
    /// summary cannot be converted to JSON.
    pub fn summary_json(&self) -> Result<String, serde_json::Error> {
        // A `BTreeMap` iterates in key order, whereas a `HashMap` iterates in a
        // randomized order; sorting keeps the emitted JSON reproducible.
        let summary: std::collections::BTreeMap<String, _> = self
            .collector
            .summary()
            .into_iter()
            .map(|(metric, value)| (metric.as_str().to_string(), value))
            .collect();
        serde_json::to_string_pretty(&summary)
    }
}
impl Default for MonitorCallback {
fn default() -> Self {
Self::new()
}
}
impl TrainerCallback for MonitorCallback {
fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
self.collector.record(crate::monitor::Metric::Loss, f64::from(ctx.loss));
self.collector.record(crate::monitor::Metric::LearningRate, f64::from(ctx.lr));
CallbackAction::Continue
}
fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
self.collector.record(crate::monitor::Metric::Epoch, ctx.epoch as f64);
if ctx.loss.is_nan() {
self.andon.critical("NaN loss detected");
} else if ctx.loss.is_infinite() {
self.andon.critical("Infinite loss detected");
}
if self.andon.should_stop() {
CallbackAction::Stop
} else {
CallbackAction::Continue
}
}
fn name(&self) -> &'static str {
"MonitorCallback"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A finite loss lets both hooks continue and leaves a loss entry in the summary.
    #[test]
    fn test_monitor_callback() {
        let mut monitor = MonitorCallback::new();
        let ctx = CallbackContext { epoch: 0, step: 0, loss: 0.5, lr: 0.001, ..Default::default() };
        assert_eq!(monitor.on_step_end(&ctx), CallbackAction::Continue);
        assert_eq!(monitor.on_epoch_end(&ctx), CallbackAction::Continue);
        let summary = monitor.collector().summary();
        assert!(summary.contains_key(&crate::monitor::Metric::Loss));
    }

    /// A NaN loss at epoch end stops training.
    #[test]
    fn test_monitor_callback_nan_detection() {
        let mut monitor = MonitorCallback::new();
        let ctx = CallbackContext { loss: f32::NAN, ..Default::default() };
        assert_eq!(monitor.on_epoch_end(&ctx), CallbackAction::Stop);
    }

    /// The `Default` implementation yields a usable callback.
    #[test]
    fn test_monitor_callback_default() {
        let mc = MonitorCallback::default();
        assert_eq!(mc.name(), "MonitorCallback");
    }

    /// The summary serializes to JSON after at least one step was recorded.
    #[test]
    fn test_monitor_callback_summary_json() {
        let mut mc = MonitorCallback::new();
        let ctx = CallbackContext { loss: 0.5, lr: 0.001, ..Default::default() };
        mc.on_step_end(&ctx);
        let json = mc.summary_json();
        assert!(json.is_ok());
    }

    /// A positive-infinity loss at epoch end stops training.
    #[test]
    fn test_monitor_callback_inf_detection() {
        let mut mc = MonitorCallback::new();
        let ctx = CallbackContext { loss: f32::INFINITY, ..Default::default() };
        assert_eq!(mc.on_epoch_end(&ctx), CallbackAction::Stop);
    }

    /// A NaN loss is still caught after an earlier healthy epoch on the same callback.
    #[test]
    fn test_monitor_callback_nan_loss() {
        let mut cb = MonitorCallback::new();
        let healthy = CallbackContext { loss: 0.5, ..Default::default() };
        assert_eq!(cb.on_epoch_end(&healthy), CallbackAction::Continue);

        let diverged = CallbackContext { loss: f32::NAN, ..Default::default() };
        // Pin the exact outcome; accepting either action would verify nothing.
        assert_eq!(cb.on_epoch_end(&diverged), CallbackAction::Stop);
    }

    /// Negative infinity is treated the same as positive infinity.
    #[test]
    fn test_monitor_callback_infinite_loss() {
        let mut cb = MonitorCallback::new();
        let ctx = CallbackContext { loss: f32::NEG_INFINITY, ..Default::default() };
        assert_eq!(cb.on_epoch_end(&ctx), CallbackAction::Stop);
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Runs `on_epoch_end` on a freshly built callback with the given loss and
    /// returns the resulting action.
    fn epoch_end_action(loss: f32) -> CallbackAction {
        let mut callback = MonitorCallback::new();
        let ctx = CallbackContext { loss, ..Default::default() };
        callback.on_epoch_end(&ctx)
    }

    proptest! {
        /// Finite losses in the sampled range continue training, while NaN and
        /// positive infinity stop it.
        #[test]
        fn monitor_callback_detects_nan_inf(
            normal_loss in -100.0f32..100.0,
        ) {
            prop_assert_eq!(epoch_end_action(normal_loss), CallbackAction::Continue);
            prop_assert_eq!(epoch_end_action(f32::NAN), CallbackAction::Stop);
            prop_assert_eq!(epoch_end_action(f32::INFINITY), CallbackAction::Stop);
        }
    }
}