Skip to main content

entrenar/train/tui/callback/
monitor.rs

//! Terminal monitor callback: the [`TerminalMonitorCallback`] struct, its builder methods, and its [`TrainerCallback`] implementation.
2
3use std::io::Write;
4use std::time::{Duration, Instant};
5
6use super::render::CallbackRenderer;
7use crate::train::callback::{CallbackAction, CallbackContext, TrainerCallback};
8use crate::train::tui::andon::AndonSystem;
9use crate::train::tui::buffer::MetricsBuffer;
10use crate::train::tui::capability::{DashboardLayout, TerminalMode};
11use crate::train::tui::progress::ProgressBar;
12use crate::train::tui::refresh::RefreshPolicy;
13
14/// Real-time terminal monitoring callback.
15///
16/// Integrates with training loop to provide live visualization.
17#[derive(Debug)]
18pub struct TerminalMonitorCallback {
19    /// Loss buffer
20    pub(crate) loss_buffer: MetricsBuffer,
21    /// Validation loss buffer
22    pub(crate) val_loss_buffer: MetricsBuffer,
23    /// Learning rate buffer
24    pub(crate) lr_buffer: MetricsBuffer,
25    /// Progress bar
26    pub(crate) progress: ProgressBar,
27    /// Refresh policy
28    pub(crate) refresh_policy: RefreshPolicy,
29    /// Andon system
30    pub(crate) andon: AndonSystem,
31    /// Terminal mode
32    pub(crate) mode: TerminalMode,
33    /// Dashboard layout
34    pub(crate) layout: DashboardLayout,
35    /// Sparkline width
36    pub(crate) sparkline_width: usize,
37    /// Start time
38    pub(crate) start_time: Instant,
39    /// Model name (for display)
40    pub(crate) model_name: String,
41}
42
43impl Default for TerminalMonitorCallback {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl TerminalMonitorCallback {
50    /// Create a new terminal monitor callback.
51    pub fn new() -> Self {
52        Self {
53            loss_buffer: MetricsBuffer::new(100),
54            val_loss_buffer: MetricsBuffer::new(100),
55            lr_buffer: MetricsBuffer::new(100),
56            progress: ProgressBar::new(100, 30),
57            refresh_policy: RefreshPolicy::default(),
58            andon: AndonSystem::new(),
59            mode: TerminalMode::default(),
60            layout: DashboardLayout::default(),
61            sparkline_width: 20,
62            start_time: Instant::now(),
63            model_name: "model".to_string(),
64        }
65    }
66
67    /// Set terminal mode.
68    pub fn mode(mut self, mode: TerminalMode) -> Self {
69        self.mode = mode;
70        self
71    }
72
73    /// Set dashboard layout.
74    pub fn layout(mut self, layout: DashboardLayout) -> Self {
75        self.layout = layout;
76        self
77    }
78
79    /// Set model name.
80    pub fn model_name(mut self, name: impl Into<String>) -> Self {
81        self.model_name = name.into();
82        self
83    }
84
85    /// Set sparkline width.
86    pub fn sparkline_width(mut self, width: usize) -> Self {
87        self.sparkline_width = width;
88        self
89    }
90
91    /// Set refresh interval.
92    pub fn refresh_interval_ms(mut self, ms: u64) -> Self {
93        self.refresh_policy.min_interval = Duration::from_millis(ms);
94        self
95    }
96}
97
98impl TrainerCallback for TerminalMonitorCallback {
99    fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
100        self.start_time = Instant::now();
101        self.progress = ProgressBar::new(ctx.max_epochs * ctx.steps_per_epoch, 30);
102
103        // Clear screen and hide cursor
104        print!("\x1b[?25l\x1b[2J\x1b[H");
105        let _ = std::io::stdout().flush();
106
107        CallbackAction::Continue
108    }
109
110    fn on_train_end(&mut self, ctx: &CallbackContext) {
111        // Final render
112        self.print_display(ctx);
113
114        // Show cursor
115        println!("\x1b[?25h");
116        let _ = std::io::stdout().flush();
117
118        // Print summary
119        println!("\nTraining complete!");
120        if let Some(best) = self.loss_buffer.min() {
121            println!("Best loss: {best:.4}");
122        }
123        println!(
124            "Total time: {}",
125            crate::train::tui::progress::format_duration(self.start_time.elapsed().as_secs_f64())
126        );
127    }
128
129    fn on_step_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
130        // Record metrics
131        self.loss_buffer.push(ctx.loss);
132        self.lr_buffer.push(ctx.lr);
133        if let Some(val) = ctx.val_loss {
134            self.val_loss_buffer.push(val);
135        }
136
137        // Update progress
138        self.progress.update(ctx.global_step);
139
140        // Check health
141        if self.andon.check_loss(ctx.loss) {
142            return CallbackAction::Stop;
143        }
144
145        // Rate-limited refresh
146        if self.refresh_policy.should_refresh(ctx.global_step) {
147            self.print_display(ctx);
148        }
149
150        CallbackAction::Continue
151    }
152
153    fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
154        // Force refresh at epoch boundaries
155        self.refresh_policy.force_refresh(ctx.global_step);
156        self.print_display(ctx);
157        CallbackAction::Continue
158    }
159
160    fn name(&self) -> &'static str {
161        "TerminalMonitorCallback"
162    }
163}