use std::default::Default;
use std::rt::util::min_stack;
use thunk::Thunk;
use std::mem::transmute;
use std::rt::unwind::try;
use std::any::Any;
use std::cell::UnsafeCell;
use std::ops::Deref;
use std::ptr;
use std::sync::Arc;
use std::cell::RefCell;
use context::Context;
use stack::{StackPool, Stack};
/// Lifecycle state of a coroutine.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum State {
    /// Has resumed another coroutine and is waiting for control to come back.
    Normal,
    /// Yielded through `sched()`; `resume` will switch back into it.
    Suspended,
    /// Yielded through `block()`; `resume` will also switch back into it.
    Blocked,
    /// Currently executing.
    Running,
    /// Its closure returned normally.
    Finished,
    /// Its closure panicked.
    Panicked,
}
/// Outcome of running a coroutine: `Err` carries the payload of a panic
/// raised inside it.
pub type ResumeResult<T> = ::std::result::Result<T, Box<Any + Send + 'static>>;
/// Spawn-time configuration for a coroutine.
#[derive(Debug)]
pub struct Options {
    /// Stack size requested from the stack pool (presumably bytes, matching
    /// `min_stack()` — TODO confirm against `StackPool::take_stack`).
    pub stack_size: usize,
    /// Optional human-readable name; used when reporting a panic.
    pub name: Option<String>,
}

impl Default for Options {
    /// Requests `min_stack()` worth of stack and leaves the coroutine unnamed.
    fn default() -> Options {
        let stack_size = min_stack();
        Options {
            name: None,
            stack_size: stack_size,
        }
    }
}
/// Shared, reference-counted handle to a coroutine.
///
/// Cloning only bumps the `Arc` count. The coroutine, and with it its stack,
/// is dropped when the last handle goes away.
#[derive(Clone, Debug)]
pub struct Handle(Arc<RefCell<Coroutine>>);

// NOTE(review): `RefCell` is not thread-safe, so these impls rely on callers
// never touching one coroutine from two threads at the same time; nothing in
// this module enforces that.
unsafe impl Sync for Handle {}
unsafe impl Send for Handle {}
impl Handle {
    /// Wraps `c` in a fresh reference-counted handle.
    fn new(c: Coroutine) -> Handle {
        Handle(Arc::new(RefCell::new(c)))
    }

    /// Mutable access to the coroutine that bypasses the `RefCell` borrow flag.
    ///
    /// # Safety
    /// The caller must not use overlapping references to the same coroutine at
    /// the same time; nothing checks this at runtime.
    unsafe fn get_inner_mut(&self) -> &mut Coroutine {
        &mut *self.0.as_unsafe_cell().get()
    }

    /// Shared access to the coroutine that bypasses the `RefCell` borrow flag.
    ///
    /// # Safety
    /// Same contract as `get_inner_mut`.
    unsafe fn get_inner(&self) -> &Coroutine {
        &*self.0.as_unsafe_cell().get()
    }

    /// Switches into this coroutine and runs it until it yields, blocks,
    /// finishes or panics, then returns control to the caller.
    ///
    /// Resuming is a no-op returning `Ok(())` in three cases:
    /// - the coroutine has already finished;
    /// - it is the coroutine currently running;
    /// - it is `Normal`, i.e. it is itself waiting inside a `resume` further up
    ///   the current chain.
    ///
    /// Switching into such an ancestor would overwrite its `parent` link and
    /// create a cycle, so control could never return to the original caller.
    ///
    /// # Errors
    /// Returns the panic payload if the coroutine panicked during this run.
    ///
    /// # Panics
    /// If the coroutine had already panicked in an earlier run.
    pub fn resume(&self) -> ResumeResult<()> {
        match self.state() {
            State::Finished | State::Running | State::Normal => return Ok(()),
            State::Panicked => panic!("Trying to resume a panicked coroutine"),
            State::Suspended | State::Blocked => {}
        }

        let env = Environment::current();
        let from_coro_hdl = Coroutine::current();
        {
            let from_coro: &mut Coroutine = unsafe { from_coro_hdl.get_inner_mut() };
            let to_coro: &mut Coroutine = unsafe { self.get_inner_mut() };
            to_coro.set_state(State::Running);
            // `yield_now` inside the target switches back through this pointer.
            to_coro.parent = from_coro;
            from_coro.set_state(State::Normal);
            env.current_running = self.clone();
            Context::swap(&mut from_coro.saved_context, &to_coro.saved_context);
            // Control has come back to us, so we are the running coroutine
            // again. Leaving the state at `Normal` would misreport it and make
            // it look like a suspended ancestor to later `resume` calls.
            from_coro.set_state(State::Running);
        }
        env.current_running = from_coro_hdl;
        // `coroutine_initialize` stores the payload here when the target panicked.
        match env.running_state.take() {
            Some(err) => Err(err),
            None => Ok(()),
        }
    }

    /// Resumes the coroutine for as long as it keeps yielding `Suspended`,
    /// i.e. until it finishes, blocks or panics.
    ///
    /// Returns `Ok(())` straight away if the coroutine is not `Suspended`
    /// (including when it is `Blocked`).
    ///
    /// # Errors
    /// Propagates the panic payload returned by `resume`.
    #[inline]
    pub fn join(&self) -> ResumeResult<()> {
        while self.state() == State::Suspended {
            try!(self.resume());
        }
        Ok(())
    }

    /// Current lifecycle state of the coroutine.
    #[inline]
    pub fn state(&self) -> State {
        unsafe { self.get_inner().state() }
    }

    /// Overwrites the coroutine's state without any validation.
    #[inline]
    pub fn set_state(&self, state: State) {
        unsafe { self.get_inner_mut().set_state(state) }
    }
}
/// Lets a `Handle` be used wherever a `&Coroutine` is expected (e.g. `name()`).
impl Deref for Handle {
    type Target = Coroutine;

    #[inline]
    fn deref(&self) -> &Coroutine {
        // Read straight through the `UnsafeCell`, skipping the `RefCell`
        // borrow flag, just like `get_inner` does.
        let cell = self.0.as_unsafe_cell();
        unsafe { &*cell.get() }
    }
}
/// A single coroutine.
///
/// It holds its stack, the context saved when it last switched away, a link
/// to whoever resumed it, its lifecycle state and an optional name.
// `derive(Debug)` on a struct with a raw-pointer field (`parent`) tripped the
// old `raw_pointer_derive` lint, hence the allow.
#[allow(raw_pointer_derive)]
#[derive(Debug)]
pub struct Coroutine {
    // Stack taken from the pool at spawn time; `None` for coroutines built by
    // `Coroutine::empty` (the environment's root coroutine).
    current_stack_segment: Option<Stack>,
    // Context captured by `Context::swap` when this coroutine last switched
    // away; swapping back to it continues execution from that point.
    saved_context: Context,
    // Coroutine that last resumed this one; `yield_now` switches back to it.
    // Raw, non-owning pointer — it is not reference-counted.
    parent: *mut Coroutine,
    state: State,
    name: Option<String>,
}
// NOTE(review): these impls assert thread-safety the struct does not provide
// on its own (raw `parent` pointer, unsynchronized state); callers must keep
// each coroutine on a single thread.
unsafe impl Send for Coroutine {}
unsafe impl Sync for Coroutine {}
/// Returns the coroutine's stack to the current thread's stack pool so a later
/// spawn can reuse it. Stackless coroutines (the root) have nothing to return.
impl Drop for Coroutine {
    fn drop(&mut self) {
        // NOTE(review): this touches the thread-local environment; a coroutine
        // dropped while that thread-local is being torn down may not be able to
        // reach it — confirm.
        if let Some(stack) = self.current_stack_segment.take() {
            Environment::current().stack_pool.give_stack(stack);
        }
    }
}
/// Entry trampoline for every spawned coroutine; it runs on the coroutine's
/// own stack and never returns.
///
/// `f` is presumably the boxed closure handed to `Context::new` in
/// `Coroutine::spawn_opts`, passed through as a raw pointer (TODO confirm
/// against the `context` crate). The first argument is unused; `spawn_opts`
/// passes `0`.
///
/// The closure runs under `try`, which catches a panic and yields its payload
/// as `Err`:
/// - On success, `running_state` is cleared.
/// - On panic, a message naming the coroutine is written to stderr and the
///   payload is stored in `running_state`, where `Handle::resume` picks it up.
///
/// The coroutine then yields `Finished` or `Panicked` to its parent, in a loop,
/// because there is nothing left to run if it is ever switched back into.
extern "C" fn coroutine_initialize(_: usize, f: *mut ()) -> ! {
    // Reclaim ownership of the boxed closure from the raw pointer.
    let func: Box<Thunk> = unsafe { transmute(f) };
    // Run the user closure, catching any panic instead of letting it unwind
    // off the top of this stack.
    let ret = unsafe { try(move|| func.invoke(())) };
    let env = Environment::current();
    // The coroutine whose stack we are on: `Handle::resume` made it
    // `current_running` before switching here.
    let cur: &mut Coroutine = unsafe { env.current_running.get_inner_mut() };
    let state = match ret {
        Ok(..) => {
            env.running_state = None;
            State::Finished
        }
        Err(err) => {
            {
                use std::io::stderr;
                use std::io::Write;
                // Only `&'static str` and `String` payloads are rendered;
                // anything else is reported as "Box<Any>".
                let msg = match err.downcast_ref::<&'static str>() {
                    Some(s) => *s,
                    None => match err.downcast_ref::<String>() {
                        Some(s) => &s[..],
                        None => "Box<Any>",
                    }
                };
                let name = cur.name().unwrap_or("<unnamed>");
                // Best-effort report; a failed write to stderr is ignored.
                let _ = writeln!(&mut stderr(), "Coroutine '{}' panicked at '{}'", name, msg);
            }
            // Handed back to the resumer as `Err` by `Handle::resume`.
            env.running_state = Some(err);
            State::Panicked
        }
    };
    // Never return: there is no frame below this one on the coroutine stack.
    loop {
        Coroutine::yield_now(state);
    }
}
impl Coroutine {
unsafe fn empty(name: Option<String>, state: State) -> Handle {
Handle::new(Coroutine {
current_stack_segment: None,
saved_context: Context::empty(),
parent: ptr::null_mut(),
state: state,
name: name,
})
}
fn new(name: Option<String>, stack: Stack, ctx: Context, state: State) -> Handle {
Handle::new(Coroutine {
current_stack_segment: Some(stack),
saved_context: ctx,
parent: ptr::null_mut(),
state: state,
name: name,
})
}
pub fn spawn_opts<F>(f: F, opts: Options) -> Handle
where F: FnOnce() + Send + 'static {
let env = Environment::current();
let mut stack = env.stack_pool.take_stack(opts.stack_size);
let ctx = Context::new(coroutine_initialize,
0,
f,
&mut stack);
Coroutine::new(opts.name, stack, ctx, State::Suspended)
}
pub fn spawn<F>(f: F) -> Handle
where F: FnOnce() + Send + 'static {
Coroutine::spawn_opts(f, Default::default())
}
#[inline]
pub fn yield_now(state: State) {
assert!(state != State::Running);
let env = Environment::current();
unsafe {
let from_coro = env.current_running.get_inner_mut();
from_coro.set_state(state);
let to_coro: &mut Coroutine = &mut *from_coro.parent;
Context::swap(&mut from_coro.saved_context, &to_coro.saved_context);
}
}
#[inline]
pub fn sched() {
Coroutine::yield_now(State::Suspended)
}
#[inline]
pub fn block() {
Coroutine::yield_now(State::Blocked)
}
#[inline]
pub fn current() -> Handle {
Environment::current().current_running.clone()
}
#[inline(always)]
fn state(&self) -> State {
self.state
}
#[inline(always)]
fn set_state(&mut self, state: State) {
self.state = state
}
#[inline(always)]
pub fn name(&self) -> Option<&str> {
self.name.as_ref().map(|s| &**s)
}
}
// One coroutine environment per OS thread. It lives in an `UnsafeCell` so that
// `Environment::current` can hand out mutable access to it.
thread_local! {
    static COROUTINE_ENVIRONMENT: UnsafeCell<Environment> = UnsafeCell::new(Environment::new())
}
#[allow(raw_pointer_derive)]
struct Environment {
stack_pool: StackPool,
current_running: Handle,
__main_coroutine: Handle,
running_state: Option<Box<Any + Send>>,
}
impl Environment {
fn new() -> Environment {
let coro = unsafe {
let coro = Coroutine::empty(Some("<Environment Root Coroutine>".to_string()), State::Running);
coro.0.borrow_mut().parent = coro.get_inner_mut(); coro
};
Environment {
stack_pool: StackPool::new(),
current_running: coro.clone(),
__main_coroutine: coro,
running_state: None,
}
}
fn current() -> &'static mut Environment {
COROUTINE_ENVIRONMENT.with(|env| unsafe {
&mut *env.get()
})
}
}
/// Configures and spawns a coroutine.
///
/// Starts from `Options::default()`. Each setter consumes the builder and
/// hands it back, so calls can be chained and finished with `spawn`.
pub struct Builder {
    opts: Options,
}

impl Builder {
    /// Creates a builder holding the default options.
    pub fn new() -> Builder {
        let opts: Options = Default::default();
        Builder { opts: opts }
    }

    /// Names the coroutine; the name appears in its panic report.
    pub fn name(self, name: String) -> Builder {
        let mut builder = self;
        builder.opts.name = Some(name);
        builder
    }

    /// Sets the stack size requested from the stack pool.
    pub fn stack_size(self, size: usize) -> Builder {
        let mut builder = self;
        builder.opts.stack_size = size;
        builder
    }

    /// Spawns `f` as a `Suspended` coroutine with the accumulated options.
    pub fn spawn<F>(self, f: F) -> Handle
        where F: FnOnce() + Send + 'static
    {
        let Builder { opts } = self;
        Coroutine::spawn_opts(f, opts)
    }
}
pub fn spawn<F>(f: F) -> Handle
where F: FnOnce() + Send + 'static {
Builder::new().spawn(f)
}
pub fn current() -> Handle {
Coroutine::current()
}
pub fn resume(coro: &Handle) -> ResumeResult<()> {
coro.resume()
}
pub fn sched() {
Coroutine::sched()
}
#[cfg(test)]
mod test {
    use std::sync::mpsc::channel;

    use super::Builder;
    use super::Coroutine;

    /// A spawned coroutine runs its closure when resumed.
    #[test]
    fn test_coroutine_basic() {
        let (sender, receiver) = channel();
        let coro = Coroutine::spawn(move|| {
            sender.send(1).unwrap();
        });
        coro.resume().ok().expect("Failed to resume");
        assert_eq!(receiver.recv().unwrap(), 1);
    }

    /// `sched()` pauses mid-closure; the next resume continues after it.
    #[test]
    fn test_coroutine_yield() {
        let (sender, receiver) = channel();
        let coro = Coroutine::spawn(move|| {
            sender.send(1).unwrap();
            Coroutine::sched();
            sender.send(2).unwrap();
        });

        coro.resume().ok().expect("Failed to resume");
        assert_eq!(receiver.recv().unwrap(), 1);
        // Nothing more may arrive until the coroutine is resumed again.
        assert!(receiver.try_recv().is_err());

        coro.resume().ok().expect("Failed to resume");
        assert_eq!(receiver.recv().unwrap(), 2);
    }

    /// A coroutine can spawn and join a child of its own.
    #[test]
    fn test_coroutine_spawn_inside() {
        let (sender, receiver) = channel();
        let outer = Coroutine::spawn(move|| {
            sender.send(1).unwrap();
            let inner = Coroutine::spawn(move|| {
                sender.send(2).unwrap();
            });
            inner.join().ok().expect("Failed to join");
        });
        outer.join().ok().expect("Failed to join");
        assert_eq!(receiver.recv().unwrap(), 1);
        assert_eq!(receiver.recv().unwrap(), 2);
    }

    /// A panic inside a coroutine surfaces as `Err` from `join`.
    #[test]
    fn test_coroutine_panic() {
        let coro = Coroutine::spawn(move|| {
            panic!("Panic inside a coroutine!!");
        });
        assert!(coro.join().is_err());
    }

    /// A child's panic does not take down the parent coroutine.
    #[test]
    fn test_coroutine_child_panic() {
        let parent = Coroutine::spawn(move|| {
            let child = Coroutine::spawn(move|| {
                panic!("Panic inside a coroutine's child!!");
            });
            let _ = child.join();
        });
        parent.join().ok().expect("Failed to join");
    }

    /// Resuming a finished coroutine is harmless.
    #[test]
    fn test_coroutine_resume_after_finished() {
        let coro = Coroutine::spawn(move|| {});
        for _ in 0..2 {
            assert!(coro.resume().is_ok());
        }
    }

    /// A coroutine resuming itself is a no-op.
    #[test]
    fn test_coroutine_resume_itself() {
        let coro = Coroutine::spawn(move|| {
            let me = Coroutine::current();
            me.resume().ok().expect("Failed to resume");
        });
        assert!(coro.resume().is_ok());
    }

    /// Yielding from the thread's root coroutine returns straight away.
    #[test]
    fn test_coroutine_yield_in_main() {
        Coroutine::sched();
    }

    /// `Builder` spawns a runnable coroutine.
    #[test]
    fn test_builder_basic() {
        let (sender, receiver) = channel();
        let coro = Builder::new()
            .name("Test builder".to_string())
            .spawn(move|| {
                sender.send(1).unwrap();
            });
        coro.join().ok().expect("Failed to join");
        assert_eq!(receiver.recv().unwrap(), 1);
    }
}
#[cfg(test)]
mod bench {
    use std::sync::mpsc::channel;

    use test::Bencher;

    use super::Coroutine;

    /// Spawn and join an empty coroutine. Its stack goes back to the pool when
    /// the handle drops, so later iterations can reuse it.
    #[bench]
    fn bench_coroutine_spawning_with_recycle(b: &mut Bencher) {
        b.iter(|| {
            let coro = Coroutine::spawn(move|| {});
            let _ = coro.join();
        });
    }

    /// Baseline: channel round-trips on a single stack.
    #[bench]
    fn bench_normal_counting(b: &mut Bencher) {
        const MAX_NUMBER: usize = 100;
        b.iter(|| {
            let (sender, receiver) = channel();
            let mut total = 0;
            for _ in 0..MAX_NUMBER {
                sender.send(1).unwrap();
                total += receiver.recv().unwrap();
            }
            assert_eq!(total, MAX_NUMBER);
        });
    }

    /// The same round-trips, with a context switch after every send.
    #[bench]
    fn bench_coroutine_counting(b: &mut Bencher) {
        const MAX_NUMBER: usize = 100;
        b.iter(|| {
            let (sender, receiver) = channel();
            let coro = Coroutine::spawn(move|| {
                for _ in 0..MAX_NUMBER {
                    sender.send(1).unwrap();
                    Coroutine::sched();
                }
            });
            coro.resume().ok().expect("Failed to resume");
            let mut total = 0;
            // Ends once the coroutine finishes and drops its sender.
            for n in receiver.iter() {
                coro.resume().ok().expect("Failed to resume");
                total += n;
            }
            assert_eq!(total, MAX_NUMBER);
        });
    }
}