use job::{Job, JobMode, JobRef};
use std::any::Any;
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::mem;
use std::ptr;
use std::sync::atomic::{AtomicUsize, AtomicPtr, Ordering};
use std::sync::{Condvar, Mutex};
use thread_pool::{self, WorkerThread};
use unwind;
#[cfg(test)]
mod test;
/// A fork-join scope: tracks a set of spawned jobs and lets the
/// creating thread block until every one of them has completed.
/// Created by `scope()`; jobs are added via `Scope::spawn()`.
pub struct Scope<'scope> {
// Number of outstanding jobs, plus one for the scope itself
// (initialized to 1 in `scope()`; the scope's own count is
// released just before `block_till_jobs_complete`).
counter: AtomicUsize,
// First panic payload captured from a spawned job, or null if no
// job has panicked. Set once via compare-and-swap in
// `job_panicked`; consumed and re-thrown in
// `block_till_jobs_complete`.
panic: AtomicPtr<Box<Any + Send + 'static>>,
// Mutex/condvar pair used only to signal "counter reached zero";
// the mutex guards no data of its own.
mutex: Mutex<()>,
job_completed_cvar: Condvar,
// Ties the `'scope` lifetime into the type without storing any
// data, so spawned closures cannot outlive the borrowed data.
marker: PhantomData<fn(&'scope ())>,
}
pub fn scope<'scope, OP, R>(op: OP) -> R
where OP: for<'s> FnOnce(&'s Scope<'scope>) -> R
{
let scope = Scope {
counter: AtomicUsize::new(1),
panic: AtomicPtr::new(ptr::null_mut()),
mutex: Mutex::new(()),
job_completed_cvar: Condvar::new(),
marker: PhantomData,
};
let result = op(&scope);
scope.job_completed_ok(); scope.block_till_jobs_complete();
result
}
impl<'scope> Scope<'scope> {
/// Spawns a job into this scope. The closure receives a reference to
/// the scope and may itself spawn further jobs. Pushed onto the
/// current worker's deque when called from inside the thread pool,
/// otherwise injected into the global registry.
pub fn spawn<BODY>(&self, body: BODY)
where BODY: FnOnce(&Scope<'scope>) + 'scope
{
unsafe {
// Reserve a counter slot *before* publishing the job, so the
// counter can never transiently hit zero while this job is live.
let old_value = self.counter.fetch_add(1, Ordering::SeqCst);
// The counter only reaches zero once the scope is finished, at
// which point spawning is a bug in this module's own logic.
assert!(old_value > 0); let job_ref = Box::new(HeapJob::new(self, body)).as_job_ref();
let worker_thread = WorkerThread::current();
if !worker_thread.is_null() {
// On a pool thread: push onto the local deque and bump the
// per-thread spawn count (consumed later by `pop_jobs`).
let worker_thread = &*worker_thread;
let spawn_count = worker_thread.spawn_count();
spawn_count.set(spawn_count.get() + 1);
worker_thread.push(job_ref);
} else {
// Not a pool thread: hand the job to the global registry.
thread_pool::get_registry().inject(&[job_ref]);
}
}
}
/// Records a panic from a spawned job. Only the first panic is kept;
/// later ones are dropped. Always releases the job's counter slot.
fn job_panicked(&self, err: Box<Any + Send + 'static>) {
let nil = ptr::null_mut();
// Double-box so we have a thin pointer to store in the AtomicPtr.
// If the CAS succeeds (slot was null), leak the box so the stored
// pointer stays valid; it is reclaimed and re-thrown in
// `block_till_jobs_complete`. If it fails, `err` is simply dropped.
let mut err = Box::new(err); if self.panic.compare_and_swap(nil, &mut *err, Ordering::SeqCst).is_null() {
mem::forget(err); }
self.job_completed_ok()
}
/// Releases one counter slot; if this was the last outstanding job,
/// wakes the thread blocked in `block_till_jobs_complete`.
fn job_completed_ok(&self) {
let old_value = self.counter.fetch_sub(1, Ordering::Release);
if old_value == 1 {
// Taking the mutex before notifying pairs with the waiter's
// check-then-wait under the same mutex, preventing a lost wakeup
// between its counter load and its `wait()`.
let _guard = self.mutex.lock().unwrap();
self.job_completed_cvar.notify_all();
}
}
/// Blocks until the counter reaches zero, then re-throws the first
/// recorded job panic, if any.
fn block_till_jobs_complete(&self) {
let mut guard = self.mutex.lock().unwrap();
// Standard condvar loop: re-check the predicate after every wakeup
// (handles spurious wakeups).
while self.counter.load(Ordering::Acquire) > 0 {
guard = self.job_completed_cvar.wait(guard).unwrap();
}
// Take ownership of any stored panic pointer (leaving null behind).
let panic = self.panic.swap(ptr::null_mut(), Ordering::Relaxed);
if !panic.is_null() {
unsafe {
// Reconstitute the box leaked in `job_panicked` and resume
// unwinding with the original payload.
let value: Box<Box<Any + Send + 'static>> = mem::transmute(panic);
unwind::resume_unwinding(*value);
}
}
}
}
/// A heap-allocated, one-shot job tied to a `Scope`. Converted into a
/// type-erased `JobRef` by `as_job_ref` and reconstituted (and freed)
/// in `<HeapJob as Job>::execute`.
struct HeapJob<'scope, BODY>
where BODY: FnOnce(&Scope<'scope>) + 'scope,
{
// Raw pointer back to the owning scope; `scope()` guarantees the
// scope outlives every job by blocking until the counter drains.
scope: *const Scope<'scope>,
// The closure to run, behind UnsafeCell so `execute` can `take()` it
// through a shared reference; `None` after it has been consumed.
func: UnsafeCell<Option<BODY>>,
}
impl<'scope, BODY> HeapJob<'scope, BODY>
    where BODY: FnOnce(&Scope<'scope>) + 'scope
{
    /// Wraps `func` in a job bound to `scope`. The caller (i.e.
    /// `Scope::spawn`) is responsible for having reserved a counter
    /// slot for this job.
    fn new(scope: *const Scope<'scope>, func: BODY) -> Self {
        HeapJob {
            scope: scope,
            func: UnsafeCell::new(Some(func)),
        }
    }

    /// Consumes the box, yielding a type-erased `JobRef`.
    ///
    /// Unsafe: the caller must arrange for the returned `JobRef` to be
    /// executed exactly once, since `execute` reconstitutes and frees
    /// the allocation.
    unsafe fn as_job_ref(self: Box<Self>) -> JobRef {
        // `Box::into_raw` (not `transmute`) makes the ownership
        // transfer explicit; the matching `Box::from_raw` lives in
        // `<HeapJob as Job>::execute`.
        let this = Box::into_raw(self) as *const Self;
        JobRef::new(this)
    }

    /// Pops and executes every job this worker pushed since its spawn
    /// count was `start_count`, then restores the count. A `None` pop
    /// means another worker stole that job; it will complete there.
    unsafe fn pop_jobs(worker_thread: &WorkerThread, start_count: usize) {
        let spawn_count = worker_thread.spawn_count();
        let current_count = spawn_count.get();
        for _ in start_count..current_count {
            if let Some(job_ref) = worker_thread.pop() {
                job_ref.execute(JobMode::Execute);
            }
        }
        spawn_count.set(start_count);
    }
}
impl<'scope, BODY> Job for HeapJob<'scope, BODY>
    where BODY: FnOnce(&Scope<'scope>) + 'scope
{
    /// Runs (or aborts) the job. Called exactly once per `JobRef`
    /// produced by `as_job_ref`; reclaims the heap allocation either way.
    unsafe fn execute(this: *const Self, mode: JobMode) {
        // Reconstitute the box created by `Box::into_raw` in
        // `as_job_ref` (`from_raw`, not `transmute`); it is dropped
        // when this function returns, freeing the job.
        let this: Box<Self> = Box::from_raw(this as *mut Self);
        let scope = &*this.scope;
        match mode {
            JobMode::Execute => {
                let worker_thread = &*WorkerThread::current();
                let start_count = worker_thread.spawn_count().get();

                // Take the closure out of the cell; it must run at most once.
                let func = (*this.func.get()).take().unwrap();

                // Catch a panic so the counter slot is always released;
                // `job_panicked` stores the payload for re-throw in
                // `block_till_jobs_complete`.
                match unwind::halt_unwinding(|| func(scope)) {
                    Ok(()) => { scope.job_completed_ok(); }
                    Err(err) => { scope.job_panicked(err); }
                }

                // Drain any jobs `func` pushed onto this worker's deque
                // before returning control to the caller.
                Self::pop_jobs(worker_thread, start_count);
            }
            JobMode::Abort => {
                // Not executing the closure, but the counter slot
                // reserved in `spawn` must still be released.
                scope.job_completed_ok();
            }
        }
    }
}