Skip to main content

entrenar/monitor/tui/render/
charts.rs

1//! Chart widgets: gauge, braille chart, sample panel, config panel, history table.
2
3use super::super::color::{ColorMode, Styled};
4use super::super::state::{SamplePeek, TrainingSnapshot};
5use super::bars::{build_block_bar, render_sparkline};
6use super::epoch::{compute_epoch_summaries, EpochSummary};
7use super::format::format_lr;
8
9pub fn render_gauge(value: f32, max: f32, width: usize, label: &str) -> String {
10    let percent = if max > 0.0 { value / max * 100.0 } else { 0.0 };
11    let bar = build_block_bar(percent, width.saturating_sub(label.len() + 8));
12    format!("{label}{bar} {percent:>5.1}%")
13}
14
15pub struct BrailleChart {
16    width: usize,
17    height: usize,
18    data: Vec<f32>,
19    color_mode: ColorMode,
20}
21
22impl BrailleChart {
23    pub fn new(width: usize, height: usize) -> Self {
24        Self { width, height, data: Vec::new(), color_mode: ColorMode::detect() }
25    }
26
27    pub fn color_mode(mut self, mode: ColorMode) -> Self {
28        self.color_mode = mode;
29        self
30    }
31
32    pub fn data(mut self, data: Vec<f32>) -> Self {
33        self.data = data;
34        self
35    }
36
37    #[allow(dead_code)]
38    pub fn bounds(self, _min: f32, _max: f32) -> Self {
39        self
40    }
41
42    pub fn log_scale(self, _enabled: bool) -> Self {
43        self
44    }
45
46    pub fn render(&self) -> String {
47        if self.data.is_empty() {
48            return " ".repeat(self.width).repeat(self.height);
49        }
50        let mut lines = Vec::new();
51        for row in 0..self.height {
52            let start = (row * self.data.len()) / self.height;
53            let end = ((row + 1) * self.data.len()) / self.height;
54            let slice = if end > start {
55                &self.data[start..end]
56            } else if start < self.data.len() {
57                &self.data[start..=start]
58            } else {
59                &[]
60            };
61            lines.push(render_sparkline(slice, self.width, self.color_mode));
62        }
63        lines.join("\n")
64    }
65}
66
67pub fn render_braille_chart(data: &[f32], width: usize, height: usize, _log_scale: bool) -> String {
68    BrailleChart::new(width, height).data(data.to_vec()).render()
69}
70
71pub fn render_sample_panel(
72    _sample: Option<&SamplePeek>,
73    _width: usize,
74    _color_mode: ColorMode,
75) -> String {
76    String::new()
77}
78
79pub fn render_config_panel(
80    snapshot: &TrainingSnapshot,
81    width: usize,
82    color_mode: ColorMode,
83) -> String {
84    let mut lines = Vec::new();
85
86    let model_name = if snapshot.model_name.is_empty() { "N/A" } else { &snapshot.model_name };
87    let model_display: String = model_name.chars().take(width - 8).collect();
88    lines.push(Styled::new(&model_display, color_mode).fg((180, 180, 255)).to_string());
89
90    let opt = if snapshot.optimizer_name.is_empty() { "N/A" } else { &snapshot.optimizer_name };
91    let batch = if snapshot.batch_size > 0 {
92        format!("batch:{}", snapshot.batch_size)
93    } else {
94        "N/A".to_string()
95    };
96    lines.push(format!("{opt}  {batch}"));
97
98    lines.join("\n")
99}
100
101pub fn render_history_table(
102    snapshot: &TrainingSnapshot,
103    width: usize,
104    max_rows: usize,
105    color_mode: ColorMode,
106) -> String {
107    let mut lines = Vec::new();
108
109    let header = format!(
110        "{:>5} {:>8} {:>8} {:>8} {:>10} {:>10} {:>5}",
111        "Epoch", "Loss", "Min", "Max", "LR", "Tok/s", "Trend"
112    );
113    lines.push(Styled::new(&header, color_mode).fg((150, 150, 150)).to_string());
114    lines.push("\u{2500}".repeat(width.min(70)));
115
116    let summaries = compute_epoch_summaries(snapshot);
117    if summaries.is_empty() {
118        lines.push("(waiting for epoch data...)".to_string());
119        return lines.join("\n");
120    }
121
122    let start_idx = summaries.len().saturating_sub(max_rows);
123    for (i, summary) in summaries.iter().skip(start_idx).enumerate() {
124        let trend = history_trend(i, start_idx, summary, &summaries, color_mode);
125
126        let row = format!(
127            "{:>5} {:>8.3} {:>8.3} {:>8.3} {:>10} {:>10.1} {}",
128            summary.epoch,
129            summary.avg_loss,
130            summary.min_loss,
131            summary.max_loss,
132            format_lr(summary.lr),
133            summary.tokens_per_sec,
134            Styled::new(trend.0, color_mode).fg(trend.1)
135        );
136        lines.push(row);
137    }
138
139    if start_idx > 0 {
140        lines.push(format!("  \u{2191} {start_idx} more epochs above"));
141    }
142
143    lines.join("\n")
144}
145
146fn history_trend<'a>(
147    i: usize,
148    start_idx: usize,
149    summary: &EpochSummary,
150    summaries: &[EpochSummary],
151    _color_mode: ColorMode,
152) -> (&'a str, (u8, u8, u8)) {
153    if i > 0 || start_idx > 0 {
154        let prev_idx = if i > 0 { start_idx + i - 1 } else { start_idx.saturating_sub(1) };
155        if let Some(prev) = summaries.get(prev_idx) {
156            let change = (summary.avg_loss - prev.avg_loss) / prev.avg_loss.abs().max(0.001);
157            if change < -0.02 {
158                ("\u{2193}", (100, 255, 100))
159            } else if change > 0.02 {
160                ("\u{2191}", (255, 100, 100))
161            } else {
162                ("\u{2192}", (150, 150, 150))
163            }
164        } else {
165            ("", (150, 150, 150))
166        }
167    } else {
168        ("", (150, 150, 150))
169    }
170}
171
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Baseline epoch summary; trend tests override only the fields they vary.
    fn base_summary() -> EpochSummary {
        EpochSummary {
            epoch: 1,
            avg_loss: 5.0,
            min_loss: 4.0,
            max_loss: 6.0,
            end_loss: 4.5,
            avg_grad: 1.0,
            lr: 0.001,
            tokens_per_sec: 100.0,
        }
    }

    /// Renders `snapshot` as an 80-column monochrome history table.
    fn mono_table(snapshot: &TrainingSnapshot, max_rows: usize) -> String {
        render_history_table(snapshot, 80, max_rows, ColorMode::Mono)
    }

    /// True when every character of `rendered` is a plain space.
    fn is_blank(rendered: &str) -> bool {
        rendered.chars().all(|c| c == ' ')
    }

    // ── render_history_table basics ────────────────────────────────

    #[test]
    fn test_history_table_render() {
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 4,
            loss_history: vec![10.0, 9.5, 9.0, 8.5, 8.0, 7.5, 7.0, 6.5, 6.0, 5.5, 5.0, 4.5],
            tokens_per_second: 100.0,
            learning_rate: 0.0001,
            gradient_norm: 2.5,
            ..Default::default()
        };
        let table = mono_table(&snapshot, 10);
        for column in ["Epoch", "Loss"] {
            assert!(table.contains(column));
        }
    }

    #[test]
    fn test_history_table_empty() {
        let table = mono_table(&TrainingSnapshot::default(), 10);
        assert!(table.contains("waiting for epoch data"));
    }

    // ── render_gauge ───────────────────────────────────────────────

    #[test]
    fn test_render_gauge_zero() {
        let gauge = render_gauge(0.0, 100.0, 30, "GPU: ");
        assert!(gauge.contains("0.0%"));
        assert!(gauge.starts_with("GPU: "));
    }

    #[test]
    fn test_render_gauge_full() {
        assert!(render_gauge(100.0, 100.0, 30, "").contains("100.0%"));
    }

    #[test]
    fn test_render_gauge_zero_max() {
        // A non-positive maximum is reported as 0% instead of dividing by zero.
        assert!(render_gauge(50.0, 0.0, 30, "").contains("0.0%"));
    }

    #[test]
    fn test_render_gauge_half() {
        assert!(render_gauge(50.0, 100.0, 30, "VRAM: ").contains("50.0%"));
    }

    // ── BrailleChart ───────────────────────────────────────────────

    #[test]
    fn test_braille_chart_empty_data() {
        // No samples: the chart degrades to blank padding.
        assert!(is_blank(&BrailleChart::new(10, 3).data(Vec::new()).render()));
    }

    #[test]
    fn test_braille_chart_with_data() {
        let chart = BrailleChart::new(10, 2).data(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(!chart.render().is_empty());
    }

    #[test]
    fn test_braille_chart_color_mode() {
        let chart = BrailleChart::new(10, 2).color_mode(ColorMode::Mono).data(vec![1.0, 2.0, 3.0]);
        assert!(!chart.render().is_empty());
    }

    #[test]
    fn test_braille_chart_log_scale_noop() {
        let chart = BrailleChart::new(10, 2).log_scale(true).data(vec![1.0, 10.0, 100.0]);
        assert!(!chart.render().is_empty());
    }

    #[test]
    fn test_braille_chart_bounds_noop() {
        let chart = BrailleChart::new(10, 2).bounds(0.0, 10.0).data(vec![1.0, 5.0, 10.0]);
        assert!(!chart.render().is_empty());
    }

    #[test]
    fn test_braille_chart_single_datapoint() {
        let chart = BrailleChart::new(10, 2).data(vec![5.0]);
        assert!(!chart.render().is_empty());
    }

    #[test]
    fn test_render_braille_chart_function() {
        let series: [f32; 8] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert!(!render_braille_chart(&series, 10, 2, false).is_empty());
    }

    #[test]
    fn test_render_braille_chart_empty() {
        assert!(is_blank(&render_braille_chart(&[], 10, 2, false)));
    }

    // ── render_sample_panel ────────────────────────────────────────

    #[test]
    fn test_render_sample_panel_none() {
        assert!(render_sample_panel(None, 80, ColorMode::Mono).is_empty());
    }

    #[test]
    fn test_render_sample_panel_some() {
        let sample = SamplePeek {
            input_preview: "fn hello()".to_string(),
            target_preview: "fn test_hello()".to_string(),
            generated_preview: "fn test_hello()".to_string(),
            token_match_percent: 100.0,
        };
        // The panel is still a placeholder, so even a populated sample renders nothing.
        assert!(render_sample_panel(Some(&sample), 80, ColorMode::Mono).is_empty());
    }

    // ── render_config_panel ────────────────────────────────────────

    #[test]
    fn test_render_config_panel_defaults() {
        let panel = render_config_panel(&TrainingSnapshot::default(), 80, ColorMode::Mono);
        assert!(panel.contains("N/A")); // empty model and optimizer names
    }

    #[test]
    fn test_render_config_panel_with_values() {
        let snapshot = TrainingSnapshot {
            model_name: "Qwen2.5-Coder-0.5B".to_string(),
            optimizer_name: "AdamW".to_string(),
            batch_size: 4,
            ..Default::default()
        };
        let panel = render_config_panel(&snapshot, 80, ColorMode::Mono);
        for needle in ["Qwen2.5-Coder-0.5B", "AdamW", "batch:4"] {
            assert!(panel.contains(needle));
        }
    }

    #[test]
    fn test_render_config_panel_zero_batch() {
        let snapshot = TrainingSnapshot {
            model_name: "model".to_string(),
            optimizer_name: "SGD".to_string(),
            batch_size: 0,
            ..Default::default()
        };
        // A zero batch size is shown as N/A.
        assert!(render_config_panel(&snapshot, 80, ColorMode::Mono).contains("N/A"));
    }

    #[test]
    fn test_render_config_panel_long_model_name_truncated() {
        let snapshot = TrainingSnapshot { model_name: "A".repeat(200), ..Default::default() };
        let panel = render_config_panel(&snapshot, 30, ColorMode::Mono);
        // The model name is cut to width - 8 characters, so the first line fits in 30.
        let first_line = panel.lines().next().unwrap_or("");
        assert!(first_line.len() <= 30);
    }

    // ── render_history_table: multi-epoch behaviour ────────────────

    #[test]
    fn test_history_table_multiple_epochs_with_trend() {
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 2,
            loss_history: vec![10.0, 9.0, 5.0, 4.0, 2.0, 1.0],
            lr_history: vec![0.001, 0.001, 0.0005, 0.0005, 0.0001, 0.0001],
            tokens_per_second: 500.0,
            ..Default::default()
        };
        // At minimum: header, separator and two epoch rows.
        assert!(mono_table(&snapshot, 10).lines().count() >= 4);
    }

    #[test]
    fn test_history_table_max_rows_truncation() {
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 1,
            loss_history: vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
            tokens_per_second: 100.0,
            learning_rate: 0.001,
            ..Default::default()
        };
        // More epochs than the 3 allowed rows: the overflow footer is expected.
        assert!(mono_table(&snapshot, 3).contains("more epochs above"));
    }

    #[test]
    fn test_history_table_single_epoch() {
        let snapshot = TrainingSnapshot {
            steps_per_epoch: 3,
            loss_history: vec![5.0, 4.0, 3.0],
            tokens_per_second: 200.0,
            learning_rate: 0.001,
            ..Default::default()
        };
        assert!(mono_table(&snapshot, 10).contains("Epoch"));
    }

    // ── history_trend ──────────────────────────────────────────────

    #[test]
    fn test_history_trend_first_epoch() {
        let summaries = vec![base_summary()];
        let (arrow, _color) = history_trend(0, 0, &summaries[0], &summaries, ColorMode::Mono);
        assert_eq!(arrow, ""); // nothing precedes the first epoch
    }

    #[test]
    fn test_history_trend_decreasing() {
        let summaries = vec![
            base_summary(),
            EpochSummary {
                epoch: 2,
                avg_loss: 3.0,
                min_loss: 2.5,
                max_loss: 3.5,
                end_loss: 2.8,
                avg_grad: 0.8,
                ..base_summary()
            },
        ];
        // Loss fell 5.0 -> 3.0: green down arrow.
        assert_eq!(
            history_trend(1, 0, &summaries[1], &summaries, ColorMode::Mono),
            ("\u{2193}", (100, 255, 100))
        );
    }

    #[test]
    fn test_history_trend_increasing() {
        let summaries = vec![
            EpochSummary {
                avg_loss: 3.0,
                min_loss: 2.5,
                max_loss: 3.5,
                end_loss: 2.8,
                ..base_summary()
            },
            EpochSummary { epoch: 2, avg_grad: 0.8, ..base_summary() },
        ];
        // Loss rose 3.0 -> 5.0: red up arrow.
        assert_eq!(
            history_trend(1, 0, &summaries[1], &summaries, ColorMode::Mono),
            ("\u{2191}", (255, 100, 100))
        );
    }

    #[test]
    fn test_history_trend_stable() {
        let summaries = vec![
            base_summary(),
            EpochSummary { epoch: 2, avg_loss: 5.01, avg_grad: 0.8, ..base_summary() },
        ];
        // 5.0 -> 5.01 is within the +/-2% band: grey right arrow.
        assert_eq!(
            history_trend(1, 0, &summaries[1], &summaries, ColorMode::Mono),
            ("\u{2192}", (150, 150, 150))
        );
    }
}