Skip to main content

entrenar/train/tui/charts/
loss_curve.rs

1//! Loss curve display wrapper for trueno-viz.
2
3use trueno_viz::output::{TerminalEncoder, TerminalMode as TruenoTerminalMode};
4use trueno_viz::plots::{LossCurve, MetricSeries};
5use trueno_viz::prelude::{Rgba, WithDimensions};
6
7/// Helper to set dimensions on a LossCurve builder.
8fn with_dimensions(mut curve: LossCurve, width: u32, height: u32) -> LossCurve {
9    curve.set_dimensions(width, height);
10    curve
11}
12
13use crate::train::tui::capability::TerminalMode;
14
/// Per-series summary produced by `LossCurveDisplay::summary`.
///
/// Fields, in order: `(name, min_value, last_smoothed, best_epoch)`.
pub type SeriesSummaryTuple = (
    String,        // name
    Option<f32>,   // min_value
    Option<f32>,   // last_smoothed
    Option<usize>, // best_epoch
);
17
18/// Wrapper for trueno-viz LossCurve with terminal output support.
19///
20/// Provides streaming loss curve visualization with:
21/// - Train and validation loss tracking
22/// - Exponential moving average smoothing
23/// - Best value markers
24/// - ASCII/Unicode/ANSI terminal rendering modes
25///
26/// # Example
27///
28/// ```no_run
29/// use entrenar::train::tui::LossCurveDisplay;
30///
31/// let mut display = LossCurveDisplay::new(80, 20);
32/// display.push_train_loss(1.0);
33/// display.push_val_loss(1.2);
34/// println!("{}", display.render_terminal());
35/// ```
36pub struct LossCurveDisplay {
37    loss_curve: LossCurve,
38    width: u32,
39    height: u32,
40    pub(crate) terminal_mode: TerminalMode,
41}
42
43impl LossCurveDisplay {
44    /// Create a new loss curve display.
45    pub fn new(width: u32, height: u32) -> Self {
46        let loss_curve = with_dimensions(
47            LossCurve::new()
48                .add_series(MetricSeries::new("Train", Rgba::rgb(66, 133, 244)))
49                .add_series(MetricSeries::new("Val", Rgba::rgb(255, 128, 0))),
50            width,
51            height,
52        )
53        .margin(2)
54        .best_markers(true)
55        .lower_is_better(true)
56        .build()
57        .expect("LossCurve build should succeed");
58        Self { loss_curve, width, height, terminal_mode: TerminalMode::Unicode }
59    }
60
61    /// Set terminal rendering mode.
62    pub fn terminal_mode(mut self, mode: TerminalMode) -> Self {
63        self.terminal_mode = mode;
64        self
65    }
66
67    /// Set smoothing factor (0.0 = none, 0.99 = heavy).
68    pub fn smoothing(mut self, factor: f32) -> Self {
69        // Re-create with smoothing applied
70        self.loss_curve = with_dimensions(
71            LossCurve::new()
72                .add_series(MetricSeries::new("Train", Rgba::rgb(66, 133, 244)).smoothing(factor))
73                .add_series(MetricSeries::new("Val", Rgba::rgb(255, 128, 0)).smoothing(factor)),
74            self.width,
75            self.height,
76        )
77        .margin(2)
78        .best_markers(true)
79        .lower_is_better(true)
80        .build()
81        .expect("LossCurve build should succeed");
82        self
83    }
84
85    /// Push a training loss value.
86    pub fn push_train_loss(&mut self, value: f32) {
87        self.loss_curve.push(0, value);
88    }
89
90    /// Push a validation loss value.
91    pub fn push_val_loss(&mut self, value: f32) {
92        self.loss_curve.push(1, value);
93    }
94
95    /// Push both train and val loss at once.
96    pub fn push_losses(&mut self, train: f32, val: f32) {
97        self.loss_curve.push_all(&[train, val]);
98    }
99
100    /// Get the number of epochs recorded.
101    pub fn epochs(&self) -> usize {
102        self.loss_curve.max_epochs()
103    }
104
105    /// Get summary of all series.
106    pub fn summary(&self) -> Vec<SeriesSummaryTuple> {
107        self.loss_curve
108            .summary()
109            .into_iter()
110            .map(|s| (s.name, s.min, s.last_smoothed, s.best_epoch))
111            .collect()
112    }
113
114    /// Render to terminal string.
115    pub fn render_terminal(&self) -> String {
116        if self.loss_curve.max_epochs() < 2 {
117            return String::from("(waiting for data...)");
118        }
119
120        let fb = match self.loss_curve.to_framebuffer() {
121            Ok(fb) => fb,
122            Err(_) => return String::from("(render error)"),
123        };
124
125        let trueno_mode = match self.terminal_mode {
126            TerminalMode::Ascii => TruenoTerminalMode::Ascii,
127            TerminalMode::Unicode => TruenoTerminalMode::UnicodeHalfBlock,
128            TerminalMode::Ansi => TruenoTerminalMode::AnsiTrueColor,
129        };
130
131        let encoder =
132            TerminalEncoder::new().mode(trueno_mode).width(self.width).height(self.height / 2); // Terminal chars are ~2:1 aspect
133
134        encoder.render(&fb)
135    }
136
137    /// Print to stdout.
138    pub fn print(&self) {
139        println!("{}", self.render_terminal());
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Push `count` training losses following the sequence 1.0, 0.9, 0.8, ...
    fn push_descending_train(display: &mut LossCurveDisplay, count: u8) {
        for step in 0..count {
            display.push_train_loss(1.0 - f32::from(step) * 0.1);
        }
    }

    #[test]
    fn test_loss_curve_display_new() {
        let fresh = LossCurveDisplay::new(80, 20);
        assert_eq!((fresh.width, fresh.height), (80, 20));
        assert_eq!(fresh.terminal_mode, TerminalMode::Unicode);
    }

    #[test]
    fn test_loss_curve_display_terminal_mode() {
        let ansi = LossCurveDisplay::new(80, 20).terminal_mode(TerminalMode::Ansi);
        assert_eq!(ansi.terminal_mode, TerminalMode::Ansi);
    }

    #[test]
    fn test_loss_curve_display_smoothing() {
        // Rebuilding the curve with smoothing must not panic and starts empty.
        let smoothed = LossCurveDisplay::new(80, 20).smoothing(0.9);
        assert_eq!(smoothed.epochs(), 0);
    }

    #[test]
    fn test_loss_curve_display_push_train_loss() {
        let mut display = LossCurveDisplay::new(80, 20);
        for loss in [1.0_f32, 0.9, 0.8] {
            display.push_train_loss(loss);
        }
        assert_eq!(display.epochs(), 3);
    }

    #[test]
    fn test_loss_curve_display_push_val_loss() {
        let mut display = LossCurveDisplay::new(80, 20);
        for loss in [1.2_f32, 1.1] {
            display.push_val_loss(loss);
        }
        // epochs() reports the longest series, so val-only data still counts.
        assert!(display.epochs() >= 2);
    }

    #[test]
    fn test_loss_curve_display_push_losses() {
        let mut display = LossCurveDisplay::new(80, 20);
        for (train, val) in [(1.0_f32, 1.2_f32), (0.9, 1.1)] {
            display.push_losses(train, val);
        }
        assert!(display.epochs() >= 2);
    }

    #[test]
    fn test_loss_curve_display_summary() {
        let mut display = LossCurveDisplay::new(80, 20);
        for loss in [1.0_f32, 0.5] {
            display.push_train_loss(loss);
        }
        for loss in [1.2_f32, 0.6] {
            display.push_val_loss(loss);
        }

        let names: Vec<String> = display.summary().into_iter().map(|(name, ..)| name).collect();
        assert_eq!(names, ["Train", "Val"]);
    }

    #[test]
    fn test_loss_curve_display_render_insufficient_data() {
        let mut display = LossCurveDisplay::new(80, 20);
        // A single point is not enough to draw a curve.
        display.push_train_loss(1.0);
        assert!(display.render_terminal().contains("waiting for data"));
    }

    #[test]
    fn test_loss_curve_display_render_with_data() {
        let mut display = LossCurveDisplay::new(80, 20);
        push_descending_train(&mut display, 10);
        // Enough data: the placeholder must be replaced by real chart output.
        assert!(!display.render_terminal().contains("waiting for data"));
    }

    #[test]
    fn test_loss_curve_display_render_ascii_mode() {
        let mut display = LossCurveDisplay::new(80, 20).terminal_mode(TerminalMode::Ascii);
        push_descending_train(&mut display, 10);
        assert!(!display.render_terminal().is_empty());
    }

    #[test]
    fn test_loss_curve_display_render_ansi_mode() {
        let mut display = LossCurveDisplay::new(80, 20).terminal_mode(TerminalMode::Ansi);
        push_descending_train(&mut display, 10);
        assert!(!display.render_terminal().is_empty());
    }
}