use std::collections::VecDeque;
use std::sync::{Arc,Mutex,Condvar};
use std::thread::JoinHandle;
use std::marker::PhantomData;
use std::any::Any;
extern crate num_cpus;
/// A type-erased unit of work that a worker thread consumes exactly once.
///
/// `state` is the calling worker's own state slot; implementations that have
/// no use for it simply ignore it.
trait AbstractTask: Send
{
    fn run(self: Box<Self>, state: &mut Box<dyn Any>);
}
/// Wraps a closure that needs `&mut State` so it can be queued as a task;
/// the `PhantomData` records which state type the closure expects.
struct FnState<State, F: FnOnce(&mut State) + Send>(F, PhantomData<State>);
/// Wraps a plain `FnOnce()` closure so it can be queued as a task.
struct FnStateless<F: FnOnce() + Send>(F);
impl<State: 'static + Send, F> AbstractTask for FnState<State,F>
where F: FnOnce(&mut State) + Send
{
fn run(self: Box<Self>, state: &mut Box<Any>)
{
let state_any = &mut *state;
let state = state_any.downcast_mut().unwrap();
(*self).0(state);
}
}
impl<F> AbstractTask for FnStateless<F>
where F: FnOnce() + Send
{
fn run(self: Box<Self>, _: &mut Box<Any>)
{
(*self).0();
}
}
/// Object-safe factory for a worker's state: every call returns a newly
/// boxed, type-erased state value.
trait StateMakerBox
{
    fn create(&self) -> Box<dyn Any>;
}
/// Any closure producing a boxed state can act as a state maker.
impl<F: Fn() -> Box<dyn Any>> StateMakerBox for F
{
    fn create(&self) -> Box<dyn Any>
    {
        let make: &F = self;
        make()
    }
}
type StateMaker<'a> = Box<StateMakerBox + Send + 'a>;
/// Signal the coordinating thread broadcasts to the workers through
/// `Messaging::flag`; while that field is `None`, workers only run queued jobs.
#[derive(PartialEq)]
enum Flag
{
    /// Keep taking jobs until the queue is empty, then park and report in.
    Checkpoint,
    /// Leave the parked state and report in.
    Resume,
    /// Leave the worker loop.
    Exit,
    /// Once the queue is empty, rebuild the worker's state from
    /// `Messaging::state_maker`, then park and report in.
    SetState,
}
struct Messaging
{
job_queue : VecDeque<Box<AbstractTask>>,
flag : Option<Flag>,
completion_counter : usize,
panic_detected : bool,
state_maker : Option<StateMaker<'static>>,
}
impl Messaging
{
    /// An idle mailbox: empty queue, no flag, nothing counted, no panic seen,
    /// no state maker.
    fn new() -> Self
    {
        Messaging {
            state_maker: None,
            panic_detected: false,
            completion_counter: 0,
            flag: None,
            job_queue: VecDeque::new(),
        }
    }
}
/// State shared through an `Arc` between a `Pool` and its worker threads.
struct PoolStatus
{
    /// The shared mailbox.
    messaging: Mutex<Messaging>,
    /// Notified by the coordinator when a job is queued or the flag changes.
    incoming_notif_cv: Condvar,
    /// Notified by workers when they take a job, acknowledge a flag, or stop.
    thread_response_cv: Condvar,
}
/// A fixed set of worker threads that run borrowed closures submitted inside
/// [`Pool::scoped`].
pub struct Pool
{
    /// One join handle per worker thread.
    threads: Vec<JoinHandle<()>>,
    /// Mailbox shared with the workers.
    pool_status: Arc<PoolStatus>,
    /// `execute` blocks while more than this many jobs are queued;
    /// `None` disables the limit.
    backlog: Option<usize>,
}
impl Pool
{
pub fn new()
-> Pool
{
let c = num_cpus::get();
Self::base_new(c, Some(c*4))
}
pub fn unbounded()
-> Pool
{
let c = num_cpus::get();
Self::base_new(c, None)
}
pub fn new_threads(nthreads : usize, backlog : usize)
-> Pool
{
Self::base_new(nthreads, Some(backlog))
}
pub fn new_threads_unbounded(nthreads : usize)
-> Pool
{
Self::base_new(nthreads, None)
}
fn run_thread(pool_status : Arc<PoolStatus>)
{
let has_panic = std::panic::catch_unwind(
std::panic::AssertUnwindSafe(||
{
let mut state = Box::new(0) as Box<Any>;
let messaging = &pool_status.messaging;
let mut in_checkpoint = false;
loop
{
let mut messaging = messaging
.lock()
.unwrap_or_else(|e| e.into_inner());
while (
in_checkpoint
&& *messaging.flag.as_ref()
.unwrap_or(&Flag::Checkpoint)
== Flag::Checkpoint
) || (
messaging.job_queue.is_empty()
&& messaging.flag.is_none()
)
{
messaging = pool_status
.incoming_notif_cv
.wait(messaging)
.unwrap_or_else(|e| e.into_inner());
}
match messaging.flag
{
Some(Flag::Checkpoint) =>
{
if !in_checkpoint && messaging.job_queue.is_empty()
{
messaging.completion_counter += 1;
in_checkpoint = true;
pool_status.thread_response_cv.notify_one();
continue;
}
},
Some(Flag::Resume) =>
{
if in_checkpoint
{
messaging.completion_counter += 1;
in_checkpoint = false;
pool_status.thread_response_cv.notify_one();
}
},
Some(Flag::SetState) =>
{
if !in_checkpoint && messaging.job_queue.is_empty()
{
state = messaging.state_maker.as_ref().unwrap().create();
messaging.completion_counter += 1;
in_checkpoint = true;
pool_status.thread_response_cv.notify_one();
}
},
Some(Flag::Exit) =>
{
return;
},
None => { }
}
if let Some(t) = messaging.job_queue.pop_front()
{
pool_status.thread_response_cv.notify_one();
drop(messaging);
t.run(&mut state);
}
}
})
);
let mut messaging = pool_status.messaging
.lock().unwrap_or_else(|e| e.into_inner());
if has_panic.is_err()
{
messaging.panic_detected = true;
}
pool_status.thread_response_cv.notify_one();
}
fn base_new(nthreads : usize, backlog: Option<usize>)
-> Pool
{
assert!(nthreads >= 1);
let mut threads = Vec::with_capacity(nthreads);
let pool_status = Arc::new(
PoolStatus
{
messaging: Mutex::new(Messaging::new()),
incoming_notif_cv : Condvar::new(),
thread_response_cv : Condvar::new(),
}
);
for _ in 0..nthreads
{
let pool_status = pool_status.clone();
let t = std::thread::spawn(
move || Self::run_thread(pool_status)
);
threads.push(t);
}
Pool
{
threads: threads,
pool_status: pool_status,
backlog: backlog,
}
}
pub fn scoped<'pool, 'scope, F>(&'pool mut self, f : F)
where F: FnOnce(Scope<'pool, 'scope>)
{
{
let scope = Scope
{
pool: self,
_scope: PhantomData,
};
f(scope);
}
{ let mut messaging = self.pool_status.messaging
.lock().unwrap_or_else(|e| e.into_inner());
messaging.completion_counter = 0;
messaging.flag = Some(Flag::Checkpoint);
self.pool_status.incoming_notif_cv.notify_all();
while messaging.completion_counter != self.threads.len()
&& !messaging.panic_detected
{
messaging
= self.pool_status
.thread_response_cv
.wait(messaging)
.unwrap_or_else(|e| e.into_inner());
}
if messaging.panic_detected
{
panic!("worker thread panicked");
}
}
{ let mut messaging = self.pool_status.messaging
.lock().unwrap_or_else(|e| e.into_inner());
messaging.completion_counter = 0;
messaging.flag = Some(Flag::Resume);
self.pool_status.incoming_notif_cv.notify_all();
while messaging.completion_counter != self.threads.len()
{
messaging
= self.pool_status
.thread_response_cv
.wait(messaging)
.unwrap_or_else(|e| e.into_inner());
}
messaging.flag = None;
}
}
}
impl Drop for Pool
{
    /// Tells every worker to exit and joins them all.
    ///
    /// Jobs still queued here mean a worker panicked, because a `scoped` call
    /// that returns normally has already waited for the queue to drain. That
    /// is reported by panicking, but with two safeguards:
    /// - the panic happens only after the workers have been told to exit and
    ///   joined, so they are not leaked;
    /// - it never happens while this thread is already unwinding, where a
    ///   second panic would abort the process.
    fn drop(&mut self)
    {
        let jobs_left_over;
        {
            let mut messaging = self.pool_status.messaging
                .lock().unwrap_or_else(|e| e.into_inner());
            jobs_left_over = !messaging.job_queue.is_empty();
            messaging.flag = Some(Flag::Exit);
            self.pool_status.incoming_notif_cv.notify_all();
        }
        // A worker in the middle of a job finishes it before it sees `Exit`.
        let mut join_failed = false;
        for t in self.threads.drain(..)
        {
            if t.join().is_err()
            {
                join_failed = true;
            }
        }
        if std::thread::panicking()
        {
            return;
        }
        if jobs_left_over
        {
            panic!("pond::Pool: one or more worker thread panicked");
        }
        if join_failed
        {
            panic!("pond::Pool: a worker thread could not be joined");
        }
    }
}
/// Handle given to the closure passed to [`Pool::scoped`]; jobs queued
/// through it may borrow anything that lives for `'scope`.
pub struct Scope<'pool, 'scope>
{
    pool: &'pool Pool,
    // `Cell<&'scope ()>` makes the handle invariant in `'scope`, so that
    // lifetime can be neither shortened nor lengthened by subtyping.
    _scope: PhantomData<std::cell::Cell<&'scope ()>>,
}
impl<'pool, 'scope> Scope<'pool, 'scope>
{
    /// Installs a per-worker state built by `state_maker` and returns a
    /// scope whose jobs receive `&mut State`.
    ///
    /// Each worker calls `state_maker` once, after it finds the job queue
    /// empty, while holding the pool's lock. All workers are resumed before
    /// this returns.
    ///
    /// # Panics
    /// Panics with "worker thread panicked" if a worker has panicked,
    /// including a panic inside `state_maker`.
    pub fn with_state<StateMaker, State>(self, state_maker : StateMaker)
        -> ScopeWithState<'pool, 'scope, State>
        where StateMaker: Fn() -> State + Send + 'scope,
            State: 'static
    {
        let f =
            move ||
            {
                let state = state_maker();
                Box::new(state) as Box<dyn Any>
            };
        // SAFETY: the maker may borrow data that only lives for `'scope`. The
        // lifetime is erased only so the maker can be stored in the shared
        // `Messaging`. It is reset to `None` below before this function
        // returns. NOTE(review): the early "worker thread panicked" panic
        // skips that reset and leaves the maker installed.
        let f = unsafe
        {
            std::mem::transmute::<
                Box<dyn StateMakerBox + 'scope + Send>,
                Box<dyn StateMakerBox + 'static + Send>
            >(Box::new(f))
        };
        { let mut messaging = self.pool.pool_status.messaging
            .lock().unwrap_or_else(|e| e.into_inner());
            messaging.completion_counter = 0;
            messaging.flag = Some(Flag::SetState);
            messaging.state_maker = Some(f);
            self.pool.pool_status.incoming_notif_cv.notify_all();
            // Each worker acknowledges once it has rebuilt its state.
            while messaging.completion_counter != self.pool.threads.len()
                && !messaging.panic_detected
            {
                messaging
                    = self.pool.pool_status
                    .thread_response_cv
                    .wait(messaging)
                    .unwrap_or_else(|e| e.into_inner());
            }
            if messaging.panic_detected
            {
                panic!("worker thread panicked");
            }
        }
        { let mut messaging = self.pool.pool_status.messaging
            .lock().unwrap_or_else(|e| e.into_inner());
            messaging.completion_counter = 0;
            messaging.state_maker = None;
            messaging.flag = Some(Flag::Resume);
            self.pool.pool_status.incoming_notif_cv.notify_all();
            while messaging.completion_counter != self.pool.threads.len()
            {
                messaging
                    = self.pool.pool_status
                    .thread_response_cv
                    .wait(messaging)
                    .unwrap_or_else(|e| e.into_inner());
            }
            messaging.flag = None;
        }
        ScopeWithState
        {
            pool: self.pool,
            _scope: PhantomData,
            _state: PhantomData,
        }
    }

    /// Queues `f` to run on some worker thread; `f` may borrow anything that
    /// lives for `'scope`.
    ///
    /// Blocks while more jobs are queued than the pool's backlog allows.
    ///
    /// # Panics
    /// Panics with "worker thread panicked" if a worker has panicked. In
    /// that case `f` is dropped on the calling thread instead of being
    /// queued.
    pub fn execute<F>(&self, f: F)
        where F: FnOnce() + Send + 'scope
    {
        // SAFETY: the job may borrow data that only lives for `'scope`; the
        // transmute erases that lifetime so the job fits in the shared queue.
        // This relies on `Pool::scoped` not returning until its checkpoint has
        // seen every worker parked with an empty queue. NOTE(review): that wait
        // is abandoned once a worker panics, so jobs still queued or running at
        // that point are not covered by this argument.
        let boxed_fn
            = unsafe
            {
                std::mem::transmute::<
                    Box<dyn AbstractTask + 'scope>,
                    Box<dyn AbstractTask + 'static>
                >(Box::new(FnStateless(f)))
            };
        let status = &self.pool.pool_status;
        let mut messaging = status.messaging
            .lock().unwrap_or_else(|e| e.into_inner());
        // A panicked worker never takes another job. Once a panic is seen
        // this call fails anyway, and the queue might never shrink again, so
        // stop waiting for room.
        while !messaging.panic_detected
            && self.pool.backlog.map(
                |allowed| allowed < messaging.job_queue.len()
            ).unwrap_or(false)
        {
            messaging
                = status
                .thread_response_cv
                .wait(messaging)
                .unwrap_or_else(|e| e.into_inner());
        }
        // Check before queueing, so a job is never left behind in a pool that
        // may no longer drain it.
        if messaging.panic_detected
        {
            panic!("worker thread panicked");
        }
        messaging.job_queue.push_back( boxed_fn );
        status.incoming_notif_cv.notify_one();
    }
}
/// Scope handle whose workers each hold a `State` installed by
/// [`Scope::with_state`]; jobs queued through it receive `&mut State`.
pub struct ScopeWithState<'pool, 'scope, State>
    where State : 'static
{
    pool: &'pool Pool,
    // `Cell<&'scope ()>` keeps the handle invariant in `'scope`.
    _scope: PhantomData<std::cell::Cell<&'scope ()>>,
    // Ties the handle to the state type its jobs expect.
    _state: PhantomData<&'scope State>,
}
impl<'pool, 'scope, State> ScopeWithState<'pool, 'scope, State>
    where State: 'static + Send
{
    /// Queues `f` to run on some worker thread with `&mut` access to that
    /// worker's state; `f` may borrow anything that lives for `'scope`.
    ///
    /// Blocks while more jobs are queued than the pool's backlog allows.
    ///
    /// # Panics
    /// Panics with "worker thread panicked" if a worker has panicked. In
    /// that case `f` is dropped on the calling thread instead of being
    /// queued.
    pub fn execute<F>(&self, f: F)
        where F: FnOnce(&mut State) + Send + 'scope
    {
        // SAFETY: the job may borrow data that only lives for `'scope`; the
        // transmute erases that lifetime so the job fits in the shared queue.
        // This relies on `Pool::scoped` not returning until its checkpoint has
        // seen every worker parked with an empty queue. NOTE(review): that wait
        // is abandoned once a worker panics, so jobs still queued or running at
        // that point are not covered by this argument.
        let boxed_fn
            = unsafe
            {
                std::mem::transmute::<
                    Box<dyn AbstractTask + 'scope>,
                    Box<dyn AbstractTask + 'static>
                >(Box::new(FnState(f, PhantomData)))
            };
        let status = &self.pool.pool_status;
        let mut messaging = status.messaging
            .lock().unwrap_or_else(|e| e.into_inner());
        // A panicked worker never takes another job. Once a panic is seen
        // this call fails anyway, and the queue might never shrink again, so
        // stop waiting for room.
        while !messaging.panic_detected
            && self.pool.backlog.map(
                |allowed| allowed < messaging.job_queue.len()
            ).unwrap_or(false)
        {
            messaging
                = status
                .thread_response_cv
                .wait(messaging)
                .unwrap_or_else(|e| e.into_inner());
        }
        // Check before queueing, so a job is never left behind in a pool that
        // may no longer drain it.
        if messaging.panic_detected
        {
            panic!("worker thread panicked");
        }
        messaging.job_queue.push_back( boxed_fn );
        status.incoming_notif_cv.notify_one();
    }
}
#[cfg(test)]
mod tests
{
    use super::Pool;
    use std::thread;
    use std::sync;
    use std::time;

    /// Blocks the current thread for `ms` milliseconds.
    fn sleep_ms(ms: u64)
    {
        thread::sleep(time::Duration::from_millis(ms));
    }

    /// Jobs may mutate disjoint borrowed elements; the results are visible
    /// as soon as `scoped` returns, across repeated scopes on one pool.
    #[test]
    fn smoketest()
    {
        let mut pool = Pool::new_threads_unbounded(4);
        for i in 1..7
        {
            let mut vec = vec![0, 1, 2, 3, 4];
            pool.scoped(
                |s|
                {
                    for e in vec.iter_mut()
                    {
                        s.execute(
                            move ||
                            {
                                *e += i;
                            }
                        );
                    }
                }
            );
            let mut vec2 = vec![0, 1, 2, 3, 4];
            for e in vec2.iter_mut()
            {
                *e += i;
            }
            assert_eq!(vec, vec2);
        }
    }

    /// A panicking job makes `scoped` panic.
    #[test]
    #[should_panic]
    fn thread_panic()
    {
        let mut pool = Pool::new_threads_unbounded(4);
        pool.scoped(
            |scoped|
            {
                scoped.execute(
                    move ||
                    {
                        panic!()
                    }
                );
            }
        );
    }

    /// A panic in the scope closure itself propagates out of `scoped`.
    #[test]
    #[should_panic]
    fn scope_panic()
    {
        let mut pool = Pool::new_threads_unbounded(4);
        pool.scoped(
            |_scoped|
            {
                panic!()
            }
        );
    }

    /// Dropping a pool while the thread is already unwinding must not abort.
    #[test]
    #[should_panic]
    fn pool_panic()
    {
        let _pool = Pool::new_threads_unbounded(4);
        panic!()
    }

    /// `scoped` waits for every job; the channel sees them in completion
    /// order (the jobs sleep 1000, 0 and 500 ms).
    #[test]
    fn join_all()
    {
        let mut pool = Pool::new_threads_unbounded(4);
        let (tx_, rx) = sync::mpsc::channel();
        pool.scoped(
            |scoped|
            {
                let tx = tx_.clone();
                scoped.execute(
                    move ||
                    {
                        sleep_ms(1000);
                        tx.send(2).unwrap();
                    }
                );
                let tx = tx_.clone();
                scoped.execute(
                    move ||
                    {
                        tx.send(1).unwrap();
                    }
                );
                let tx = tx_.clone();
                scoped.execute(
                    move ||
                    {
                        sleep_ms(500);
                        tx.send(3).unwrap();
                    }
                );
            }
        );
        assert_eq!(rx.iter().take(3).collect::<Vec<_>>(), vec![1, 3, 2]);
    }

    /// When a job panics, the panic surfaces from `scoped` and the other jobs
    /// still run to completion. The drop guard reports first (`1`), then the
    /// seven slower jobs (`0`s).
    #[test]
    fn join_all_with_thread_panic()
    {
        use std::sync::mpsc::Sender;
        struct OnScopeEnd(Sender<u8>);
        impl Drop for OnScopeEnd
        {
            fn drop(&mut self)
            {
                self.0.send(1).unwrap();
                sleep_ms(200);
            }
        }
        let (tx_, rx) = sync::mpsc::channel();
        let handle = thread::spawn(
            move ||
            {
                let mut pool = Pool::new_threads_unbounded(8);
                let _on_scope_end = OnScopeEnd(tx_.clone());
                pool.scoped(
                    |scoped|
                    {
                        scoped.execute(
                            move ||
                            {
                                sleep_ms(1000);
                                panic!();
                            }
                        );
                        for _ in 1..8
                        {
                            let tx = tx_.clone();
                            scoped.execute(
                                move ||
                                {
                                    sleep_ms(2000);
                                    tx.send(0).unwrap();
                                }
                            );
                        }
                    }
                );
            }
        );
        if let Ok(..) = handle.join()
        {
            panic!("Pool didn't panic as expected");
        }
        let values: Vec<u8> = rx.into_iter().collect();
        assert_eq!(&values[..], &[1, 0, 0, 0, 0, 0, 0, 0]);
    }

    /// The job's captured `Arc` clone has been released once the pool is
    /// dropped.
    #[test]
    fn no_leak()
    {
        let counters = ::std::sync::Arc::new(());
        let mut pool = Pool::new_threads_unbounded(4);
        pool.scoped(
            |scoped|
            {
                let c = ::std::sync::Arc::clone(&counters);
                scoped.execute(
                    move ||
                    {
                        let _c = c;
                        sleep_ms(100);
                    }
                );
            }
        );
        drop(pool);
        assert_eq!(::std::sync::Arc::strong_count(&counters), 1);
    }

    /// Several sleeping jobs complete without hanging the scope.
    #[test]
    fn no_leak2()
    {
        let mut pool = Pool::new_threads_unbounded(4);
        pool.scoped(
            |scoped|
            {
                for _ in 0..4
                {
                    scoped.execute(
                        move ||
                        {
                            sleep_ms(100);
                        }
                    );
                }
            }
        );
    }

    /// A job that borrows (rather than moves) its captures releases the
    /// temporary `Arc` clone it makes.
    #[test]
    fn no_leak_state()
    {
        let counters = ::std::sync::Arc::new(());
        let mut pool = Pool::new_threads_unbounded(4);
        pool.scoped(
            |scoped|
            {
                scoped.execute(
                    ||
                    {
                        let _c = ::std::sync::Arc::clone(&counters);
                    }
                );
            }
        );
        drop(pool);
        assert_eq!(::std::sync::Arc::strong_count(&counters), 1);
    }

    /// An empty job runs fine.
    #[test]
    fn safe_execute()
    {
        let mut pool = Pool::new_threads_unbounded(4);
        pool.scoped(
            |scoped|
            {
                scoped.execute(
                    move ||
                    {
                    }
                );
            }
        );
    }

    /// With a backlog of 1, queueing sixteen one-second jobs on four workers
    /// blocks the producer for at least two seconds.
    #[test]
    fn backlog_positive()
    {
        let mut pool = Pool::new_threads(4, 1);
        pool.scoped(
            |scoped|
            {
                let begin = ::std::time::Instant::now();
                for _ in 0..16
                {
                    scoped.execute(
                        move ||
                        {
                            sleep_ms(1000);
                        }
                    );
                }
                assert!(::std::time::Instant::now().duration_since(begin).as_secs() > 1);
            }
        );
    }

    /// An unbounded pool accepts 120 jobs without the producer blocking.
    #[test]
    fn backlog_negative()
    {
        let mut pool = Pool::new_threads_unbounded(20);
        pool.scoped(
            |scoped|
            {
                let begin = ::std::time::Instant::now();
                for _ in 0..120
                {
                    scoped.execute(
                        move ||
                        {
                            sleep_ms(1000);
                        }
                    );
                }
                assert!(::std::time::Instant::now().duration_since(begin).as_secs() < 2);
            }
        );
    }

    /// Forty workers get through 120 short jobs.
    #[test]
    fn many_threads()
    {
        let mut pool = Pool::new_threads_unbounded(40);
        pool.scoped(
            |scoped|
            {
                for _ in 0..120
                {
                    scoped.execute(
                        move ||
                        {
                            sleep_ms(200);
                        }
                    );
                }
            }
        );
    }

    /// The state maker runs exactly once per worker (four workers, four
    /// calls), and may borrow from the enclosing scope.
    #[test]
    fn state_creator_scope()
    {
        let sc = "hello".to_string();
        let counter = ::std::sync::Mutex::new(0);
        let mut pool = Pool::new_threads_unbounded(4);
        pool.scoped(
            |scoped|
            {
                let scoped = scoped.with_state(
                    ||
                    {
                        *counter.lock().unwrap() += 1;
                        sc.clone()
                    }
                );
                for _ in 0..120
                {
                    scoped.execute(
                        move |_|
                        {
                        }
                    );
                }
            }
        );
        assert_eq!(*counter.lock().unwrap(), 4);
    }

    /// Jobs get `&mut` access to their worker's state. Every state is a clone
    /// of one `Arc`, so all 120 increments land in `counter`.
    #[test]
    fn modify_state()
    {
        let counter = ::std::sync::Arc::new(::std::sync::Mutex::new(0));
        let mut pool = Pool::new_threads_unbounded(4);
        pool.scoped(
            |scoped|
            {
                let scoped = scoped.with_state(
                    ||
                    {
                        counter.clone()
                    }
                );
                for _ in 0..120
                {
                    scoped.execute(
                        move |f|
                        {
                            *f.lock().unwrap() += 1;
                        }
                    );
                }
            }
        );
        assert_eq!(*counter.lock().unwrap(), 120);
    }

    /// All 3000 jobs run under a bounded backlog.
    #[test]
    fn do_them_all()
    {
        let counter = ::std::sync::Arc::new(::std::sync::Mutex::new(0));
        let mut pool = Pool::new_threads(512, 1000);
        pool.scoped(
            |scoped|
            {
                for _ in 0..3000
                {
                    let counter = counter.clone();
                    scoped.execute(
                        move ||
                        {
                            *counter.lock().unwrap() += 1;
                            sleep_ms(1000);
                        }
                    );
                }
            }
        );
        assert_eq!(*counter.lock().unwrap(), 3000);
    }

    /// A panicking state maker surfaces from `scoped` as the
    /// "worker thread panicked" panic.
    #[test]
    fn state_maker_panic()
    {
        let mut pool = Pool::new_threads_unbounded(2);
        let panic = ::std::panic::catch_unwind(
            ::std::panic::AssertUnwindSafe(
                ||
                {
                    pool.scoped(
                        |scoped|
                        {
                            scoped.with_state(
                                || panic!()
                            );
                        }
                    );
                }
            )
        );
        // The panic must actually happen; previously a missing panic let this
        // test pass silently.
        let e = panic.expect_err("scoped() should propagate the state maker panic");
        let s = e.downcast_ref::<&str>().unwrap();
        assert_eq!(
            *s,
            "worker thread panicked"
        );
    }
}