use crossbeam_channel::{Receiver, Sender};
use super::{DropIndicator, Scope};
use std::{
cmp,
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
/// Channel endpoints held by the consuming side of a parallel map.
///
/// `I` is the type of the items sent to the workers and `O` the type of the
/// results they send back; both travel tagged with a `usize` sequence index.
struct ParallelMapInner<I, O> {
// Sending half of the input channel. It is an `Option` so that it can be
// replaced with `None` (dropping the sender) once the source iterator is
// exhausted.
tx: Option<crossbeam_channel::Sender<(usize, I)>>,
// Receiving half of the output channel carrying `(index, result)` pairs.
rx: crossbeam_channel::Receiver<(usize, O)>,
}
/// Builder that collects the settings for a parallel map over `iter`.
///
/// Both tuning knobs are optional; `None` means "not set by the caller".
pub struct ParallelMapBuilder<I: Iterator> {
    /// The source iterator whose items will be mapped.
    iter: I,
    /// Requested number of worker threads, if any.
    num_threads: Option<usize>,
    /// Requested channel buffer size, if any.
    buffer_size: Option<usize>,
}
impl<I> ParallelMapBuilder<I>
where
I: Iterator,
{
pub fn new(iter: I) -> Self {
Self {
iter,
num_threads: None,
buffer_size: None,
}
}
pub fn threads(self, num: usize) -> Self {
Self {
num_threads: Some(num),
..self
}
}
pub fn buffer_size(self, num: usize) -> Self {
Self {
buffer_size: Some(num),
..self
}
}
fn num_threads<T: Into<Option<usize>>>(num_threads: T) -> usize {
let mut num = num_threads.into().unwrap_or(0);
if num == 0 {
num = num_cpus::get_physical();
}
if num == 0 {
num = 1
}
num
}
fn with_common<O>(
self,
) -> (
ParallelMap<I, O>,
Receiver<(usize, I::Item)>,
Sender<(usize, O)>,
)
where
I: Iterator,
{
let num_threads = Self::num_threads(self.num_threads);
let buffer_size = cmp::max(1, self.buffer_size.unwrap_or(num_threads * 2));
let (in_tx, in_rx) = crossbeam_channel::bounded(buffer_size);
let (out_tx, out_rx) = crossbeam_channel::bounded(buffer_size);
(
ParallelMap {
iter: self.iter,
iter_done: false,
worker_panicked: Arc::new(AtomicBool::new(false)),
num_threads,
buffer_size,
out_of_order: HashMap::new(),
next_tx_i: 0,
next_rx_i: 0,
inner: Some(ParallelMapInner {
tx: Some(in_tx),
rx: out_rx,
}),
},
in_rx,
out_tx,
)
}
pub fn with<F, O>(self, f: F) -> ParallelMap<I, O>
where
I: Iterator + 'static,
F: 'static + Send + Clone,
O: Send + 'static,
I::Item: Send + 'static,
F: FnMut(I::Item) -> O,
{
let (ret, in_rx, out_tx) = self.with_common();
for _ in 0..ret.num_threads {
let in_rx = in_rx.clone();
let out_tx = out_tx.clone();
let mut f = f.clone();
let drop_indicator = DropIndicator::new(ret.worker_panicked.clone());
std::thread::spawn(move || {
for (i, item) in in_rx.into_iter() {
let _ = out_tx.send((i, (f)(item)));
}
drop_indicator.cancel();
});
}
ret
}
pub fn with_scoped<'env, 'scope, F, O>(
self,
scope: &'scope Scope<'env>,
f: F,
) -> ParallelMap<I, O>
where
I: Iterator + 'env,
F: 'env + Send + Clone,
O: Send + 'env,
I::Item: Send + 'env,
F: FnMut(I::Item) -> O,
{
let (ret, in_rx, out_tx) = self.with_common();
for _ in 0..ret.num_threads {
let in_rx = in_rx.clone();
let out_tx = out_tx.clone();
let mut f = f.clone();
let drop_indicator = DropIndicator::new(ret.worker_panicked.clone());
scope.spawn(move |_scope| {
for (i, item) in in_rx.into_iter() {
let _ = out_tx.send((i, (f)(item)));
}
drop_indicator.cancel();
});
}
ret
}
}
/// Consumer side of a parallel map: pulls items from `iter`, hands them to
/// worker threads, and yields the results of type `O`.
pub struct ParallelMap<I: Iterator, O> {
    /// Source iterator supplying the items to map.
    iter: I,
    /// Set once `iter` has returned `None`.
    iter_done: bool,
    /// Number of worker threads serving this map.
    num_threads: usize,
    /// Maximum number of items allowed in flight at once.
    buffer_size: usize,
    /// Sequence index that the next item sent to the workers will carry.
    next_tx_i: usize,
    /// Sequence index of the next result to be yielded.
    next_rx_i: usize,
    /// Flag shared with the workers and polled while waiting for results.
    worker_panicked: Arc<AtomicBool>,
    /// Results that arrived ahead of `next_rx_i`, keyed by sequence index.
    out_of_order: HashMap<usize, O>,
    /// Channel endpoints used to talk to the workers.
    inner: Option<ParallelMapInner<I::Item, O>>,
}
impl<I, O> ParallelMap<I, O>
where
    I: Iterator,
    I::Item: Send,
    O: Send,
{
    /// Feeds the workers: pulls items from the source iterator and sends them,
    /// tagged with their sequence index, until `buffer_size` items are in
    /// flight or the source runs dry.
    ///
    /// When the source is exhausted the input sender is dropped and
    /// `iter_done` is set, after which this method is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if the channel state is missing or if sending to the workers
    /// fails.
    fn pump_tx(&mut self) {
        if self.iter_done {
            return;
        }
        while self.next_tx_i < self.next_rx_i + self.buffer_size {
            match self.iter.next() {
                Some(item) => {
                    let inner = self.inner.as_ref().expect("not started");
                    let sender = inner.tx.as_ref().expect("inner-iterator exhausted");
                    let tagged = (self.next_tx_i, item);
                    sender.send(tagged).expect("send failed");
                    self.next_tx_i += 1;
                }
                None => {
                    self.iter_done = true;
                    // Dropping the sender closes the input channel.
                    let inner = self.inner.as_mut().expect("not started");
                    inner.tx = None;
                    break;
                }
            }
        }
    }
}
impl<I, O> Iterator for ParallelMap<I, O>
where
    I: Iterator,
    I::Item: Send,
    O: Send,
{
    type Item = O;

    /// Returns the next mapped result, in the order the items were pulled from
    /// the source iterator.
    ///
    /// Results that arrive early are parked in `out_of_order` until their turn
    /// comes. Returns `None` once the source is exhausted and every sent item
    /// has been yielded.
    ///
    /// # Panics
    ///
    /// Panics if the shared `worker_panicked` flag is set while waiting, or if
    /// the output channel is disconnected while results are still expected.
    fn next(&mut self) -> Option<Self::Item> {
        self.pump_tx();
        loop {
            if self.next_rx_i == self.next_tx_i && self.iter_done {
                return None;
            }
            // A result that arrived early may already be waiting for its turn.
            if let Some(item) = self.out_of_order.remove(&self.next_rx_i) {
                self.next_rx_i += 1;
                self.pump_tx();
                return Some(item);
            }
            match self
                .inner
                .as_ref()
                .expect("not started")
                .rx
                // Wait with a short timeout so the panic flag gets re-checked
                // even if the awaited result never arrives.
                .recv_timeout(std::time::Duration::from_micros(100))
            {
                Ok((item_i, item)) => {
                    if item_i == self.next_rx_i {
                        self.next_rx_i += 1;
                        self.pump_tx();
                        return Some(item);
                    } else {
                        // Anything older than `next_rx_i` was already yielded.
                        assert!(item_i > self.next_rx_i);
                        self.out_of_order.insert(item_i, item);
                    }
                }
                Err(crossbeam_channel::RecvTimeoutError::Timeout) => {
                    if self.worker_panicked.load(SeqCst) {
                        panic!("parallel_map worker thread panicked: panic indicator set");
                    }
                }
                Err(crossbeam_channel::RecvTimeoutError::Disconnected) => {
                    panic!("parallel_map worker thread panicked: channel disconnected");
                }
            }
        }
    }

    /// Bounds on the number of results still to be yielded.
    ///
    /// Items already pulled from the source but not yet yielded are no longer
    /// counted by the source's own hint, so they are added back here. Without
    /// that, the upper bound could be smaller than the true remaining count.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Sent-but-not-yet-yielded items, including those in `out_of_order`.
        let in_flight = self.next_tx_i - self.next_rx_i;
        if self.iter_done {
            // The source is never polled again, so the count is exact.
            return (in_flight, Some(in_flight));
        }
        let (lower, upper) = self.iter.size_hint();
        (
            lower.saturating_add(in_flight),
            upper.and_then(|upper| upper.checked_add(in_flight)),
        )
    }
}