// nuviz_cli/tui/app.rs

1use std::collections::HashMap;
2
3use crate::data::metrics::MetricRecord;
4
/// Which panel of the watch TUI currently owns keyboard focus.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Panel {
    /// First chart panel.
    Chart1,
    /// Second chart panel.
    Chart2,
    /// Info panel.
    Info,
}

impl Panel {
    /// Returns the panel that takes focus after `self`.
    ///
    /// Focus cycles `Chart1 -> Chart2 -> Info` and then wraps around to `Chart1`.
    pub fn next(self) -> Self {
        match self {
            Self::Chart1 => Self::Chart2,
            Self::Chart2 => Self::Info,
            Self::Info => Self::Chart1,
        }
    }
}

23/// Application state for the watch TUI.
24pub struct App {
25    /// All metric records received so far, per experiment
26    pub records: HashMap<String, Vec<MetricRecord>>,
27    /// Which experiments we are watching
28    pub experiment_names: Vec<String>,
29    /// Currently focused panel
30    pub focus: Panel,
31    /// Time axis zoom level (1.0 = fit all data)
32    pub zoom: f64,
33    /// Whether the app should quit
34    pub should_quit: bool,
35    /// Metric names to display in chart panels
36    pub chart_metrics: (String, String),
37    /// Alert messages
38    pub alerts: Vec<String>,
39}
40
41impl App {
42    pub fn new(experiment_names: Vec<String>) -> Self {
43        Self {
44            records: HashMap::new(),
45            experiment_names,
46            focus: Panel::Chart1,
47            zoom: 1.0,
48            should_quit: false,
49            chart_metrics: ("loss".into(), "psnr".into()),
50            alerts: Vec::new(),
51        }
52    }
53
54    /// Add new records for an experiment.
55    pub fn push_records(&mut self, experiment: &str, new_records: Vec<MetricRecord>) {
56        // Auto-detect chart metrics from first record
57        if self.records.is_empty() || self.records.values().all(|v| v.is_empty()) {
58            if let Some(first) = new_records.first() {
59                let mut names: Vec<&String> = first.metrics.keys().collect();
60                names.sort();
61                if let Some(name) = names.first() {
62                    self.chart_metrics.0 = (*name).clone();
63                }
64                if let Some(name) = names.get(1) {
65                    self.chart_metrics.1 = (*name).clone();
66                }
67            }
68        }
69
70        let entry = self.records.entry(experiment.into()).or_default();
71        entry.extend(new_records);
72    }
73
74    /// Get metric values for a specific metric across all steps.
75    pub fn metric_series(&self, experiment: &str, metric: &str) -> Vec<f64> {
76        self.records
77            .get(experiment)
78            .map(|records| {
79                records
80                    .iter()
81                    .filter_map(|r| r.metrics.get(metric).copied())
82                    .filter(|v| v.is_finite())
83                    .collect()
84            })
85            .unwrap_or_default()
86    }
87
88    /// Get current step for an experiment.
89    pub fn current_step(&self, experiment: &str) -> Option<u64> {
90        self.records
91            .get(experiment)
92            .and_then(|r| r.last())
93            .map(|r| r.step)
94    }
95
96    /// Get best value for a metric.
97    pub fn best_metric(&self, experiment: &str, metric: &str) -> Option<f64> {
98        let series = self.metric_series(experiment, metric);
99        if series.is_empty() {
100            return None;
101        }
102
103        let lower = metric.to_lowercase();
104        let minimize = lower.contains("loss") || lower.contains("lpips") || lower.contains("error");
105
106        if minimize {
107            series.into_iter().reduce(f64::min)
108        } else {
109            series.into_iter().reduce(f64::max)
110        }
111    }
112
113    /// Estimate ETA based on step rate.
114    #[allow(dead_code)]
115    pub fn eta_seconds(&self, experiment: &str, total_steps: u64) -> Option<f64> {
116        let records = self.records.get(experiment)?;
117        if records.len() < 2 {
118            return None;
119        }
120
121        let first = records.first()?;
122        let last = records.last()?;
123        let elapsed = last.timestamp - first.timestamp;
124        let steps_done = last.step - first.step;
125
126        if steps_done == 0 || elapsed <= 0.0 {
127            return None;
128        }
129
130        let steps_remaining = total_steps.saturating_sub(last.step);
131        let rate = elapsed / steps_done as f64;
132        Some(steps_remaining as f64 * rate)
133    }
134
135    pub fn handle_key(&mut self, key: crossterm::event::KeyCode) {
136        use crossterm::event::KeyCode;
137        match key {
138            KeyCode::Char('q') => self.should_quit = true,
139            KeyCode::Tab => self.focus = self.focus.next(),
140            KeyCode::Char(']') => {
141                self.zoom = (self.zoom * 1.5).min(10.0);
142            }
143            KeyCode::Char('[') => {
144                self.zoom = (self.zoom / 1.5).max(0.1);
145            }
146            _ => {}
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a record with `loss` and `psnr` metrics; the timestamp equals `step` seconds.
    fn make_record(step: u64, loss: f64, psnr: f64) -> MetricRecord {
        MetricRecord {
            step,
            timestamp: step as f64 * 1.0,
            metrics: HashMap::from([("loss".into(), loss), ("psnr".into(), psnr)]),
            gpu: None,
        }
    }

    /// Builds a record with arbitrary metrics; the timestamp equals `step` seconds.
    fn record_with(step: u64, metrics: &[(&str, f64)]) -> MetricRecord {
        MetricRecord {
            step,
            timestamp: step as f64,
            metrics: metrics
                .iter()
                .map(|&(name, value)| (name.to_string(), value))
                .collect(),
            gpu: None,
        }
    }

    #[test]
    fn test_push_records() {
        let mut app = App::new(vec!["exp-1".into()]);
        app.push_records("exp-1", vec![make_record(0, 1.0, 20.0)]);
        assert_eq!(app.records["exp-1"].len(), 1);

        app.push_records("exp-1", vec![make_record(1, 0.5, 25.0)]);
        assert_eq!(app.records["exp-1"].len(), 2);
    }

    #[test]
    fn test_push_records_autodetects_chart_metrics() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records("exp", vec![record_with(0, &[("ssim", 0.9), ("lpips", 0.2)])]);
        assert_eq!(
            app.chart_metrics,
            ("lpips".to_string(), "ssim".to_string())
        );

        // Only the very first stored record drives auto-detection.
        app.push_records(
            "exp",
            vec![record_with(1, &[("accuracy", 0.5), ("beta", 1.0)])],
        );
        assert_eq!(
            app.chart_metrics,
            ("lpips".to_string(), "ssim".to_string())
        );
    }

    #[test]
    fn test_autodetect_after_empty_batch() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records("exp", vec![]);
        app.push_records("exp", vec![record_with(0, &[("ssim", 0.9), ("lpips", 0.2)])]);
        assert_eq!(
            app.chart_metrics,
            ("lpips".to_string(), "ssim".to_string())
        );
    }

    #[test]
    fn test_metric_series() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records(
            "exp",
            vec![
                make_record(0, 1.0, 20.0),
                make_record(1, 0.5, 25.0),
                make_record(2, 0.3, 28.0),
            ],
        );

        let loss = app.metric_series("exp", "loss");
        assert_eq!(loss, vec![1.0, 0.5, 0.3]);

        let psnr = app.metric_series("exp", "psnr");
        assert_eq!(psnr, vec![20.0, 25.0, 28.0]);
    }

    #[test]
    fn test_metric_series_filters_nan() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records(
            "exp",
            vec![
                make_record(0, 1.0, 20.0),
                MetricRecord {
                    step: 1,
                    timestamp: 1.0,
                    metrics: HashMap::from([("loss".into(), f64::NAN)]),
                    gpu: None,
                },
                make_record(2, 0.5, 25.0),
            ],
        );

        let loss = app.metric_series("exp", "loss");
        assert_eq!(loss, vec![1.0, 0.5]);
    }

    #[test]
    fn test_queries_on_unknown_experiment() {
        let app = App::new(vec![]);
        assert!(app.metric_series("missing", "loss").is_empty());
        assert_eq!(app.current_step("missing"), None);
        assert_eq!(app.best_metric("missing", "loss"), None);
        assert_eq!(app.eta_seconds("missing", 100), None);
    }

    #[test]
    fn test_best_metric_loss_minimized() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records(
            "exp",
            vec![
                make_record(0, 1.0, 20.0),
                make_record(1, 0.3, 28.0),
                make_record(2, 0.5, 25.0),
            ],
        );
        assert_eq!(app.best_metric("exp", "loss"), Some(0.3));
    }

    #[test]
    fn test_best_metric_psnr_maximized() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records(
            "exp",
            vec![
                make_record(0, 1.0, 20.0),
                make_record(1, 0.3, 28.0),
                make_record(2, 0.5, 25.0),
            ],
        );
        assert_eq!(app.best_metric("exp", "psnr"), Some(28.0));
    }

    #[test]
    fn test_best_metric_error_minimized_case_insensitive() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records(
            "exp",
            vec![
                record_with(0, &[("Val_Error", 0.4)]),
                record_with(1, &[("Val_Error", 0.2)]),
                record_with(2, &[("Val_Error", 0.3)]),
            ],
        );
        assert_eq!(app.best_metric("exp", "Val_Error"), Some(0.2));
        assert_eq!(app.best_metric("exp", "unknown"), None);
    }

    #[test]
    fn test_current_step() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records(
            "exp",
            vec![make_record(0, 1.0, 20.0), make_record(99, 0.1, 30.0)],
        );
        assert_eq!(app.current_step("exp"), Some(99));
    }

    #[test]
    fn test_eta_seconds() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records(
            "exp",
            vec![
                MetricRecord {
                    step: 0,
                    timestamp: 0.0,
                    metrics: HashMap::from([("loss".into(), 1.0)]),
                    gpu: None,
                },
                MetricRecord {
                    step: 100,
                    timestamp: 100.0,
                    metrics: HashMap::from([("loss".into(), 0.5)]),
                    gpu: None,
                },
            ],
        );

        let eta = app.eta_seconds("exp", 200).unwrap();
        assert!((eta - 100.0).abs() < 1.0);
    }

    #[test]
    fn test_eta_seconds_needs_two_records() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records("exp", vec![make_record(0, 1.0, 20.0)]);
        assert_eq!(app.eta_seconds("exp", 100), None);
    }

    #[test]
    fn test_eta_seconds_past_total_is_zero() {
        let mut app = App::new(vec!["exp".into()]);
        app.push_records(
            "exp",
            vec![
                record_with(0, &[("loss", 1.0)]),
                record_with(100, &[("loss", 0.5)]),
            ],
        );
        assert_eq!(app.eta_seconds("exp", 50), Some(0.0));
    }

    #[test]
    fn test_handle_key_quit() {
        let mut app = App::new(vec![]);
        assert!(!app.should_quit);
        app.handle_key(crossterm::event::KeyCode::Char('q'));
        assert!(app.should_quit);
    }

    #[test]
    fn test_handle_key_tab_cycles() {
        let mut app = App::new(vec![]);
        assert_eq!(app.focus, Panel::Chart1);
        app.handle_key(crossterm::event::KeyCode::Tab);
        assert_eq!(app.focus, Panel::Chart2);
        app.handle_key(crossterm::event::KeyCode::Tab);
        assert_eq!(app.focus, Panel::Info);
        app.handle_key(crossterm::event::KeyCode::Tab);
        assert_eq!(app.focus, Panel::Chart1);
    }

    #[test]
    fn test_handle_key_zoom() {
        let mut app = App::new(vec![]);
        let initial_zoom = app.zoom;
        app.handle_key(crossterm::event::KeyCode::Char(']'));
        assert!(app.zoom > initial_zoom);
        app.handle_key(crossterm::event::KeyCode::Char('['));
        assert!((app.zoom - initial_zoom).abs() < 0.01);
    }

    #[test]
    fn test_handle_key_zoom_is_clamped() {
        let mut app = App::new(vec![]);
        for _ in 0..20 {
            app.handle_key(crossterm::event::KeyCode::Char(']'));
        }
        assert_eq!(app.zoom, 10.0);

        for _ in 0..50 {
            app.handle_key(crossterm::event::KeyCode::Char('['));
        }
        assert_eq!(app.zoom, 0.1);
    }
}