use trueno_viz::output::{TerminalEncoder, TerminalMode as TruenoTerminalMode};
use trueno_viz::plots::{LossCurve, MetricSeries};
use trueno_viz::prelude::{Rgba, WithDimensions};
/// Takes ownership of `curve`, applies `set_dimensions(width, height)` to it, and
/// hands the same curve back so the call can sit inside a builder chain.
fn with_dimensions(curve: LossCurve, width: u32, height: u32) -> LossCurve {
    let mut sized = curve;
    sized.set_dimensions(width, height);
    sized
}
use crate::train::tui::capability::TerminalMode;
/// One summary row per series, in the order `(name, min, last_smoothed, best_epoch)`,
/// as produced by `LossCurveDisplay::summary`.
pub type SeriesSummaryTuple = (String, Option<f32>, Option<f32>, Option<usize>);
/// Terminal-rendered loss chart backed by a `trueno_viz` [`LossCurve`].
pub struct LossCurveDisplay {
    /// Chart width, used both for the curve and for the terminal encoder.
    width: u32,
    /// Chart height; the terminal encoder is given half of this value.
    height: u32,
    /// Underlying curve holding the "Train" and "Val" series.
    loss_curve: LossCurve,
    /// Output mode used when rendering to the terminal.
    pub(crate) terminal_mode: TerminalMode,
}
impl LossCurveDisplay {
pub fn new(width: u32, height: u32) -> Self {
let loss_curve = with_dimensions(
LossCurve::new()
.add_series(MetricSeries::new("Train", Rgba::rgb(66, 133, 244)))
.add_series(MetricSeries::new("Val", Rgba::rgb(255, 128, 0))),
width,
height,
)
.margin(2)
.best_markers(true)
.lower_is_better(true)
.build()
.expect("LossCurve build should succeed");
Self { loss_curve, width, height, terminal_mode: TerminalMode::Unicode }
}
pub fn terminal_mode(mut self, mode: TerminalMode) -> Self {
self.terminal_mode = mode;
self
}
pub fn smoothing(mut self, factor: f32) -> Self {
self.loss_curve = with_dimensions(
LossCurve::new()
.add_series(MetricSeries::new("Train", Rgba::rgb(66, 133, 244)).smoothing(factor))
.add_series(MetricSeries::new("Val", Rgba::rgb(255, 128, 0)).smoothing(factor)),
self.width,
self.height,
)
.margin(2)
.best_markers(true)
.lower_is_better(true)
.build()
.expect("LossCurve build should succeed");
self
}
pub fn push_train_loss(&mut self, value: f32) {
self.loss_curve.push(0, value);
}
pub fn push_val_loss(&mut self, value: f32) {
self.loss_curve.push(1, value);
}
pub fn push_losses(&mut self, train: f32, val: f32) {
self.loss_curve.push_all(&[train, val]);
}
pub fn epochs(&self) -> usize {
self.loss_curve.max_epochs()
}
pub fn summary(&self) -> Vec<SeriesSummaryTuple> {
self.loss_curve
.summary()
.into_iter()
.map(|s| (s.name, s.min, s.last_smoothed, s.best_epoch))
.collect()
}
pub fn render_terminal(&self) -> String {
if self.loss_curve.max_epochs() < 2 {
return String::from("(waiting for data...)");
}
let fb = match self.loss_curve.to_framebuffer() {
Ok(fb) => fb,
Err(_) => return String::from("(render error)"),
};
let trueno_mode = match self.terminal_mode {
TerminalMode::Ascii => TruenoTerminalMode::Ascii,
TerminalMode::Unicode => TruenoTerminalMode::UnicodeHalfBlock,
TerminalMode::Ansi => TruenoTerminalMode::AnsiTrueColor,
};
let encoder =
TerminalEncoder::new().mode(trueno_mode).width(self.width).height(self.height / 2);
encoder.render(&fb)
}
pub fn print(&self) {
println!("{}", self.render_terminal());
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds ten steadily decreasing train losses (1.0, 0.9, ...) into `display`.
    fn with_descending_train_loss(mut display: LossCurveDisplay) -> LossCurveDisplay {
        (0..10)
            .map(|step| 1.0 - step as f32 * 0.1)
            .for_each(|loss| display.push_train_loss(loss));
        display
    }

    #[test]
    fn test_loss_curve_display_new() {
        let d = LossCurveDisplay::new(80, 20);
        assert_eq!((d.width, d.height), (80, 20));
        assert_eq!(d.terminal_mode, TerminalMode::Unicode);
    }

    #[test]
    fn test_loss_curve_display_terminal_mode() {
        let d = LossCurveDisplay::new(80, 20).terminal_mode(TerminalMode::Ansi);
        assert_eq!(d.terminal_mode, TerminalMode::Ansi);
    }

    #[test]
    fn test_loss_curve_display_smoothing() {
        assert_eq!(LossCurveDisplay::new(80, 20).smoothing(0.9).epochs(), 0);
    }

    #[test]
    fn test_loss_curve_display_push_train_loss() {
        let mut d = LossCurveDisplay::new(80, 20);
        for loss in [1.0, 0.9, 0.8] {
            d.push_train_loss(loss);
        }
        assert_eq!(d.epochs(), 3);
    }

    #[test]
    fn test_loss_curve_display_push_val_loss() {
        let mut d = LossCurveDisplay::new(80, 20);
        for loss in [1.2, 1.1] {
            d.push_val_loss(loss);
        }
        assert!(d.epochs() >= 2);
    }

    #[test]
    fn test_loss_curve_display_push_losses() {
        let mut d = LossCurveDisplay::new(80, 20);
        for (train, val) in [(1.0, 1.2), (0.9, 1.1)] {
            d.push_losses(train, val);
        }
        assert!(d.epochs() >= 2);
    }

    #[test]
    fn test_loss_curve_display_summary() {
        let mut d = LossCurveDisplay::new(80, 20);
        for loss in [1.0, 0.5] {
            d.push_train_loss(loss);
        }
        for loss in [1.2, 0.6] {
            d.push_val_loss(loss);
        }
        let rows = d.summary();
        let names: Vec<&str> = rows.iter().map(|row| row.0.as_str()).collect();
        assert_eq!(names, ["Train", "Val"]);
    }

    #[test]
    fn test_loss_curve_display_render_insufficient_data() {
        let mut d = LossCurveDisplay::new(80, 20);
        d.push_train_loss(1.0);
        assert!(d.render_terminal().contains("waiting for data"));
    }

    #[test]
    fn test_loss_curve_display_render_with_data() {
        let out = with_descending_train_loss(LossCurveDisplay::new(80, 20)).render_terminal();
        assert!(!out.contains("waiting for data"));
    }

    #[test]
    fn test_loss_curve_display_render_ascii_mode() {
        let d = LossCurveDisplay::new(80, 20).terminal_mode(TerminalMode::Ascii);
        assert!(!with_descending_train_loss(d).render_terminal().is_empty());
    }

    #[test]
    fn test_loss_curve_display_render_ansi_mode() {
        let d = LossCurveDisplay::new(80, 20).terminal_mode(TerminalMode::Ansi);
        assert!(!with_descending_train_loss(d).render_terminal().is_empty());
    }
}