entrenar/train/tui/callback/
render.rs1use std::io::Write;
4
5use crate::train::callback::CallbackContext;
6use crate::train::tui::andon::AlertLevel;
7use crate::train::tui::capability::DashboardLayout;
8use crate::train::tui::progress::format_duration;
9use crate::train::tui::sparkline::sparkline;
10
11use super::monitor::TerminalMonitorCallback;
12
13pub(crate) trait CallbackRenderer {
15 fn render(&self, ctx: &CallbackContext) -> String;
17 fn print_display(&self, ctx: &CallbackContext);
19}
20
21impl CallbackRenderer for TerminalMonitorCallback {
22 fn render(&self, ctx: &CallbackContext) -> String {
24 match self.layout {
25 DashboardLayout::Minimal => render_minimal(self, ctx),
26 DashboardLayout::Compact => render_compact(self, ctx),
27 DashboardLayout::Full => render_full(self, ctx),
28 }
29 }
30
31 fn print_display(&self, ctx: &CallbackContext) {
33 let output = self.render(ctx);
34 print!("{output}");
35 let _ = std::io::stdout().flush();
36 }
37}
38
39fn render_minimal(_callback: &TerminalMonitorCallback, ctx: &CallbackContext) -> String {
41 let percent = (ctx.epoch as f32 / ctx.max_epochs as f32) * 100.0;
42 format!(
43 "\rEpoch {}/{} [{:.1}%] loss={:.4} lr={:.2e}",
44 ctx.epoch + 1,
45 ctx.max_epochs,
46 percent,
47 ctx.loss,
48 ctx.lr
49 )
50}
51
52fn render_compact(callback: &TerminalMonitorCallback, ctx: &CallbackContext) -> String {
54 let loss_spark = sparkline(&callback.loss_buffer.values(), callback.sparkline_width);
55 let elapsed = callback.start_time.elapsed().as_secs_f64();
56
57 let val_info = ctx.val_loss.map(|v| format!(" val={v:.4}")).unwrap_or_default();
58
59 let best_info = callback.loss_buffer.min().map(|m| format!(" best={m:.4}")).unwrap_or_default();
60
61 format!(
62 "\x1b[H\x1b[2J\
63 ═══ {} Training ═══\n\
64 Epoch {}/{} │ loss={:.4}{}{}\n\
65 Loss: {} \n\
66 LR: {:.2e} │ {:.1} steps/s\n\
67 {}",
68 callback.model_name,
69 ctx.epoch + 1,
70 ctx.max_epochs,
71 ctx.loss,
72 val_info,
73 best_info,
74 loss_spark,
75 ctx.lr,
76 ctx.global_step as f64 / elapsed.max(0.001),
77 callback.progress.render()
78 )
79}
80
81fn render_full(callback: &TerminalMonitorCallback, ctx: &CallbackContext) -> String {
83 let loss_spark = sparkline(&callback.loss_buffer.values(), callback.sparkline_width);
84 let lr_spark = sparkline(&callback.lr_buffer.values(), callback.sparkline_width);
85 let elapsed = callback.start_time.elapsed().as_secs_f64();
86 let steps_per_sec = ctx.global_step as f64 / elapsed.max(0.001);
87
88 let val_spark = if callback.val_loss_buffer.is_empty() {
89 String::new()
90 } else {
91 format!(
92 "Val Loss: {} {:.4}\n",
93 sparkline(&callback.val_loss_buffer.values(), callback.sparkline_width),
94 callback.val_loss_buffer.last().unwrap_or(0.0)
95 )
96 };
97
98 let alerts = render_alerts(callback);
99
100 format!(
101 "\x1b[H\x1b[2J\
102╔═══════════════════════════════════════════════════════════════════╗
103║ ENTRENAR TRAINING MONITOR [RUNNING] ║
104╠═══════════════════════════════════════════════════════════════════╣
105║ Model: {:<20} │ Epoch: {}/{} ║
106╠═══════════════════════════════════════════════════════════════════╣
107║ Loss: {} {:.4} ║
108║ {}║ LR: {} {:.2e} ║
109╠═══════════════════════════════════════════════════════════════════╣
110║ Steps/s: {:.1} │ Elapsed: {} ║
111║ {}║
112╚═══════════════════════════════════════════════════════════════════╝
113{}",
114 callback.model_name,
115 ctx.epoch + 1,
116 ctx.max_epochs,
117 loss_spark,
118 ctx.loss,
119 val_spark,
120 lr_spark,
121 ctx.lr,
122 steps_per_sec,
123 format_duration(elapsed),
124 callback.progress.render(),
125 alerts
126 )
127}
128
129pub(crate) fn render_alerts(callback: &TerminalMonitorCallback) -> String {
131 let alerts = callback.andon.recent_alerts(3);
132 if alerts.is_empty() {
133 return String::new();
134 }
135
136 alerts
137 .iter()
138 .map(|a| {
139 let prefix = match a.level {
140 AlertLevel::Info => "ℹ️ ",
141 AlertLevel::Warning => "⚠️ ",
142 AlertLevel::Critical => "🛑",
143 };
144 format!("{} {}", prefix, a.message)
145 })
146 .collect::<Vec<_>>()
147 .join("\n")
148}