1use crate::renderer::{tui::NumericMetricsState, MetricsRenderer};
2use crate::renderer::{MetricState, TrainingProgress};
3use crate::TrainingInterrupter;
4use ratatui::{
5 crossterm::{
6 event::{self, Event, KeyCode},
7 execute,
8 terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
9 },
10 prelude::*,
11 Terminal,
12};
13use std::panic::{set_hook, take_hook};
14use std::sync::Arc;
15use std::{
16 error::Error,
17 io::{self, Stdout},
18 time::{Duration, Instant},
19};
20
21use super::{
22 Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState,
23 TextMetricsState,
24};
25
26pub(crate) type TerminalBackend = CrosstermBackend<Stdout>;
28pub(crate) type TerminalFrame<'a> = ratatui::Frame<'a>;
30
31#[allow(deprecated)] type PanicHook = Box<dyn Fn(&std::panic::PanicInfo<'_>) + 'static + Sync + Send>;
33
34const MAX_REFRESH_RATE_MILLIS: u64 = 100;
35
36pub struct TuiMetricsRenderer {
38 terminal: Terminal<TerminalBackend>,
39 last_update: std::time::Instant,
40 progress: ProgressBarState,
41 metrics_numeric: NumericMetricsState,
42 metrics_text: TextMetricsState,
43 status: StatusState,
44 interuptor: TrainingInterrupter,
45 popup: PopupState,
46 previous_panic_hook: Option<Arc<PanicHook>>,
47 persistent: bool,
48}
49
50impl MetricsRenderer for TuiMetricsRenderer {
51 fn update_train(&mut self, state: MetricState) {
52 match state {
53 MetricState::Generic(entry) => {
54 self.metrics_text.update_train(entry);
55 }
56 MetricState::Numeric(entry, value) => {
57 self.metrics_numeric.push_train(entry.name.clone(), value);
58 self.metrics_text.update_train(entry);
59 }
60 };
61 }
62
63 fn update_valid(&mut self, state: MetricState) {
64 match state {
65 MetricState::Generic(entry) => {
66 self.metrics_text.update_valid(entry);
67 }
68 MetricState::Numeric(entry, value) => {
69 self.metrics_numeric.push_valid(entry.name.clone(), value);
70 self.metrics_text.update_valid(entry);
71 }
72 };
73 }
74
75 fn render_train(&mut self, item: TrainingProgress) {
76 self.progress.update_train(&item);
77 self.metrics_numeric.update_progress_train(&item);
78 self.status.update_train(item);
79 self.render().unwrap();
80 }
81
82 fn render_valid(&mut self, item: TrainingProgress) {
83 self.progress.update_valid(&item);
84 self.metrics_numeric.update_progress_valid(&item);
85 self.status.update_valid(item);
86 self.render().unwrap();
87 }
88}
89
90impl TuiMetricsRenderer {
91 pub fn new(interuptor: TrainingInterrupter, checkpoint: Option<usize>) -> Self {
93 let mut stdout = io::stdout();
94 execute!(stdout, EnterAlternateScreen).unwrap();
95 enable_raw_mode().unwrap();
96 let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap();
97
98 let previous_panic_hook = Arc::new(take_hook());
101 set_hook(Box::new({
102 let previous_panic_hook = previous_panic_hook.clone();
103 move |panic_info| {
104 let _ = disable_raw_mode();
105 let _ = execute!(io::stdout(), LeaveAlternateScreen);
106 previous_panic_hook(panic_info);
107 }
108 }));
109
110 Self {
111 terminal,
112 last_update: Instant::now(),
113 progress: ProgressBarState::new(checkpoint),
114 metrics_numeric: NumericMetricsState::default(),
115 metrics_text: TextMetricsState::default(),
116 status: StatusState::default(),
117 interuptor,
118 popup: PopupState::Empty,
119 previous_panic_hook: Some(previous_panic_hook),
120 persistent: false,
121 }
122 }
123
124 pub fn persistent(mut self) -> Self {
126 self.persistent = true;
127 self
128 }
129
130 fn render(&mut self) -> Result<(), Box<dyn Error>> {
131 let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);
132 if self.last_update.elapsed() < tick_rate {
133 return Ok(());
134 }
135
136 self.draw()?;
137 self.handle_events()?;
138
139 self.last_update = Instant::now();
140
141 Ok(())
142 }
143
144 fn draw(&mut self) -> Result<(), Box<dyn Error>> {
145 self.terminal.draw(|frame| {
146 let size = frame.area();
147
148 match self.popup.view() {
149 Some(view) => view.render(frame, size),
150 None => {
151 let view = MetricsView::new(
152 self.metrics_numeric.view(),
153 self.metrics_text.view(),
154 self.progress.view(),
155 ControlsView,
156 self.status.view(),
157 );
158
159 view.render(frame, size);
160 }
161 };
162 })?;
163
164 Ok(())
165 }
166
167 fn handle_events(&mut self) -> Result<(), Box<dyn Error>> {
168 while event::poll(Duration::from_secs(0))? {
169 let event = event::read()?;
170 self.popup.on_event(&event);
171
172 if self.popup.is_empty() {
173 self.metrics_numeric.on_event(&event);
174
175 if let Event::Key(key) = event {
176 if let KeyCode::Char('q') = key.code {
177 self.popup = PopupState::Full(
178 "Quit".to_string(),
179 vec![
180 Callback::new(
181 "Stop the training.",
182 "Stop the training immediately. This will break from the \
183 training loop, but any remaining code after the loop will be \
184 executed.",
185 's',
186 QuitPopupAccept(self.interuptor.clone()),
187 ),
188 Callback::new(
189 "Stop the training immediately.",
190 "Kill the program. This will create a panic! which will make \
191 the current training fails. Any code following the training \
192 won't be executed.",
193 'k',
194 KillPopupAccept,
195 ),
196 Callback::new(
197 "Cancel",
198 "Cancel the action, continue the training.",
199 'c',
200 PopupCancel,
201 ),
202 ],
203 );
204 }
205 }
206 }
207 }
208
209 Ok(())
210 }
211
212 fn handle_post_training(&mut self) -> Result<(), Box<dyn Error>> {
213 self.popup = PopupState::Full(
214 "Training is done".to_string(),
215 vec![Callback::new(
216 "Training Done",
217 "Press 'x' to close this popup. Press 'q' to exit the application after the \
218 popup is closed.",
219 'x',
220 PopupCancel,
221 )],
222 );
223
224 self.draw().ok();
225
226 loop {
227 if let Ok(true) = event::poll(Duration::from_millis(MAX_REFRESH_RATE_MILLIS)) {
228 match event::read() {
229 Ok(event @ Event::Key(key)) => {
230 if self.popup.is_empty() {
231 self.metrics_numeric.on_event(&event);
232 if let KeyCode::Char('q') = key.code {
233 break;
234 }
235 } else {
236 self.popup.on_event(&event);
237 }
238 self.draw().ok();
239 }
240
241 Ok(Event::Resize(..)) => {
242 self.draw().ok();
243 }
244 Err(err) => {
245 eprintln!("Error reading event: {}", err);
246 break;
247 }
248 _ => continue,
249 }
250 }
251 }
252 Ok(())
253 }
254}
255
256struct QuitPopupAccept(TrainingInterrupter);
257struct KillPopupAccept;
258struct PopupCancel;
259
260impl CallbackFn for KillPopupAccept {
261 fn call(&self) -> bool {
262 panic!("Killing training from user input.");
263 }
264}
265
266impl CallbackFn for QuitPopupAccept {
267 fn call(&self) -> bool {
268 self.0.stop();
269 true
270 }
271}
272
273impl CallbackFn for PopupCancel {
274 fn call(&self) -> bool {
275 true
276 }
277}
278
279impl Drop for TuiMetricsRenderer {
280 fn drop(&mut self) {
281 if !std::thread::panicking() {
284 if self.persistent {
285 if let Err(err) = self.handle_post_training() {
286 eprintln!("Error in post-training handling: {}", err);
287 }
288 }
289
290 disable_raw_mode().ok();
291 execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap();
292 self.terminal.show_cursor().ok();
293
294 let _ = take_hook();
296 if let Some(previous_panic_hook) =
297 Arc::into_inner(self.previous_panic_hook.take().unwrap())
298 {
299 set_hook(previous_panic_hook);
300 }
301 }
302 }
303}