burn_train/renderer/tui/
renderer.rs

1use crate::renderer::{tui::NumericMetricsState, MetricsRenderer};
2use crate::renderer::{MetricState, TrainingProgress};
3use crate::TrainingInterrupter;
4use ratatui::{
5    crossterm::{
6        event::{self, Event, KeyCode},
7        execute,
8        terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
9    },
10    prelude::*,
11    Terminal,
12};
13use std::panic::{set_hook, take_hook};
14use std::sync::Arc;
15use std::{
16    error::Error,
17    io::{self, Stdout},
18    time::{Duration, Instant},
19};
20
21use super::{
22    Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState,
23    TextMetricsState,
24};
25
26/// The current terminal backend.
27pub(crate) type TerminalBackend = CrosstermBackend<Stdout>;
28/// The current terminal frame.
29pub(crate) type TerminalFrame<'a> = ratatui::Frame<'a>;
30
31#[allow(deprecated)] // `PanicInfo` type is renamed to `PanicHookInfo` in Rust 1.82
32type PanicHook = Box<dyn Fn(&std::panic::PanicInfo<'_>) + 'static + Sync + Send>;
33
34const MAX_REFRESH_RATE_MILLIS: u64 = 100;
35
36/// The terminal UI metrics renderer.
37pub struct TuiMetricsRenderer {
38    terminal: Terminal<TerminalBackend>,
39    last_update: std::time::Instant,
40    progress: ProgressBarState,
41    metrics_numeric: NumericMetricsState,
42    metrics_text: TextMetricsState,
43    status: StatusState,
44    interuptor: TrainingInterrupter,
45    popup: PopupState,
46    previous_panic_hook: Option<Arc<PanicHook>>,
47    persistent: bool,
48}
49
50impl MetricsRenderer for TuiMetricsRenderer {
51    fn update_train(&mut self, state: MetricState) {
52        match state {
53            MetricState::Generic(entry) => {
54                self.metrics_text.update_train(entry);
55            }
56            MetricState::Numeric(entry, value) => {
57                self.metrics_numeric.push_train(entry.name.clone(), value);
58                self.metrics_text.update_train(entry);
59            }
60        };
61    }
62
63    fn update_valid(&mut self, state: MetricState) {
64        match state {
65            MetricState::Generic(entry) => {
66                self.metrics_text.update_valid(entry);
67            }
68            MetricState::Numeric(entry, value) => {
69                self.metrics_numeric.push_valid(entry.name.clone(), value);
70                self.metrics_text.update_valid(entry);
71            }
72        };
73    }
74
75    fn render_train(&mut self, item: TrainingProgress) {
76        self.progress.update_train(&item);
77        self.metrics_numeric.update_progress_train(&item);
78        self.status.update_train(item);
79        self.render().unwrap();
80    }
81
82    fn render_valid(&mut self, item: TrainingProgress) {
83        self.progress.update_valid(&item);
84        self.metrics_numeric.update_progress_valid(&item);
85        self.status.update_valid(item);
86        self.render().unwrap();
87    }
88}
89
90impl TuiMetricsRenderer {
91    /// Create a new terminal UI renderer.
92    pub fn new(interuptor: TrainingInterrupter, checkpoint: Option<usize>) -> Self {
93        let mut stdout = io::stdout();
94        execute!(stdout, EnterAlternateScreen).unwrap();
95        enable_raw_mode().unwrap();
96        let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap();
97
98        // Reset the terminal to raw mode on panic before running the panic handler
99        // This prevents that the panic message is not visible for the user.
100        let previous_panic_hook = Arc::new(take_hook());
101        set_hook(Box::new({
102            let previous_panic_hook = previous_panic_hook.clone();
103            move |panic_info| {
104                let _ = disable_raw_mode();
105                let _ = execute!(io::stdout(), LeaveAlternateScreen);
106                previous_panic_hook(panic_info);
107            }
108        }));
109
110        Self {
111            terminal,
112            last_update: Instant::now(),
113            progress: ProgressBarState::new(checkpoint),
114            metrics_numeric: NumericMetricsState::default(),
115            metrics_text: TextMetricsState::default(),
116            status: StatusState::default(),
117            interuptor,
118            popup: PopupState::Empty,
119            previous_panic_hook: Some(previous_panic_hook),
120            persistent: false,
121        }
122    }
123
124    /// Set the renderer to persistent mode.
125    pub fn persistent(mut self) -> Self {
126        self.persistent = true;
127        self
128    }
129
130    fn render(&mut self) -> Result<(), Box<dyn Error>> {
131        let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);
132        if self.last_update.elapsed() < tick_rate {
133            return Ok(());
134        }
135
136        self.draw()?;
137        self.handle_events()?;
138
139        self.last_update = Instant::now();
140
141        Ok(())
142    }
143
144    fn draw(&mut self) -> Result<(), Box<dyn Error>> {
145        self.terminal.draw(|frame| {
146            let size = frame.area();
147
148            match self.popup.view() {
149                Some(view) => view.render(frame, size),
150                None => {
151                    let view = MetricsView::new(
152                        self.metrics_numeric.view(),
153                        self.metrics_text.view(),
154                        self.progress.view(),
155                        ControlsView,
156                        self.status.view(),
157                    );
158
159                    view.render(frame, size);
160                }
161            };
162        })?;
163
164        Ok(())
165    }
166
167    fn handle_events(&mut self) -> Result<(), Box<dyn Error>> {
168        while event::poll(Duration::from_secs(0))? {
169            let event = event::read()?;
170            self.popup.on_event(&event);
171
172            if self.popup.is_empty() {
173                self.metrics_numeric.on_event(&event);
174
175                if let Event::Key(key) = event {
176                    if let KeyCode::Char('q') = key.code {
177                        self.popup = PopupState::Full(
178                            "Quit".to_string(),
179                            vec![
180                                Callback::new(
181                                    "Stop the training.",
182                                    "Stop the training immediately. This will break from the \
183                                     training loop, but any remaining code after the loop will be \
184                                     executed.",
185                                    's',
186                                    QuitPopupAccept(self.interuptor.clone()),
187                                ),
188                                Callback::new(
189                                    "Stop the training immediately.",
190                                    "Kill the program. This will create a panic! which will make \
191                                     the current training fails. Any code following the training \
192                                     won't be executed.",
193                                    'k',
194                                    KillPopupAccept,
195                                ),
196                                Callback::new(
197                                    "Cancel",
198                                    "Cancel the action, continue the training.",
199                                    'c',
200                                    PopupCancel,
201                                ),
202                            ],
203                        );
204                    }
205                }
206            }
207        }
208
209        Ok(())
210    }
211
212    fn handle_post_training(&mut self) -> Result<(), Box<dyn Error>> {
213        self.popup = PopupState::Full(
214            "Training is done".to_string(),
215            vec![Callback::new(
216                "Training Done",
217                "Press 'x' to close this popup.  Press 'q' to exit the application after the \
218                popup is closed.",
219                'x',
220                PopupCancel,
221            )],
222        );
223
224        self.draw().ok();
225
226        loop {
227            if let Ok(true) = event::poll(Duration::from_millis(MAX_REFRESH_RATE_MILLIS)) {
228                match event::read() {
229                    Ok(event @ Event::Key(key)) => {
230                        if self.popup.is_empty() {
231                            self.metrics_numeric.on_event(&event);
232                            if let KeyCode::Char('q') = key.code {
233                                break;
234                            }
235                        } else {
236                            self.popup.on_event(&event);
237                        }
238                        self.draw().ok();
239                    }
240
241                    Ok(Event::Resize(..)) => {
242                        self.draw().ok();
243                    }
244                    Err(err) => {
245                        eprintln!("Error reading event: {}", err);
246                        break;
247                    }
248                    _ => continue,
249                }
250            }
251        }
252        Ok(())
253    }
254}
255
256struct QuitPopupAccept(TrainingInterrupter);
257struct KillPopupAccept;
258struct PopupCancel;
259
260impl CallbackFn for KillPopupAccept {
261    fn call(&self) -> bool {
262        panic!("Killing training from user input.");
263    }
264}
265
266impl CallbackFn for QuitPopupAccept {
267    fn call(&self) -> bool {
268        self.0.stop();
269        true
270    }
271}
272
273impl CallbackFn for PopupCancel {
274    fn call(&self) -> bool {
275        true
276    }
277}
278
279impl Drop for TuiMetricsRenderer {
280    fn drop(&mut self) {
281        // Reset the terminal back to raw mode. This can be skipped during
282        // panicking because the panic hook has already reset the terminal
283        if !std::thread::panicking() {
284            if self.persistent {
285                if let Err(err) = self.handle_post_training() {
286                    eprintln!("Error in post-training handling: {}", err);
287                }
288            }
289
290            disable_raw_mode().ok();
291            execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap();
292            self.terminal.show_cursor().ok();
293
294            // Reinstall the previous panic hook
295            let _ = take_hook();
296            if let Some(previous_panic_hook) =
297                Arc::into_inner(self.previous_panic_hook.take().unwrap())
298            {
299                set_hook(previous_panic_hook);
300            }
301        }
302    }
303}