use crate::runtime::execution::ExecutionState;
use crate::{ContinuationFunctionBehavior, UNGRACEFUL_SHUTDOWN_CONFIG};
use corosensei::Yielder;
use corosensei::{stack::DefaultStack, Coroutine, CoroutineResult};
use scoped_tls::scoped_thread_local;
use std::cell::{Cell, RefCell};
use std::collections::VecDeque;
use std::ops::Deref;
use std::ops::DerefMut;
use std::panic::Location;
use std::rc::Rc;
use tracing::trace;
// Pool of idle continuations for the current thread. `ContinuationPool::acquire` reads whichever
// pool has been installed here with `CONTINUATION_POOL.set(..)`.
scoped_thread_local! {
    pub(crate) static CONTINUATION_POOL: ContinuationPool
}
/// A coroutine that runs a sequence of boxed functions, one at a time, on its own stack.
pub(crate) struct Continuation {
    /// The coroutine that executes the functions handed over through `function`.
    coroutine: Coroutine<ContinuationInput, ContinuationOutput, ContinuationOutput>,
    /// Slot shared with the coroutine closure. The next function to run is stored here and the
    /// coroutine takes it out when it is resumed.
    function: ContinuationFunction,
    /// Lifecycle state as tracked by this handle.
    state: ContinuationState,
    /// Pointer to the coroutine's `corosensei::Yielder`, reported by the coroutine on its first
    /// resume. Code running on the coroutine uses it to suspend back to the resumer.
    pub yielder: *const Yielder<ContinuationInput, ContinuationOutput>,
}
/// Shared, take-once slot holding the next function a continuation should run.
///
/// One handle lives in `Continuation` and a clone lives inside the coroutine closure. A function
/// stored through either handle can therefore be taken out through the other.
#[derive(Clone)]
// Silences clippy's `type_complexity` lint for the nested field type below.
#[allow(clippy::type_complexity)]
struct ContinuationFunction(Rc<Cell<Option<Box<dyn FnOnce()>>>>);

// SAFETY: NOTE(review): neither `Rc` nor `Box<dyn FnOnce()>` is `Send`. This impl is only sound if
// all clones of one `ContinuationFunction` are used from a single thread at a time. Presumably the
// scheduler guarantees that; confirm.
unsafe impl Send for ContinuationFunction {}
/// Instruction passed into the coroutine each time it is resumed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ContinuationInput {
    /// Keep going: start the installed function, or continue it from its last suspension.
    Resume,
    /// Leave the run loop so the coroutine returns.
    Exit,
}
/// Value the coroutine hands back each time it suspends or returns.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ContinuationOutput {
    /// The running function suspended part-way and can be resumed later.
    Yielded,
    /// The coroutine is parked between functions; the previous function, if any, ran to
    /// completion. Carries a pointer to the coroutine's `Yielder`.
    Finished(*const Yielder<ContinuationInput, ContinuationOutput>),
    /// The coroutine left its run loop and will not run any more functions.
    Exited,
}
/// Where a continuation's coroutine is in its lifecycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ContinuationState {
    /// Parked at the top of its run loop with no function installed.
    NotReady,
    /// A function has been installed but not started yet.
    Initialized,
    /// The function suspended part-way and can be resumed.
    Ready,
    /// A resume is in progress, or the last resume did not return normally.
    Running,
    /// The function ran to completion; parked at the top of the run loop again.
    FinishedIteration,
    /// The coroutine has returned and cannot run anything else.
    Exited,
}
impl Continuation {
    /// Creates a continuation whose coroutine runs on a freshly allocated stack of `stack_size`
    /// bytes.
    ///
    /// The coroutine is resumed once before returning, so it can hand back a pointer to its
    /// `corosensei::Yielder` (stored in `yielder`). After that first resume it is parked at the top
    /// of its run loop, and the continuation is `NotReady` and waiting for `initialize`.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine stack cannot be allocated. Also panics if the first resume does not
    /// yield `ContinuationOutput::Finished`.
    pub fn new(stack_size: usize) -> Self {
        // Shared slot through which `initialize` passes the next function to the coroutine.
        let function = ContinuationFunction(Rc::new(Cell::new(None)));

        let mut coroutine = {
            let function = function.clone();
            // Report the requested size and the underlying error instead of a bare `unwrap` panic.
            let stack = DefaultStack::new(stack_size).unwrap_or_else(|e| {
                panic!("failed to allocate a {}-byte continuation stack: {:?}", stack_size, e)
            });

            Coroutine::with_stack(stack, move |yielder, input| {
                // Being told to exit before the run loop starts leaves nothing to tear down.
                if let ContinuationInput::Exit = input {
                    return ContinuationOutput::Exited;
                }

                // Mentioning `function` as a whole makes the closure capture the entire
                // `ContinuationFunction`, not just its `.0` field (Rust 2021 disjoint captures).
                // NOTE(review): presumably so the wrapper's `unsafe impl Send` applies; confirm.
                let _ = &function;

                loop {
                    // Report that the previous function (if any) has finished and pass back this
                    // coroutine's yielder. Then wait to be told to run another function or exit.
                    match yielder.suspend(ContinuationOutput::Finished(yielder as *const _)) {
                        ContinuationInput::Exit => break,
                        ContinuationInput::Resume => {}
                    }

                    let f = function.0.take().expect("must have a function to run");
                    f();
                }

                ContinuationOutput::Exited
            })
        };

        // Drive the coroutine to its first suspension point to learn where its yielder lives.
        let yielder = match coroutine.resume(ContinuationInput::Resume) {
            CoroutineResult::Yield(ContinuationOutput::Finished(yielder)) => yielder,
            _ => panic!("Coroutine should yield a pointer to its `corosensei::Yielder` from the first resume"),
        };

        Self {
            coroutine,
            yielder,
            function,
            state: ContinuationState::NotReady,
        }
    }

    /// Installs `fun` as the function to run on the next `resume`, and moves to `Initialized`.
    ///
    /// Must only be called while the continuation is reusable (`NotReady` or
    /// `FinishedIteration`) and holds no pending function. Both conditions are checked only in
    /// debug builds.
    pub fn initialize(&mut self, fun: Box<dyn FnOnce()>) {
        debug_assert!(self.reusable(), "shouldn't replace a function before it completes");
        let old = self.function.0.replace(Some(fun));
        debug_assert!(old.is_none(), "shouldn't replace a function before it runs");
        self.state = ContinuationState::Initialized;
    }

    /// Starts the installed function, or continues it from its last suspension.
    ///
    /// Returns `true` if the function ran to completion, or `false` if it suspended again.
    pub fn resume(&mut self) -> bool {
        debug_assert!(self.state == ContinuationState::Ready || self.state == ContinuationState::Initialized);
        let ret = self.resume_with_input(ContinuationInput::Resume);
        debug_assert_ne!(
            ret,
            ContinuationOutput::Exited,
            "continuation should not exit if resumed from user code"
        );
        matches!(ret, ContinuationOutput::Finished(_))
    }

    /// Resumes the coroutine with `input` and records the resulting state.
    ///
    /// The state is set to `Running` before resuming. It stays `Running` if
    /// `coroutine.resume` does not return normally, for example because it unwinds.
    fn resume_with_input(&mut self, input: ContinuationInput) -> ContinuationOutput {
        self.state = ContinuationState::Running;
        match self.coroutine.resume(input) {
            CoroutineResult::Yield(output) => {
                self.state = match output {
                    ContinuationOutput::Finished(_) => ContinuationState::FinishedIteration,
                    ContinuationOutput::Yielded => ContinuationState::Ready,
                    ContinuationOutput::Exited => ContinuationState::Exited,
                };
                output
            }
            CoroutineResult::Return(output) => {
                self.state = ContinuationState::Exited;
                output
            }
        }
    }

    /// Whether the coroutine is parked between functions and can accept a new one.
    fn reusable(&self) -> bool {
        self.state == ContinuationState::NotReady || self.state == ContinuationState::FinishedIteration
    }
}
impl Drop for Continuation {
    /// Shuts down the coroutine according to where it is in its lifecycle.
    fn drop(&mut self) {
        match self.state {
            // The coroutine is parked at the top of its run loop, with no function mid-flight.
            // Resuming it with `Exit` breaks out of the loop, and the coroutine returns `Exited`.
            ContinuationState::Initialized | ContinuationState::FinishedIteration | ContinuationState::NotReady => {
                let ret = self.resume_with_input(ContinuationInput::Exit);
                debug_assert_eq!(ret, ContinuationOutput::Exited);
            }
            // Either a function is suspended part-way (`Ready`), or the last `coroutine.resume`
            // did not return normally (`Running`), for example because the function panicked.
            ContinuationState::Running | ContinuationState::Ready => {
                // Unwinding the coroutine stack while this thread is already panicking would
                // risk a panic during a panic, so the stack is abandoned instead.
                // NOTE(review): per the corosensei docs, `force_reset` marks the coroutine finished
                // without running destructors on its stack. After that, `force_unwind` should have
                // nothing left to unwind. Confirm against the corosensei version in use.
                if std::thread::panicking() {
                    unsafe {
                        self.coroutine.force_reset();
                    }
                }
                self.coroutine.force_unwind();
            }
            ContinuationState::Exited => {
                // The coroutine already returned; nothing left to clean up.
            }
        }
    }
}
/// A queue of idle continuations that are handed out and recycled in FIFO order.
pub(crate) struct ContinuationPool {
    /// Shared with every `PooledContinuation` handed out, so each one can put itself back when
    /// dropped.
    continuations: Rc<RefCell<VecDeque<Continuation>>>,
}
impl ContinuationPool {
    /// Creates an empty pool.
    pub fn new() -> Self {
        Self {
            continuations: Rc::new(RefCell::new(VecDeque::new())),
        }
    }

    /// Takes a continuation from the pool installed in `CONTINUATION_POOL`. If the pool is empty,
    /// a new one with a `stack_size`-byte stack is created.
    ///
    /// # Panics
    ///
    /// Panics if called outside a `CONTINUATION_POOL.set(..)` scope.
    pub fn acquire(stack_size: usize) -> PooledContinuation {
        CONTINUATION_POOL.with(|p| p.acquire_inner(stack_size))
    }

    /// Takes the continuation at the front of the queue, or builds a fresh one if the queue is
    /// empty.
    ///
    /// `stack_size` only applies to newly built continuations. A recycled continuation keeps the
    /// stack it was created with.
    fn acquire_inner(&self, stack_size: usize) -> PooledContinuation {
        // Bind the popped value in its own statement so the `RefMut` is released before
        // `Continuation::new` runs. That call allocates a stack and executes coroutine code, and
        // should never observe the queue as mutably borrowed.
        let recycled = self.continuations.borrow_mut().pop_front();
        let continuation = recycled.unwrap_or_else(|| Continuation::new(stack_size));
        PooledContinuation {
            continuation: Some(continuation),
            queue: Rc::clone(&self.continuations),
        }
    }
}
/// A continuation on loan from a `ContinuationPool`.
///
/// It dereferences to the `Continuation` and tries to return it to the pool's queue when dropped.
pub(crate) struct PooledContinuation {
    /// `Some` from construction until `Drop::drop` takes it out.
    continuation: Option<Continuation>,
    /// The pool's queue. Holding the `Rc` keeps it alive, so the continuation can be returned even
    /// after the `ContinuationPool` itself has been dropped.
    queue: Rc<RefCell<VecDeque<Continuation>>>,
}
impl Drop for PooledContinuation {
    /// Puts the continuation back into the shared queue whenever it can accept a new function.
    ///
    /// - Reusable (`NotReady` / `FinishedIteration`): pushed back as-is.
    /// - `Initialized` (function installed but never started): the pending function is removed,
    ///   the state is reset to `NotReady`, and the continuation is pushed back. The removed
    ///   function is dropped. The exception: if this thread is already panicking, the
    ///   `UNGRACEFUL_SHUTDOWN_CONFIG` setting decides whether it is dropped or leaked.
    /// - Anything else (stopped part-way through a function): not pooled; it is dropped when
    ///   this method returns.
    fn drop(&mut self) {
        let mut continuation = self.continuation.take().unwrap();

        if !continuation.reusable() {
            if continuation.state != ContinuationState::Initialized {
                // Not recyclable: let `continuation` go out of scope here.
                return;
            }

            let pending = continuation.function.0.take();
            continuation.state = ContinuationState::NotReady;

            if !std::thread::panicking() {
                drop(pending);
            } else {
                match UNGRACEFUL_SHUTDOWN_CONFIG.get().continuation_function_behavior {
                    ContinuationFunctionBehavior::Leak => std::mem::forget(pending),
                    ContinuationFunctionBehavior::Drop => drop(pending),
                }
            }
        }

        self.queue.borrow_mut().push_back(continuation);
    }
}
impl Deref for PooledContinuation {
    type Target = Continuation;

    /// Borrows the loaned continuation.
    ///
    /// The inner `Option` is only emptied by `Drop::drop`, so the `unwrap` cannot fail while the
    /// handle is in use.
    fn deref(&self) -> &Self::Target {
        let slot = &self.continuation;
        slot.as_ref().unwrap()
    }
}
impl DerefMut for PooledContinuation {
    /// Mutably borrows the loaned continuation.
    ///
    /// The inner `Option` is only emptied by `Drop::drop`, so the `unwrap` cannot fail while the
    /// handle is in use.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let slot = &mut self.continuation;
        slot.as_mut().unwrap()
    }
}
impl std::fmt::Debug for PooledContinuation {
    /// Prints only the type name, without exposing the wrapped continuation.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("PooledContinuation")
    }
}
// SAFETY: NOTE(review): `PooledContinuation` holds an `Rc` and, through `Continuation`, a raw
// `Yielder` pointer, and neither is `Send`. This impl presumably relies on a pooled continuation and
// its queue only ever being used from one thread at a time. Confirm against the scheduler.
unsafe impl Send for PooledContinuation {}
/// Gives the scheduler a chance to switch away from the currently running task.
///
/// Records a tick and traces the caller's location, then asks `ExecutionState::maybe_yield`
/// whether to switch. If it does, the current task's coroutine is suspended with
/// `ContinuationOutput::Yielded` until it is resumed.
///
/// # Panics
///
/// Panics if the suspended coroutine is resumed with `ContinuationInput::Exit`.
#[track_caller]
pub(crate) fn switch() {
    crate::annotations::record_tick();
    trace!("switch from {}", Location::caller());

    if !ExecutionState::maybe_yield() {
        return;
    }

    let yielder = ExecutionState::with(|state| state.current().yielder);
    // SAFETY: NOTE(review): assumes `yielder` points at the live `Yielder` of the coroutine this
    // code is currently running on. Confirm that invariant with `ExecutionState`.
    let input = unsafe { &*yielder }.suspend(ContinuationOutput::Yielded);
    match input {
        ContinuationInput::Resume => {}
        ContinuationInput::Exit => panic!("unexpected exit continuation"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Config;

    /// A continuation returns to the pool only if its function ran to completion. One dropped
    /// part-way through its function is discarded. A loaned continuation may outlive its pool.
    #[test]
    fn reusable_continuation_drop() {
        let pool = ContinuationPool::new();
        let config = Config::default();
        let stack_size = config.stack_size;

        // Phase 1: a one-step function finishes, so the continuation is recycled.
        let mut first = pool.acquire_inner(stack_size);
        first.initialize(Box::new(|| {
            let _ = 1 + 1;
        }));
        let yielder = first.yielder;
        assert!(first.resume(), "continuation only has one step");
        drop(first);
        assert_eq!(
            pool.continuations.borrow().len(),
            1,
            "continuation should be reusable because the function finished"
        );

        // Phase 2: the recycled continuation (same coroutine, so the same yielder) gets a function
        // that suspends part-way. Dropping it then must not put it back in the pool.
        let mut second = pool.acquire_inner(stack_size);
        second.initialize(Box::new(move || {
            unsafe { &(*yielder) }.suspend(ContinuationOutput::Yielded);
            let _ = 1 + 1;
        }));
        assert!(!second.resume(), "continuation yields once, shouldn't be finished yet");
        drop(second);
        assert_eq!(
            pool.continuations.borrow().len(),
            0,
            "continuation should not be reusable because the function wasn't finished"
        );

        // Phase 3: a loaned continuation can be dropped after the pool it came from.
        let third = pool.acquire_inner(stack_size);
        drop(pool);
        drop(third);
    }
}