use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Receiver, TryRecvError};
use std::panic::UnwindSafe;
use std::thread;
use failure::{self, Error};
use panic;
use error_handling;
/// Generate a family of `extern "C"` functions that let foreign code drive a
/// `Task` through a `TaskHandle`.
///
/// Each `key: function_name;` pair emits one `#[no_mangle]` function. Pairs
/// may appear in any order and any of them may be omitted. Supported keys are
/// `spawn`, `poll`, `wait`, `cancel`, `cancelled`, `handle_destroy` and
/// `result_destroy`. Attributes written before `Task:` (for example doc
/// comments) are copied onto every generated function.
///
/// Ownership rules for callers of the generated API:
///
/// - `spawn` clones the task behind the pointer it is given and returns a
///   heap-allocated handle owned by the caller.
/// - `poll` returns null both while the task is still running and when it
///   failed. On failure the error is recorded with `update_last_error`, so
///   callers should clear the last error before polling.
/// - `wait` blocks until the task finishes and *consumes* the handle. The
///   handle must not be passed to `handle_destroy` (or used at all) afterwards.
/// - `handle_destroy` frees a handle without waiting. Dropping a `TaskHandle`
///   signals its cancellation token.
/// - Non-null results returned by `poll` or `wait` are owned by the caller and
///   must be released with `result_destroy`.
///
/// # Safety
///
/// Every generated function is `unsafe`. Pointer arguments must either be
/// null or have been produced by the matching generated function (a valid
/// `*const Task` for `spawn`), and must not be used after being consumed by
/// `wait` or freed by `handle_destroy` / `result_destroy`.
#[macro_export]
macro_rules! export_task {
    ($( #[$attr:meta] )* Task: $Task:ty; spawn: $spawn:ident; $( $tokens:tt )*) => {
        #[allow(dead_code)]
        #[no_mangle]
        $( #[$attr] )*
        pub unsafe extern "C" fn $spawn(task: *const $Task) -> *mut $crate::task::TaskHandle<<$Task as $crate::Task>::Output> {
            null_pointer_check!(task);
            // Clone so the caller keeps ownership of `*task`; the clone is
            // what gets moved onto the worker thread.
            let task = (&*task).clone();
            let handle = $crate::task::TaskHandle::spawn(task);
            Box::into_raw(Box::new(handle))
        }
        export_task!($( #[$attr] )* Task: $Task; $( $tokens )*);
    };
    ($( #[$attr:meta] )* Task: $Task:ty; poll: $poll:ident; $( $tokens:tt )*) => {
        #[allow(dead_code)]
        #[no_mangle]
        $( #[$attr] )*
        pub unsafe extern "C" fn $poll(handle: *mut $crate::task::TaskHandle<<$Task as $crate::Task>::Output>) -> *mut <$Task as $crate::Task>::Output {
            null_pointer_check!(handle);
            match (&*handle).poll() {
                Some(Ok(value)) => Box::into_raw(Box::new(value)),
                Some(Err(e)) => {
                    // Failure: report through the last-error mechanism.
                    $crate::error_handling::update_last_error(e);
                    ::std::ptr::null_mut()
                }
                // Still running: null without touching the last error.
                None => ::std::ptr::null_mut()
            }
        }
        export_task!($( #[$attr] )* Task: $Task; $( $tokens )*);
    };
    ($( #[$attr:meta] )* Task: $Task:ty; handle_destroy: $handle_destructor:ident; $( $tokens:tt )*) => {
        #[allow(dead_code)]
        #[no_mangle]
        $( #[$attr] )*
        pub unsafe extern "C" fn $handle_destructor(handle: *mut $crate::task::TaskHandle<<$Task as $crate::Task>::Output>) {
            null_pointer_check!(handle);
            // Reclaim and free the handle; `TaskHandle`'s `Drop` impl signals
            // cancellation to the still-running task.
            drop(Box::from_raw(handle));
        }
        export_task!($( #[$attr] )* Task: $Task; $( $tokens )*);
    };
    ($( #[$attr:meta] )* Task: $Task:ty; result_destroy: $result_destroy:ident; $( $tokens:tt )*) => {
        #[allow(dead_code)]
        #[no_mangle]
        $( #[$attr] )*
        pub unsafe extern "C" fn $result_destroy(result: *mut <$Task as $crate::Task>::Output) {
            null_pointer_check!(result);
            // Results are handed out via `Box::into_raw`, so free them the same way.
            drop(Box::from_raw(result));
        }
        export_task!($( #[$attr] )* Task: $Task; $( $tokens )*);
    };
    ($( #[$attr:meta] )* Task: $Task:ty; wait: $wait:ident; $( $tokens:tt )*) => {
        #[allow(dead_code)]
        #[no_mangle]
        $( #[$attr] )*
        pub unsafe extern "C" fn $wait(handle: *mut $crate::task::TaskHandle<<$Task as $crate::Task>::Output>)
            -> *mut <$Task as $crate::Task>::Output
        {
            null_pointer_check!(handle);
            // Take back ownership: the handle is consumed (and freed) here,
            // so the caller must not destroy it afterwards.
            let handle = Box::from_raw(handle);
            match handle.wait() {
                Ok(value) => Box::into_raw(Box::new(value)),
                Err(e) => {
                    // Same error path as `poll`, for consistency.
                    $crate::error_handling::update_last_error(e);
                    ::std::ptr::null_mut()
                }
            }
        }
        export_task!($( #[$attr] )* Task: $Task; $( $tokens )*);
    };
    ($( #[$attr:meta] )* Task: $Task:ty; cancel: $cancel:ident; $( $tokens:tt )*) => {
        #[allow(dead_code)]
        #[no_mangle]
        $( #[$attr] )*
        pub unsafe extern "C" fn $cancel(handle: *mut $crate::task::TaskHandle<<$Task as $crate::Task>::Output>) {
            null_pointer_check!(handle);
            // Only sets the cancellation flag; the task stops once it notices.
            (&*handle).cancel();
        }
        export_task!($( #[$attr] )* Task: $Task; $( $tokens )*);
    };
    ($( #[$attr:meta] )* Task: $Task:ty; cancelled: $cancelled:ident; $( $tokens:tt )*) => {
        #[allow(dead_code)]
        #[no_mangle]
        $( #[$attr] )*
        pub unsafe extern "C" fn $cancelled(handle: *mut $crate::task::TaskHandle<<$Task as $crate::Task>::Output>) -> ::std::os::raw::c_int {
            null_pointer_check!(handle);
            // C has no `bool` in this ABI, so report 1 for cancelled, 0 otherwise.
            if (&*handle).cancelled() {
                1
            } else {
                0
            }
        }
        export_task!($( #[$attr] )* Task: $Task; $( $tokens )*);
    };
    // Recursion base case: every `key: name;` pair has been consumed.
    ($( #[$attr:meta] )* Task: $Task:ty;) => {};
}
/// A unit of work that can be run on a background thread via `TaskHandle`
/// and exposed to foreign code with `export_task!`.
///
/// Implementors must be cloneable (the generated `spawn` function clones the
/// task it is given) and shareable between threads.
pub trait Task: Clone + Send + Sync {
    /// The value produced when the task completes successfully.
    type Output: Sync + Send;

    /// Execute the task to completion.
    ///
    /// `cancel_tok` reports whether cancellation has been requested.
    /// Long-running implementations should check it periodically and return
    /// early once it is set.
    fn run(&self, cancel_tok: &CancellationToken) -> Result<Self::Output, Error>;
}
/// A shared flag used to ask a running `Task` to stop.
///
/// Cloning the token does not create a new flag: all clones point at the same
/// `Arc<AtomicBool>`, so cancelling through any clone is observed by all.
#[derive(Debug, Clone)]
pub struct CancellationToken(Arc<AtomicBool>);

impl CancellationToken {
    /// Create a token in the "not cancelled" state.
    pub fn new() -> CancellationToken {
        let flag = AtomicBool::new(false);
        CancellationToken(Arc::new(flag))
    }

    /// Has cancellation been requested on this token (or any of its clones)?
    pub fn cancelled(&self) -> bool {
        let CancellationToken(ref flag) = *self;
        flag.load(Ordering::SeqCst)
    }

    /// Request cancellation. The flag is never reset once set.
    pub fn cancel(&self) {
        let CancellationToken(ref flag) = *self;
        flag.store(true, Ordering::SeqCst);
    }

    /// Return `Err(Cancelled)` once cancellation has been requested, `Ok(())`
    /// otherwise.
    ///
    /// Handy inside `Task::run` as `cancel_tok.is_done()?;`.
    pub fn is_done(&self) -> Result<(), Cancelled> {
        if self.cancelled() {
            return Err(Cancelled);
        }
        Ok(())
    }
}

impl Default for CancellationToken {
    /// Equivalent to `CancellationToken::new()`.
    fn default() -> CancellationToken {
        Self::new()
    }
}
/// Error returned by `CancellationToken::is_done` once cancellation has been
/// requested. Displays as "The task was cancelled".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Fail)]
#[fail(display = "The task was cancelled")]
pub struct Cancelled;
/// An owning handle to a `Task` running on its own background thread.
///
/// The task's outcome is delivered over a channel and can be fetched either
/// without blocking (`poll`) or by blocking until the task finishes (`wait`).
/// Dropping the handle signals the task's `CancellationToken`.
pub struct TaskHandle<T> {
    // The worker sends exactly one message: the task's result or its error.
    result: Receiver<Result<T, Error>>,
    // Shared with the worker thread so `cancel()` and `Drop` can signal it.
    token: CancellationToken,
}

impl<T> TaskHandle<T> {
    /// Run `task` on a newly spawned thread and return a handle to it.
    ///
    /// If `Task::run` fails or panics, the failure is delivered as the `Err`
    /// side of the result seen by `poll` / `wait`. When no specific error can
    /// be recovered, a generic "The task failed" error is used instead.
    pub fn spawn<K>(task: K) -> TaskHandle<T>
    where
        K: Task<Output = T> + UnwindSafe + Send + Sync + 'static,
        T: Send + Sync + 'static,
    {
        let (tx, rx) = mpsc::channel();
        let token = CancellationToken::new();
        let worker_token = token.clone();

        thread::spawn(move || {
            // `catch_panic` only returns `Err(())`; the actual error is read
            // back from the last-error slot below, so start with it empty.
            error_handling::clear_last_error();
            let outcome = panic::catch_panic(move || task.run(&worker_token)).map_err(|_| {
                error_handling::take_last_error()
                    .unwrap_or_else(|| failure::err_msg("The task failed"))
            });
            // If the handle has already been dropped nobody is listening;
            // discarding the result is the intended behaviour in that case.
            let _ = tx.send(outcome);
        });

        TaskHandle { result: rx, token }
    }

    /// Check for the task's result without blocking.
    ///
    /// Returns `None` while the task is still running and `Some(result)` once
    /// it has finished. The result can be received only once. Later calls
    /// yield `Some(Err(..))`, because the worker has dropped its end of the
    /// channel.
    pub fn poll(&self) -> Option<Result<T, Error>> {
        match self.result.try_recv() {
            Ok(value) => Some(value),
            Err(TryRecvError::Empty) => None,
            Err(e @ TryRecvError::Disconnected) => Some(Err(e.into())),
        }
    }

    /// Block until the task finishes and return its result, consuming the
    /// handle.
    ///
    /// # Errors
    ///
    /// Returns the task's own error, or a channel error if the worker hung
    /// up without sending a result (for example because it was already
    /// retrieved through `poll`).
    pub fn wait(self) -> Result<T, Error> {
        // `?` turns a `RecvError` into `failure::Error`; the inner `Result`
        // is the task's outcome and is returned unchanged.
        self.result.recv()?
    }

    /// Ask the task to stop by setting its cancellation flag.
    ///
    /// This does not wait; the task stops only when it observes the flag.
    pub fn cancel(&self) {
        self.token.cancel();
    }

    /// Has cancellation been requested for this task?
    pub fn cancelled(&self) -> bool {
        self.token.cancelled()
    }
}

impl<T> Drop for TaskHandle<T> {
    /// Signal cancellation so an abandoned task has the chance to stop early.
    fn drop(&mut self) {
        self.token.cancel();
    }
}
#[cfg(test)]
#[allow(private_no_mangle_fns)]
mod tests {
    use super::*;
    use std::time::Duration;
    use panic::Panic;

    /// A task that counts 10ms sleeps until it is cancelled.
    #[derive(Debug, Clone, Copy)]
    pub struct Spin;

    impl Task for Spin {
        type Output = usize;

        fn run(&self, cancel_tok: &CancellationToken) -> Result<Self::Output, Error> {
            let mut spins = 0;
            while !cancel_tok.cancelled() {
                thread::sleep(Duration::from_millis(10));
                spins += 1;
            }
            Ok(spins)
        }
    }

    #[test]
    fn spawn_a_task() {
        let task = Spin;
        let handle = TaskHandle::spawn(task);
        for _ in 0..10 {
            thread::sleep(Duration::from_millis(10));
            let got = handle.poll();
            assert!(got.is_none());
        }
        handle.cancel();
        let got = handle.wait().unwrap();
        // NOTE: timing-sensitive; may be flaky on a heavily loaded machine.
        assert!(9 <= got && got <= 12);
    }

    #[test]
    fn cancellation_token_clones_share_state() {
        let tok = CancellationToken::new();
        let other = tok.clone();
        assert!(!tok.cancelled());
        assert_eq!(tok.is_done(), Ok(()));

        other.cancel();

        assert!(tok.cancelled());
        assert_eq!(tok.is_done(), Err(Cancelled));
    }

    export_task!{
        Task: Spin;
        spawn: spin_spawn;
        wait: spin_wait;
        poll: spin_poll;
        cancel: spin_cancel;
        cancelled: spin_cancelled;
        handle_destroy: spin_handle_destroy;
        result_destroy: spin_result_destroy;
    }

    #[test]
    fn use_the_c_api() {
        use error_handling::*;
        let s = Spin;
        unsafe {
            let handle = spin_spawn(&s);
            assert_eq!(
                spin_cancelled(handle),
                0,
                "The spin shouldn't have been cancelled yet"
            );
            clear_last_error();
            let ret = spin_poll(handle);
            assert!(ret.is_null(), "The task should still be running");
            assert_eq!(
                last_error_length(),
                0,
                "There shouldn't have been any errors"
            );
            spin_cancel(handle);
            // `spin_wait` consumes (and frees) the handle.
            let got = spin_wait(handle);
            assert_eq!(
                last_error_length(),
                0,
                "There shouldn't have been any errors"
            );
            assert!(!got.is_null(), "Oops!");
            // The boxed result belongs to us; release it to avoid a leak.
            spin_result_destroy(got);
        }
    }

    /// A task that always panics with `PANIC_MESSAGE`.
    #[derive(Copy, Clone)]
    struct PanicTask;

    const PANIC_MESSAGE: &str = "Oops";

    impl Task for PanicTask {
        type Output = ();

        fn run(&self, _: &CancellationToken) -> Result<Self::Output, Error> {
            panic!(PANIC_MESSAGE)
        }
    }

    #[test]
    fn task_can_catch_panic_messages() {
        let task = PanicTask;
        let err = TaskHandle::spawn(task).wait().unwrap_err();
        if let Some(p) = err.downcast_ref::<Panic>() {
            assert_eq!(p.message, PANIC_MESSAGE);
        } else {
            panic!("Expected a panic failure, got {}", err);
        }
    }
}