use crate::tpch_cli::output_plan::OutputPlan;
use crate::tpch_cli::Table;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
/// Sink for generation progress, reported per [`Table`].
///
/// Implementations are shared as `Arc<dyn ProgressTracker>`, hence the `Send + Sync`
/// bounds; the `fmt::Debug` supertrait lets types that hold one (such as `RunProgress`)
/// derive `Debug`.
pub trait ProgressTracker: Send + Sync + fmt::Debug {
    /// Announces that `table` will report `total_units` units in total.
    ///
    /// The default implementation ignores the announcement.
    fn register(&self, table: Table, total_units: u64) {
        let _ = (table, total_units);
    }

    /// Advances the progress of `table` by `units`.
    fn increment(&self, table: Table, units: u64);

    /// Signals that reporting is over. The default implementation does nothing.
    fn finish(&self) {}
}
/// Run-wide progress handle wrapping an optional [`ProgressTracker`].
///
/// The `Default` value carries no tracker, in which case every method is a no-op.
/// Clones share the same tracker through the `Arc`.
#[derive(Clone, Debug, Default)]
pub(crate) struct RunProgress {
    tracker: Option<Arc<dyn ProgressTracker>>,
}

impl RunProgress {
    /// Builds a handle that forwards every event to `tracker`.
    pub(crate) fn with_tracker(tracker: Arc<dyn ProgressTracker>) -> Self {
        let tracker = Some(tracker);
        Self { tracker }
    }

    /// Sums `chunk_count()` per table over all `plans` and registers each table once.
    ///
    /// Totals are gathered in a `BTreeMap`, so tables are registered in ascending
    /// `Table` order. Nothing is computed when there is no tracker.
    pub(crate) fn register_totals(&self, plans: &[OutputPlan]) {
        let Some(tracker) = &self.tracker else {
            return;
        };
        let totals = plans
            .iter()
            .fold(BTreeMap::<Table, u64>::new(), |mut acc, plan| {
                *acc.entry(plan.table()).or_default() += plan.chunk_count() as u64;
                acc
            });
        totals
            .into_iter()
            .for_each(|(table, total)| tracker.register(table, total));
    }

    /// Credits all of `plan`'s chunks to its table in one increment.
    ///
    /// NOTE(review): judging by the name, this is meant for outputs that already exist and
    /// are not regenerated; confirm against the callers.
    pub(crate) fn increment_for_existing(&self, plan: &OutputPlan) {
        if let Some(tracker) = &self.tracker {
            let table = plan.table();
            let units = plan.chunk_count() as u64;
            tracker.increment(table, units);
        }
    }

    /// Returns a per-table handle that shares this run's tracker (if any).
    pub(crate) fn for_table(&self, table: Table) -> TableProgress {
        let shared = self.tracker.clone();
        TableProgress::for_table(shared, table)
    }

    /// Consumes the handle and tells the tracker (if any) that reporting is over.
    pub(crate) fn finish(self) {
        if let Some(tracker) = self.tracker.as_deref() {
            tracker.finish();
        }
    }
}
/// Per-table progress handle that reports progress one unit at a time.
///
/// Pairs an optional tracker with the table it reports for. The `Default` value has no
/// tracker, and every call on it is a no-op.
///
/// Derives `Debug` like `RunProgress`. The field is `Debug` because `ProgressTracker`
/// has a `fmt::Debug` supertrait, so containing types can derive `Debug` too.
#[derive(Clone, Debug, Default)]
pub(crate) struct TableProgress {
    tracker: Option<(Arc<dyn ProgressTracker>, Table)>,
}

impl TableProgress {
    /// Binds `progress` (when present) to `table`; `None` yields a no-op handle.
    pub(crate) fn for_table(progress: Option<Arc<dyn ProgressTracker>>, table: Table) -> Self {
        Self {
            tracker: progress.map(|tracker| (tracker, table)),
        }
    }

    /// Advances the bound table by exactly one unit.
    pub(crate) fn increment_output_unit(&self) {
        if let Some((tracker, table)) = &self.tracker {
            tracker.increment(*table, 1);
        }
    }
}
#[cfg(feature = "indicatif-progress")]
pub use indicatif_impl::IndicatifProgress;
#[cfg(feature = "indicatif-progress")]
mod indicatif_impl {
    use super::ProgressTracker;
    use crate::tpch_cli::Table;
    use indicatif::{MultiProgress, ProgressBar, ProgressFinish, ProgressStyle};
    use std::collections::BTreeMap;
    use std::io::{self, Write};
    use std::sync::{OnceLock, RwLock};

    /// [`ProgressTracker`] that draws one `indicatif` progress bar per registered table.
    ///
    /// All bars are attached to a single [`MultiProgress`]. Two kinds of call are silently
    /// ignored: increments for tables that were never registered, and calls made after the
    /// internal lock has been poisoned.
    #[derive(Debug)]
    pub struct IndicatifProgress {
        /// Shared draw target that every bar created by `register` is added to.
        multi: MultiProgress,
        /// The current bar for each registered table.
        tables: RwLock<BTreeMap<Table, ProgressBar>>,
    }

    impl IndicatifProgress {
        /// Creates a tracker with a fresh [`MultiProgress`] and no registered tables.
        pub fn new() -> Self {
            Self {
                multi: MultiProgress::new(),
                tables: RwLock::new(BTreeMap::new()),
            }
        }

        /// Returns a writer that sends log output to stderr.
        ///
        /// Each operation runs inside [`MultiProgress::suspend`]. That hides the bars while
        /// the bytes are written and redraws them afterwards, so log output does not collide
        /// with the bars.
        pub fn log_writer(&self) -> Box<dyn io::Write + Send + 'static> {
            Box::new(IndicatifLogWriter {
                multi: self.multi.clone(),
            })
        }
    }

    impl Default for IndicatifProgress {
        fn default() -> Self {
            Self::new()
        }
    }

    impl ProgressTracker for IndicatifProgress {
        /// Creates the bar for `table` with length `total_units`.
        ///
        /// If `table` already has a bar, it is replaced by a fresh one.
        fn register(&self, table: Table, total_units: u64) {
            let Ok(mut tables) = self.tables.write() else {
                return;
            };
            let pb = self.multi.add(ProgressBar::new(total_units));
            pb.set_style(bar_style().clone());
            pb.set_message(table.to_string());
            // Keep the bar (in its final state) on screen when it is dropped.
            let pb = pb.with_finish(ProgressFinish::AndLeave);
            if let Some(stale) = tables.insert(table, pb) {
                // The replaced bar is still attached to `multi` and finishes with `AndLeave`
                // on drop. It would otherwise stay on screen as a duplicate that never
                // advances, so detach it before it is dropped.
                self.multi.remove(&stale);
            }
        }

        fn increment(&self, table: Table, units: u64) {
            // Clone the bar out so the read lock is released before `inc` runs.
            let bar = {
                let Ok(tables) = self.tables.read() else {
                    return;
                };
                tables.get(&table).cloned()
            };
            if let Some(bar) = bar {
                bar.inc(units);
            }
        }

        fn finish(&self) {
            // Snapshot the bars so finishing happens without holding the lock.
            let bars = {
                let Ok(tables) = self.tables.read() else {
                    return;
                };
                tables.values().cloned().collect::<Vec<_>>()
            };
            for bar in bars {
                // Applies the bar's configured finish behavior (`AndLeave`, set in `register`).
                bar.finish_using_style();
            }
        }
    }

    /// Shared bar style, built once on first use.
    fn bar_style() -> &'static ProgressStyle {
        static STYLE: OnceLock<ProgressStyle> = OnceLock::new();
        STYLE.get_or_init(|| {
            ProgressStyle::default_bar()
                .template("{msg:10} [{bar:28}] Progress: {percent:>3}%")
                .expect("static progress bar template is valid")
                .progress_chars("█▓░")
        })
    }

    /// stderr writer that suspends the bars of `multi` around every operation.
    struct IndicatifLogWriter {
        multi: MultiProgress,
    }

    impl Write for IndicatifLogWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.multi.suspend(|| {
                let mut stderr = io::stderr().lock();
                stderr.write(buf)
            })
        }

        // Overridden so the whole buffer is written inside a single `suspend`. The default
        // `write_all` loops over `write`, which could redraw the bars between partial writes
        // of the same record.
        fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
            self.multi.suspend(|| {
                let mut stderr = io::stderr().lock();
                stderr.write_all(buf)
            })
        }

        fn flush(&mut self) -> io::Result<()> {
            self.multi.suspend(|| {
                let mut stderr = io::stderr().lock();
                stderr.flush()
            })
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn registers_and_increments() {
            let t = IndicatifProgress::new();
            t.register(Table::Lineitem, 60);
            t.register(Table::Orders, 15);
            t.increment(Table::Lineitem, 1);
            t.increment(Table::Orders, 5);
            let tables = t.tables.read().unwrap();
            assert_eq!(tables[&Table::Lineitem].position(), 1);
            assert_eq!(tables[&Table::Orders].position(), 5);
        }

        #[test]
        fn reaches_total() {
            let t = IndicatifProgress::new();
            t.register(Table::Orders, 5);
            for _ in 0..5 {
                t.increment(Table::Orders, 1);
            }
            assert_eq!(t.tables.read().unwrap()[&Table::Orders].position(), 5);
        }

        #[test]
        fn unknown_table_is_no_op() {
            let t = IndicatifProgress::new();
            t.register(Table::Orders, 1);
            t.increment(Table::Lineitem, 1);
            assert_eq!(t.tables.read().unwrap()[&Table::Orders].position(), 0);
        }

        #[test]
        fn re_register_replaces_bar() {
            let t = IndicatifProgress::new();
            t.register(Table::Orders, 5);
            t.increment(Table::Orders, 3);
            t.register(Table::Orders, 10);
            let tables = t.tables.read().unwrap();
            assert_eq!(tables.len(), 1);
            assert_eq!(tables[&Table::Orders].position(), 0);
        }

        #[test]
        fn finish_marks_registered_bars_finished() {
            let t = IndicatifProgress::new();
            t.register(Table::Orders, 2);
            t.increment(Table::Orders, 2);
            t.finish();
            assert!(t.tables.read().unwrap()[&Table::Orders].is_finished());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tpch_cli::Table;
    use std::sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Mutex,
    };

    /// Tracker that records every call so tests can assert on what reached it.
    #[derive(Debug, Default)]
    struct MockTracker {
        registered: Mutex<Vec<(Table, u64)>>,
        total_increments: AtomicU64,
        finished: AtomicU64,
    }

    impl ProgressTracker for MockTracker {
        fn register(&self, table: Table, total_units: u64) {
            self.registered.lock().unwrap().push((table, total_units));
        }

        fn increment(&self, _table: Table, units: u64) {
            self.total_increments.fetch_add(units, Ordering::Relaxed);
        }

        fn finish(&self) {
            self.finished.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[test]
    fn mock_tracker_works_through_arc_dyn() {
        let mock = Arc::new(MockTracker::default());
        let dynamic: Arc<dyn ProgressTracker> = mock.clone();
        dynamic.register(Table::Lineitem, 10);
        dynamic.register(Table::Orders, 4);
        dynamic.increment(Table::Lineitem, 3);
        dynamic.increment(Table::Orders, 1);
        dynamic.finish();
        assert_eq!(
            *mock.registered.lock().unwrap(),
            vec![(Table::Lineitem, 10), (Table::Orders, 4)]
        );
        assert_eq!(mock.total_increments.load(Ordering::Relaxed), 4);
        assert_eq!(mock.finished.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn default_register_and_finish_are_noops() {
        #[derive(Debug)]
        struct Minimal(AtomicU64);

        impl ProgressTracker for Minimal {
            fn increment(&self, _t: Table, c: u64) {
                self.0.fetch_add(c, Ordering::Relaxed);
            }
        }

        let m = Minimal(AtomicU64::new(0));
        m.register(Table::Region, 99);
        m.increment(Table::Region, 7);
        m.finish();
        assert_eq!(m.0.load(Ordering::Relaxed), 7);
    }

    #[test]
    fn table_progress_forwards_single_units() {
        let mock = Arc::new(MockTracker::default());
        let run = RunProgress::with_tracker(mock.clone());
        let orders = run.for_table(Table::Orders);
        orders.increment_output_unit();
        // Clones share the same tracker.
        orders.clone().increment_output_unit();
        assert_eq!(mock.total_increments.load(Ordering::Relaxed), 2);
        run.finish();
        assert_eq!(mock.finished.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn progress_without_tracker_is_noop() {
        let run = RunProgress::default();
        run.for_table(Table::Orders).increment_output_unit();
        TableProgress::default().increment_output_unit();
        run.finish();
    }
}