use crate::{
metric::{MetricName, NumericEntry},
renderer::{EvaluationProgress, TrainingProgress, tui::TuiTag},
};
use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame, TuiSplit};
use ratatui::{
crossterm::event::{Event, KeyCode, KeyEventKind},
prelude::{Alignment, Constraint, Direction, Layout, Rect},
style::{Color, Modifier, Style, Stylize},
text::Line,
widgets::{
Axis, BarChart, BarGroup, Block, Borders, Chart, LegendPosition, Padding, Paragraph, Tabs,
},
};
use std::collections::BTreeMap;
/// Sample capacity handed to each metric's `RecentHistoryPlot::new`.
const MAX_NUM_SAMPLES_RECENT: usize = 1_000;
/// Sample capacity handed to each metric's `FullHistoryPlot::new`.
const MAX_NUM_SAMPLES_FULL: usize = 250;
/// UI state for the numeric-metrics panel: the recorded plots for every metric plus the
/// current tab selection and plot kind.
#[derive(Default)]
pub(crate) struct NumericMetricsState {
/// Recent-history and full-history plots for each metric, keyed by metric name.
data: BTreeMap<MetricName, (RecentHistoryPlot, FullHistoryPlot)>,
/// Metric names in the order they were first recorded. This is the tab order, and
/// `selected` indexes into it.
names: Vec<MetricName>,
/// Index into `names` of the metric currently displayed.
selected: usize,
/// Which plot kind is displayed for the selected metric.
kind: PlotKind,
/// Total training items, captured from the first training progress update.
num_samples_train: Option<usize>,
/// Total validation items, captured from the first validation progress update.
num_samples_valid: Option<usize>,
/// Total test items, captured from the first evaluation progress update.
num_samples_test: Option<usize>,
/// Epoch reported by the training/validation progress updates.
/// NOTE(review): not read anywhere in the code shown here — confirm it is still needed.
epoch: usize,
}
/// Which visualization of the selected metric is displayed.
///
/// Cycled with the Up/Down arrow keys; defaults to [`PlotKind::Full`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PlotKind {
    /// Line chart built from the metric's full-history plot.
    #[default]
    Full,
    /// Line chart built from the metric's recent-history plot.
    Recent,
    /// Bar-chart summary built from the metric's full-history plot.
    Summary,
}
impl NumericMetricsState {
pub(crate) fn push(&mut self, tag: TuiTag, name: MetricName, data: NumericEntry) {
if let Some((recent, full)) = self.data.get_mut(name.as_ref()) {
recent.push(tag.clone(), data.current());
full.push(tag, data);
} else {
let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT);
let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL);
recent.push(tag.clone(), data.current());
full.push(tag, data);
self.names.push(name.clone());
self.data.insert(name, (recent, full));
}
}
pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) {
self.epoch = progress.epoch;
if self.num_samples_train.is_some() {
return;
}
self.num_samples_train = Some(progress.progress.items_total);
}
pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) {
if self.num_samples_valid.is_some() {
return;
}
if let Some(num_sample_train) = self.num_samples_train {
for (_, (_recent, full)) in self.data.iter_mut() {
let ratio = progress.progress.items_total as f64 / num_sample_train as f64;
full.update_max_sample(TuiSplit::Valid, ratio);
}
}
self.epoch = progress.epoch;
self.num_samples_valid = Some(progress.progress.items_total);
}
pub(crate) fn update_progress_test(&mut self, progress: &EvaluationProgress) {
if self.num_samples_test.is_some() {
return;
}
if let Some(num_sample_train) = self.num_samples_train {
for (_, (_recent, full)) in self.data.iter_mut() {
let ratio = progress.progress.items_total as f64 / num_sample_train as f64;
full.update_max_sample(TuiSplit::Test, ratio);
}
}
self.num_samples_test = Some(progress.progress.items_total);
}
pub(crate) fn view(&self) -> NumericMetricView<'_> {
match self.names.is_empty() {
true => NumericMetricView::None,
false => match self.kind {
PlotKind::Summary => {
NumericMetricView::BarPlots(&self.names, self.selected, self.bar_chart())
}
_ => NumericMetricView::LinePlots(
&self.names,
self.selected,
self.line_chart(),
self.kind,
),
},
}
}
pub(crate) fn on_event(&mut self, event: &Event) {
if let Event::Key(key) = event {
match key.kind {
KeyEventKind::Release | KeyEventKind::Repeat => (),
#[cfg(target_os = "windows")] KeyEventKind::Press => return,
#[cfg(not(target_os = "windows"))]
KeyEventKind::Press => (),
}
match key.code {
KeyCode::Right => self.next_metric(),
KeyCode::Left => self.previous_metric(),
KeyCode::Up => self.switch_kind(),
KeyCode::Down => self.switch_kind(),
_ => {}
}
}
}
fn switch_kind(&mut self) {
self.kind = match self.kind {
PlotKind::Full => PlotKind::Recent,
PlotKind::Recent => PlotKind::Summary,
PlotKind::Summary => PlotKind::Full,
};
}
fn next_metric(&mut self) {
self.selected = (self.selected + 1) % {
let this = &self;
this.data.len()
};
}
fn previous_metric(&mut self) {
if self.selected > 0 {
self.selected -= 1;
} else {
self.selected = ({
let this = &self;
this.data.len()
}) - 1;
}
}
fn line_chart<'a>(&'a self) -> Chart<'a> {
let name = self.names.get(self.selected).unwrap();
let (recent, full) = self.data.get(name).unwrap();
let (datasets, axes) = match self.kind {
PlotKind::Full => (full.datasets(), &full.axes),
PlotKind::Recent => (recent.datasets(), &recent.axes),
_ => unreachable!(),
};
Chart::<'a>::new(datasets)
.block(Block::default())
.x_axis(
Axis::default()
.style(Style::default().fg(Color::DarkGray))
.title("Iteration")
.labels(axes.labels_x.clone().into_iter().map(|s| s.bold()))
.bounds(axes.bounds_x),
)
.y_axis(
Axis::default()
.style(Style::default().fg(Color::DarkGray))
.labels(axes.labels_y.clone().into_iter().map(|s| s.bold()))
.bounds(axes.bounds_y),
)
.legend_position(Some(LegendPosition::Right))
}
fn bar_chart<'a>(&'a self) -> BarChart<'a> {
let name = self.names.get(self.selected).unwrap();
let (_recent, full) = self.data.get(name).unwrap();
let mut bar_width = 0;
let bars = full.bars(100, &mut bar_width);
let data = BarGroup::default().bars(&bars);
BarChart::default()
.block(Block::default().padding(Padding::new(2, 2, 2, 0)))
.bar_width(bar_width as u16)
.bar_gap(2)
.data(data)
}
}
/// A renderable snapshot of the numeric-metrics panel. `NumericMetricsState::view` builds it
/// and `render` consumes it.
// The plot variants carry whole chart widgets and are much larger than `None`, so the lint is
// silenced instead of boxing them. NOTE(review): presumably acceptable because the value is
// short-lived (it is consumed by `render`) — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(new)]
pub(crate) enum NumericMetricView<'a> {
/// Metric names (tab titles), selected tab index, line chart, and which history it shows.
LinePlots(&'a [MetricName], usize, Chart<'a>, PlotKind),
/// Metric names (tab titles), selected tab index, and the summary bar chart.
BarPlots(&'a [MetricName], usize, BarChart<'a>),
/// Nothing to draw; `NumericMetricsState::view` returns this when no metric is recorded.
None,
}
impl NumericMetricView<'_> {
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
match self {
Self::LinePlots(titles, selected, chart, kind) => {
let block = Block::default()
.borders(Borders::ALL)
.title("Plots")
.title_alignment(Alignment::Left);
let size_new = block.inner(size);
frame.render_widget(block, size);
let size = size_new;
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints(
[
Constraint::Length(2),
Constraint::Length(1),
Constraint::Min(0),
]
.as_ref(),
)
.split(size);
let tabs = Tabs::new(
titles
.iter()
.map(|i| Line::from(vec![i.to_string().yellow()])),
)
.select(selected)
.style(Style::default())
.highlight_style(
Style::default()
.add_modifier(Modifier::BOLD)
.add_modifier(Modifier::UNDERLINED)
.fg(Color::LightYellow),
);
let title = match kind {
PlotKind::Full => "Full History",
PlotKind::Recent => "Recent History",
_ => unreachable!(),
};
let plot_type =
Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center);
frame.render_widget(tabs, chunks[0]);
frame.render_widget(plot_type, chunks[1]);
frame.render_widget(chart, chunks[2]);
}
Self::BarPlots(titles, selected, chart) => {
let block = Block::default()
.borders(Borders::ALL)
.title("Summary")
.title_alignment(Alignment::Left);
let size_new = block.inner(size);
frame.render_widget(block, size);
let size = size_new;
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints([
Constraint::Length(2),
Constraint::Length(1),
Constraint::Min(0),
])
.split(size);
let tabs = Tabs::new(
titles
.iter()
.map(|i| Line::from(vec![i.to_string().yellow()])),
)
.select(selected)
.style(Style::default())
.highlight_style(
Style::default()
.add_modifier(Modifier::BOLD)
.add_modifier(Modifier::UNDERLINED)
.fg(Color::LightYellow),
);
let title = "Summary";
let plot_type =
Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center);
frame.render_widget(tabs, chunks[0]);
frame.render_widget(plot_type, chunks[1]);
frame.render_widget(chart, chunks[2]);
}
Self::None => {}
};
}
}