// burn 0.3.0
//
// BURN: Burn Unstoppable Rusty Neurons
// Documentation
use super::{DashboardMetricState, DashboardRenderer, TextPlot, TrainingProgress};
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
use std::{collections::HashMap, fmt::Write};

/// Minimum number of milliseconds between two dashboard redraws; render requests arriving
/// sooner than this after the previous redraw are skipped. (Despite the name, this is an
/// interval, not a rate.)
const MAX_REFRESH_RATE_MILLIS: u128 = 50;

/// Terminal dashboard renderer built on `indicatif` progress bars.
///
/// Metric values and text plots are accumulated through the [`DashboardRenderer`] update
/// methods and redrawn, at most once every `MAX_REFRESH_RATE_MILLIS` milliseconds, as part
/// of the iteration bar's template. A second bar shows epoch progress.
pub struct CLIDashboardRenderer {
    /// Bar whose position/length are the current epoch and the total number of epochs.
    pb_epoch: ProgressBar,
    /// Bar whose position/length are `items_processed`/`items_total` of the latest progress;
    /// its template also carries the plots and metrics sections.
    pb_iteration: ProgressBar,
    /// Time of the last completed redraw, used to throttle rendering.
    last_update: std::time::Instant,
    /// Most recent progress snapshot received via `render_train` / `render_valid`.
    progress: TrainingProgress,
    /// Latest formatted value of each training metric, keyed by metric name.
    metric_train: HashMap<String, String>,
    /// Latest formatted value of each validation metric, keyed by metric name.
    metric_valid: HashMap<String, String>,
    /// Plots holding both training and validation series, keyed by metric name.
    metric_both_plot: HashMap<String, TextPlot>,
    /// Training-only plots, keyed by metric name, not yet merged into `metric_both_plot`.
    metric_train_plot: HashMap<String, TextPlot>,
    /// Validation-only plots, keyed by metric name, not yet merged into `metric_both_plot`.
    metric_valid_plot: HashMap<String, TextPlot>,
}

impl Default for CLIDashboardRenderer {
    fn default() -> Self {
        CLIDashboardRenderer::new()
    }
}

impl Drop for CLIDashboardRenderer {
    /// Calls `ProgressBar::finish` on both bars when the renderer goes away.
    fn drop(&mut self) {
        // Same order as before: the iteration bar is finished before the epoch bar.
        for bar in [&self.pb_iteration, &self.pb_epoch] {
            bar.finish();
        }
    }
}

impl DashboardRenderer for CLIDashboardRenderer {
    /// Records a metric reported during training.
    ///
    /// The formatted value is always kept for the metrics section. Numeric metrics are also
    /// plotted: into the combined train/valid plot when one exists or can be assembled from
    /// the two single-split plots, otherwise into a training-only plot created on first use.
    fn update_train(&mut self, state: DashboardMetricState) {
        let (metric, value) = match state {
            DashboardMetricState::Generic(metric) => {
                self.metric_train.insert(metric.name(), metric.pretty());
                return;
            }
            DashboardMetricState::Numeric(metric, value) => (metric, value),
        };

        let key = metric.name();
        self.metric_train.insert(key.clone(), metric.pretty());
        let point = value as f32;

        match self.text_plot_in_both(&key) {
            Some(mut combined) => {
                // `text_plot_in_both` took the plot out of the maps; put it back as combined.
                combined.update_train(point);
                self.metric_both_plot.insert(key, combined);
            }
            None => {
                self.metric_train_plot
                    .entry(key)
                    .or_insert_with(TextPlot::new)
                    .update_train(point);
            }
        }
    }

    /// Records a metric reported during validation.
    ///
    /// Mirrors `update_train`, writing to the validation tables and the validation series of
    /// the plots instead.
    fn update_valid(&mut self, state: DashboardMetricState) {
        let (metric, value) = match state {
            DashboardMetricState::Generic(metric) => {
                self.metric_valid.insert(metric.name(), metric.pretty());
                return;
            }
            DashboardMetricState::Numeric(metric, value) => (metric, value),
        };

        let key = metric.name();
        self.metric_valid.insert(key.clone(), metric.pretty());
        let point = value as f32;

        match self.text_plot_in_both(&key) {
            Some(mut combined) => {
                // `text_plot_in_both` took the plot out of the maps; put it back as combined.
                combined.update_valid(point);
                self.metric_both_plot.insert(key, combined);
            }
            None => {
                self.metric_valid_plot
                    .entry(key)
                    .or_insert_with(TextPlot::new)
                    .update_valid(point);
            }
        }
    }

    /// Stores the latest training progress and asks for a (throttled) redraw.
    fn render_train(&mut self, item: TrainingProgress) {
        self.progress = item;
        self.render();
    }

    /// Stores the latest validation progress and asks for a (throttled) redraw.
    fn render_valid(&mut self, item: TrainingProgress) {
        self.progress = item;
        self.render();
    }
}

impl CLIDashboardRenderer {
    /// Creates a renderer with empty metric/plot tables and two progress bars registered on a
    /// [`MultiProgress`]: the iteration bar first, then the epoch bar.
    pub fn new() -> Self {
        let multi = MultiProgress::new();
        // `MultiProgress::add` places each bar below the ones already added, so the iteration
        // bar (whose template carries plots and metrics) is drawn above the epoch bar.
        let pb_iteration = multi.add(ProgressBar::new(0));
        let pb_epoch = multi.add(ProgressBar::new(0));

        Self {
            pb_epoch,
            pb_iteration,
            last_update: std::time::Instant::now(),
            progress: TrainingProgress::none(),
            metric_train: HashMap::new(),
            metric_valid: HashMap::new(),
            metric_both_plot: HashMap::new(),
            metric_train_plot: HashMap::new(),
            metric_valid_plot: HashMap::new(),
        }
    }

    /// Takes ownership of the plot that should receive values of both splits for `key`.
    ///
    /// Returns the existing combined plot if there is one. Otherwise, when both a train-only
    /// and a valid-only plot exist, it removes them and merges them. In both cases the plot is
    /// removed from the maps, and the caller re-inserts it into `metric_both_plot`. Returns
    /// `None`, leaving the maps untouched, when `key` has a plot for at most one split.
    fn text_plot_in_both(&mut self, key: &str) -> Option<TextPlot> {
        if let Some(plot) = self.metric_both_plot.remove(key) {
            return Some(plot);
        }
        if !(self.metric_train_plot.contains_key(key) && self.metric_valid_plot.contains_key(key))
        {
            return None;
        }

        let plot_train = self.metric_train_plot.remove(key)?;
        let plot_valid = self.metric_valid_plot.remove(key)?;
        Some(plot_train.merge(plot_valid))
    }

    /// Escapes `{` and `}` so arbitrary text can appear as a literal in an indicatif template.
    ///
    /// Unescaped braces in metric names/values would be parsed as placeholders or rejected by
    /// `ProgressStyle::with_template`, which `render` would turn into a panic.
    fn escape_template(text: &str) -> String {
        text.replace('{', "{{").replace('}', "}}")
    }

    /// Returns the map's entries sorted by key, so the dashboard lines keep a stable order
    /// (`HashMap` iteration order is unspecified).
    fn sorted_entries<V>(map: &HashMap<String, V>) -> Vec<(&String, &V)> {
        let mut entries: Vec<_> = map.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        entries
    }

    /// Appends the plots section (one text plot per numeric metric) to `template`.
    ///
    /// Combined train/valid plots come first, then train-only, then valid-only. Each group is
    /// sorted by metric name. Nothing is appended when there are no plots.
    fn register_template_plots(&self, mut template: String) -> String {
        let mut sections = Vec::new();

        for (name, plot) in Self::sorted_entries(&self.metric_both_plot) {
            sections.push(format!(
                "  - {} RED: train | BLUE: valid \n{}",
                Self::escape_template(name),
                Self::escape_template(&plot.render())
            ));
        }
        for (name, plot) in Self::sorted_entries(&self.metric_train_plot) {
            sections.push(format!(
                "  - Train {}: \n{}",
                Self::escape_template(name),
                Self::escape_template(&plot.render())
            ));
        }
        for (name, plot) in Self::sorted_entries(&self.metric_valid_plot) {
            sections.push(format!(
                "  - Valid {}: \n{}",
                Self::escape_template(name),
                Self::escape_template(&plot.render())
            ));
        }

        if !sections.is_empty() {
            template.push_str(&format!("{}\n{}\n", PLOTS_TAG, sections.join("\n")));
        }

        template
    }

    /// Appends the metrics section (latest formatted value of every metric) to `template`.
    ///
    /// Training metrics come first, then validation metrics, each sorted by name. Nothing is
    /// appended when no metric has been recorded.
    fn register_template_metrics(&self, mut template: String) -> String {
        let mut lines = Vec::new();

        for (name, value) in Self::sorted_entries(&self.metric_train) {
            lines.push(format!(
                "  - Train {}: {}",
                Self::escape_template(name),
                Self::escape_template(value)
            ));
        }
        for (name, value) in Self::sorted_entries(&self.metric_valid) {
            lines.push(format!(
                "  - Valid {}: {}",
                Self::escape_template(name),
                Self::escape_template(value)
            ));
        }

        if !lines.is_empty() {
            template.push_str(&format!("{}\n{}\n", METRICS_TAG, lines.join("\n")));
        }

        template
    }

    /// Registers a custom key `name` on `style` that renders as `<name>: <value>`.
    fn register_style_progress(
        &self,
        name: &'static str,
        style: ProgressStyle,
        value: String,
    ) -> ProgressStyle {
        self.register_key_item(name, style, name.to_string(), value)
    }

    /// Appends a labelled progress-bar line (bar plus ETA) to `template`.
    ///
    /// `progress` is inserted verbatim, so it must not contain template braces; callers pass
    /// fixed labels.
    fn register_template_progress(&self, progress: &str, mut template: String) -> String {
        let bar = "[{wide_bar:.cyan/blue}] ({eta})";
        template.push_str(&format!("  - {} {}", progress, bar));
        template
    }

    /// Rebuilds both progress-bar styles from the current plots, metrics and progress, then
    /// pushes the latest positions and lengths to the bars.
    ///
    /// A call arriving less than `MAX_REFRESH_RATE_MILLIS` ms after the previous redraw is
    /// skipped. NOTE(review): a skipped update is never replayed, so the last progress
    /// reported before the renderer is dropped may not be displayed.
    ///
    /// # Panics
    ///
    /// Only if indicatif rejects the generated template. That should not happen: all
    /// user-provided text is brace-escaped.
    fn render(&mut self) {
        if self.last_update.elapsed().as_millis() < MAX_REFRESH_RATE_MILLIS {
            return;
        }

        let template = self.register_template_plots(String::new());
        let mut template = self.register_template_metrics(template);
        template.push_str(&format!(
            "\n{}\n  - Iteration {} Epoch {}/{}\n",
            PROGRESS_TAG, self.progress.iteration, self.progress.epoch, self.progress.epoch_total
        ));

        let template = self.register_template_progress("iteration", template);
        let style_iteration = ProgressStyle::with_template(&template)
            .expect("iteration template is valid: user-provided text is brace-escaped");
        let style_iteration = self.register_style_progress(
            "iteration",
            style_iteration,
            self.progress.iteration.to_string(),
        );

        let template = self.register_template_progress("epoch    ", String::new());
        let style_epoch = ProgressStyle::with_template(&template)
            .expect("epoch template contains only fixed, valid placeholders");
        let style_epoch =
            self.register_style_progress("epoch", style_epoch, self.progress.epoch.to_string());

        self.pb_iteration
            .set_style(style_iteration.progress_chars("#>-"));
        self.pb_iteration
            .set_position(self.progress.progress.items_processed as u64);
        self.pb_iteration
            .set_length(self.progress.progress.items_total as u64);

        self.pb_epoch.set_style(style_epoch.progress_chars("#>-"));
        self.pb_epoch.set_position(self.progress.epoch as u64);
        self.pb_epoch.set_length(self.progress.epoch_total as u64);

        self.last_update = std::time::Instant::now();
    }

    /// Returns `style` with a custom template key `key` that writes `<name>: <formatted>`.
    ///
    /// NOTE(review): the templates built in `render` never contain `{iteration}` or `{epoch}`,
    /// so the keys registered through `register_style_progress` are currently not displayed.
    pub fn register_key_item(
        &self,
        key: &'static str,
        style: ProgressStyle,
        name: String,
        formatted: String,
    ) -> ProgressStyle {
        style.with_key(key, move |_state: &ProgressState, w: &mut dyn Write| {
            write!(w, "{}: {}", name, formatted).unwrap()
        })
    }
}

/// Header line of the dashboard section listing the latest metric values.
const METRICS_TAG: &str = "[Metrics]";
/// Header line of the dashboard section containing the text plots.
const PLOTS_TAG: &str = "[Plots]";
/// Header line of the dashboard section showing iteration/epoch progress.
const PROGRESS_TAG: &str = "[Progress]";