use std::sync::mpsc;
use std::sync::mpsc::{Receiver, RecvError, RecvTimeoutError, Sender};
use std::thread;
use std::thread::{JoinHandle, ThreadId};
use std::time::Duration;
/// Reasons a [`ThreadGroup`] join call can fail to produce a thread's result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    /// The group holds no thread handles, so there is nothing left to join.
    AllDone,
    /// The thread that was joined terminated by panicking.
    Panicked,
    /// No thread reported completion before the requested timeout elapsed.
    Timeout,
    /// The channel used for completion notices reported a disconnect.
    Disconnected,
}

impl ::std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message here.
        let msg = match *self {
            JoinError::AllDone => "no threads left to join",
            JoinError::Panicked => "thread panicked",
            JoinError::Timeout => "timed out waiting for a thread to finish",
            JoinError::Disconnected => "completion channel disconnected",
        };
        f.write_str(msg)
    }
}

impl ::std::error::Error for JoinError {}
/// A collection of spawned threads that can be joined in the order in which
/// they finish rather than the order in which they were started.
///
/// `T` is the value each thread returns.
pub struct ThreadGroup<T> {
    /// Sending half of the completion channel; a clone is handed to every
    /// spawned thread.
    tx: Sender<ThreadId>,
    /// Receiving half of the completion channel; yields the id of each thread
    /// as it finishes.
    rx: Receiver<ThreadId>,
    /// Join handles of the threads that have been spawned but not yet joined.
    handles: Vec<JoinHandle<T>>,
}
/// Guard that sends the id of the thread it is dropped on over `tx`.
///
/// Because the send happens in `Drop`, the notice goes out both when the
/// owning scope returns normally and when it unwinds from a panic.
struct SendOnDrop {
    /// Channel on which the dropping thread's id is sent.
    tx: Sender<ThreadId>,
}

impl Drop for SendOnDrop {
    fn drop(&mut self) {
        // `send` only fails when the receiving half has already been dropped,
        // in which case nobody is waiting for this notice. Ignore the error
        // instead of unwrapping: panicking here would kill an otherwise
        // healthy thread, and if the thread is already unwinding a second
        // panic inside `drop` aborts the whole process.
        let _ = self.tx.send(thread::current().id());
    }
}
impl<T> ThreadGroup<T> {
pub fn new() -> ThreadGroup<T> {
let (tx, rx): (Sender<ThreadId>, Receiver<ThreadId>) = mpsc::channel();
ThreadGroup::<T>{tx: tx, rx: rx, handles: vec![]}
}
pub fn spawn<F, R>(&mut self, f: F)
where
F: FnOnce() -> T,
F: Send + 'static,
R: Send + 'static,
T: Send + 'static,
{
let thread_tx = self.tx.clone();
let jh: JoinHandle<T> = thread::spawn(move || {
let _sender = SendOnDrop{tx: thread_tx.clone()};
f()
});
self.handles.push(jh);
}
pub fn len(&self) -> usize {
self.handles.len()
}
pub fn is_empty(&self) -> bool {
self.handles.is_empty()
}
pub fn join(&mut self) -> Result<T, JoinError> {
match self.handles.is_empty() {
true => Err(JoinError::AllDone),
false => match self.rx.recv() {
Ok(id) => self.do_join(id),
Err(RecvError{}) => Err(JoinError::Disconnected)
}
}
}
pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, JoinError> {
match self.handles.is_empty() {
true => Err(JoinError::AllDone),
false => match self.rx.recv_timeout(timeout) {
Ok(id) => self.do_join(id),
Err(RecvTimeoutError::Timeout) => Err(JoinError::Timeout),
Err(RecvTimeoutError::Disconnected) => Err(JoinError::Disconnected)
}
}
}
fn find(&self, id: ThreadId) -> Option<usize> {
for (i,jh) in self.handles.iter().enumerate() {
if jh.thread().id() == id {
return Some(i)
}
}
None
}
fn do_join(&mut self, id: ThreadId) -> Result<T, JoinError> {
let i = self.find(id).unwrap();
match self.handles.remove(i).join() {
Ok(ret) => Ok(ret),
Err(_) => Err(JoinError::Panicked),
}
}
}
#[cfg(test)]
mod tests {
    // `super::` resolves on every edition; the former `use ::{..}` form is
    // rejected from edition 2018 onwards.
    use super::{JoinError, ThreadGroup};
    use std::thread::sleep;
    use std::time::Duration;

    // The sleeps below are spaced 10 ms or more apart so that the expected
    // finish order does not depend on scheduler jitter.

    /// An empty group reports zero length and `AllDone` on join.
    #[test]
    fn empty_group() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        assert!(tg.is_empty());
        assert_eq!(tg.len(), 0);
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// Threads are joined in finish order, not spawn order.
    #[test]
    fn basic_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_millis(10));
            1
        });
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_millis(30));
            3
        });
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_millis(20));
            2
        });
        assert_eq!(1, tg.join().unwrap());
        assert_eq!(2, tg.join().unwrap());
        assert_eq!(3, tg.join().unwrap());
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// A panicking thread still reports in and surfaces as `Panicked`.
    #[test]
    fn panic_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_millis(30));
            panic!()
        });
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_millis(10));
            1
        });
        assert_eq!(1, tg.join().unwrap());
        assert_eq!(JoinError::Panicked, tg.join().unwrap_err());
        assert_eq!(JoinError::AllDone, tg.join().unwrap_err());
    }

    /// `join_timeout` gives up with `Timeout` and leaves the slow thread in
    /// the group.
    #[test]
    fn timeout_join() {
        let mut tg: ThreadGroup<u32> = ThreadGroup::new();
        // Effectively never finishes within the lifetime of the test.
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_secs(1_000_000));
            2
        });
        tg.spawn::<_, u32>(|| {
            sleep(Duration::from_millis(10));
            1
        });
        let t = Duration::from_secs(1);
        assert_eq!(1, tg.join_timeout(t).unwrap());
        assert_eq!(JoinError::Timeout, tg.join_timeout(t).unwrap_err());
        assert!(!tg.is_empty());
        assert_eq!(tg.len(), 1);
    }
}