use std::io::Write;
use std::time::{Duration, Instant};
use super::render::CallbackRenderer;
use crate::train::callback::{CallbackAction, CallbackContext, TrainerCallback};
use crate::train::tui::andon::AndonSystem;
use crate::train::tui::buffer::MetricsBuffer;
use crate::train::tui::capability::{DashboardLayout, TerminalMode};
use crate::train::tui::progress::ProgressBar;
use crate::train::tui::refresh::RefreshPolicy;
/// Trainer callback that tracks training metrics and draws a live display
/// on the terminal.
///
/// Instances are created with [`TerminalMonitorCallback::new`] and configured
/// through the consuming builder methods on the type.
#[derive(Debug)]
pub struct TerminalMonitorCallback {
/// Receives `ctx.loss` at the end of every step; its minimum is printed as
/// the "Best loss" when training ends.
pub(crate) loss_buffer: MetricsBuffer,
/// Receives `ctx.val_loss` at the end of a step whenever it is `Some`.
pub(crate) val_loss_buffer: MetricsBuffer,
/// Receives `ctx.lr` at the end of every step.
pub(crate) lr_buffer: MetricsBuffer,
/// Progress tracker; rebuilt in `on_train_begin` and updated with
/// `ctx.global_step` on every step.
pub(crate) progress: ProgressBar,
/// Consulted via `should_refresh` on each step to decide whether to redraw;
/// the `refresh_interval_ms` builder sets its `min_interval`.
pub(crate) refresh_policy: RefreshPolicy,
/// Checked against each step's loss; when `check_loss` returns `true` the
/// callback asks the trainer to stop.
pub(crate) andon: AndonSystem,
/// Terminal mode chosen through the `mode` builder.
/// NOTE(review): not read in this part of the file; presumably consumed by
/// the rendering code — confirm.
pub(crate) mode: TerminalMode,
/// Dashboard layout chosen through the `layout` builder.
/// NOTE(review): not read in this part of the file; presumably consumed by
/// the rendering code — confirm.
pub(crate) layout: DashboardLayout,
/// Sparkline width (defaults to 20).
/// NOTE(review): presumably measured in character cells — confirm against
/// the renderer.
pub(crate) sparkline_width: usize,
/// Reset when training begins; used to report the total elapsed time when
/// training ends.
pub(crate) start_time: Instant,
/// Model name set through the `model_name` builder (defaults to `"model"`).
/// NOTE(review): presumably shown in the rendered display — confirm.
pub(crate) model_name: String,
}
impl Default for TerminalMonitorCallback {
fn default() -> Self {
Self::new()
}
}
impl TerminalMonitorCallback {
    /// Creates a monitor with the default configuration.
    ///
    /// Defaults: each metric buffer is built with `MetricsBuffer::new(100)`,
    /// the progress bar with `ProgressBar::new(100, 30)` (it is rebuilt from
    /// the real step count in `on_train_begin`), the refresh policy, terminal
    /// mode and layout take their `Default` values, the sparkline width is
    /// `20`, and the model name is `"model"`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss_buffer: MetricsBuffer::new(100),
            val_loss_buffer: MetricsBuffer::new(100),
            lr_buffer: MetricsBuffer::new(100),
            progress: ProgressBar::new(100, 30),
            refresh_policy: RefreshPolicy::default(),
            andon: AndonSystem::new(),
            mode: TerminalMode::default(),
            layout: DashboardLayout::default(),
            sparkline_width: 20,
            start_time: Instant::now(),
            model_name: "model".to_string(),
        }
    }

    // The builder methods below consume `self` and return the updated value.
    // They are `#[must_use]` because discarding the result silently loses the
    // configuration change.

    /// Sets the terminal mode and returns the updated callback.
    #[must_use]
    pub fn mode(mut self, mode: TerminalMode) -> Self {
        self.mode = mode;
        self
    }

    /// Sets the dashboard layout and returns the updated callback.
    #[must_use]
    pub fn layout(mut self, layout: DashboardLayout) -> Self {
        self.layout = layout;
        self
    }

    /// Sets the model name and returns the updated callback.
    ///
    /// Accepts anything convertible into a `String`.
    #[must_use]
    pub fn model_name(mut self, name: impl Into<String>) -> Self {
        self.model_name = name.into();
        self
    }

    /// Sets the sparkline width and returns the updated callback.
    ///
    /// NOTE(review): the unit is presumably character cells — confirm against
    /// the rendering code.
    #[must_use]
    pub fn sparkline_width(mut self, width: usize) -> Self {
        self.sparkline_width = width;
        self
    }

    /// Sets the refresh policy's `min_interval` to `ms` milliseconds and
    /// returns the updated callback.
    #[must_use]
    pub fn refresh_interval_ms(mut self, ms: u64) -> Self {
        self.refresh_policy.min_interval = Duration::from_millis(ms);
        self
    }
}
impl TrainerCallback for TerminalMonitorCallback {
    /// Resets the timer and progress bar for a new run and prepares the
    /// terminal for drawing. Always returns `CallbackAction::Continue`.
    fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
        self.start_time = Instant::now();
        // Saturate instead of overflowing: a very large epoch or step count
        // would otherwise panic in debug builds and wrap to a bogus total in
        // release builds.
        let total_steps = ctx.max_epochs.saturating_mul(ctx.steps_per_epoch);
        self.progress = ProgressBar::new(total_steps, 30);
        // ESC[?25l hides the cursor, ESC[2J clears the screen, ESC[H moves the
        // cursor to the top-left corner.
        print!("\x1b[?25l\x1b[2J\x1b[H");
        // Best effort: a failed flush must not abort training.
        let _ = std::io::stdout().flush();
        CallbackAction::Continue
    }

    /// Draws the display one last time, restores the cursor and prints a
    /// short summary (best buffered loss, if any, and total elapsed time).
    fn on_train_end(&mut self, ctx: &CallbackContext) {
        self.print_display(ctx);
        // ESC[?25h makes the cursor visible again (it was hidden in
        // `on_train_begin`).
        println!("\x1b[?25h");
        let _ = std::io::stdout().flush();
        println!("\nTraining complete!");
        // The minimum is taken over whatever the loss buffer currently holds.
        if let Some(best) = self.loss_buffer.min() {
            println!("Best loss: {best:.4}");
        }
        println!(
            "Total time: {}",
            crate::train::tui::progress::format_duration(self.start_time.elapsed().as_secs_f64())
        );
    }

    /// Records the step's metrics, advances the progress bar and redraws the
    /// display when the refresh policy allows it.
    ///
    /// Returns `CallbackAction::Stop` when `andon.check_loss` reports a
    /// problem with the current loss, otherwise `CallbackAction::Continue`.
    fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
        self.loss_buffer.push(ctx.loss);
        self.lr_buffer.push(ctx.lr);
        // Validation loss is only recorded on steps that provide one.
        if let Some(val) = ctx.val_loss {
            self.val_loss_buffer.push(val);
        }
        self.progress.update(ctx.global_step);
        // The stop request is returned before any redraw happens.
        if self.andon.check_loss(ctx.loss) {
            return CallbackAction::Stop;
        }
        if self.refresh_policy.should_refresh(ctx.global_step) {
            self.print_display(ctx);
        }
        CallbackAction::Continue
    }

    /// Forces a refresh through the refresh policy and redraws the display at
    /// the end of every epoch. Always returns `CallbackAction::Continue`.
    fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
        self.refresh_policy.force_refresh(ctx.global_step);
        self.print_display(ctx);
        CallbackAction::Continue
    }

    /// Returns the fixed identifier of this callback.
    fn name(&self) -> &'static str {
        "TerminalMonitorCallback"
    }
}