use std::any::Any;
use std::cell::{Cell, UnsafeCell};
use std::rc::Rc;
use std::{mem, panic, ptr};
use ignore_result::Ignore;
use static_assertions::assert_not_impl_any;
use super::Coroutine;
use crate::error::{JoinError, PanicError};
use crate::select::{Identifier, Permit, PermitReader, Selectable, Selector, TrySelectError};
use crate::task::{self, Yielding};
/// Lifecycle of a suspension joint.
enum SuspensionState<T>
where
    T: 'static,
{
    /// No result yet and nobody is waiting for one.
    Empty,
    /// Resumed with a value that has not been taken yet.
    Value(T),
    /// Cancelled or faulted with an error that has not been taken yet.
    Panicked(PanicError),
    /// The given coroutine is suspended in `join`, waiting for a result.
    Joining(ptr::NonNull<Coroutine>),
    /// A selector is registered to be notified once a result arrives.
    Selector(Selector),
    /// The result has been moved out; nothing further can be stored.
    Joined,
}
/// State shared by a `Suspension` and its `Resumption` handles.
struct SuspensionJoint<T>
where
    T: 'static,
{
    /// Current phase; read and written through `UnsafeCell::get` from `&self` methods.
    state: UnsafeCell<SuspensionState<T>>,
    /// Count of `Resumption` handles not yet dropped (starts at 1). Handles consumed by
    /// `resume`/`send`/`set_result` skip their destructor and are never subtracted.
    wakers: Cell<usize>,
}
impl<T> Yielding for SuspensionJoint<T> {
fn interrupt(&self, reason: &'static str) -> bool {
self.cancel(PanicError::Static(reason));
true
}
}
impl<T> SuspensionJoint<T> {
    /// Creates an `Empty` joint with one registered waker. The waker accounts for the single
    /// `Resumption` that `suspension()` creates alongside the `Suspension`.
    fn new() -> Rc<SuspensionJoint<T>> {
        Rc::new(SuspensionJoint {
            state: UnsafeCell::new(SuspensionState::Empty),
            wakers: Cell::new(1),
        })
    }

    /// Returns `true` when a value or an error is stored and has not been taken yet.
    fn is_ready(&self) -> bool {
        // SAFETY: the joint is only reachable through `!Send` `Rc` handles. This short-lived
        // shared borrow ends before anything else can touch the state.
        let state = unsafe { &*self.state.get() };
        matches!(state, SuspensionState::Value(_) | SuspensionState::Panicked(_))
    }

    /// Resumes the suspended coroutine `co` on the current task.
    fn wake_coroutine(co: ptr::NonNull<Coroutine>) {
        // SAFETY: assumes `task::current()` yields a valid pointer to the task running this
        // code — TODO confirm against `crate::task`.
        let task = unsafe { task::current().as_mut() };
        task.resume(co);
    }

    /// Accounts for one more `Resumption` handle (called when a handle is cloned).
    fn add_waker(&self) {
        let wakers = self.wakers.get() + 1;
        self.wakers.set(wakers);
    }

    /// Accounts for a dropped `Resumption` handle.
    ///
    /// When the count reaches zero, nobody is left to resume the suspension. The joint is then
    /// faulted with `"suspend: no resumption"`. That is a no-op if a result is already stored
    /// or taken.
    fn remove_waker(&self) {
        let wakers = self.wakers.get() - 1;
        self.wakers.set(wakers);
        if wakers == 0 {
            self.fault(PanicError::Static("suspend: no resumption"));
        }
    }

    /// Stores `err` as the result unless a result is already stored or taken.
    ///
    /// A registered selector is notified immediately. A coroutine blocked in `join` is *not*
    /// woken; it is returned so the caller decides what to do with it (see `fault`).
    fn cancel(&self, err: PanicError) -> Option<ptr::NonNull<Coroutine>> {
        // SAFETY: single-threaded access. The borrow is not used after the old state has been
        // swapped out, so re-entrant accesses from `apply` do not overlap it.
        let state = unsafe { &mut *self.state.get() };
        if matches!(
            state,
            SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined
        ) {
            return None;
        }
        let previous = mem::replace(state, SuspensionState::Panicked(err));
        if let SuspensionState::Joining(co) = previous {
            return Some(co);
        } else if let SuspensionState::Selector(selector) = previous {
            selector.apply(Permit::default());
        }
        None
    }

    /// Stores `err` like `cancel` and also wakes a coroutine blocked in `join`.
    fn fault(&self, err: PanicError) {
        if let Some(co) = self.cancel(err) {
            Self::wake_coroutine(co);
        }
    }

    /// Stores `value` and wakes whoever waits for it: a coroutine blocked in `join`, or a
    /// registered selector.
    ///
    /// # Errors
    /// Returns `Err(value)` if a result is already stored or has been taken.
    pub fn wake(&self, value: T) -> Result<(), T> {
        // SAFETY: single-threaded access. The borrow is dead once the old state has been
        // swapped out, before the waiter is resumed.
        let state = unsafe { &mut *self.state.get() };
        if matches!(
            state,
            SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined
        ) {
            return Err(value);
        }
        let previous = mem::replace(state, SuspensionState::Value(value));
        if let SuspensionState::Joining(co) = previous {
            Self::wake_coroutine(co);
        } else if let SuspensionState::Selector(selector) = previous {
            selector.apply(Permit::default());
        }
        Ok(())
    }

    /// Stores `Ok` values via `wake` and `Err` payloads as a `PanicError::Unwind` fault.
    ///
    /// A value arriving at an already-resolved joint is silently dropped.
    fn set_result(&self, result: Result<T, Box<dyn Any + Send + 'static>>) {
        match result {
            Ok(value) => self.wake(value).ignore(),
            Err(err) => self.fault(PanicError::Unwind(err)),
        }
    }

    /// Registers `selector` to be notified when a result arrives.
    ///
    /// If a result is already stored or taken, the selector is notified at once and `true` is
    /// returned. Otherwise it is stored and `false` is returned.
    ///
    /// # Panics
    /// Panics if a coroutine is already joining or another selector is registered.
    fn watch_permit(&self, selector: Selector) -> bool {
        // SAFETY: single-threaded access; the borrow is not used after `apply` is called.
        let state = unsafe { &mut *self.state.get() };
        match state {
            SuspensionState::Value(_) | SuspensionState::Panicked(_) | SuspensionState::Joined => {
                selector.apply(Permit::default());
                return true;
            },
            SuspensionState::Joining(_) => unreachable!("suspension: joining state"),
            SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
            // `Empty` has no destructor, so plain assignment just installs the selector.
            SuspensionState::Empty => *state = SuspensionState::Selector(selector),
        }
        false
    }

    /// Removes the selector registered by `watch_permit`, if it is still registered.
    ///
    /// # Panics
    /// Panics if the registered selector does not match `identifier`.
    fn unwatch_permit(&self, identifier: &Identifier) {
        // SAFETY: single-threaded access; the borrow does not escape this function.
        let state = unsafe { &mut *self.state.get() };
        if let SuspensionState::Selector(selector) = state {
            assert!(selector.identify(identifier), "suspension: selecting by other");
            *state = SuspensionState::Empty;
        }
    }

    /// Takes the result after a selector has been granted a permit for this joint.
    fn consume_permit(&self) -> Result<T, PanicError> {
        self.take()
    }

    /// Moves the stored value or error out, leaving the joint `Joined`.
    ///
    /// # Panics
    /// Panics if no result is currently stored.
    fn take(&self) -> Result<T, PanicError> {
        // SAFETY: single-threaded access; the temporary borrow ends with the `replace`.
        let state = mem::replace(unsafe { &mut *self.state.get() }, SuspensionState::Joined);
        match state {
            SuspensionState::Value(value) => Ok(value),
            SuspensionState::Panicked(err) => Err(err),
            SuspensionState::Empty => unreachable!("suspension: empty state"),
            SuspensionState::Joining(_) => unreachable!("suspension: joining state"),
            SuspensionState::Joined => unreachable!("suspension: joined state"),
            SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
        }
    }

    /// Returns the result, suspending the current coroutine until one arrives if necessary.
    ///
    /// The joint always ends up `Joined`.
    ///
    /// # Panics
    /// Panics if another coroutine is joining, the result was already taken, or a selector is
    /// registered.
    fn join(&self) -> Result<T, PanicError> {
        let co = super::current();
        // Inspect before mutating. Only switch to `Joining` when actually about to suspend.
        // A stale `Joining(co)` left behind on the ready path would let a later `wake`/`fault`
        // from another `Resumption` clone report success and resume `co` while it is not
        // waiting here.
        // SAFETY: single-threaded access; this shared borrow ends with the `match`.
        let ready = match unsafe { &*self.state.get() } {
            SuspensionState::Empty => false,
            SuspensionState::Value(_) | SuspensionState::Panicked(_) => true,
            SuspensionState::Joining(_) => unreachable!("suspension: join joining state"),
            SuspensionState::Joined => unreachable!("suspension: join joined state"),
            SuspensionState::Selector(_) => unreachable!("suspension: selecting"),
        };
        if !ready {
            // SAFETY: no borrow of the state is live here. Writing through the raw pointer
            // (rather than holding a `&mut`) keeps no reference alive across the suspension,
            // during which `wake`/`cancel` mutate the state.
            unsafe { *self.state.get() = SuspensionState::Joining(co) };
            // SAFETY: assumes `task::current()` yields a valid pointer to the task running this
            // coroutine — TODO confirm against `crate::task`.
            let task = unsafe { task::current().as_mut() };
            task.suspend(co, self);
        }
        self.take()
    }
}
/// Waiting half of a suspension pair: `suspend` blocks the current coroutine until a
/// `Resumption` delivers a value.
pub struct Suspension<T>(Rc<SuspensionJoint<T>>)
where
    T: 'static;

/// Waking half of a suspension pair; may be cloned, and the first successful `send`/`resume`
/// wins.
pub struct Resumption<T>
where
    T: 'static,
{
    joint: Rc<SuspensionJoint<T>>,
}

// Both halves share an `Rc` over unsynchronised cells, so neither may cross threads.
assert_not_impl_any!(Suspension<()>: Send);
assert_not_impl_any!(Resumption<()>: Send);
impl<T> Suspension<T> {
    /// Unwraps the shared joint without running `Drop`, which would cancel the suspension.
    ///
    /// # Safety
    /// `self` is consumed, so the `Rc` is moved out exactly once. The caller takes over the
    /// joint and is expected to join it.
    unsafe fn into_joint(self) -> Rc<SuspensionJoint<T>> {
        let this = mem::ManuallyDrop::new(self);
        ptr::read(&this.0)
    }

    /// Returns `true` if a value or error is already available, so `suspend` would not block.
    pub fn is_ready(&self) -> bool {
        (*self.0).is_ready()
    }

    /// Suspends the current coroutine until resumed and returns the delivered value.
    ///
    /// # Panics
    /// Re-raises an error result as a panic. An unwound payload is resumed with
    /// `resume_unwind`; a static reason (e.g. `"suspend: no resumption"`) is raised with
    /// `panic_any`.
    pub fn suspend(self) -> T {
        let joint = unsafe { self.into_joint() };
        let outcome = joint.join();
        match outcome {
            Ok(value) => value,
            Err(PanicError::Unwind(payload)) => panic::resume_unwind(payload),
            Err(PanicError::Static(reason)) => panic::panic_any(reason),
        }
    }
}
impl<T> Drop for Suspension<T> {
    /// Cancels a suspension that was dropped before being suspended on.
    ///
    /// After this, `send` hands the value back as `Err`.
    fn drop(&mut self) {
        let reason = PanicError::Static("suspension dropped");
        let _ = self.0.cancel(reason);
    }
}
impl<T> Resumption<T> {
    /// Wraps `joint`; the waker count was already set up by the joint's creator.
    fn new(joint: Rc<SuspensionJoint<T>>) -> Self {
        Self { joint }
    }

    /// Unwraps the shared joint without running `Drop`, so the waker count is not decremented.
    ///
    /// # Safety
    /// `self` is consumed, so the `Rc` is moved out exactly once.
    unsafe fn into_joint(self) -> Rc<SuspensionJoint<T>> {
        let this = mem::ManuallyDrop::new(self);
        ptr::read(&this.joint)
    }

    /// Delivers `value`. Returns `false`, dropping the value, if the suspension was already
    /// resumed, cancelled or joined.
    pub fn resume(self, value: T) -> bool {
        self.send(value).is_ok()
    }

    /// Delivers `value`, handing it back as `Err(value)` if the suspension was already
    /// resumed, cancelled or joined.
    pub fn send(self, value: T) -> Result<(), T> {
        unsafe { self.into_joint() }.wake(value)
    }

    /// Delivers either a value or a captured panic payload.
    pub(super) fn set_result(self, result: Result<T, Box<dyn Any + Send + 'static>>) {
        let joint = unsafe { self.into_joint() };
        joint.set_result(result)
    }
}
impl<T> Clone for Resumption<T> {
fn clone(&self) -> Self {
self.joint.add_waker();
Resumption { joint: self.joint.clone() }
}
}
impl<T> Drop for Resumption<T> {
fn drop(&mut self) {
self.joint.remove_waker();
}
}
pub fn suspension<T>() -> (Suspension<T>, Resumption<T>) {
let joint = SuspensionJoint::new();
let suspension = Suspension(joint.clone());
(suspension, Resumption::new(joint))
}
/// Handle to a spawned coroutine's result; can be `join`ed or used in `select!`.
pub struct JoinHandle<T>
where
    T: 'static,
{
    /// `None` once the result has been consumed through `select!`.
    suspension: Option<Suspension<T>>,
}

// Wraps an `Rc`-based suspension, so it must stay on its thread.
assert_not_impl_any!(JoinHandle<()>: Send);
impl<T> JoinHandle<T> {
    /// Wraps the suspension that receives the coroutine's result.
    pub(super) fn new(suspension: Suspension<T>) -> Self {
        Self { suspension: Some(suspension) }
    }

    /// Waits for the coroutine to finish and returns its result.
    ///
    /// # Errors
    /// Returns a `JoinError` wrapping the panic the coroutine's result carried.
    ///
    /// # Panics
    /// Panics with `"already joined by select"` if the result was consumed by `select!`.
    pub fn join(mut self) -> Result<T, JoinError> {
        let suspension = match self.suspension.take() {
            Some(suspension) => suspension,
            None => panic!("already joined by select"),
        };
        let joint = unsafe { suspension.into_joint() };
        joint.join().map_err(JoinError::new)
    }
}
impl<T: 'static> Selectable for JoinHandle<T> {
    /// Always `false`; the meaning of "parallel" is defined by `Selectable` in `crate::select`.
    fn parallel(&self) -> bool {
        false
    }

    /// Returns `Completed` once the result was consumed, a permit when it is ready, and
    /// `WouldBlock` otherwise.
    fn select_permit(&self) -> Result<Permit, TrySelectError> {
        match &self.suspension {
            None => Err(TrySelectError::Completed),
            Some(suspension) if suspension.is_ready() => Ok(Permit::default()),
            Some(_) => Err(TrySelectError::WouldBlock),
        }
    }

    /// Forwards the selector to the underlying joint; returns `false` once consumed.
    fn watch_permit(&self, selector: Selector) -> bool {
        match &self.suspension {
            Some(suspension) => suspension.0.watch_permit(selector),
            None => false,
        }
    }

    /// Withdraws a selector registered by `watch_permit`, if the result is still pending.
    fn unwatch_permit(&self, identifier: &Identifier) {
        if let Some(suspension) = &self.suspension {
            suspension.0.unwatch_permit(identifier);
        }
    }
}
impl<T: 'static> PermitReader for JoinHandle<T> {
    type Result = Result<T, JoinError>;

    /// Takes the coroutine's result after `select!` granted a permit.
    ///
    /// # Panics
    /// Panics with `"JoinHandle: already consumed"` if called a second time.
    fn consume_permit(&mut self, _permit: Permit) -> Result<T, JoinError> {
        match self.suspension.take() {
            Some(suspension) => {
                let joint = unsafe { suspension.into_joint() };
                joint.consume_permit().map_err(JoinError::new)
            },
            None => panic!("JoinHandle: already consumed"),
        }
    }
}
#[cfg(test)]
mod tests {
    use ignore_result::Ignore;
    use crate::{coroutine, select};

    // Two coroutines race to resume the same suspension: exactly one `send` succeeds and
    // the loser gets its value handed back.
    #[crate::test(crate = "crate")]
    fn resumption() {
        let (suspension, resumption) = coroutine::suspension();
        // Dropping a clone must not fault the joint while the original handle is alive.
        drop(resumption.clone());
        assert_eq!(suspension.0.is_ready(), false);
        let co1 = coroutine::spawn({
            let resumption = resumption.clone();
            move || resumption.send(5)
        });
        let co2 = coroutine::spawn(move || resumption.send(6));
        let value = suspension.suspend();
        let mut result1 = co1.join().unwrap();
        let mut result2 = co2.join().unwrap();
        // Normalise so that `result1` is the winner's and `result2` the loser's outcome.
        if result1.is_err() {
            std::mem::swap(&mut result1, &mut result2);
        }
        assert_eq!(result1, Ok(()));
        assert_eq!(result2.is_err(), true);
        // The two candidate values sum to 11, so the delivered one is 11 minus the rejected one.
        assert_eq!(value, 11 - result2.unwrap_err());
    }

    // Dropping the waiting half cancels the joint, which therefore reports ready.
    #[crate::test(crate = "crate")]
    fn suspension_dropped() {
        let (suspension, resumption) = coroutine::suspension::<()>();
        drop(suspension);
        assert_eq!(resumption.joint.is_ready(), true);
    }

    // Suspending while the only resumption is still held by the same coroutine can never be
    // resumed. The expected message is raised outside this file (presumably by the task
    // scheduler's deadlock detection — confirm in `crate::task`).
    #[crate::test(crate = "crate")]
    #[should_panic(expected = "deadlock suspending coroutines")]
    fn suspension_deadlock() {
        let (suspension, resumption) = coroutine::suspension::<()>();
        suspension.suspend();
        drop(resumption);
    }

    #[crate::test(crate = "crate")]
    fn join_handle_join() {
        let join_handle = coroutine::spawn(|| 5);
        assert_eq!(join_handle.join().unwrap(), 5);
    }

    // A panic inside the coroutine surfaces as a `JoinError` carrying the panic message.
    #[crate::test(crate = "crate")]
    fn join_handle_join_panic() {
        const REASON: &'static str = "oooooops";
        let co = coroutine::spawn(|| panic!("{}", REASON));
        let err = co.join().unwrap_err();
        assert!(err.to_string().contains(REASON))
    }

    #[crate::test(crate = "crate")]
    fn join_handle_select() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
    }

    // After the result is consumed, the handle reports `Completed`, so the second `select!`
    // is expected to take its `complete` branch.
    #[crate::test(crate = "crate")]
    fn join_handle_select_complete() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
            complete => {},
        }
    }

    // `join` after the result was consumed by `select!` must panic.
    #[crate::test(crate = "crate")]
    #[should_panic(expected = "already joined by select")]
    fn join_handle_join_consumed() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        join_handle.join().ignore();
    }

    // Selecting again on a consumed handle without a `complete` branch panics. The message
    // comes from the `select!` machinery outside this file.
    #[crate::test(crate = "crate")]
    #[should_panic(expected = "all selectables are disabled or completed")]
    fn join_handle_select_consumed() {
        let mut join_handle = coroutine::spawn(|| 5);
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
        select! {
            r = <-join_handle => assert_eq!(r.unwrap(), 5),
        }
    }
}