use core::cell::Cell;
use core::hint::unreachable_unchecked;
use core::marker::PhantomData;
use core::mem::{self, ManuallyDrop};
use core::ptr;
use crate::arch::{self, STACK_ALIGNMENT};
use crate::sanitizer::SanitizerFiber;
#[cfg(feature = "default-stack")]
use crate::stack::DefaultStack;
#[cfg(windows)]
use crate::stack::StackTebFields;
use crate::stack::{self, StackPointer};
use crate::trap::CoroutineTrapHandler;
use crate::unwind::{self, initial_func_abi, CaughtPanic, ForcedUnwindErr};
use crate::util::{self, EncodedValue};
/// Outcome of resuming a coroutine: either a value it handed out while
/// suspending, or the value its body produced when it finished.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CoroutineResult<Yield, Return> {
    /// The coroutine suspended itself and produced this value.
    Yield(Yield),
    /// The coroutine's body completed and produced this value.
    Return(Return),
}

impl<Yield, Return> CoroutineResult<Yield, Return> {
    /// Extracts the yielded value; `None` when this is a `Return`.
    pub fn as_yield(self) -> Option<Yield> {
        if let Self::Yield(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Extracts the returned value; `None` when this is a `Yield`.
    pub fn as_return(self) -> Option<Return> {
        if let Self::Return(value) = self {
            Some(value)
        } else {
            None
        }
    }
}
/// Returns the base of `stack` lowered by `size_of::<SanitizerFiber>()` bytes.
///
/// NOTE(review): the reserved bytes presumably hold the `SanitizerFiber`
/// state at the top of the stack (see `SanitizerFiber::from_parent_link`) —
/// confirm against `arch`.
#[inline]
pub(crate) fn adjusted_stack_base(stack: &impl stack::Stack) -> StackPointer {
    let top = stack.base().get();
    let reserved = mem::size_of::<SanitizerFiber>();
    // SAFETY: assumes the stack base lies above `reserved`, so the adjusted
    // address is still a valid (non-zero) stack pointer — NOTE(review).
    unsafe { StackPointer::new_unchecked(top - reserved) }
}
/// A stackful coroutine that runs a closure on its own stack.
///
/// `Input` is the type passed in by [`Coroutine::resume`], `Yield` is the
/// type handed out by [`Yielder::suspend`], and `Return` is the closure's
/// final result. With the `default-stack` feature, `Stack` defaults to
/// [`DefaultStack`].
#[cfg(feature = "default-stack")]
pub struct Coroutine<Input, Yield, Return, Stack: stack::Stack = DefaultStack> {
    // Stack the coroutine body runs on.
    stack: Stack,
    // Saved stack pointer of the suspended coroutine. `None` while it is
    // running (set in `resume_inner` before switching) or once it completed.
    stack_ptr: Option<StackPointer>,
    // Value returned by `arch::init_stack`; while `stack_ptr` still equals
    // it the coroutine has not started (see `started`).
    initial_stack_ptr: StackPointer,
    // Type-erased destructor for the closure stored on the stack, used when
    // a coroutine that never started is unwound.
    drop_fn: unsafe fn(ptr: *mut u8),
    // Sanitizer bookkeeping for switches onto this coroutine's stack.
    sanitizer_fiber: SanitizerFiber,
    // Records the input/yield/return types without storing them.
    marker: PhantomData<fn(Input) -> CoroutineResult<Yield, Return>>,
    // Raw-pointer marker: makes the type `!Send` and `!Sync` by default.
    marker2: PhantomData<*mut ()>,
}
/// A stackful coroutine that runs a closure on its own stack.
///
/// `Input` is the type passed in by [`Coroutine::resume`], `Yield` is the
/// type handed out by [`Yielder::suspend`], and `Return` is the closure's
/// final result.
#[cfg(not(feature = "default-stack"))]
pub struct Coroutine<Input, Yield, Return, Stack: stack::Stack> {
    // Stack the coroutine body runs on.
    stack: Stack,
    // Saved stack pointer of the suspended coroutine. `None` while it is
    // running (set in `resume_inner` before switching) or once it completed.
    stack_ptr: Option<StackPointer>,
    // Value returned by `arch::init_stack`; while `stack_ptr` still equals
    // it the coroutine has not started (see `started`).
    initial_stack_ptr: StackPointer,
    // Type-erased destructor for the closure stored on the stack, used when
    // a coroutine that never started is unwound.
    drop_fn: unsafe fn(ptr: *mut u8),
    // Sanitizer bookkeeping for switches onto this coroutine's stack.
    sanitizer_fiber: SanitizerFiber,
    // Records the input/yield/return types without storing them.
    marker: PhantomData<fn(Input) -> CoroutineResult<Yield, Return>>,
    // Raw-pointer marker: makes the type `!Send` and `!Sync` by default.
    marker2: PhantomData<*mut ()>,
}
// SAFETY: the `&self` methods on `Coroutine` in this file (`started`,
// `done`, `trap_handler`) only read plain fields and call
// `Stack::base`/`Stack::limit`. Sharing `&Coroutine` across threads is
// therefore sound as long as the stack itself is `Sync`. Without this impl,
// `marker2` would keep the type `!Sync`. No `Send` impl appears here.
unsafe impl<Input, Yield, Return, Stack: stack::Stack + Sync> Sync
    for Coroutine<Input, Yield, Return, Stack>
{
}
#[cfg(feature = "default-stack")]
impl<Input, Yield, Return> Coroutine<Input, Yield, Return, DefaultStack> {
    /// Creates a coroutine that runs `func` on a freshly constructed
    /// [`DefaultStack`].
    ///
    /// Equivalent to calling [`Coroutine::with_stack`] with
    /// `DefaultStack::default()`.
    pub fn new<F>(func: F) -> Self
    where
        F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
        F: 'static,
        Input: 'static,
        Yield: 'static,
        Return: 'static,
    {
        let fresh_stack = DefaultStack::default();
        Self::with_stack(fresh_stack, func)
    }
}
impl<Input, Yield, Return, Stack: stack::Stack> Coroutine<Input, Yield, Return, Stack> {
    /// Creates a coroutine that runs `func` on `stack`.
    ///
    /// `func` receives a [`Yielder`] for suspending the coroutine and the
    /// value passed to the first [`resume`](Self::resume). This is the safe
    /// counterpart of [`with_stack_unchecked`](Self::with_stack_unchecked);
    /// it only adds `'static` bounds.
    pub fn with_stack<F>(stack: Stack, func: F) -> Self
    where
        F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
        F: 'static,
        Input: 'static,
        Yield: 'static,
        Return: 'static,
        Stack: 'static,
    {
        // NOTE(review): relies on the `'static` bounds above to satisfy the
        // unchecked constructor's contract (see its `# Safety` section).
        unsafe { Self::with_stack_unchecked(stack, func) }
    }
    /// Creates a coroutine without requiring `'static` on the closure,
    /// input, yield, return or stack types.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not stated in this file. The safe
    /// wrapper `with_stack` adds `'static` bounds on `F`, `Input`, `Yield`,
    /// `Return` and `Stack`. That suggests the caller must ensure nothing
    /// these types borrow is invalidated while the coroutine (and its
    /// suspended stack) is still alive — confirm.
    pub unsafe fn with_stack_unchecked<F>(stack: Stack, func: F) -> Self
    where
        F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
    {
        initial_func_abi! {
            // Entry point of the coroutine body, executed on the coroutine's
            // own stack. It is registered via `arch::init_stack` below and
            // never returns: control goes back to the parent through
            // `arch::switch_and_reset`.
            unsafe fn coroutine_func<Input, Yield, Return, F>(
                input: EncodedValue,
                parent_link: &mut StackPointer,
                func: *mut F,
            ) -> !
            where
                F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
            {
                // `Yielder` is a `#[repr(transparent)]` wrapper around a
                // `Cell<StackPointer>`, so the parent-link slot can be viewed
                // directly as the yielder.
                let yielder = &*(parent_link as *mut StackPointer as *const Yielder<Input, Yield>);
                // Complete the sanitizer switch started by `resume_inner`.
                // Store the fiber state in the slot that
                // `SanitizerFiber::from_parent_link` locates from the parent link.
                unsafe {
                    *SanitizerFiber::from_parent_link(parent_link) =
                        SanitizerFiber::finish_switch(ptr::null_mut());
                }
                // Move the closure out of the stack slot it was placed in.
                debug_assert_eq!(func as usize % mem::align_of::<F>(), 0);
                let f = func.read();
                // Input of the first `resume`. A forced unwind never reaches
                // this point: `force_unwind_slow` drops the initial closure
                // instead of resuming a coroutine that has not started.
                let input : Result<Input, ForcedUnwindErr> = util::decode_val(input);
                let input = match input {
                    Ok(input) => input,
                    #[cfg_attr(feature = "asm-unwind", allow(unreachable_patterns))]
                    Err(_) => unreachable_unchecked(),
                };
                // Run the body. Panics are caught here, producing the
                // `Result<Return, CaughtPanic>` that `resume_inner` decodes.
                let result = unwind::catch_unwind_at_root(|| f(&yielder, input));
                // Begin the sanitizer switch back to the parent stack.
                unsafe {
                    (*SanitizerFiber::from_parent_link(parent_link)).start_switch();
                }
                // Ownership of `result` moves to the parent through the
                // encoded value, so it must not be dropped here.
                let mut result = ManuallyDrop::new(result);
                arch::switch_and_reset(util::encode_val(&mut result), yielder.stack_ptr.as_ptr());
            }
        }
        // Type-erased destructor for the closure left on the stack. Used by
        // `force_unwind_slow` when a coroutine that never started is unwound.
        unsafe fn drop_fn<T>(ptr: *mut u8) {
            ptr::drop_in_place(ptr as *mut T);
        }
        unsafe {
            // Prepare `stack` so the first switch enters `coroutine_func`
            // with `func`. NOTE(review): the frame layout lives in `arch`.
            let stack_ptr =
                arch::init_stack(&stack, coroutine_func::<Input, Yield, Return, F>, func);
            let sanitizer_fiber = stack.sanitizer_fiber();
            Self {
                stack,
                // Both start out equal; `started()` compares them.
                stack_ptr: Some(stack_ptr),
                initial_stack_ptr: stack_ptr,
                drop_fn: drop_fn::<F>,
                sanitizer_fiber,
                marker: PhantomData,
                marker2: PhantomData,
            }
        }
    }
    /// Resumes the coroutine, passing `val` in.
    ///
    /// On the first call `val` becomes the closure's second argument. After
    /// that it is what the pending [`Yielder::suspend`] returns. The result is
    /// `Yield` if the coroutine suspended again, or `Return` with the
    /// closure's value if it finished.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine has already completed. If the body panicked,
    /// the caught panic goes through `unwind::maybe_resume_unwind`, which
    /// presumably re-raises it in the caller.
    pub fn resume(&mut self, val: Input) -> CoroutineResult<Yield, Return> {
        unsafe {
            let stack_ptr = self
                .stack_ptr
                .expect("attempt to resume a completed coroutine");
            match self.resume_inner(stack_ptr, Ok(val)) {
                CoroutineResult::Yield(val) => CoroutineResult::Yield(val),
                CoroutineResult::Return(result) => {
                    CoroutineResult::Return(unwind::maybe_resume_unwind(result))
                }
            }
        }
    }
    /// Switches to the coroutine at `stack_ptr`, sending `input`, and runs it
    /// until it suspends or finishes.
    ///
    /// `input` is `Err` when the coroutine is being force-unwound. A `Return`
    /// carries either the body's value or the panic caught in it.
    ///
    /// # Safety
    ///
    /// NOTE(review): `stack_ptr` is presumably required to be this
    /// coroutine's saved stack pointer (`self.stack_ptr` before the call) —
    /// confirm.
    unsafe fn resume_inner(
        &mut self,
        stack_ptr: StackPointer,
        input: Result<Input, ForcedUnwindErr>,
    ) -> CoroutineResult<Yield, Result<Return, CaughtPanic>> {
        // Mark the coroutine as running. The switch below reports the new
        // stack pointer (or `None` on completion) when control comes back.
        self.stack_ptr = None;
        // Begin the sanitizer switch onto the coroutine stack.
        let fake_stack = self.sanitizer_fiber.start_switch();
        // Ownership of `input` moves to the coroutine through the encoded
        // value, so it must not be dropped here.
        let mut input = ManuallyDrop::new(input);
        let (result, stack_ptr) = arch::switch_and_link(
            util::encode_val(&mut input),
            stack_ptr,
            adjusted_stack_base(&self.stack),
        );
        // Back on the caller's stack: record where the coroutine stopped and
        // complete the sanitizer switch.
        self.stack_ptr = stack_ptr;
        self.sanitizer_fiber = SanitizerFiber::finish_switch(fake_stack);
        // A saved stack pointer means the coroutine suspended. `None` means
        // its body finished and `result` encodes its value or caught panic.
        if stack_ptr.is_some() {
            CoroutineResult::Yield(util::decode_val(result))
        } else {
            CoroutineResult::Return(util::decode_val(result))
        }
    }
    /// Returns `true` once the coroutine has been resumed at least once, and
    /// also after it completed or was reset.
    pub fn started(&self) -> bool {
        self.stack_ptr != Some(self.initial_stack_ptr)
    }
    /// Returns `true` if the coroutine has completed, been unwound or reset,
    /// i.e. it has no saved stack pointer to resume from.
    pub fn done(&self) -> bool {
        self.stack_ptr.is_none()
    }
    /// Marks the coroutine as completed without unwinding its stack.
    ///
    /// On Windows, a started and suspended coroutine first has its TEB fields
    /// reset via `arch::reset_teb_fields_from_suspended`.
    ///
    /// # Safety
    ///
    /// Nothing on the coroutine's stack is dropped, including the closure if
    /// the coroutine never started. NOTE(review): the caller presumably must
    /// ensure skipping those destructors is acceptable — confirm.
    pub unsafe fn force_reset(&mut self) {
        #[cfg(windows)]
        if let Some(stack_ptr) = self.stack_ptr {
            if self.started() {
                arch::reset_teb_fields_from_suspended(adjusted_stack_base(&self.stack), stack_ptr);
            }
        }
        self.stack_ptr = None;
    }
    /// Unwinds the coroutine's stack and marks it completed. Does nothing if
    /// it has already completed.
    ///
    /// A coroutine that never started only has its stored closure dropped.
    ///
    /// # Panics
    ///
    /// Panics if a started, suspended coroutine is unwound without the
    /// `unwind` feature. Any panic other than this coroutine's own forced
    /// unwind is re-raised.
    pub fn force_unwind(&mut self) {
        if let Some(stack_ptr) = self.stack_ptr {
            self.force_unwind_slow(stack_ptr);
        }
    }
    // Out-of-line slow path of `force_unwind`. `mut stack_ptr` is only
    // reassigned when the `unwind` feature is enabled, hence the allow.
    #[cold]
    #[allow(unused_mut)]
    fn force_unwind_slow(&mut self, mut stack_ptr: StackPointer) {
        // Not started yet: there are no frames to unwind. Only the closure
        // stored by `init_stack` needs dropping.
        if !self.started() {
            unsafe {
                arch::drop_initial_obj(adjusted_stack_base(&self.stack), stack_ptr, self.drop_fn);
            }
            self.stack_ptr = None;
            return;
        }
        // A suspended coroutine can only be unwound with the `unwind` feature.
        if cfg!(not(feature = "unwind")) {
            panic!("can't unwind a suspended coroutine without the \"unwind\" feature");
        }
        // Resume the coroutine with a `ForcedUnwind` tagged with this
        // coroutine's initial stack pointer. Repeat for as long as the body
        // keeps suspending instead of letting the unwind finish.
        #[cfg(feature = "unwind")]
        loop {
            extern crate std;
            let forced_unwind = unwind::ForcedUnwind(self.initial_stack_ptr);
            let result = unwind::catch_forced_unwind(|| {
                #[cfg(not(feature = "asm-unwind"))]
                let result = unsafe { self.resume_inner(stack_ptr, Err(forced_unwind)) };
                #[cfg(feature = "asm-unwind")]
                let result = unsafe { self.resume_with_exception(stack_ptr, forced_unwind) };
                result
            });
            match result {
                // The body suspended again (presumably it caught the unwind).
                // A yield always leaves `self.stack_ptr` as `Some`, so unwind
                // again from the new suspension point.
                CoroutineResult::Yield(_) => {
                    stack_ptr = self.stack_ptr.unwrap();
                    continue;
                }
                // The body returned normally; its value is discarded.
                CoroutineResult::Return(Ok(_)) => return,
                CoroutineResult::Return(Err(e)) => {
                    // Our own marker reaching the root means the stack was
                    // fully unwound.
                    if let Some(forced_unwind) = e.downcast_ref::<unwind::ForcedUnwind>() {
                        if forced_unwind.0 == self.initial_stack_ptr {
                            return;
                        }
                    }
                    // Any other panic is propagated to the caller.
                    std::panic::resume_unwind(e);
                }
            }
        }
    }
    /// Like `resume_inner`, but delivers `forced_unwind` through
    /// `arch::switch_and_throw` rather than as an encoded input value.
    /// Presumably this raises it at the coroutine's suspension point.
    ///
    /// NOTE(review): unlike `resume_inner`, this path does not call
    /// `start_switch`/`finish_switch` on `self.sanitizer_fiber`. Confirm
    /// whether that is intentional for sanitizer builds.
    #[cfg(feature = "asm-unwind")]
    unsafe fn resume_with_exception(
        &mut self,
        stack_ptr: StackPointer,
        forced_unwind: unwind::ForcedUnwind,
    ) -> CoroutineResult<Yield, Result<Return, CaughtPanic>> {
        // Mark the coroutine as running until the switch reports back.
        self.stack_ptr = None;
        let (result, stack_ptr) =
            arch::switch_and_throw(forced_unwind, stack_ptr, adjusted_stack_base(&self.stack));
        self.stack_ptr = stack_ptr;
        if stack_ptr.is_some() {
            CoroutineResult::Yield(util::decode_val(result))
        } else {
            CoroutineResult::Return(util::decode_val(result))
        }
    }
    /// Consumes a completed coroutine and returns its stack for reuse.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine has not completed (see [`done`](Self::done)).
    // `mut self` is only needed for the Windows TEB update below.
    #[allow(unused_mut)]
    pub fn into_stack(mut self) -> Stack {
        assert!(
            self.done(),
            "cannot extract stack from an incomplete coroutine"
        );
        // Presumably clears sanitizer poisoning left on the stack memory.
        unsafe {
            self.stack.sanitizer_fiber().unpoison_stack();
        }
        #[cfg(windows)]
        unsafe {
            arch::update_stack_teb_fields(&mut self.stack);
        }
        // Move the stack out bitwise. `mem::forget` then keeps `Drop` (and
        // the stack's own destructor) from running on the copy left in `self`.
        unsafe {
            let stack = ptr::read(&self.stack);
            mem::forget(self);
            stack
        }
    }
    /// Returns a `CoroutineTrapHandler` describing this coroutine's stack
    /// bounds: the adjusted stack base and the stack's limit.
    pub fn trap_handler(&self) -> CoroutineTrapHandler<Return> {
        CoroutineTrapHandler {
            stack_base: adjusted_stack_base(&self.stack),
            stack_limit: self.stack.limit(),
            marker: PhantomData,
        }
    }
}
impl<Input, Yield, Return, Stack: stack::Stack> Drop for Coroutine<Input, Yield, Return, Stack> {
    /// Unwinds the coroutine if it has not completed. Then unpoisons its
    /// stack for the sanitizer and, on Windows, updates the stack's TEB fields.
    fn drop(&mut self) {
        // Armed only while `force_unwind` runs. If that call panics, the
        // guard's closure runs during unwinding and panics a second time
        // instead of letting the first panic propagate.
        let abort_on_panic = scopeguard::guard((), |()| {
            panic!("cannot propagate coroutine panic with #![no_std]");
        });
        self.force_unwind();
        // Reached only when `force_unwind` returned normally: disarm.
        mem::forget(abort_on_panic);

        let stack = &mut self.stack;
        unsafe {
            stack.sanitizer_fiber().unpoison_stack();
        }
        #[cfg(windows)]
        unsafe {
            arch::update_stack_teb_fields(stack);
        }
    }
}
/// Handle given to a coroutine body for suspending the coroutine.
///
/// The type is `#[repr(transparent)]` so that the coroutine entry code can
/// reinterpret the parent-link slot on the coroutine stack as a `Yielder`.
#[repr(transparent)]
pub struct Yielder<Input, Yield> {
    // Parent link. `arch::switch_yield`/`arch::switch_and_reset` receive a
    // pointer to it. NOTE(review): presumably it holds the saved stack
    // pointer of the context that resumed the coroutine — confirm.
    // The `Cell` also makes `&Yielder` non-`Send`.
    stack_ptr: Cell<StackPointer>,
    // The yielder sends `Yield` values out and receives `Input` values back.
    marker: PhantomData<fn(Yield) -> Input>,
}
impl<Input, Yield> Yielder<Input, Yield> {
    /// Suspends the coroutine, handing `val` to the pending `resume` call.
    /// Returns the `Input` passed to the next `resume`.
    ///
    /// The incoming value goes through `unwind::maybe_force_unwind`. That
    /// presumably begins unwinding this stack when the parent sends an
    /// `Err(ForcedUnwind)` to force-unwind the coroutine.
    pub fn suspend(&self, val: Yield) -> Input {
        unsafe {
            // The yielder occupies the parent-link slot. The sanitizer fiber
            // state is located relative to it.
            let parent_link = self as *const Self as *mut StackPointer;
            let sanitizer_fiber = &mut *SanitizerFiber::from_parent_link(parent_link);
            let fake_stack = sanitizer_fiber.start_switch();
            // Ownership of `val` moves to the parent through the encoded value.
            let mut val = ManuallyDrop::new(val);
            let result = arch::switch_yield(util::encode_val(&mut val), self.stack_ptr.as_ptr());
            // Execution continues here once the coroutine is resumed again.
            *sanitizer_fiber = SanitizerFiber::finish_switch(fake_stack);
            unwind::maybe_force_unwind(util::decode_val(result))
        }
    }
    /// Runs `f` on the parent context's stack and returns its result.
    ///
    /// The region used starts at the parent link's stack pointer minus
    /// `arch::PARENT_STACK_OFFSET`, rounded down to `STACK_ALIGNMENT` by
    /// `ParentStack::new`.
    ///
    /// NOTE(review): `F: Send` presumably prevents non-`Send` handles such as
    /// `&Yielder` from being moved into `f` — confirm intent.
    pub fn on_parent_stack<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
        F: Send,
    {
        // Stack pointer on the parent stack, derived from the parent link.
        let stack_ptr = unsafe {
            StackPointer::new_unchecked(self.stack_ptr.get().get() - arch::PARENT_STACK_OFFSET)
        };
        // Copy the sanitizer fiber state stored next to the parent link.
        let parent_link = self as *const Self as *mut StackPointer;
        let sanitizer_fiber = unsafe { *SanitizerFiber::from_parent_link(parent_link) };
        let stack = unsafe { ParentStack::new(stack_ptr, sanitizer_fiber) };
        on_stack(stack, f)
    }
}
/// Runs `f` on `stack` and returns its result on the calling stack.
///
/// A panic in `f` is caught on `stack` by `unwind::catch_unwind_at_root`. It
/// is then handed to `unwind::maybe_resume_unwind`, which presumably re-raises
/// it in the caller.
pub fn on_stack<F, R>(stack: impl stack::Stack, f: F) -> R
where
    F: FnOnce() -> R,
{
    // A single slot shared by the closure (going in) and its result (coming
    // out). `ManuallyDrop` keeps the union from dropping either implicitly.
    union FuncOrResult<F, R> {
        func: ManuallyDrop<F>,
        result: ManuallyDrop<Result<R, CaughtPanic>>,
    }
    initial_func_abi! {
        // Runs on `stack`. It takes the closure out of the slot, runs it, and
        // writes the result back into the same slot.
        unsafe fn wrapper<F, R>(ptr: *mut u8)
        where
            F: FnOnce() -> R,
        {
            let sanitizer_fiber = SanitizerFiber::finish_switch(ptr::null_mut());
            let data = &mut *(ptr as *mut FuncOrResult<F, R>);
            let func = ManuallyDrop::take(&mut data.func);
            let result = unwind::catch_unwind_at_root(func);
            data.result = ManuallyDrop::new(result);
            sanitizer_fiber.start_switch();
        }
    }
    unsafe {
        let mut data = FuncOrResult {
            func: ManuallyDrop::new(f),
        };
        let sanitizer_fiber = stack.sanitizer_fiber();
        let fake_stack = sanitizer_fiber.start_switch();
        arch::on_stack(&mut data as *mut _ as *mut u8, stack, wrapper::<F, R>);
        SanitizerFiber::finish_switch(fake_stack);
        // Assuming `arch::on_stack` ran `wrapper` to completion, the slot
        // now holds `result`.
        unwind::maybe_resume_unwind(ManuallyDrop::take(&mut data.result))
    }
}
/// Pseudo-stack used by `Yielder::on_parent_stack` to run code on the
/// parent context's stack, starting at an aligned address at or below the
/// given stack pointer.
struct ParentStack {
    // Aligned base of the region used on the parent stack.
    stack_base: StackPointer,
    // Unaligned parent stack pointer, used to read and update the parent's
    // TEB fields.
    #[cfg(windows)]
    stack_ptr: StackPointer,
    // Sanitizer state reported through `Stack::sanitizer_fiber`.
    sanitizer_fiber: SanitizerFiber,
}

impl ParentStack {
    /// Builds a `ParentStack` whose base is `stack_ptr` rounded down to
    /// `STACK_ALIGNMENT`.
    ///
    /// # Safety
    ///
    /// NOTE(review): `stack_ptr` presumably must point into a live parent
    /// stack with room below it — confirm.
    #[inline]
    unsafe fn new(stack_ptr: StackPointer, sanitizer_fiber: SanitizerFiber) -> Self {
        // Clearing the low bits rounds down. The mask form assumes
        // STACK_ALIGNMENT is a power of two.
        let align_mask = STACK_ALIGNMENT - 1;
        let aligned_base = stack_ptr.get() & !align_mask;
        Self {
            stack_base: StackPointer::new_unchecked(aligned_base),
            #[cfg(windows)]
            stack_ptr,
            sanitizer_fiber,
        }
    }
}
// `ParentStack` implements `Stack` so that it can be passed to `on_stack`.
unsafe impl stack::Stack for ParentStack {
    #[inline]
    fn base(&self) -> StackPointer {
        self.stack_base
    }

    /// Reports the same address as `base`; no separate limit is tracked.
    #[inline]
    fn limit(&self) -> StackPointer {
        self.base()
    }

    #[inline]
    #[cfg(windows)]
    fn teb_fields(&self) -> StackTebFields {
        let parent_sp = self.stack_ptr;
        unsafe { arch::read_parent_stack_teb_fields(parent_sp) }
    }

    #[inline]
    #[cfg(windows)]
    fn update_teb_fields(&mut self, stack_limit: usize, guaranteed_stack_bytes: usize) {
        let parent_sp = self.stack_ptr;
        unsafe {
            arch::update_parent_stack_teb_fields(parent_sp, stack_limit, guaranteed_stack_bytes);
        }
    }

    #[inline]
    fn sanitizer_fiber(&self) -> SanitizerFiber {
        let fiber = self.sanitizer_fiber;
        fiber
    }
}