progpool 0.3.2

Job pool with terminal progress bar
Documentation
use std::collections::{BTreeMap, BTreeSet};
use std::convert::TryInto as _;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

fn progress_bar_style(task_count: usize) -> indicatif::ProgressStyle {
  let task_count_digits = task_count.to_string().len();
  let count = format!("{{pos:>{}}}/{{len}}", 8 - task_count_digits);
  let template = format!(
    "[{{elapsed_precise}}] {{prefix}} {} {{bar:40.cyan/blue}}: {{msg}}",
    count
  );
  indicatif::ProgressStyle::with_template(&template)
    .unwrap()
    .progress_chars("##-")
}

#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)]
struct TaskMonitorId((Instant, String, usize));

struct TaskMonitor {
  counter: usize,
  tasks: BTreeSet<TaskMonitorId>,
  progress_bar: indicatif::ProgressBar,
}

impl TaskMonitor {
  fn new(progress_bar: indicatif::ProgressBar) -> TaskMonitor {
    TaskMonitor {
      counter: 0,
      tasks: BTreeSet::new(),
      progress_bar,
    }
  }

  fn started(&mut self, task_name: String) -> TaskMonitorId {
    let now = Instant::now();
    let task_id = TaskMonitorId((now, task_name, self.counter));
    self.counter += 1;
    self.tasks.insert(task_id.clone());
    self.update_progress_bar();
    task_id
  }

  fn finished(&mut self, task_id: TaskMonitorId) {
    let removed = self.tasks.remove(&task_id);
    assert!(removed);
    self.progress_bar.inc(1);
    self.update_progress_bar();
  }

  fn update_progress_bar(&mut self) {
    match self.tasks.iter().next() {
      Some(task) => {
        self.progress_bar.set_message((task.0).1.clone());
      }

      None => {
        self.progress_bar.set_message("");
      }
    }
  }
}

pub struct Pool {
  thread_count: Option<usize>,
  thread_pool: Option<rayon::ThreadPool>,
  quiet: bool,
}

pub struct ExecutionResult<T> {
  pub name: String,
  pub result: T,
}

pub struct ExecutionResults<T, E> {
  pub successful: Vec<ExecutionResult<T>>,
  pub failed: Vec<ExecutionResult<E>>,
}

impl Pool {
  pub fn with_default_size() -> Pool {
    Pool {
      thread_count: None,
      thread_pool: None,
      quiet: false,
    }
  }

  pub fn with_size(thread_count: usize) -> Pool {
    Pool {
      thread_count: Some(thread_count),
      thread_pool: None,
      quiet: false,
    }
  }

  pub fn quiet(&mut self, quiet: bool) {
    self.quiet = quiet;
  }

  fn thread_pool(&mut self) -> &mut rayon::ThreadPool {
    let thread_count = self.thread_count;
    self.thread_pool.get_or_insert_with(|| {
      let builder = rayon::ThreadPoolBuilder::new();
      let builder = match thread_count {
        Some(count) => builder.num_threads(count),
        None => builder,
      };

      builder
        .build()
        .unwrap_or_else(|err| panic!("failed to create job pool: {}", err))
    })
  }

  pub fn execute<T: Send, E: Send>(&mut self, mut job: Job<T, E>) -> ExecutionResults<T, E> {
    let mut task_monitor = TaskMonitor::new(if self.quiet {
      indicatif::ProgressBar::hidden()
    } else {
      let task_count = job.tasks.len();
      let pb = indicatif::ProgressBar::new(task_count.try_into().unwrap());
      pb.set_style(progress_bar_style(task_count));
      pb.set_prefix(job.name.clone());
      pb.enable_steady_tick(Duration::from_secs(1));
      pb
    });

    tracing::span!(tracing::Level::INFO, "Pool::execute", name = job.name).in_scope(|| {
      let mut successful = BTreeMap::new();
      let mut failed = BTreeMap::new();
      {
        let state = Arc::new(Mutex::new((&mut task_monitor, &mut successful, &mut failed)));
        self.thread_pool().scope(|scope| {
          for (i, task) in job.tasks.drain(..).enumerate() {
            let state = state.clone();
            scope.spawn(move |_| {
              tracing::span!(tracing::Level::INFO, "Task::task", name = task.name).in_scope(|| {
                let task_id = {
                  let mut guard = state.lock().unwrap();
                  let (task_monitor, _, _) = &mut *guard;
                  task_monitor.started(task.name.clone())
                };

                let result = (task.task)();

                let mut guard = state.lock().unwrap();
                let (task_monitor, successful, failed) = &mut *guard;
                task_monitor.finished(task_id);
                match result {
                  Ok(result) => {
                    successful.insert(
                      i,
                      ExecutionResult {
                        name: task.name,
                        result,
                      },
                    );
                  }

                  Err(err) => {
                    failed.insert(
                      i,
                      ExecutionResult {
                        name: task.name,
                        result: err,
                      },
                    );
                  }
                }
              })
            })
          }
        });
      }
      task_monitor.progress_bar.finish();

      ExecutionResults {
        successful: successful.into_values().collect(),
        failed: failed.into_values().collect(),
      }
    })
  }
}

pub struct Job<'task, T: Send, E: Send> {
  name: String,
  tasks: Vec<Task<'task, T, E>>,
}

impl<'task, T: Send, E: Send> Job<'task, T, E> {
  pub fn with_name<N: Into<String>>(name: N) -> Self {
    Job {
      name: name.into(),
      tasks: Vec::new(),
    }
  }

  pub fn add_task<N: Into<String>, F: FnOnce() -> Result<T, E> + Send + 'task>(&mut self, name: N, task: F) {
    let name = name.into();
    self.tasks.push(Task {
      name,
      task: Box::new(task),
    });
  }
}

struct Task<'task, T: Send, E: Send> {
  name: String,
  task: Box<dyn FnOnce() -> Result<T, E> + Send + 'task>,
}

#[cfg(test)]
mod tests {
  #[test]
  fn smoke() {
    use super::*;

    let (send_1, recv_1) = std::sync::mpsc::channel();
    let (send_2, recv_2) = std::sync::mpsc::channel();
    let (send_3, recv_3) = std::sync::mpsc::channel();

    let mut job = Job::with_name("forward");
    job.add_task("first", move || {
      let data = recv_1.recv().unwrap();
      send_2.send(data + 1).unwrap();
      Ok(data)
    });

    job.add_task("second", move || {
      let data = recv_2.recv().unwrap();
      send_3.send(data + 1).unwrap();
      Ok(data)
    });

    job.add_task("third", move || {
      let data = recv_3.recv().unwrap();
      Err(format!("{}", data + 1))
    });

    send_1.send(0).unwrap();

    // Dispatch order is not guaranteed, must schedule all at once to avoid deadlock.
    let mut pool = Pool::with_size(job.tasks.len());
    let results = pool.execute(job);

    assert_eq!(results.successful.len(), 2);
    assert_eq!(results.successful[0].name, "first");
    assert_eq!(results.successful[0].result, 0);

    assert_eq!(results.successful[1].name, "second");
    assert_eq!(results.successful[1].result, 1);

    assert_eq!(results.failed.len(), 1);
    assert_eq!(results.failed[0].name, "third");
    assert_eq!(format!("{}", results.failed[0].result), format!("{}", 3));
  }
}