1use std::collections::HashMap;
2
3use crate::data::metrics::MetricRecord;
4
/// Identifies one of the three UI panels that can hold focus.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Panel {
    /// First chart panel.
    Chart1,
    /// Second chart panel.
    Chart2,
    /// Info panel.
    Info,
}

impl Panel {
    /// Returns the panel that follows `self` in the focus cycle
    /// `Chart1 -> Chart2 -> Info -> Chart1`.
    ///
    /// The match is kept exhaustive (no wildcard arm) so that adding a
    /// variant forces this cycle to be revisited at compile time.
    pub fn next(self) -> Self {
        match self {
            Self::Chart1 => Self::Chart2,
            Self::Chart2 => Self::Info,
            Self::Info => Self::Chart1,
        }
    }
}
22
23pub struct App {
25 pub records: HashMap<String, Vec<MetricRecord>>,
27 pub experiment_names: Vec<String>,
29 pub focus: Panel,
31 pub zoom: f64,
33 pub should_quit: bool,
35 pub chart_metrics: (String, String),
37 pub alerts: Vec<String>,
39}
40
41impl App {
42 pub fn new(experiment_names: Vec<String>) -> Self {
43 Self {
44 records: HashMap::new(),
45 experiment_names,
46 focus: Panel::Chart1,
47 zoom: 1.0,
48 should_quit: false,
49 chart_metrics: ("loss".into(), "psnr".into()),
50 alerts: Vec::new(),
51 }
52 }
53
54 pub fn push_records(&mut self, experiment: &str, new_records: Vec<MetricRecord>) {
56 if self.records.is_empty() || self.records.values().all(|v| v.is_empty()) {
58 if let Some(first) = new_records.first() {
59 let mut names: Vec<&String> = first.metrics.keys().collect();
60 names.sort();
61 if let Some(name) = names.first() {
62 self.chart_metrics.0 = (*name).clone();
63 }
64 if let Some(name) = names.get(1) {
65 self.chart_metrics.1 = (*name).clone();
66 }
67 }
68 }
69
70 let entry = self.records.entry(experiment.into()).or_default();
71 entry.extend(new_records);
72 }
73
74 pub fn metric_series(&self, experiment: &str, metric: &str) -> Vec<f64> {
76 self.records
77 .get(experiment)
78 .map(|records| {
79 records
80 .iter()
81 .filter_map(|r| r.metrics.get(metric).copied())
82 .filter(|v| v.is_finite())
83 .collect()
84 })
85 .unwrap_or_default()
86 }
87
88 pub fn current_step(&self, experiment: &str) -> Option<u64> {
90 self.records
91 .get(experiment)
92 .and_then(|r| r.last())
93 .map(|r| r.step)
94 }
95
96 pub fn best_metric(&self, experiment: &str, metric: &str) -> Option<f64> {
98 let series = self.metric_series(experiment, metric);
99 if series.is_empty() {
100 return None;
101 }
102
103 let lower = metric.to_lowercase();
104 let minimize = lower.contains("loss") || lower.contains("lpips") || lower.contains("error");
105
106 if minimize {
107 series.into_iter().reduce(f64::min)
108 } else {
109 series.into_iter().reduce(f64::max)
110 }
111 }
112
113 #[allow(dead_code)]
115 pub fn eta_seconds(&self, experiment: &str, total_steps: u64) -> Option<f64> {
116 let records = self.records.get(experiment)?;
117 if records.len() < 2 {
118 return None;
119 }
120
121 let first = records.first()?;
122 let last = records.last()?;
123 let elapsed = last.timestamp - first.timestamp;
124 let steps_done = last.step - first.step;
125
126 if steps_done == 0 || elapsed <= 0.0 {
127 return None;
128 }
129
130 let steps_remaining = total_steps.saturating_sub(last.step);
131 let rate = elapsed / steps_done as f64;
132 Some(steps_remaining as f64 * rate)
133 }
134
135 pub fn handle_key(&mut self, key: crossterm::event::KeyCode) {
136 use crossterm::event::KeyCode;
137 match key {
138 KeyCode::Char('q') => self.should_quit = true,
139 KeyCode::Tab => self.focus = self.focus.next(),
140 KeyCode::Char(']') => {
141 self.zoom = (self.zoom * 1.5).min(10.0);
142 }
143 KeyCode::Char('[') => {
144 self.zoom = (self.zoom / 1.5).max(0.1);
145 }
146 _ => {}
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a record carrying both `loss` and `psnr`; the timestamp mirrors the step.
    fn make_record(step: u64, loss: f64, psnr: f64) -> MetricRecord {
        MetricRecord {
            step,
            timestamp: step as f64,
            metrics: HashMap::from([("loss".into(), loss), ("psnr".into(), psnr)]),
            gpu: None,
        }
    }

    /// Builds a record that only carries a `loss` metric, with an explicit timestamp.
    fn loss_only_record(step: u64, timestamp: f64, loss: f64) -> MetricRecord {
        MetricRecord {
            step,
            timestamp,
            metrics: HashMap::from([("loss".into(), loss)]),
            gpu: None,
        }
    }

    /// Creates an app that knows a single experiment and already holds `records` for it.
    fn app_with(experiment: &str, records: Vec<MetricRecord>) -> App {
        let mut app = App::new(vec![experiment.into()]);
        app.push_records(experiment, records);
        app
    }

    /// Three records whose best loss (0.3) and best psnr (28.0) sit in the middle.
    fn unordered_records() -> Vec<MetricRecord> {
        vec![
            make_record(0, 1.0, 20.0),
            make_record(1, 0.3, 28.0),
            make_record(2, 0.5, 25.0),
        ]
    }

    #[test]
    fn test_push_records() {
        let mut app = app_with("exp-1", vec![make_record(0, 1.0, 20.0)]);
        assert_eq!(app.records["exp-1"].len(), 1);

        // A second push appends to the existing history instead of replacing it.
        app.push_records("exp-1", vec![make_record(1, 0.5, 25.0)]);
        assert_eq!(app.records["exp-1"].len(), 2);
    }

    #[test]
    fn test_metric_series() {
        let records = vec![
            make_record(0, 1.0, 20.0),
            make_record(1, 0.5, 25.0),
            make_record(2, 0.3, 28.0),
        ];
        let app = app_with("exp", records);

        assert_eq!(app.metric_series("exp", "loss"), vec![1.0, 0.5, 0.3]);
        assert_eq!(app.metric_series("exp", "psnr"), vec![20.0, 25.0, 28.0]);
    }

    #[test]
    fn test_metric_series_filters_nan() {
        let records = vec![
            make_record(0, 1.0, 20.0),
            loss_only_record(1, 1.0, f64::NAN),
            make_record(2, 0.5, 25.0),
        ];
        let app = app_with("exp", records);

        // The NaN sample in the middle must be dropped from the series.
        assert_eq!(app.metric_series("exp", "loss"), vec![1.0, 0.5]);
    }

    #[test]
    fn test_best_metric_loss_minimized() {
        let app = app_with("exp", unordered_records());
        assert_eq!(app.best_metric("exp", "loss"), Some(0.3));
    }

    #[test]
    fn test_best_metric_psnr_maximized() {
        let app = app_with("exp", unordered_records());
        assert_eq!(app.best_metric("exp", "psnr"), Some(28.0));
    }

    #[test]
    fn test_current_step() {
        let records = vec![make_record(0, 1.0, 20.0), make_record(99, 0.1, 30.0)];
        let app = app_with("exp", records);
        assert_eq!(app.current_step("exp"), Some(99));
    }

    #[test]
    fn test_eta_seconds() {
        // 100 steps took 100 time units, so the remaining 100 steps need about 100 more.
        let records = vec![
            loss_only_record(0, 0.0, 1.0),
            loss_only_record(100, 100.0, 0.5),
        ];
        let app = app_with("exp", records);

        let eta = app.eta_seconds("exp", 200).unwrap();
        assert!((eta - 100.0).abs() < 1.0);
    }

    #[test]
    fn test_handle_key_quit() {
        let mut app = App::new(vec![]);
        assert!(!app.should_quit);

        app.handle_key(crossterm::event::KeyCode::Char('q'));
        assert!(app.should_quit);
    }

    #[test]
    fn test_handle_key_tab_cycles() {
        let mut app = App::new(vec![]);
        assert_eq!(app.focus, Panel::Chart1);

        // Three Tab presses walk the full cycle and land back on the first panel.
        for expected in [Panel::Chart2, Panel::Info, Panel::Chart1] {
            app.handle_key(crossterm::event::KeyCode::Tab);
            assert_eq!(app.focus, expected);
        }
    }

    #[test]
    fn test_handle_key_zoom() {
        let mut app = App::new(vec![]);
        let initial_zoom = app.zoom;

        app.handle_key(crossterm::event::KeyCode::Char(']'));
        assert!(app.zoom > initial_zoom);

        // Zooming back out by the same factor restores the starting zoom.
        app.handle_key(crossterm::event::KeyCode::Char('['));
        assert!((app.zoom - initial_zoom).abs() < 0.01);
    }
}