entrenar/train/tui/charts/
loss_curve.rs1use trueno_viz::output::{TerminalEncoder, TerminalMode as TruenoTerminalMode};
4use trueno_viz::plots::{LossCurve, MetricSeries};
5use trueno_viz::prelude::{Rgba, WithDimensions};
6
7fn with_dimensions(mut curve: LossCurve, width: u32, height: u32) -> LossCurve {
9 curve.set_dimensions(width, height);
10 curve
11}
12
13use crate::train::tui::capability::TerminalMode;
14
/// Per-series summary row produced by `LossCurveDisplay::summary`.
pub type SeriesSummaryTuple = (
    // Series name ("Train" or "Val").
    String,
    // Minimum value recorded for the series, if any.
    Option<f32>,
    // Most recent smoothed value, if any.
    Option<f32>,
    // Epoch index at which the best value was seen, if any.
    Option<usize>,
);
17
18pub struct LossCurveDisplay {
37 loss_curve: LossCurve,
38 width: u32,
39 height: u32,
40 pub(crate) terminal_mode: TerminalMode,
41}
42
43impl LossCurveDisplay {
44 pub fn new(width: u32, height: u32) -> Self {
46 let loss_curve = with_dimensions(
47 LossCurve::new()
48 .add_series(MetricSeries::new("Train", Rgba::rgb(66, 133, 244)))
49 .add_series(MetricSeries::new("Val", Rgba::rgb(255, 128, 0))),
50 width,
51 height,
52 )
53 .margin(2)
54 .best_markers(true)
55 .lower_is_better(true)
56 .build()
57 .expect("LossCurve build should succeed");
58 Self { loss_curve, width, height, terminal_mode: TerminalMode::Unicode }
59 }
60
61 pub fn terminal_mode(mut self, mode: TerminalMode) -> Self {
63 self.terminal_mode = mode;
64 self
65 }
66
67 pub fn smoothing(mut self, factor: f32) -> Self {
69 self.loss_curve = with_dimensions(
71 LossCurve::new()
72 .add_series(MetricSeries::new("Train", Rgba::rgb(66, 133, 244)).smoothing(factor))
73 .add_series(MetricSeries::new("Val", Rgba::rgb(255, 128, 0)).smoothing(factor)),
74 self.width,
75 self.height,
76 )
77 .margin(2)
78 .best_markers(true)
79 .lower_is_better(true)
80 .build()
81 .expect("LossCurve build should succeed");
82 self
83 }
84
85 pub fn push_train_loss(&mut self, value: f32) {
87 self.loss_curve.push(0, value);
88 }
89
90 pub fn push_val_loss(&mut self, value: f32) {
92 self.loss_curve.push(1, value);
93 }
94
95 pub fn push_losses(&mut self, train: f32, val: f32) {
97 self.loss_curve.push_all(&[train, val]);
98 }
99
100 pub fn epochs(&self) -> usize {
102 self.loss_curve.max_epochs()
103 }
104
105 pub fn summary(&self) -> Vec<SeriesSummaryTuple> {
107 self.loss_curve
108 .summary()
109 .into_iter()
110 .map(|s| (s.name, s.min, s.last_smoothed, s.best_epoch))
111 .collect()
112 }
113
114 pub fn render_terminal(&self) -> String {
116 if self.loss_curve.max_epochs() < 2 {
117 return String::from("(waiting for data...)");
118 }
119
120 let fb = match self.loss_curve.to_framebuffer() {
121 Ok(fb) => fb,
122 Err(_) => return String::from("(render error)"),
123 };
124
125 let trueno_mode = match self.terminal_mode {
126 TerminalMode::Ascii => TruenoTerminalMode::Ascii,
127 TerminalMode::Unicode => TruenoTerminalMode::UnicodeHalfBlock,
128 TerminalMode::Ansi => TruenoTerminalMode::AnsiTrueColor,
129 };
130
131 let encoder =
132 TerminalEncoder::new().mode(trueno_mode).width(self.width).height(self.height / 2); encoder.render(&fb)
135 }
136
137 pub fn print(&self) {
139 println!("{}", self.render_terminal());
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an 80x20 display in `mode` holding ten decreasing training losses.
    fn filled_display(mode: TerminalMode) -> LossCurveDisplay {
        let mut display = LossCurveDisplay::new(80, 20).terminal_mode(mode);
        (0..10).for_each(|i| display.push_train_loss(1.0 - i as f32 * 0.1));
        display
    }

    #[test]
    fn test_loss_curve_display_new() {
        let display = LossCurveDisplay::new(80, 20);
        assert_eq!((display.width, display.height), (80, 20));
        assert_eq!(display.terminal_mode, TerminalMode::Unicode);
    }

    #[test]
    fn test_loss_curve_display_terminal_mode() {
        let ansi = LossCurveDisplay::new(80, 20).terminal_mode(TerminalMode::Ansi);
        assert_eq!(ansi.terminal_mode, TerminalMode::Ansi);
    }

    #[test]
    fn test_loss_curve_display_smoothing() {
        let smoothed = LossCurveDisplay::new(80, 20).smoothing(0.9);
        assert_eq!(smoothed.epochs(), 0);
    }

    #[test]
    fn test_loss_curve_display_push_train_loss() {
        let mut display = LossCurveDisplay::new(80, 20);
        for loss in [1.0, 0.9, 0.8] {
            display.push_train_loss(loss);
        }
        assert_eq!(display.epochs(), 3);
    }

    #[test]
    fn test_loss_curve_display_push_val_loss() {
        let mut display = LossCurveDisplay::new(80, 20);
        for loss in [1.2, 1.1] {
            display.push_val_loss(loss);
        }
        assert!(display.epochs() >= 2);
    }

    #[test]
    fn test_loss_curve_display_push_losses() {
        let mut display = LossCurveDisplay::new(80, 20);
        for (train, val) in [(1.0, 1.2), (0.9, 1.1)] {
            display.push_losses(train, val);
        }
        assert!(display.epochs() >= 2);
    }

    #[test]
    fn test_loss_curve_display_summary() {
        let mut display = LossCurveDisplay::new(80, 20);
        for loss in [1.0, 0.5] {
            display.push_train_loss(loss);
        }
        for loss in [1.2, 0.6] {
            display.push_val_loss(loss);
        }

        let rows = display.summary();
        let names: Vec<&str> = rows.iter().map(|row| row.0.as_str()).collect();
        assert_eq!(names, ["Train", "Val"]);
    }

    #[test]
    fn test_loss_curve_display_render_insufficient_data() {
        let mut display = LossCurveDisplay::new(80, 20);
        display.push_train_loss(1.0);
        let output = display.render_terminal();
        assert!(output.contains("waiting for data"));
    }

    #[test]
    fn test_loss_curve_display_render_with_data() {
        let output = filled_display(TerminalMode::Unicode).render_terminal();
        assert!(!output.contains("waiting for data"));
    }

    #[test]
    fn test_loss_curve_display_render_ascii_mode() {
        let output = filled_display(TerminalMode::Ascii).render_terminal();
        assert!(!output.is_empty());
    }

    #[test]
    fn test_loss_curve_display_render_ansi_mode() {
        let output = filled_display(TerminalMode::Ansi).render_terminal();
        assert!(!output.is_empty());
    }
}