use std::cell::RefCell;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::ops::DerefMut;
use std::rc::Rc;
use std::thread;
use std::io;
/// Lets a heap-allocated `FnOnce` closure be stored as a trait object and
/// invoked exactly once through [`FnBox::call_box`].
#[doc(hidden)]
trait FnBox<T> {
    /// Consumes the boxed value and runs it, returning its result.
    fn call_box(self: Box<Self>) -> T;
}

impl<T, F> FnBox<T> for F
where
    F: FnOnce() -> T,
{
    fn call_box(self: Box<Self>) -> T {
        // Move the closure out of its box, then call it by value.
        let closure: F = *self;
        closure()
    }
}
/// Spawns an OS thread running `f` without requiring `f: 'static`.
///
/// A default `thread::Builder` is used, and this function panics if the
/// thread cannot be created.
///
/// # Safety
///
/// The returned handle is not tied to `'a`. The caller must join the thread
/// before any data borrowed by `f` is invalidated; otherwise the thread may
/// observe dangling references.
pub unsafe fn spawn_unsafe<'a, F>(f: F) -> thread::JoinHandle<()>
where
    F: FnOnce() + Send + 'a,
{
    builder_spawn_unsafe(thread::Builder::new(), f).unwrap()
}
pub unsafe fn builder_spawn_unsafe<'a, F>(
builder: thread::Builder,
f: F,
) -> io::Result<thread::JoinHandle<()>>
where
F: FnOnce() + Send + 'a,
{
let closure: Box<FnBox<()> + 'a> = Box::new(f);
let closure: Box<FnBox<()> + Send> = mem::transmute(closure);
builder.spawn(move || closure.call_box())
}
/// A region in which threads that borrow from the caller's stack may be
/// spawned; deferred work registered on it runs before `scope` returns.
pub struct Scope<'a> {
    /// Stack of deferred closures (including thread joins); the head is the
    /// most recently registered entry and is run first.
    dtors: RefCell<Option<DtorChain<'a, ()>>>,
    /// Raw-pointer marker: makes `Scope` neither `Send` nor `Sync`.
    _marker: PhantomData<*const ()>,
}
struct DtorChain<'a, T> {
dtor: Box<FnBox<T> + 'a>,
next: Option<Box<DtorChain<'a, T>>>,
}
impl<'a, T> DtorChain<'a, T> {
pub fn pop(chain: &mut Option<DtorChain<'a, T>>) -> Option<Box<FnBox<T> + 'a>> {
chain.take().map(|mut node| {
*chain = node.next.take().map(|b| *b);
node.dtor
})
}
}
/// A spawned thread that has not been joined yet, plus the address of the
/// heap slot into which the thread writes its `T` result.
struct JoinState<T> {
    join_handle: thread::JoinHandle<()>,
    // Address of the result slot, kept as `usize` (presumably so it can be
    // moved into the spawned closure, since raw pointers are not `Send`).
    result: usize,
    _marker: PhantomData<T>,
}

impl<T: Send> JoinState<T> {
    /// Pairs a thread handle with the address of its result slot.
    fn new(join_handle: thread::JoinHandle<()>, result: usize) -> JoinState<T> {
        JoinState {
            join_handle,
            result,
            _marker: PhantomData,
        }
    }

    /// Waits for the thread and, if it finished normally, takes ownership of
    /// the value stored at `result` (freeing the slot).
    ///
    /// If the thread panicked, its panic payload is returned and `result` is
    /// not dereferenced.
    fn join(self) -> thread::Result<T> {
        let JoinState {
            join_handle,
            result,
            ..
        } = self;
        match join_handle.join() {
            // SAFETY: assumes `result` came from `Box::into_raw` of a
            // `T`-sized allocation that the finished thread initialized.
            Ok(()) => Ok(unsafe { *Box::from_raw(result as *mut T) }),
            Err(payload) => Err(payload),
        }
    }
}
/// Handle to a thread spawned inside a [`Scope`].
///
/// The join state is shared through an `Rc` with the closure the scope runs
/// when it ends, so whichever side joins first takes the state.
pub struct ScopedJoinHandle<'a, T: 'a> {
    /// Shared join state; `None` once the thread has been joined.
    inner: Rc<RefCell<Option<JoinState<T>>>>,
    /// The spawned thread's `Thread` handle.
    thread: thread::Thread,
    /// Ties the handle to the scope lifetime `'a` and the result type `T`.
    _marker: PhantomData<&'a T>,
}
/// Creates a [`Scope`], passes it to `f`, and runs all work deferred on the
/// scope (including joining the threads spawned in it) before returning the
/// value produced by `f`.
pub fn scope<'a, F, R>(f: F) -> R
where
    F: FnOnce(&Scope<'a>) -> R,
{
    let mut region = Scope {
        dtors: RefCell::new(None),
        _marker: PhantomData,
    };
    let output = f(&region);
    // Run deferred work now. `Scope`'s `Drop` impl also calls `drop_all`, which
    // covers the case where `f` unwinds instead of returning.
    region.drop_all();
    output
}
impl<'a> fmt::Debug for Scope<'a> {
    /// Writes a fixed placeholder; the deferred closures are not shown.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_fmt(format_args!("Scope {{ ... }}"))
    }
}
impl<'a, T> fmt::Debug for ScopedJoinHandle<'a, T> {
    /// Writes a fixed placeholder; the join state is not shown.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_fmt(format_args!("ScopedJoinHandle {{ ... }}"))
    }
}
impl<'a> Scope<'a> {
    /// Pops and runs every deferred closure, most recently registered first,
    /// until none remain.
    ///
    /// Because this takes `&mut self`, the chain is reached through
    /// `RefCell::get_mut`, so no runtime `RefCell` borrow is taken. In
    /// particular, nothing is borrowed while each closure runs; the previous
    /// version kept a `RefMut` alive across every `call_box()`.
    fn drop_all(&mut self) {
        while let Some(dtor) = DtorChain::pop(self.dtors.get_mut()) {
            dtor.call_box();
        }
    }

    /// Registers `f` to run when the scope ends.
    ///
    /// Deferred closures run in reverse registration order. They share that
    /// order with the joins of threads spawned in this scope, which are
    /// registered through this same method.
    pub fn defer<F>(&self, f: F)
    where
        F: FnOnce() + 'a,
    {
        let mut dtors = self.dtors.borrow_mut();
        *dtors = Some(DtorChain {
            dtor: Box::new(f),
            next: dtors.take().map(Box::new),
        });
    }

    /// Spawns `f` on a new thread that is joined, at the latest, when the
    /// scope ends.
    ///
    /// # Panics
    ///
    /// Panics if the OS thread cannot be created. Use [`Scope::builder`] to
    /// receive that error as an `io::Result` instead.
    pub fn spawn<'s, F, T>(&'s self, f: F) -> ScopedJoinHandle<'a, T>
    where
        'a: 's,
        F: FnOnce() -> T + Send + 'a,
        T: Send + 'a,
    {
        self.builder().spawn(f).unwrap()
    }

    /// Returns a builder for configuring (name, stack size) a thread that
    /// will be spawned in this scope.
    pub fn builder<'s>(&'s self) -> ScopedThreadBuilder<'s, 'a> {
        ScopedThreadBuilder {
            scope: self,
            builder: thread::Builder::new(),
        }
    }
}
/// Configures a thread (name, stack size) before spawning it in a [`Scope`].
pub struct ScopedThreadBuilder<'s, 'a: 's> {
    /// Scope on which the spawned thread's join is deferred.
    scope: &'s Scope<'a>,
    /// Standard-library builder holding the thread configuration.
    builder: thread::Builder,
}
impl<'s, 'a: 's> ScopedThreadBuilder<'s, 'a> {
pub fn name(mut self, name: String) -> ScopedThreadBuilder<'s, 'a> {
self.builder = self.builder.name(name);
self
}
pub fn stack_size(mut self, size: usize) -> ScopedThreadBuilder<'s, 'a> {
self.builder = self.builder.stack_size(size);
self
}
pub fn spawn<F, T>(self, f: F) -> io::Result<ScopedJoinHandle<'a, T>>
where
F: FnOnce() -> T + Send + 'a,
T: Send + 'a,
{
let result = Box::into_raw(Box::<T>::new(unsafe { mem::uninitialized() })) as usize;
let join_handle = try!(unsafe {
builder_spawn_unsafe(self.builder, move || {
let mut result = Box::from_raw(result as *mut T);
*result = f();
mem::forget(result);
})
});
let thread = join_handle.thread().clone();
let join_state = JoinState::<T>::new(join_handle, result);
let deferred_handle = Rc::new(RefCell::new(Some(join_state)));
let my_handle = deferred_handle.clone();
self.scope.defer(move || {
let state = mem::replace(deferred_handle.borrow_mut().deref_mut(), None);
if let Some(state) = state {
state.join().unwrap();
}
});
Ok(ScopedJoinHandle {
inner: my_handle,
thread: thread,
_marker: PhantomData,
})
}
}
impl<'a, T: Send + 'a> ScopedJoinHandle<'a, T> {
pub fn join(self) -> thread::Result<T> {
let state = mem::replace(self.inner.borrow_mut().deref_mut(), None);
state.unwrap().join()
}
pub fn thread(&self) -> &thread::Thread {
&self.thread
}
}
impl<'a> Drop for Scope<'a> {
    /// Runs any deferred work still pending, so spawned threads are joined
    /// even if the scope is dropped without `drop_all` having completed
    /// (e.g. while unwinding).
    fn drop(&mut self) {
        self.drop_all();
    }
}