use std::any::Any;
use std::cell::RefCell;
use std::panic::{self, AssertUnwindSafe, UnwindSafe};
use context::Context;
use context::stack::{Stack, ProtectedFixedSizeStack};
use futures::{Async, Future, Poll};
use futures::unsync::oneshot::{self, Receiver};
use errors::{Dropped, StackError, TaskFailed};
use switch::{Switch, WaitTask};
/// Outcome of a coroutine task, sent from the coroutine to its `CoroutineResult`.
enum TaskResult<R> {
    /// The task returned normally with this value.
    Finished(R),
    /// The task panicked; the panic payload is handed to the result future.
    Panicked(Box<Any + Send + 'static>),
    /// The task panicked, but the payload was kept for propagation instead of being sent.
    PanicPropagated,
    /// The task was unwound by a `Dropped` panic, so no result exists.
    Lost,
}
/// Future resolving to the outcome of a spawned coroutine.
///
/// Returned by `Coroutine::spawn` and friends. It resolves to the task's return value, or to a
/// `TaskFailed` describing why no value is available. Dropping it only discards the outcome:
/// the coroutine ignores a failed send of its result.
pub struct CoroutineResult<R> {
    // Receiving end of the oneshot channel the coroutine sends its `TaskResult` through.
    receiver: Receiver<TaskResult<R>>,
}
impl<R> Future for CoroutineResult<R> {
    type Item = R;
    type Error = TaskFailed;

    /// Polls the underlying oneshot channel and translates the coroutine's outcome.
    ///
    /// A finished task yields its value. A panic becomes `TaskFailed::Panicked` (with the
    /// payload) or `TaskFailed::PanicPropagated`. A lost task, or a channel error (the sender
    /// went away without sending), becomes `TaskFailed::Lost`.
    fn poll(&mut self) -> Poll<R, TaskFailed> {
        let outcome = match self.receiver.poll() {
            Ok(Async::Ready(outcome)) => outcome,
            Ok(Async::NotReady) => return Ok(Async::NotReady),
            Err(_) => return Err(TaskFailed::Lost),
        };
        match outcome {
            TaskResult::Finished(value) => Ok(Async::Ready(value)),
            TaskResult::Panicked(payload) => Err(TaskFailed::Panicked(payload)),
            TaskResult::PanicPropagated => Err(TaskFailed::PanicPropagated),
            TaskResult::Lost => Err(TaskFailed::Lost),
        }
    }
}
/// Strategy for dealing with a suspended coroutine that will not be resumed normally.
///
/// NOTE(review): this file only stores the strategy (in `CoroutineContext` and in the
/// `WaitTask` built by `Coroutine::wait`); the actual handling lives in `switch`. The
/// per-variant descriptions below are inferred from the variant names and this file's tests.
/// Confirm them against `switch`.
// Variant order is significant: `Ord`/`PartialOrd` are derived from it.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum CleanupStrategy {
    /// Always clean up. This is the default used by `Coroutine::new`. Presumably the coroutine
    /// is woken so that `Coroutine::wait` returns `Err(Dropped)` and its stack unwinds.
    CleanupAlways,
    /// Presumably like `CleanupAlways`, but the coroutine is leaked if cleanup happens while
    /// the thread is already panicking (exercised by the `panic_leak` test).
    LeakOnPanic,
    /// Presumably never unwinds the coroutine, leaking it without running destructors (the
    /// `leak_always` test relies on a destructor not running).
    LeakAlways,
    /// Presumably aborts the process if cleanup is needed while the thread is panicking.
    AbortOnPanic,
    /// Presumably aborts the process whenever cleanup would be needed.
    AbortAlways,
}
/// Bookkeeping for one coroutine that is currently running on this thread.
///
/// Entries live on the thread-local `CONTEXTS` stack. One is pushed when a coroutine starts
/// or resumes inside `Coroutine::wait`, and popped when it suspends in `wait` or finishes.
struct CoroutineContext {
    // Context to switch to in order to hand control back to whoever is running the coroutine.
    parent_context: Context,
    // Stack the coroutine runs on; handed over to `Switch` when it suspends or finishes.
    stack: ProtectedFixedSizeStack,
    // Cleanup strategy copied from the `Coroutine` builder that spawned it.
    cleanup_strategy: CleanupStrategy,
}
thread_local! {
    // Coroutines currently running on this thread, innermost last: the top entry belongs to
    // the coroutine executing right now (a coroutine removes its entry while suspended).
    static CONTEXTS: RefCell<Vec<CoroutineContext>> = RefCell::new(Vec::new());
    // Builder used by the free `spawn` function; replaced by `Coroutine::set_thread_local`.
    static BUILDER: RefCell<Coroutine> = RefCell::new(Coroutine::new());
}
/// Builder for spawning coroutines with a particular configuration.
///
/// Configure it with `stack_size` and `cleanup_strategy`, then spawn with `spawn` or
/// `spawn_catch_panic`. A builder can also be installed as this thread's default with
/// `set_thread_local`; the free `spawn` function uses that default.
#[derive(Clone, Debug)]
pub struct Coroutine {
    // Stack size passed to `Switch::run_new_coroutine` for each spawned coroutine.
    stack_size: usize,
    // Cleanup strategy recorded for every coroutine spawned by this builder.
    cleanup_strategy: CleanupStrategy,
}
impl Coroutine {
    /// Creates a builder with the default stack size (`Stack::default_size()`) and the
    /// `CleanupStrategy::CleanupAlways` cleanup strategy.
    pub fn new() -> Self {
        Coroutine {
            stack_size: Stack::default_size(),
            cleanup_strategy: CleanupStrategy::CleanupAlways,
        }
    }

    /// Sets the stack size for coroutines spawned by this builder. Returns `self` for
    /// chaining.
    pub fn stack_size(&mut self, size: usize) -> &mut Self {
        self.stack_size = size;
        self
    }

    /// Sets the cleanup strategy for coroutines spawned by this builder. Returns `self` for
    /// chaining.
    pub fn cleanup_strategy(&mut self, strategy: CleanupStrategy) -> &mut Self {
        self.cleanup_strategy = strategy;
        self
    }

    /// Spawns `task` using a freshly created default builder (`Coroutine::new()`), not the
    /// thread-local one. Panics in `task` are propagated, as with `spawn`.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if spawning fails with a `StackError`.
    pub fn with_defaults<R, Task>(task: Task) -> CoroutineResult<R>
    where
        R: 'static,
        Task: FnOnce() -> R + 'static,
    {
        Coroutine::new().spawn(task).unwrap()
    }

    /// Shared implementation of `spawn` and `spawn_catch_panic`.
    ///
    /// Starts a new coroutine with a stack of `self.stack_size` and runs `task` on it. The
    /// task's outcome is sent through a oneshot channel. The channel's receiving end is
    /// returned wrapped in a `CoroutineResult`:
    ///
    /// * normal return -> `TaskResult::Finished(value)`;
    /// * a panic whose payload is `Dropped` -> `TaskResult::Lost`;
    /// * any other panic with `propagate_panic == true` -> `TaskResult::PanicPropagated`, and
    ///   the payload is returned to `Switch` together with the parent context;
    /// * any other panic with `propagate_panic == false` -> `TaskResult::Panicked(payload)`.
    ///
    /// # Errors
    ///
    /// Returns the `StackError` produced by `Switch::run_new_coroutine` (presumably a stack
    /// allocation failure — NOTE(review): confirm in `switch`).
    fn spawn_inner<R, Task>(&self, task: Task, propagate_panic: bool)
        -> Result<CoroutineResult<R>, StackError>
    where
        R: 'static,
        Task: FnOnce() -> R + UnwindSafe + 'static,
    {
        let (sender, receiver) = oneshot::channel();
        let cleanup_strategy = self.cleanup_strategy;
        // Runs on the new coroutine's stack. It receives the context to return to and the
        // stack it is running on.
        let perform = move |context, stack| {
            // Register this coroutine so that `Coroutine::wait` called from inside `task` can
            // find the context to switch back to and the stack it runs on.
            let my_context = CoroutineContext {
                parent_context: context,
                stack,
                cleanup_strategy,
            };
            CONTEXTS.with(|c| c.borrow_mut().push(my_context));
            // Holds the panic payload when it is to be propagated rather than sent.
            let mut panic_result = None;
            // NOTE(review): `Task` is already bounded by `UnwindSafe`, so this wrapper looks
            // redundant.
            let result = match panic::catch_unwind(AssertUnwindSafe(task)) {
                Ok(res) => TaskResult::Finished(res),
                Err(panic) => {
                    // A `Dropped` payload is the unwinding started when `wait` reported
                    // cleanup, not a genuine failure of the task.
                    if panic.is::<Dropped>() {
                        TaskResult::Lost
                    } else if propagate_panic {
                        panic_result = Some(panic);
                        TaskResult::PanicPropagated
                    } else {
                        TaskResult::Panicked(panic)
                    }
                },
            };
            // A failed send only means the `CoroutineResult` was dropped; nobody wants the
            // outcome, so the error is deliberately ignored.
            drop(sender.send(result));
            // Unregister. `wait` pops and re-pushes this entry around every suspension, so it
            // is on top again here and holds the latest parent context.
            let my_context = CONTEXTS.with(|c| c.borrow_mut().pop().unwrap());
            // Handed back to `Switch`. NOTE(review): presumably it switches to the parent,
            // releases the stack and resumes `panic_result` there if it is set.
            (my_context.parent_context, my_context.stack, panic_result)
        };
        Switch::run_new_coroutine(self.stack_size, Box::new(Some(perform)))?;
        Ok(CoroutineResult { receiver })
    }

    /// Spawns `task` in a new coroutine using this builder's settings.
    ///
    /// The task starts running immediately, before `spawn` returns, and runs until it
    /// finishes or first suspends in `Coroutine::wait` (the `spawn_some` test relies on this).
    /// A panic in the task is propagated rather than reported through the result. The
    /// `panics_spawn` and `panics_run` tests show it escaping `run` / `block_on_all`. In that
    /// case the `CoroutineResult` would report `TaskFailed::PanicPropagated`. `task` does not
    /// need to be `UnwindSafe`.
    ///
    /// # Errors
    ///
    /// Returns a `StackError` if the coroutine could not be started.
    pub fn spawn<R, Task>(&self, task: Task) -> Result<CoroutineResult<R>, StackError>
    where
        R: 'static,
        Task: FnOnce() -> R + 'static,
    {
        self.spawn_inner(AssertUnwindSafe(task), true)
    }

    /// Like `spawn`, but a panic in `task` is caught. It is reported as
    /// `TaskFailed::Panicked(payload)` from the returned `CoroutineResult` instead of being
    /// propagated. Requires `Task: UnwindSafe`.
    ///
    /// # Errors
    ///
    /// Returns a `StackError` if the coroutine could not be started.
    pub fn spawn_catch_panic<R, Task>(&self, task: Task) -> Result<CoroutineResult<R>, StackError>
    where
        R: 'static,
        Task: FnOnce() -> R + UnwindSafe + 'static,
    {
        self.spawn_inner(task, false)
    }

    /// Suspends the current coroutine until `fut` resolves, then returns the future's result.
    ///
    /// Must be called from inside a coroutine. Control switches to the parent context
    /// together with a closure that polls `fut`. NOTE(review): presumably `Switch` polls it
    /// from there and switches back once it is ready.
    ///
    /// Returns `Err(Dropped)` when the coroutine is woken for cleanup (`Switch::Cleanup`)
    /// instead of being resumed with a result.
    ///
    /// # Panics
    ///
    /// * Panics with "Can't wait outside of a coroutine" if no coroutine is running on this
    ///   thread.
    /// * If polling `fut` panics, that panic is re-raised here
    ///   (`Switch::PropagateFuturePanic`; see the `panic_in_future` test).
    pub fn wait<I, E, Fut>(mut fut: Fut) -> Result<Result<I, E>, Dropped>
    where
        Fut: Future<Item = I, Error = E>,
    {
        // Take this coroutine's entry off the stack while it is suspended. It is pushed back,
        // with the context we get resumed from, once the switch returns.
        let my_context = CONTEXTS.with(|c| {
            c.borrow_mut().pop().expect("Can't wait outside of a coroutine")
        });
        // Filled in by the poll closure once the future resolves.
        let mut result: Option<Result<I, E>> = None;
        let (reply_instruction, context) = {
            // The closure reaches `result` and `fut` through their addresses cast to `usize`,
            // so its type carries no borrow of them. NOTE(review): presumably required by the
            // type `WaitTask::poll` expects. Soundness relies on this frame staying alive
            // (suspended inside `exchange`, not unwound) for as long as the closure may run.
            let res_ref = &mut result as *mut _ as usize;
            let fut_ref = &mut fut as *mut _ as usize;
            let mut poll = move || {
                let fut = fut_ref as *mut Fut;
                // SAFETY: `fut_ref` is the address of `fut` in the suspended `wait` frame
                // (see the note above).
                let res = match unsafe { fut.as_mut() }.unwrap().poll() {
                    Ok(Async::NotReady) => return Ok(Async::NotReady),
                    Ok(Async::Ready(ok)) => Ok(ok),
                    Err(err) => Err(err),
                };
                let result = res_ref as *mut Option<Result<I, E>>;
                // SAFETY: `res_ref` is the address of `result` in the suspended `wait` frame.
                unsafe { *result = Some(res) };
                Ok(Async::Ready(()))
            };
            let task = WaitTask {
                poll: &mut poll,
                context: None,
                cleanup_strategy: my_context.cleanup_strategy,
                stack: Some(my_context.stack),
            };
            let instruction = Switch::WaitFuture { task };
            // Switch to the parent. This returns when something switches back into this
            // coroutine, yielding the reply and the context it came from.
            instruction.exchange(my_context.parent_context)
        };
        let (result, stack) = match reply_instruction {
            // The future resolved, so the poll closure has stored its result.
            Switch::Resume { stack } => (Ok(Ok(result.unwrap())), stack),
            // Woken only to be cleaned up; the future never resolved.
            Switch::Cleanup { stack } => (Ok(Err(Dropped)), stack),
            // Polling the future panicked; re-raise below once registered again.
            Switch::PropagateFuturePanic { stack, panic } => (Err(panic), stack),
            _ => unreachable!("Invalid instruction on wakeup"),
        };
        // Re-register the coroutine with the context it was just resumed from.
        let new_context = CoroutineContext {
            parent_context: context,
            stack,
            cleanup_strategy: my_context.cleanup_strategy,
        };
        CONTEXTS.with(|c| c.borrow_mut().push(new_context));
        match result {
            Ok(result) => result,
            Err(panic) => panic::resume_unwind(panic),
        }
    }

    /// Checks that this builder can start a coroutine by spawning a no-op one. The no-op's
    /// result is discarded.
    ///
    /// # Errors
    ///
    /// Returns the `StackError` from spawning, if any.
    pub fn verify(&self) -> Result<(), StackError> {
        self.spawn(|| ()).map(|_| ())
    }

    /// Verifies this builder and installs a copy of it as the thread-local builder used by
    /// the free `spawn` function.
    ///
    /// # Errors
    ///
    /// Returns the `StackError` from `verify`; the thread-local builder is left unchanged then.
    pub fn set_thread_local(&self) -> Result<(), StackError> {
        self.verify()?;
        BUILDER.with(|builder| builder.replace(self.clone()));
        Ok(())
    }

    /// Returns a copy of the thread-local builder.
    pub fn from_thread_local() -> Self {
        BUILDER.with(|builder| builder.borrow().clone())
    }

    /// Runs `task` in a coroutine on a tokio current-thread executor (`block_on_all`) and
    /// returns the task's value. Only available with the `convenient-run` feature.
    ///
    /// As a side effect, this installs the builder as the thread-local builder
    /// (`set_thread_local`). It stays installed after `run` returns.
    ///
    /// # Errors
    ///
    /// Returns a `StackError` if the builder fails verification.
    ///
    /// # Panics
    ///
    /// Panics with "Lost a coroutine when waiting for all of them" if the coroutine's result
    /// future fails. Panics raised by `task` are propagated.
    #[cfg(feature = "convenient-run")]
    pub fn run<R, Task>(&self, task: Task) -> Result<R, StackError>
    where
        R: 'static,
        Task: FnOnce() -> R + 'static,
    {
        self.set_thread_local()?;
        let result = ::tokio::runtime::current_thread::block_on_all(::futures::future::lazy(|| {
            spawn(task)
        })).expect("Lost a coroutine when waiting for all of them");
        Ok(result)
    }
}
impl Default for Coroutine {
fn default() -> Self {
Self::new()
}
}
/// Spawns `task` in a coroutine using this thread's thread-local builder.
///
/// The builder is the one installed by `Coroutine::set_thread_local`, or `Coroutine::new()`
/// if none was installed. Panics inside `task` are propagated, as with `Coroutine::spawn`.
///
/// # Panics
///
/// Panics with "Unverified builder in thread local storage" if spawning with the stored
/// builder fails.
pub fn spawn<R, Task>(task: Task) -> CoroutineResult<R>
where
    R: 'static,
    Task: FnOnce() -> R + 'static,
{
    // Copy the (cheap) builder out before spawning so the `RefCell` is not borrowed while
    // the new coroutine runs. Spawning executes the task immediately, and a task calling
    // `Coroutine::set_thread_local` would otherwise panic with `BorrowMutError`.
    let builder = BUILDER.with(|builder| builder.borrow().clone());
    builder
        .spawn(task)
        .expect("Unverified builder in thread local storage")
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::rc::Rc;
    // Only the `run`-based tests need these; gate them to avoid unused-import warnings.
    #[cfg(feature = "convenient-run")]
    use std::time::Duration;
    use futures::future;
    #[cfg(feature = "convenient-run")]
    use tokio::clock;
    use tokio::prelude::*;
    use tokio::runtime::current_thread::{self, Runtime};
    #[cfg(feature = "convenient-run")]
    use tokio::timer::Delay;
    use super::*;

    /// Spawned closures, including nested ones, run before `spawn` returns, and their results
    /// can be chained.
    #[test]
    fn spawn_some() {
        let s1 = Rc::new(AtomicBool::new(false));
        let s2 = Rc::new(AtomicBool::new(false));
        let s1c = s1.clone();
        let s2c = s2.clone();
        let mut builder = Coroutine::new();
        builder.stack_size(40960);
        let builder_inner = builder.clone();
        let result = builder.spawn(move || {
            let result = builder_inner.spawn(move || {
                s2c.store(true, Ordering::Relaxed);
                42
            })
            .unwrap();
            s1c.store(true, Ordering::Relaxed);
            result
        })
        .unwrap();
        assert!(s1.load(Ordering::Relaxed), "The outer closure didn't run");
        assert!(s2.load(Ordering::Relaxed), "The inner closure didn't run");
        let extract = result.and_then(|r| r);
        assert_eq!(42, current_thread::block_on_all(extract).unwrap());
    }

    /// `Coroutine::run` drives a coroutine that waits on a future to completion.
    // `run` only exists with the `convenient-run` feature.
    #[test]
    #[cfg(feature = "convenient-run")]
    fn coroutine_run() {
        let result = Coroutine::new().run(|| {
            Coroutine::wait(future::ok::<(), ()>(())).unwrap().unwrap();
            42
        }).unwrap();
        assert_eq!(42, result);
    }

    /// Coroutines can wait on each other and on timers.
    #[test]
    #[cfg(feature = "convenient-run")]
    fn future_wait() {
        let result = Coroutine::new().run(|| {
            let (sender, receiver) = oneshot::channel();
            let all_done = Coroutine::with_defaults(move || {
                Coroutine::wait(receiver).unwrap().unwrap()
            });
            Coroutine::with_defaults(move || {
                let timeout = Delay::new(clock::now() + Duration::from_millis(50));
                Coroutine::wait(timeout).unwrap().unwrap();
                sender.send(42).unwrap();
            });
            Coroutine::wait(all_done).unwrap().unwrap()
        });
        assert_eq!(42, result.unwrap());
    }

    /// `spawn_catch_panic` reports the panic through the result, and the runtime stays usable.
    #[test]
    fn panics_catch() {
        let mut rt = Runtime::new().unwrap();
        let catch_panic = future::lazy(|| {
            Coroutine::new().spawn_catch_panic(|| panic!("Test")).unwrap()
        });
        match rt.block_on(catch_panic) {
            Err(TaskFailed::Panicked(_)) => (),
            _ => panic!("Panic not reported properly"),
        }
        assert_eq!(42, rt.block_on(future::lazy(|| Coroutine::with_defaults(|| 42))).unwrap());
    }

    /// A panic in a coroutine spawned with the free `spawn` escapes `run`.
    #[test]
    #[should_panic]
    #[cfg(feature = "convenient-run")]
    fn panics_spawn() {
        let _ = Coroutine::new().run(|| {
            spawn(|| panic!("Test"))
        });
    }

    /// A panic after a `wait` propagates out of the executor.
    #[test]
    fn panics_run() {
        panic::catch_unwind(|| {
            current_thread::block_on_all(future::lazy(|| {
                Coroutine::with_defaults(|| {
                    let _ = Coroutine::wait(future::ok::<(), ()>(()));
                    panic!("Test");
                })
            }))
        }).unwrap_err();
    }

    /// `wait` outside a coroutine panics.
    #[test]
    #[should_panic]
    fn panic_without_coroutine() {
        drop(Coroutine::wait(future::ok::<_, ()>(42)));
    }

    /// With `LeakOnPanic`, a suspended coroutine is not resumed while the thread unwinds.
    #[test]
    fn panic_leak() {
        panic::catch_unwind(|| current_thread::block_on_all(future::lazy(|| -> Result<(), ()> {
            let _coroutine = Coroutine::new()
                .cleanup_strategy(CleanupStrategy::LeakOnPanic)
                .spawn(|| {
                    let _ = Coroutine::wait(future::empty::<(), ()>());
                    panic!("Should never get here!");
                })
                .unwrap();
            panic!("Test");
        }))).unwrap_err();
    }

    /// With `LeakAlways`, destructors on a suspended coroutine's stack never run.
    #[test]
    fn leak_always() {
        let mut rt = Runtime::new().unwrap();
        let _ = rt.block_on(future::lazy(|| {
            Coroutine::new()
                .cleanup_strategy(CleanupStrategy::LeakAlways)
                .spawn(|| {
                    struct Destroyer;
                    impl Drop for Destroyer {
                        fn drop(&mut self) {
                            panic!("Destructor called");
                        }
                    }
                    let _d = Destroyer;
                    let _ = Coroutine::wait(future::empty::<(), ()>());
                })
                .unwrap();
            Ok::<(), ()>(())
        }));
        drop(rt);
    }

    /// A panic while polling the awaited future falls out of `wait`.
    #[test]
    fn panic_in_future() {
        current_thread::block_on_all(future::lazy(|| {
            Coroutine::with_defaults(|| {
                struct PanicFuture;
                impl Future for PanicFuture {
                    type Item = ();
                    type Error = ();
                    fn poll(&mut self) -> Poll<(), ()> {
                        panic!("Test");
                    }
                }
                if let Ok(_) = panic::catch_unwind(|| Coroutine::wait(PanicFuture)) {
                    panic!("A panic should fall out of wait");
                }
            })
        })).unwrap();
    }
}