// entrenar/train/tui/callback/monitor.rs

use std::io::Write;
use std::time::{Duration, Instant};

use super::render::CallbackRenderer;
use crate::train::callback::{CallbackAction, CallbackContext, TrainerCallback};
use crate::train::tui::andon::AndonSystem;
use crate::train::tui::buffer::MetricsBuffer;
use crate::train::tui::capability::{DashboardLayout, TerminalMode};
use crate::train::tui::progress::ProgressBar;
use crate::train::tui::refresh::RefreshPolicy;
14#[derive(Debug)]
18pub struct TerminalMonitorCallback {
19 pub(crate) loss_buffer: MetricsBuffer,
21 pub(crate) val_loss_buffer: MetricsBuffer,
23 pub(crate) lr_buffer: MetricsBuffer,
25 pub(crate) progress: ProgressBar,
27 pub(crate) refresh_policy: RefreshPolicy,
29 pub(crate) andon: AndonSystem,
31 pub(crate) mode: TerminalMode,
33 pub(crate) layout: DashboardLayout,
35 pub(crate) sparkline_width: usize,
37 pub(crate) start_time: Instant,
39 pub(crate) model_name: String,
41}
42
43impl Default for TerminalMonitorCallback {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl TerminalMonitorCallback {
50 pub fn new() -> Self {
52 Self {
53 loss_buffer: MetricsBuffer::new(100),
54 val_loss_buffer: MetricsBuffer::new(100),
55 lr_buffer: MetricsBuffer::new(100),
56 progress: ProgressBar::new(100, 30),
57 refresh_policy: RefreshPolicy::default(),
58 andon: AndonSystem::new(),
59 mode: TerminalMode::default(),
60 layout: DashboardLayout::default(),
61 sparkline_width: 20,
62 start_time: Instant::now(),
63 model_name: "model".to_string(),
64 }
65 }
66
67 pub fn mode(mut self, mode: TerminalMode) -> Self {
69 self.mode = mode;
70 self
71 }
72
73 pub fn layout(mut self, layout: DashboardLayout) -> Self {
75 self.layout = layout;
76 self
77 }
78
79 pub fn model_name(mut self, name: impl Into<String>) -> Self {
81 self.model_name = name.into();
82 self
83 }
84
85 pub fn sparkline_width(mut self, width: usize) -> Self {
87 self.sparkline_width = width;
88 self
89 }
90
91 pub fn refresh_interval_ms(mut self, ms: u64) -> Self {
93 self.refresh_policy.min_interval = Duration::from_millis(ms);
94 self
95 }
96}
97
98impl TrainerCallback for TerminalMonitorCallback {
99 fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
100 self.start_time = Instant::now();
101 self.progress = ProgressBar::new(ctx.max_epochs * ctx.steps_per_epoch, 30);
102
103 print!("\x1b[?25l\x1b[2J\x1b[H");
105 let _ = std::io::stdout().flush();
106
107 CallbackAction::Continue
108 }
109
110 fn on_train_end(&mut self, ctx: &CallbackContext) {
111 self.print_display(ctx);
113
114 println!("\x1b[?25h");
116 let _ = std::io::stdout().flush();
117
118 println!("\nTraining complete!");
120 if let Some(best) = self.loss_buffer.min() {
121 println!("Best loss: {best:.4}");
122 }
123 println!(
124 "Total time: {}",
125 crate::train::tui::progress::format_duration(self.start_time.elapsed().as_secs_f64())
126 );
127 }
128
129 fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
130 self.loss_buffer.push(ctx.loss);
132 self.lr_buffer.push(ctx.lr);
133 if let Some(val) = ctx.val_loss {
134 self.val_loss_buffer.push(val);
135 }
136
137 self.progress.update(ctx.global_step);
139
140 if self.andon.check_loss(ctx.loss) {
142 return CallbackAction::Stop;
143 }
144
145 if self.refresh_policy.should_refresh(ctx.global_step) {
147 self.print_display(ctx);
148 }
149
150 CallbackAction::Continue
151 }
152
153 fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
154 self.refresh_policy.force_refresh(ctx.global_step);
156 self.print_display(ctx);
157 CallbackAction::Continue
158 }
159
160 fn name(&self) -> &'static str {
161 "TerminalMonitorCallback"
162 }
163}