use std::collections::BinaryHeap;
use std::iter::IntoIterator;
use std::{marker, mem};
use std::sync::{mpsc, atomic, Mutex, Arc};
use std::thread;
use fnbox::FnBox;
use crossbeam::{self, Scope};
// A boxed, call-once job closure. It receives the per-worker `Work` senders
// so it can seed every worker thread, and uses `FnBox` because a boxed
// `FnOnce` cannot be called directly on this compiler. The real lifetime `'b`
// is erased to `'static` before the job crosses the channel to the supervisor
// thread (see `execute` / `execute_nonunsafe`).
type JobInner<'b> = Box<for<'a> FnBox<&'a [mpsc::Sender<Work>]> + Send + 'b>;
/// A single queued job, ready to be run by the supervisor thread.
struct Job {
    // NOTE(review): `'static` here is a lie maintained by `execute`'s unsafe
    // contract — the closure really borrows data for some shorter lifetime.
    func: JobInner<'static>,
}
/// A thread pool: a supervisor thread plus `n_threads` worker threads.
/// Jobs are submitted to the supervisor, which fans work out to the workers.
pub struct Pool {
    /// Channel to the supervisor thread: `Some(job)` runs a job, `None`
    /// requests shutdown; the paired sender receives the job's outcome
    /// (`Err(())` if a worker panicked).
    job_queue: mpsc::Sender<(Option<Job>, mpsc::Sender<Result<(), ()>>)>,
    /// Status of the most recently submitted job, if any; cleared when that
    /// job's `JobHandle` is dropped.
    job_status: Option<Arc<Mutex<JobStatus>>>,
    /// Number of worker threads the supervisor spawned.
    n_threads: usize,
}
/// Identifies a worker thread by its index `n` in `0..n_threads`.
#[derive(Copy, Clone)]
struct WorkerId { n: usize }
// A borrowed worker closure, called repeatedly with the worker's id. The
// lifetime is erased to `'static` so the reference can be sent over a
// channel; soundness relies on the job being waited on before the borrowed
// closure is invalidated (see `execute`).
type WorkInner<'a> = &'a mut (FnMut(WorkerId) + Send + 'a);
/// One unit of work delivered to a worker thread for a single job.
struct Work {
    func: WorkInner<'static>
}
/// Tracks completion of one submitted job.
struct JobStatus {
    /// True until the completion message has been consumed once.
    wait: bool,
    /// Receives `Ok(())` on success, `Err(())` if a worker panicked.
    job_finished: mpsc::Receiver<Result<(), ()>>,
}
/// Handle to an in-flight job; blocks until the job finishes when dropped.
pub struct JobHandle<'pool, 'f> {
    /// Exclusive borrow of the pool, so no second job can be submitted while
    /// this one is outstanding.
    pool: &'pool mut Pool,
    status: Arc<Mutex<JobStatus>>,
    // Ties the handle to the lifetime `'f` of the data the job borrows.
    _funcs: marker::PhantomData<&'f ()>,
}
impl JobStatus {
    /// Block until the job's completion message arrives. Idempotent: only
    /// the first call actually receives. Panics if the supervisor is gone or
    /// the job reported failure (a worker panicked).
    fn wait(&mut self) {
        if !self.wait {
            return;
        }
        self.wait = false;
        let outcome = self.job_finished.recv().unwrap();
        outcome.unwrap();
    }
}
impl<'pool, 'f> JobHandle<'pool, 'f> {
    /// Block until the job has completed on all workers.
    pub fn wait(&self) {
        let mut status = self.status.lock().unwrap();
        status.wait();
    }
}
impl<'pool, 'f> Drop for JobHandle<'pool, 'f> {
fn drop(&mut self) {
self.wait();
self.pool.job_status = None;
}
}
impl Drop for Pool {
    /// Ask the supervisor thread to shut down and wait for its
    /// acknowledgement (which also disconnects the workers' channels).
    fn drop(&mut self) {
        let (ack_tx, ack_rx) = mpsc::channel();
        // The supervisor exits its loop early after a worker panic, dropping
        // its receiver; the old code `unwrap()`ed here, so dropping the pool
        // during the resulting unwind would panic again and abort the
        // process. Treat shutdown as best-effort instead.
        if self.job_queue.send((None, ack_tx)).is_ok() {
            // Ignore the result: the ack is `Ok(())` on the normal path, and
            // a disconnect just means the supervisor is already gone.
            let _ = ack_rx.recv();
        }
    }
}
/// Drop guard that records in `flag` whether its owning thread panicked.
struct PanicCanary<'a> {
    flag: &'a atomic::AtomicBool
}
impl<'a> Drop for PanicCanary<'a> {
    // Drop runs during unwinding too, so this observes a panic in the
    // thread that owns the canary and raises the shared flag.
    fn drop(&mut self) {
        let unwinding = thread::panicking();
        if unwinding {
            self.flag.store(true, atomic::Ordering::SeqCst);
        }
    }
}
impl Pool {
    /// Create a pool backed by one supervisor thread and `n_threads` worker
    /// threads.
    ///
    /// The supervisor owns the receiving end of `job_queue`: for each job it
    /// calls the job closure with the per-worker `Work` senders, then reports
    /// `Ok(())` or `Err(())` (worker panic) on the job's reply channel.
    pub fn new(n_threads: usize) -> Pool {
        let (tx, rx) = mpsc::channel::<(Option<Job>, mpsc::Sender<Result<(), ()>>)>();
        thread::spawn(move || {
            // Set by a worker's PanicCanary if that worker panics.
            let panicked = Arc::new(atomic::AtomicBool::new(false));
            let mut _guards = Vec::with_capacity(n_threads);
            let mut txs = Vec::with_capacity(n_threads);
            for i in 0..n_threads {
                let id = WorkerId { n: i };
                let (subtx, subrx) = mpsc::channel::<Work>();
                txs.push(subtx);
                let panicked = panicked.clone();
                _guards.push(thread::spawn(move || {
                    let _canary = PanicCanary {
                        flag: &panicked
                    };
                    // Run one `Work` closure per job until the channel
                    // disconnects (supervisor shutdown).
                    loop {
                        match subrx.recv() {
                            Ok(mut work) => {
                                (work.func)(id)
                            }
                            Err(_) => break,
                        }
                    }
                }))
            }
            loop {
                match rx.recv() {
                    Ok((Some(job), finished_tx)) => {
                        // Seed every worker with its closure, then run the
                        // job's main part on this thread.
                        (job.func).call_box(&txs);
                        let job_panicked = panicked.load(atomic::Ordering::SeqCst);
                        let msg = if job_panicked { Err(()) } else { Ok(()) };
                        finished_tx.send(msg).unwrap();
                        // A panicked worker is gone for good, so this pool
                        // can't run further jobs: shut the supervisor down.
                        if job_panicked { break }
                    }
                    Ok((None, finished_tx)) => {
                        // Shutdown request (see `Drop for Pool`).
                        finished_tx.send(Ok(())).unwrap();
                        break
                    }
                    Err(_) => break,
                }
            }
        });
        Pool {
            job_queue: tx,
            job_status: None,
            n_threads: n_threads,
        }
    }
    /// Run `f` on every element of `iter`, in parallel across the pool's
    /// workers, blocking until every element has been processed.
    pub fn for_<Iter: IntoIterator, F>(&mut self, iter: Iter, ref f: F)
        where Iter::Item: Send,
              Iter: Send,
              F: Fn(Iter::Item) + Sync
    {
        // Protocol: an idle worker sends its id on `needwork`; the main
        // closure answers on that worker's private channel with `Some(elem)`
        // (more work) or `None` (iterator exhausted).
        let (needwork_tx, needwork_rx) = mpsc::channel();
        let mut work_txs = Vec::with_capacity(self.n_threads);
        let mut work_rxs = Vec::with_capacity(self.n_threads);
        for _ in 0..self.n_threads {
            let (t, r) = mpsc::channel();
            work_txs.push(t);
            work_rxs.push(r);
        }
        let mut work_rxs = work_rxs.into_iter();
        crossbeam::scope(|scope| unsafe {
            let handle = self.execute(
                scope,
                needwork_tx,
                // Generator: called once per worker; hands each worker its
                // own receiver (the `Option`s are `take`n on first call).
                |needwork_tx| {
                    let mut needwork_tx = Some(needwork_tx.clone());
                    let mut work_rx = Some(work_rxs.next().unwrap());
                    move |id| {
                        let work_rx = work_rx.take().unwrap();
                        let needwork = needwork_tx.take().unwrap();
                        loop {
                            // Ask for an element, tagged with our id.
                            needwork.send(id).unwrap();
                            match work_rx.recv() {
                                Ok(Some(elem)) => {
                                    f(elem);
                                }
                                Ok(None) | Err(_) => break
                            }
                        }
                    }
                },
                // Main closure: feed elements to whichever worker asks.
                move |needwork_tx| {
                    let mut iter = iter.into_iter().fuse();
                    // Drop the spare sender so `needwork_rx` disconnects
                    // once every worker has stopped asking.
                    drop(needwork_tx);
                    loop {
                        match needwork_rx.recv() {
                            Err(_) => break,
                            Ok(id) => {
                                work_txs[id.n].send(iter.next()).unwrap();
                            }
                        }
                    }
                });
            handle.wait();
        })
    }
    /// Map `f` over `iter` in parallel, returning an iterator that yields
    /// `(index, f(element))` pairs in completion order (not input order).
    ///
    /// Must be called inside a `crossbeam::scope`; the job is tied to that
    /// scope's lifetime `'a`.
    pub fn unordered_map<'pool, 'a, I: IntoIterator, F, T>(&'pool mut self, scope: &Scope<'a>, iter: I, f: F)
        -> UnorderedParMap<'pool, 'a, T>
        where I: 'a + Send,
              I::Item: Send + 'a,
              F: 'a + Sync + Send + Fn(I::Item) -> T,
              T: Send + 'a
    {
        let nthreads = self.n_threads;
        // `needwork` carries refill requests from workers; `work` is one
        // queue of `Option<(index, element)>` shared by all workers.
        let (needwork_tx, needwork_rx) = mpsc::channel();
        let (work_tx, work_rx) = mpsc::channel();
        // State shared between all workers (behind the Arc below).
        struct Shared<Chan, Atom, F> {
            work: Chan,
            sent: Atom,
            finished: Atom,
            func: F,
        }
        let shared = Arc::new(Shared {
            // All workers pull from the same receiver, serialised by a Mutex.
            work: Mutex::new(work_rx),
            // Count of messages pushed onto `work` by the main closure.
            sent: atomic::AtomicUsize::new(0),
            // Count of messages fully processed by workers.
            finished: atomic::AtomicUsize::new(0),
            func: f,
        });
        let (tx, rx) = mpsc::channel();
        // The queue is primed with INITIAL_FACTOR * nthreads messages and
        // topped up by BUFFER_FACTOR * nthreads whenever the outstanding
        // buffer drains to the threshold checked below.
        const INITIAL_FACTOR: usize = 4;
        const BUFFER_FACTOR: usize = INITIAL_FACTOR / 2;
        let handle = unsafe {
            self.execute(scope, (needwork_tx, shared),
                // Generator: each worker gets clones of the refill sender,
                // the result sender and the shared state.
                move |&mut (ref needwork_tx, ref shared)| {
                    let mut needwork_tx = Some(needwork_tx.clone());
                    let tx = tx.clone();
                    let shared = shared.clone();
                    move |_id| {
                        let needwork = needwork_tx.take().unwrap();
                        loop {
                            // Hold the lock only while pulling one message.
                            let data = {
                                let guard = shared.work.lock().unwrap();
                                guard.recv()
                            };
                            match data {
                                Ok(Some((idx, elem))) => {
                                    let data = (shared.func)(elem);
                                    let status = tx.send(Packet {
                                        idx: idx, data: data
                                    });
                                    // The consumer hung up: tell the main
                                    // closure to stop (`send(true)`) and bail.
                                    if status.is_err() {
                                        let _ = needwork.send(true);
                                        break
                                    }
                                }
                                Ok(None) | Err(_) => {
                                    break
                                }
                            };
                            let old =
                                shared.finished.fetch_add(1, atomic::Ordering::SeqCst);
                            let sent = shared.sent.load(atomic::Ordering::SeqCst);
                            // Buffer has drained to the threshold: request a
                            // refill (`send(false)`).
                            if old + BUFFER_FACTOR * nthreads == sent {
                                if needwork.send(false).is_err() {
                                    break
                                }
                            }
                        }
                    }
                },
                // Main closure: feeds the shared work queue in batches.
                move |(needwork_tx, shared)| {
                    let mut iter = iter.into_iter().fuse().enumerate();
                    // Drop the spare sender so `needwork_rx` disconnects
                    // once every worker has exited.
                    drop(needwork_tx);
                    let mut send_data = |n: usize| {
                        shared.sent.fetch_add(n, atomic::Ordering::SeqCst);
                        for _ in 0..n {
                            // Fused: sends `None` once exhausted, which
                            // tells a worker to stop.
                            let _ = work_tx.send(iter.next());
                        }
                    };
                    send_data(INITIAL_FACTOR * nthreads);
                    loop {
                        match needwork_rx.recv() {
                            // `true` means the consumer is gone; stop early.
                            Ok(true) | Err(_) => break,
                            Ok(false) => {
                                let _ = send_data(BUFFER_FACTOR * nthreads);
                            }
                        }
                    }
                })
        };
        UnorderedParMap {
            rx: rx,
            _guard: handle,
        }
    }
    /// Map `f` over `iter` in parallel, returning an iterator that yields
    /// results in the input order (`ParMap` buffers out-of-order results in
    /// a heap until their turn comes).
    pub fn map<'pool, 'a, I: IntoIterator, F, T>(&'pool mut self, scope: &Scope<'a>, iter: I, f: F)
        -> ParMap<'pool, 'a, T>
        where I: 'a + Send,
              I::Item: Send + 'a,
              F: 'a + Send + Sync + Fn(I::Item) -> T,
              T: Send + 'a
    {
        ParMap {
            unordered: self.unordered_map(scope, iter, f),
            looking_for: 0,
            queue: BinaryHeap::new(),
        }
    }
}
impl Pool {
    /// Low-level job submission: `gen_fn(&mut data)` is called once per
    /// worker (on the supervisor thread) to build that worker's closure,
    /// then `main_fn(data)` runs on the supervisor thread while the workers
    /// run their closures.
    ///
    /// Returns a handle that waits for the job on `wait()`/drop; a
    /// `scope.defer` is also registered so the job is waited on even if the
    /// handle is leaked within the scope.
    ///
    /// # Safety
    ///
    /// The closures and `data` are only valid for `'f` but are transmuted to
    /// `'static` internally so they can cross thread boundaries. The caller
    /// must ensure everything borrowed for `'f` stays alive until the job
    /// has been waited on.
    pub unsafe fn execute<'pool, 'f, A, GenFn, WorkerFn, MainFn>(
        &'pool mut self, scope: &Scope<'f>, data: A, gen_fn: GenFn, main_fn: MainFn) -> JobHandle<'pool, 'f>
        where A: 'f + Send,
              GenFn: 'f + FnMut(&mut A) -> WorkerFn + Send,
              WorkerFn: 'f + FnMut(WorkerId) + Send,
              MainFn: 'f + FnOnce(A) + Send,
    {
        self.execute_nonunsafe(scope, data, gen_fn, main_fn)
    }
    // The actual implementation; kept separate so `execute`'s `unsafe` is
    // purely a contract marker for callers, not a giant unsafe body.
    fn execute_nonunsafe<'pool, 'f, A, GenFn, WorkerFn, MainFn>(
        &'pool mut self, scope: &Scope<'f>, mut data: A,
        mut gen_fn: GenFn, main_fn: MainFn) -> JobHandle<'pool, 'f>
        where A: 'f + Send,
              GenFn: 'f + FnMut(&mut A) -> WorkerFn + Send,
              WorkerFn: 'f + FnMut(WorkerId) + Send,
              MainFn: 'f + FnOnce(A) + Send,
    {
        let n_threads = self.n_threads;
        // This closure runs on the supervisor thread, which hands it the
        // per-worker `Work` senders.
        let func: JobInner<'f> = Box::new(move |workers: &[mpsc::Sender<Work>]| {
            assert_eq!(workers.len(), n_threads);
            // One worker closure per thread; they stay alive in this Vec
            // while `main_fn` runs.
            let mut worker_fns: Vec<_> = (0..n_threads).map(|_| gen_fn(&mut data)).collect();
            for (func, worker) in worker_fns.iter_mut().zip(workers.iter()) {
                let func: WorkInner = func;
                // Erase the borrow's lifetime so it can be sent to a worker
                // thread. NOTE(review): relies on `main_fn` not returning
                // until the workers are done with these references (as the
                // callers in `for_`/`unordered_map` arrange) — part of the
                // `# Safety` contract on `execute`.
                let func: WorkInner<'static> = unsafe {
                    mem::transmute(func)
                };
                worker.send(Work { func: func }).unwrap();
            }
            main_fn(data)
        });
        // Erase the job's lifetime too, so it fits the 'static job channel.
        let func: JobInner<'static> = unsafe {
            mem::transmute(func)
        };
        let (tx, rx) = mpsc::channel();
        self.job_queue.send((Some(Job { func: func }), tx)).unwrap();
        let status = Arc::new(Mutex::new(JobStatus {
            wait: true,
            job_finished: rx,
        }));
        self.job_status = Some(status.clone());
        let status_ = status.clone();
        // Even if the returned handle is forgotten, the scope blocks on the
        // job before data borrowed for `'f` can be invalidated.
        scope.defer(move || {
            status_.lock().unwrap().wait();
        });
        JobHandle {
            pool: self,
            status: status,
            _funcs: marker::PhantomData,
        }
    }
}
use std::cmp::Ordering;
/// A mapped result tagged with the index of the input element that
/// produced it; ordered by index (reversed) for use in `BinaryHeap`.
struct Packet<T> {
    idx: usize,
    data: T,
}
impl<T> PartialOrd for Packet<T> {
    // Delegate to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl<T> Ord for Packet<T> {
    // Reversed index order, so the max-heap `BinaryHeap<Packet<T>>` pops
    // the packet with the smallest index first.
    fn cmp(&self, other: &Self) -> Ordering {
        self.idx.cmp(&other.idx).reverse()
    }
}
impl<T> PartialEq for Packet<T> {
    // Packets compare by index only; the payload is ignored.
    fn eq(&self, other: &Self) -> bool {
        other.idx == self.idx
    }
}
impl<T> Eq for Packet<T> {}
/// Iterator over `(index, result)` pairs in completion order; produced by
/// `Pool::unordered_map`.
pub struct UnorderedParMap<'pool, 'a, T: 'a + Send> {
    /// Receives results from the worker threads as they finish.
    rx: mpsc::Receiver<Packet<T>>,
    // Keeps the job alive and waits for it when this iterator is dropped.
    _guard: JobHandle<'pool, 'a>,
}
impl<'pool, 'a,T: 'a + Send> Iterator for UnorderedParMap<'pool , 'a, T> {
type Item = (usize, T);
fn next(&mut self) -> Option<(usize, T)> {
match self.rx.recv() {
Ok(Packet { data, idx }) => Some((idx, data)),
Err(mpsc::RecvError) => None,
}
}
}
/// Iterator over results in input order; produced by `Pool::map`.
pub struct ParMap<'pool, 'a, T: 'a + Send> {
    unordered: UnorderedParMap<'pool, 'a, T>,
    // The next input index to yield.
    looking_for: usize,
    // Results that arrived early, ordered by `Packet`'s reversed-index `Ord`
    // so the smallest index is at the top of the (max-)heap.
    queue: BinaryHeap<Packet<T>>
}
impl<'pool, 'a, T: Send + 'a> Iterator for ParMap<'pool, 'a, T> {
type Item = T;
fn next(&mut self) -> Option<T> {
loop {
if self.queue.peek().map_or(false, |x| x.idx == self.looking_for) {
let packet = self.queue.pop().unwrap();
self.looking_for += 1;
return Some(packet.data)
}
match self.unordered.rx.recv() {
Ok(packet) => self.queue.push(packet),
Err(mpsc::RecvError) => return None,
}
}
}
}