reinfer-cli 0.38.12

Command line interface for Re:infer, the conversational data intelligence platform
use indicatif::{ProgressBar, ProgressStyle};
use std::fmt::Write;
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    thread,
    time::Duration,
};

use crate::utils::LOG_PREFIX_INFO;

/// Value returned by a progress callback: the current progress position and a
/// text label that is displayed as the progress bar's prefix.
pub type ProgressMessage = (u64, String);

/// Display settings for a progress bar.
pub struct Options {
    /// When `true` and a target value is known, the bar shows byte counters
    /// (`{bytes} / {total_bytes}`); otherwise the plain numeric message is shown.
    pub bytes_units: bool,
}

impl Default for Options {
    /// Returns the default settings, in which byte-unit display is enabled.
    fn default() -> Self {
        let bytes_units = true;
        Self { bytes_units }
    }
}

/// Handle to a background progress-reporting thread.
///
/// Created with [`Progress::new`]; the thread is stopped and joined by
/// [`Progress::done`], which also runs when the value is dropped.
pub struct Progress {
    /// Flag polled by the reporting thread; storing `false` asks it to stop.
    report_progress_flag: Arc<AtomicBool>,
    /// Join handle of the reporting thread; `None` once it has been taken for joining.
    progress_thread: Option<thread::JoinHandle<()>>,
}

impl Progress {
    /// Starts reporting progress on a background thread.
    ///
    /// `progress_fn` receives the shared `statistics` on every refresh and returns the
    /// current progress value together with a text label. `target_value`, when present,
    /// is the total the bar counts towards, and `options` selects the display format.
    ///
    /// Reporting continues until [`Progress::done`] is called or the value is dropped.
    pub fn new<ProgressFnT, StatisticsT>(
        progress_fn: ProgressFnT,
        statistics: &Arc<StatisticsT>,
        target_value: Option<u64>,
        options: Options,
    ) -> Self
    where
        ProgressFnT: Fn(&StatisticsT) -> ProgressMessage + Sync + Send + 'static,
        StatisticsT: Sync + Send + 'static,
    {
        // The flag starts out `true`; the worker keeps refreshing while it stays set.
        let keep_reporting = Arc::new(AtomicBool::new(true));
        let shared_statistics = Arc::clone(statistics);
        let worker_flag = Arc::clone(&keep_reporting);

        Self {
            progress_thread: Some(spawn_progress_thread(
                shared_statistics,
                progress_fn,
                target_value,
                options,
                worker_flag,
            )),
            report_progress_flag: keep_reporting,
        }
    }

    /// Stops the reporting thread and waits for it to exit.
    ///
    /// Calling this again after the thread has been joined does nothing.
    ///
    /// # Panics
    ///
    /// Panics if the reporting thread cannot be joined.
    pub fn done(&mut self) {
        // Taking the handle leaves `None` behind, which makes later calls a no-op.
        let worker = match self.progress_thread.take() {
            Some(worker) => worker,
            None => return,
        };

        self.report_progress_flag.store(false, Ordering::SeqCst);
        worker.join().expect("Could not join progress thread.");
    }
}

impl Drop for Progress {
    fn drop(&mut self) {
        self.done();
    }
}

/// Spawns the background thread that periodically refreshes a progress bar.
///
/// The bar template is assembled up front from the log prefix, a spinner, the elapsed
/// time and the prefix slot, followed by one of three layouts:
///
/// * a bar with byte counters and an ETA when `max_progress_value` is set and
///   `options.bytes_units` is `true`;
/// * a bar with the `{msg}` slot and an ETA when `max_progress_value` is set and
///   `options.bytes_units` is `false`;
/// * only the `{msg}` slot when there is no `max_progress_value`.
///
/// While `report_progress` reads `true`, the thread sleeps for 100 ms, calls
/// `progress_fn` with the shared statistics and pushes the returned value and label
/// into the bar. Once the flag reads `false`, the bar is cleared and a carriage return
/// is written to stderr.
///
/// # Panics
///
/// Panics if the assembled template is rejected by `ProgressStyle::template`.
fn spawn_progress_thread<Statistics, ProgressFn>(
    statistics: Arc<Statistics>,
    progress_fn: ProgressFn,
    max_progress_value: Option<u64>,
    options: Options,
    report_progress: Arc<AtomicBool>,
) -> thread::JoinHandle<()>
where
    ProgressFn: Fn(&Statistics) -> ProgressMessage + Sync + Send + 'static,
    Statistics: Sync + Send + 'static,
{
    use std::ops::Deref;

    // Common leading part of the template; doubled braces are literal braces that
    // the progress bar library later interprets as placeholders.
    let mut template_str = String::new();
    write!(
        template_str,
        "{} {{spinner:.green}} [{{elapsed_precise}}] {{prefix}} ",
        LOG_PREFIX_INFO.deref()
    )
    .expect("writing to a String cannot fail");

    // Matching on the `Option` itself keeps the three layouts exhaustive without a
    // catch-all arm.
    template_str.push_str(match (max_progress_value, options.bytes_units) {
        (Some(_), true) => "{bar:32.cyan/blue} {bytes} / {total_bytes} ({eta})",
        (Some(_), false) => "{bar:32.cyan/blue} {msg} ({eta})",
        (None, _) => "{msg}",
    });

    let progress_bar = ProgressBar::new(max_progress_value.unwrap_or(0));
    progress_bar.set_style(
        ProgressStyle::default_bar()
            .template(&template_str)
            .expect("invalid template string")
            .progress_chars("#>-"),
    );

    // The `move` closure takes ownership of `statistics` and `progress_fn` directly;
    // no extra `Arc` clone or rebinding is needed inside the thread.
    thread::spawn(move || {
        let refresh_interval = Duration::from_millis(100);

        while report_progress.load(Ordering::SeqCst) {
            thread::sleep(refresh_interval);
            let (progress_value, message) = progress_fn(&statistics);
            progress_bar.set_position(progress_value);
            progress_bar.set_prefix(message);
            progress_bar.set_message(match max_progress_value {
                Some(value) => format!("{progress_value} / {value}"),
                None => format!("{progress_value}"),
            });
        }

        progress_bar.finish_and_clear();
        // Return the cursor to the start of the line after the bar is cleared.
        eprint!("\r");
    })
}