Skip to main content

entrenar/train/tui/callback/
render.rs

1//! Rendering methods for TerminalMonitorCallback.
2
3use std::io::Write;
4
5use crate::train::callback::CallbackContext;
6use crate::train::tui::andon::AlertLevel;
7use crate::train::tui::capability::DashboardLayout;
8use crate::train::tui::progress::format_duration;
9use crate::train::tui::sparkline::sparkline;
10
11use super::monitor::TerminalMonitorCallback;
12
13/// Rendering trait for callback display.
14pub(crate) trait CallbackRenderer {
15    /// Render the current display.
16    fn render(&self, ctx: &CallbackContext) -> String;
17    /// Print the display to stdout.
18    fn print_display(&self, ctx: &CallbackContext);
19}
20
21impl CallbackRenderer for TerminalMonitorCallback {
22    /// Render the current display.
23    fn render(&self, ctx: &CallbackContext) -> String {
24        match self.layout {
25            DashboardLayout::Minimal => render_minimal(self, ctx),
26            DashboardLayout::Compact => render_compact(self, ctx),
27            DashboardLayout::Full => render_full(self, ctx),
28        }
29    }
30
31    /// Print the display to stdout.
32    fn print_display(&self, ctx: &CallbackContext) {
33        let output = self.render(ctx);
34        print!("{output}");
35        let _ = std::io::stdout().flush();
36    }
37}
38
39/// Render minimal single-line display.
40fn render_minimal(_callback: &TerminalMonitorCallback, ctx: &CallbackContext) -> String {
41    let percent = (ctx.epoch as f32 / ctx.max_epochs as f32) * 100.0;
42    format!(
43        "\rEpoch {}/{} [{:.1}%] loss={:.4} lr={:.2e}",
44        ctx.epoch + 1,
45        ctx.max_epochs,
46        percent,
47        ctx.loss,
48        ctx.lr
49    )
50}
51
52/// Render compact 5-line display.
53fn render_compact(callback: &TerminalMonitorCallback, ctx: &CallbackContext) -> String {
54    let loss_spark = sparkline(&callback.loss_buffer.values(), callback.sparkline_width);
55    let elapsed = callback.start_time.elapsed().as_secs_f64();
56
57    let val_info = ctx.val_loss.map(|v| format!(" val={v:.4}")).unwrap_or_default();
58
59    let best_info = callback.loss_buffer.min().map(|m| format!(" best={m:.4}")).unwrap_or_default();
60
61    format!(
62        "\x1b[H\x1b[2J\
63         ═══ {} Training ═══\n\
64         Epoch {}/{} │ loss={:.4}{}{}\n\
65         Loss: {} \n\
66         LR: {:.2e} │ {:.1} steps/s\n\
67         {}",
68        callback.model_name,
69        ctx.epoch + 1,
70        ctx.max_epochs,
71        ctx.loss,
72        val_info,
73        best_info,
74        loss_spark,
75        ctx.lr,
76        ctx.global_step as f64 / elapsed.max(0.001),
77        callback.progress.render()
78    )
79}
80
81/// Render full dashboard display.
82fn render_full(callback: &TerminalMonitorCallback, ctx: &CallbackContext) -> String {
83    let loss_spark = sparkline(&callback.loss_buffer.values(), callback.sparkline_width);
84    let lr_spark = sparkline(&callback.lr_buffer.values(), callback.sparkline_width);
85    let elapsed = callback.start_time.elapsed().as_secs_f64();
86    let steps_per_sec = ctx.global_step as f64 / elapsed.max(0.001);
87
88    let val_spark = if callback.val_loss_buffer.is_empty() {
89        String::new()
90    } else {
91        format!(
92            "Val Loss: {} {:.4}\n",
93            sparkline(&callback.val_loss_buffer.values(), callback.sparkline_width),
94            callback.val_loss_buffer.last().unwrap_or(0.0)
95        )
96    };
97
98    let alerts = render_alerts(callback);
99
100    format!(
101        "\x1b[H\x1b[2J\
102╔═══════════════════════════════════════════════════════════════════╗
103║  ENTRENAR TRAINING MONITOR                              [RUNNING] ║
104╠═══════════════════════════════════════════════════════════════════╣
105║  Model: {:<20} │ Epoch: {}/{}                  ║
106╠═══════════════════════════════════════════════════════════════════╣
107║  Loss: {} {:.4}                                 ║
108║  {}║  LR:   {} {:.2e}                                 ║
109╠═══════════════════════════════════════════════════════════════════╣
110║  Steps/s: {:.1}  │  Elapsed: {}                        ║
111║  {}║
112╚═══════════════════════════════════════════════════════════════════╝
113{}",
114        callback.model_name,
115        ctx.epoch + 1,
116        ctx.max_epochs,
117        loss_spark,
118        ctx.loss,
119        val_spark,
120        lr_spark,
121        ctx.lr,
122        steps_per_sec,
123        format_duration(elapsed),
124        callback.progress.render(),
125        alerts
126    )
127}
128
129/// Render recent alerts.
130pub(crate) fn render_alerts(callback: &TerminalMonitorCallback) -> String {
131    let alerts = callback.andon.recent_alerts(3);
132    if alerts.is_empty() {
133        return String::new();
134    }
135
136    alerts
137        .iter()
138        .map(|a| {
139            let prefix = match a.level {
140                AlertLevel::Info => "ℹ️ ",
141                AlertLevel::Warning => "⚠️ ",
142                AlertLevel::Critical => "🛑",
143            };
144            format!("{} {}", prefix, a.message)
145        })
146        .collect::<Vec<_>>()
147        .join("\n")
148}