use std::any::Any;
use std::panic::{self, AssertUnwindSafe};
use std::process;
use std::thread;
use context::{Context, Transfer};
use context::stack::ProtectedFixedSizeStack;
use futures::{Async, Future, Poll};
use tokio_current_thread::TaskExecutor;
use coroutine::CleanupStrategy;
use errors::StackError;
use stack_cache;
/// A task body that can be stored behind a `Box` (see `BoxedTask`) and invoked through
/// `&mut self`.
///
/// The only implementation in this file is for `Option<F>`, where `F` is a matching `FnOnce`.
pub(crate) trait BoxableTask {
/// Runs the task, taking ownership of the given context and stack.
///
/// Returns a context, a stack and an optional panic payload. `coroutine_internal` switches
/// into the returned context and forwards the stack and payload in a `Switch::Destroy`.
fn perform(&mut self, Context, ProtectedFixedSizeStack) ->
(Context, ProtectedFixedSizeStack, Option<Box<Any + Send>>);
}
/// Lets a one-shot closure wrapped in an `Option` act as a `BoxableTask`.
///
/// An `FnOnce` must be called by value while `perform` only receives `&mut self`, so the
/// closure is moved out of the `Option` when it is run.
impl<F> BoxableTask for Option<F>
where
    F: FnOnce(Context, ProtectedFixedSizeStack)
        -> (Context, ProtectedFixedSizeStack, Option<Box<Any + Send>>),
{
    /// Moves the closure out of `self` and calls it with `context` and `stack`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `None`, which is the case after a previous call to `perform`.
    fn perform(
        &mut self,
        context: Context,
        stack: ProtectedFixedSizeStack,
    ) -> (Context, ProtectedFixedSizeStack, Option<Box<Any + Send>>) {
        let task = self.take().unwrap();
        task(context, stack)
    }
}
/// An owned, type-erased task, as carried by `Switch::StartTask`.
type BoxedTask = Box<BoxableTask>;
/// A future that drives a poll closure and switches back into a suspended context once that
/// closure stops returning `NotReady` (or panics).
pub(crate) struct WaitTask {
/// Raw pointer to the closure invoked on every `poll`; it is dereferenced in an `unsafe`
/// block there.
/// NOTE(review): presumably points at a closure owned by the suspended coroutine and stays
/// valid while that coroutine is suspended; confirm against the code that builds `WaitTask`.
pub(crate) poll: *mut FnMut() -> Poll<(), ()>,
/// Context to switch into. `Switch::run_child` fills it in before spawning the task;
/// `poll` and `drop` take it out again.
pub(crate) context: Option<Context>,
/// Stack that is handed over inside the `Switch` instruction when switching into `context`.
pub(crate) stack: Option<ProtectedFixedSizeStack>,
/// Decides what `Drop` does when the task is dropped while it still holds a context.
pub(crate) cleanup_strategy: CleanupStrategy,
}
impl Future for WaitTask {
    type Item = ();
    type Error = ();

    /// Polls the wrapped closure and, once it is finished, switches back into the stored
    /// context.
    ///
    /// * `NotReady` from the closure is returned as is and nothing else happens.
    /// * Any other result makes the task switch into the context with `Switch::Resume` and
    ///   then return that result.
    /// * A panic inside the closure is caught and shipped to the context as
    ///   `Switch::PropagateFuturePanic`; the future itself then resolves to `Err(())`.
    ///
    /// # Panics
    ///
    /// Panics if the context has already been taken (for example when polled again after
    /// completion), or if the `poll` pointer is NULL.
    fn poll(&mut self) -> Poll<(), ()> {
        assert!(self.context.is_some());

        // NOTE(review): the validity of this pointer is not established in this file; it
        // relies on whoever created the `WaitTask` keeping the closure alive. Confirm there.
        let closure = unsafe { self.poll.as_mut() }.unwrap();
        let outcome = panic::catch_unwind(AssertUnwindSafe(closure));

        // Pick the instruction to send and the value to return afterwards. The stack is
        // taken before the context, in both cases.
        let (instruction, result) = match outcome {
            Ok(Ok(Async::NotReady)) => return Ok(Async::NotReady),
            Ok(finished) => {
                let resume = Switch::Resume {
                    stack: self.stack.take().unwrap(),
                };
                (resume, finished)
            }
            Err(panic) => {
                let propagate = Switch::PropagateFuturePanic {
                    stack: self.stack.take().unwrap(),
                    panic,
                };
                (propagate, Err(()))
            }
        };

        instruction.run_child(self.context.take().unwrap());
        result
    }
}
impl Drop for WaitTask {
    /// Deals with a task that is dropped while it still holds a context.
    ///
    /// If the context was already taken there is nothing to do. Otherwise the cleanup
    /// strategy, combined with whether the current thread is panicking, selects one of:
    /// switching into the context with `Switch::Cleanup`, doing nothing further, or
    /// aborting the process.
    fn drop(&mut self) {
        let context = match self.context.take() {
            Some(context) => context,
            None => return,
        };

        let panicking = thread::panicking();
        let perform_cleanup = match self.cleanup_strategy {
            CleanupStrategy::CleanupAlways => true,
            CleanupStrategy::LeakAlways => false,
            CleanupStrategy::AbortAlways => process::abort(),
            // Clean up only when not already unwinding.
            CleanupStrategy::LeakOnPanic => !panicking,
            CleanupStrategy::AbortOnPanic => {
                if panicking {
                    process::abort();
                }
                true
            }
        };

        if perform_cleanup {
            let stack = self.stack.take().expect("Taken stack, but not context?");
            Switch::Cleanup { stack }.run_child(context);
        }
    }
}
/// Body of a freshly started coroutine, kept apart from the diverging `coroutine` entry
/// point.
///
/// Extracts the instruction delivered with the first switch, which must be
/// `Switch::StartTask`, runs the task and returns the `Switch::Destroy` instruction to send
/// back together with the context to send it to.
///
/// # Panics
///
/// Panics if the instruction is anything other than `Switch::StartTask`, or if the
/// instruction cannot be extracted (see `Switch::extract`).
fn coroutine_internal(transfer: Transfer) -> (Switch, Context) {
    let entry_context = transfer.context;
    let instruction = Switch::extract(transfer.data);

    if let Switch::StartTask { stack, mut task } = instruction {
        // The task hands back the context to return to, the stack and a possible panic.
        let (return_context, stack, panic) = task.perform(entry_context, stack);
        (Switch::Destroy { stack, panic }, return_context)
    } else {
        panic!("Invalid switch instruction on coroutine entry")
    }
}
/// Entry point handed to `Context::new` for every new coroutine.
///
/// Runs the coroutine body and then switches away with the final instruction. Control is
/// never expected to come back here, hence the `!` return type.
extern "C" fn coroutine(transfer: Transfer) -> ! {
    let (final_instruction, return_context) = coroutine_internal(transfer);
    final_instruction.exchange(return_context);
    unreachable!("Woken up after termination!");
}
/// Instruction passed between contexts on every coroutine switch.
///
/// The address of an `Option<Switch>` travels as the data word of the switch; see
/// `Switch::exchange` (sending side) and `Switch::extract` (receiving side).
pub(crate) enum Switch {
/// Sent into a freshly created context: run `task`. This is the only instruction
/// `coroutine_internal` accepts.
StartTask {
stack: ProtectedFixedSizeStack,
task: BoxedTask,
},
/// Reply handled by `run_child`: the task gets the context stored into it and is spawned
/// onto the current-thread executor.
WaitFuture {
task: WaitTask,
},
/// Sent by `WaitTask::poll` when the polled closure panicked; carries the panic payload.
PropagateFuturePanic {
stack: ProtectedFixedSizeStack,
panic: Box<Any + Send>,
},
/// Sent by `WaitTask::poll` when the polled closure finished with something other than
/// `NotReady`.
Resume {
stack: ProtectedFixedSizeStack,
},
/// Sent by the `Drop` implementation of `WaitTask` when its cleanup strategy asks for
/// cleanup.
Cleanup {
stack: ProtectedFixedSizeStack,
},
/// Reply handled by `run_child`: the stack is returned to `stack_cache` and the panic, if
/// any, is resumed on the receiving side.
Destroy {
stack: ProtectedFixedSizeStack,
panic: Option<Box<Any + Send>>,
},
}
impl Switch {
fn extract(transfer_data: usize) -> Switch {
let ptr = transfer_data as *mut Option<Self>;
let optref = unsafe { ptr.as_mut() }
.expect("NULL pointer passed through a coroutine switch");
optref.take().expect("Switch instruction already extracted")
}
pub(crate) fn exchange(self, context: Context) -> (Self, Context) {
let mut sw = Some(self);
let swp: *mut Option<Self> = &mut sw;
let transfer = unsafe { context.resume(swp as usize) };
(Self::extract(transfer.data), transfer.context)
}
pub(crate) fn run_child(self, context: Context) {
let (reply, context) = self.exchange(context);
use self::Switch::*;
match reply {
Destroy { stack, panic } => {
drop(context);
stack_cache::put(stack);
if let Some(panic) = panic {
panic::resume_unwind(panic);
}
},
WaitFuture { mut task } => {
task.context = Some(context);
let _ = TaskExecutor::current().spawn_local(Box::new(task));
},
_ => unreachable!("Invalid switch instruction when switching out"),
}
}
pub(crate) fn run_new_coroutine(stack_size: usize, task: BoxedTask) -> Result<(), StackError> {
let stack = stack_cache::get(stack_size)?;
assert!(stack.len() >= stack_size);
let context = unsafe { Context::new(&stack, coroutine) };
Switch::StartTask { stack, task }.run_child(context);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use std::cell::Cell;
    use std::rc::Rc;

    use super::*;

    /// Starts a trivial coroutine and checks that its body ran and that the task (with the
    /// `Rc` clone it captured) was dropped by the time control returned.
    #[test]
    fn switch_coroutine() {
        let flag = Rc::new(Cell::new(false));
        let flag_in_task = Rc::clone(&flag);
        let task = move |context, stack| {
            flag_in_task.set(true);
            (context, stack, None)
        };
        Switch::run_new_coroutine(40960, Box::new(Some(task))).unwrap();
        assert!(flag.get());
        assert_eq!(1, Rc::strong_count(&flag));
    }
}