#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
use std::fmt;
use std::iter::{self, FromIterator};
use std::mem;
use std::panic;
use std::process;
use std::sync::mpsc;
use std::thread;
/// A builder that runs a set of closures in parallel and gathers their results.
///
/// Closures are queued with `add` or `each` and are not executed until one of
/// the consuming methods (`run`, `collect`, `finish`, `finish_in`) is called.
/// Queued closures may borrow data that lives for `'a`; it does not have to be
/// `'static`.
#[must_use]
pub struct Parallel<'a, T> {
    /// Closures waiting to be executed, in the order they were queued.
    closures: Vec<Box<dyn FnOnce() -> T + Send + 'a>>,
}
impl<'a, T> Parallel<'a, T> {
    /// Creates an empty builder with no queued closures.
    pub fn new() -> Parallel<'a, T> {
        Parallel {
            closures: Vec::new(),
        }
    }

    /// Queues a closure to be run in parallel and returns the builder.
    ///
    /// The closure is not called until one of the consuming methods
    /// ([`run`](Parallel::run), [`collect`](Parallel::collect),
    /// [`finish`](Parallel::finish), [`finish_in`](Parallel::finish_in)) is invoked.
    // This is a by-value builder step, not an arithmetic operation, so
    // implementing `std::ops::Add` as clippy suggests would be misleading.
    #[allow(clippy::should_implement_trait)]
    pub fn add<F>(mut self, f: F) -> Parallel<'a, T>
    where
        F: FnOnce() -> T + Send + 'a,
        T: Send + 'a,
    {
        self.closures.push(Box::new(f));
        self
    }

    /// Queues one call of `f` for every item yielded by `iter`.
    ///
    /// `f` is cloned once per item and each clone is called with one item.
    /// The iterator itself is drained here, on the calling thread, so it does
    /// not need to be `Send`.
    pub fn each<A, I, F>(mut self, iter: I, f: F) -> Parallel<'a, T>
    where
        I: IntoIterator<Item = A>,
        F: FnOnce(A) -> T + Clone + Send + 'a,
        A: Send + 'a,
        T: Send + 'a,
    {
        for item in iter {
            let f = f.clone();
            self.closures.push(Box::new(move || f(item)));
        }
        self
    }

    /// Runs all queued closures and collects their results into `C`.
    ///
    /// The last queued closure runs on the current thread. Every other closure
    /// runs on its own spawned thread. Results appear in the order the closures
    /// were queued. If nothing was queued, an empty `C` is returned and no
    /// threads are spawned.
    ///
    /// # Panics
    ///
    /// If any closure panics, the panic is propagated after all spawned threads
    /// have been joined; see [`finish_in`](Parallel::finish_in).
    pub fn collect<C>(mut self) -> C
    where
        T: Send + 'a,
        C: FromIterator<T> + Extend<T>,
    {
        // Reuse the current thread for the last closure instead of spawning
        // one more thread for it; its result is appended last to keep order.
        let last = match self.closures.pop() {
            None => return iter::empty().collect(),
            Some(f) => f,
        };
        let (mut results, r) = self.finish_in::<_, _, C>(last);
        results.extend(iter::once(r));
        results
    }

    /// Runs all queued closures and returns their results as a `Vec`, in
    /// queue order.
    ///
    /// This is [`collect`](Parallel::collect) with `C = Vec<T>`.
    pub fn run(self) -> Vec<T>
    where
        T: Send + 'a,
    {
        self.collect()
    }

    /// Runs every queued closure on a spawned thread while `f` runs on the
    /// current thread.
    ///
    /// Returns the closures' results as a `Vec` (in queue order) together with
    /// the result of `f`. Panic behavior is that of
    /// [`finish_in`](Parallel::finish_in).
    pub fn finish<F, R>(self, f: F) -> (Vec<T>, R)
    where
        F: FnOnce() -> R,
        T: Send + 'a,
    {
        self.finish_in::<_, _, Vec<T>>(f)
    }

    /// Like [`finish`](Parallel::finish), but collects the closures' results
    /// into any `C: FromIterator<T>`.
    ///
    /// Every queued closure runs on its own newly spawned thread, while `f`
    /// runs on the current thread. Results are collected in queue order.
    ///
    /// # Panics
    ///
    /// All spawned threads are joined before any panic is propagated.
    ///
    /// - If one or more queued closures panicked, the panic from the last of
    ///   them (in queue order) is resumed.
    /// - Otherwise, if `f` panicked, its panic is resumed.
    ///
    /// The process is aborted, not unwound, if a panic happens while spawned
    /// threads may still be running. For example, `thread::spawn` panics when
    /// the OS cannot create a thread.
    pub fn finish_in<F, R, C>(self, f: F) -> (C, R)
    where
        F: FnOnce() -> R,
        T: Send + 'a,
        C: FromIterator<T>,
    {
        // Spawned threads may borrow `'a` data. If a panic unwound out of this
        // function before they were all joined, `'a` could end while they are
        // still running, so this guard aborts the process instead.
        let guard = NoPanic;

        let count = self.closures.len();
        let mut handles = Vec::with_capacity(count);
        let mut receivers = Vec::with_capacity(count);

        for closure in self.closures {
            // One channel per thread lets results be read back in queue order.
            // Each receiver stays alive until after its thread is joined, so
            // `send` cannot fail.
            let (sender, receiver) = mpsc::channel();
            let job = move || sender.send(closure()).unwrap();

            let job: Box<dyn FnOnce() + Send + 'a> = Box::new(job);
            // SAFETY: `thread::spawn` requires `'static`, but `job` only lives
            // for `'a`. Every thread spawned here is joined below before this
            // function returns. `guard` aborts the process if a panic would
            // unwind out of this function before then. So no spawned thread can
            // observe `'a` data after `'a` ends.
            let job: Box<dyn FnOnce() + Send + 'static> = unsafe { mem::transmute(job) };

            handles.push(thread::spawn(job));
            receivers.push(receiver);
        }

        // Run the caller's closure here, catching a panic so the spawned
        // threads are still joined before anything unwinds.
        let res = panic::catch_unwind(panic::AssertUnwindSafe(f));

        // Join every thread, remembering the panic of the last one that failed.
        let mut last_err = None;
        for handle in handles {
            if let Err(err) = handle.join() {
                last_err = Some(err);
            }
        }

        // Every spawned thread has finished, so unwinding is safe from here on.
        drop(guard);

        if let Some(err) = last_err {
            panic::resume_unwind(err);
        }

        // No thread panicked, so each one sent exactly one value before exiting.
        let results = receivers.into_iter().map(|r| r.recv().unwrap()).collect();

        match res {
            Ok(r) => (results, r),
            Err(err) => panic::resume_unwind(err),
        }
    }
}
impl<T> fmt::Debug for Parallel<'_, T> {
    /// Renders as `Parallel { len: N }`, where `N` is the number of queued
    /// closures; the closures themselves are opaque and are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let queued = self.closures.len();
        let mut out = f.debug_struct("Parallel");
        out.field("len", &queued);
        out.finish()
    }
}
impl<T> Default for Parallel<'_, T> {
    /// Returns a builder with an empty queue, identical to `Parallel::new()`.
    fn default() -> Self {
        Self {
            closures: Vec::new(),
        }
    }
}
/// Drop guard that aborts the process when it is dropped during a panic.
///
/// `Parallel::finish_in` holds one while spawned threads may still be
/// running. That way a panic can never unwind past those threads.
struct NoPanic;

impl Drop for NoPanic {
    fn drop(&mut self) {
        // On the normal (non-unwinding) path the guard does nothing.
        if !thread::panicking() {
            return;
        }
        process::abort();
    }
}