use std::cell::RefCell;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::{fmt, panic, thread};
use parking::Parker;
use slab::Slab;
use super::NEXT_EXECUTOR_ID;
use super::task::{self, CancelToken, Promise, Runnable};
use crate::channel;
use crate::executor::{ExecutorError, SIMULATION_CONTEXT, Signal, SimulationContext};
use crate::macros::scoped_thread_local::scoped_thread_local;
use crate::simulation::CURRENT_MODEL_ID;
/// Initial capacity of each executor's task queue.
const QUEUE_MIN_CAPACITY: usize = 32;
// Context of the executor currently running tasks on this thread; read by
// `schedule_task` to push rescheduled tasks onto that executor's queue.
scoped_thread_local!(static EXECUTOR_CONTEXT: ExecutorContext);
// Cancellation tokens of the tasks owned by the executor currently running on
// this thread; read by `CancellableFuture::drop` to deregister its task.
scoped_thread_local!(static ACTIVE_TASKS: RefCell<Slab<CancelToken>>);
/// Single-threaded executor: queued tasks are run one after another by
/// `ExecutorInner::run`.
pub(crate) struct Executor {
// Executor state. It is moved to a worker thread during a timed `run` and is
// left as `None` if that run times out.
inner: Option<Box<ExecutorInner>>,
// Signal set by a timed `run` when it expires; a clone of it is stored in
// `ExecutorInner` (see `new`), which checks it after each task.
abort_signal: Signal,
}
impl Executor {
pub(crate) fn new(simulation_context: SimulationContext, abort_signal: Signal) -> Self {
let executor_id = NEXT_EXECUTOR_ID.fetch_add(1, Ordering::Relaxed);
assert!(
executor_id <= usize::MAX / 2,
"too many executors have been instantiated"
);
let context = ExecutorContext::new(executor_id);
let active_tasks = RefCell::new(Slab::new());
Self {
inner: Some(Box::new(ExecutorInner {
context,
active_tasks,
simulation_context,
abort_signal: abort_signal.clone(),
})),
abort_signal,
}
}
pub(crate) fn spawn<T>(&self, future: T) -> Promise<T::Output>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
let inner = self.inner.as_ref().unwrap();
let mut active_tasks = inner.active_tasks.borrow_mut();
let task_entry = active_tasks.vacant_entry();
let future = CancellableFuture::new(future, task_entry.key());
let (promise, runnable, cancel_token) =
task::spawn(future, schedule_task, inner.context.executor_id);
task_entry.insert(cancel_token);
let mut queue = inner.context.queue.borrow_mut();
queue.push(runnable);
promise
}
pub(crate) fn spawn_and_forget<T>(&self, future: T)
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
let inner = self.inner.as_ref().unwrap();
let mut active_tasks = inner.active_tasks.borrow_mut();
let task_entry = active_tasks.vacant_entry();
let future = CancellableFuture::new(future, task_entry.key());
let (runnable, cancel_token) =
task::spawn_and_forget(future, schedule_task, inner.context.executor_id);
task_entry.insert(cancel_token);
let mut queue = inner.context.queue.borrow_mut();
queue.push(runnable);
}
pub(crate) fn run(&mut self, timeout: Duration) -> Result<(), ExecutorError> {
if timeout.is_zero() {
return self.inner.as_mut().unwrap().run();
}
let mut inner = self.inner.take().unwrap();
let parker = Parker::new();
let unparker = parker.unparker();
let th = thread::spawn(move || {
let res = inner.run();
unparker.unpark();
(inner, res)
});
if !parker.park_timeout(timeout) {
self.abort_signal.set();
return Err(ExecutorError::Timeout);
}
let (inner, res) = th.join().unwrap();
self.inner = Some(inner);
res
}
}
/// Executor state: task queue, registered tasks, and the contexts installed
/// while tasks run.
///
/// Kept separate from `Executor` so that it can be moved to a worker thread
/// during a timed `Executor::run`.
// NOTE: fields are dropped in declaration order, after `Drop::drop` has run.
struct ExecutorInner {
// Task queue, executor ID and saved in-flight message counter.
context: ExecutorContext,
// Cancellation tokens of live tasks, keyed by the slab index recorded in each
// task's `CancellableFuture`.
active_tasks: RefCell<Slab<CancelToken>>,
// Installed in `SIMULATION_CONTEXT` while tasks run.
simulation_context: SimulationContext,
// Checked after each task; `run` returns early once it is set.
abort_signal: Signal,
}
impl ExecutorInner {
    /// Runs queued tasks until the queue is empty or the abort signal is set.
    ///
    /// While tasks run, this executor's simulation context, active-task slab
    /// and executor context are installed in their scoped thread-locals. This
    /// executor's in-flight message counter is swapped into
    /// `channel::THREAD_MSG_COUNT`, and the previous value is restored on every
    /// exit path, including when a task panics.
    ///
    /// # Errors
    ///
    /// - `ExecutorError::Panic` if a task panicked, carrying the model ID taken
    ///   from `CURRENT_MODEL_ID`;
    /// - `ExecutorError::UnprocessedMessages` if the in-flight message counter
    ///   is non-zero when the run ends.
    ///
    /// # Panics
    ///
    /// Panics if the in-flight message counter ends up negative.
    fn run(&mut self) -> Result<(), ExecutorError> {
        // In case this executor is nested in another one, install this
        // executor's counter of in-flight messages and stash the caller's.
        let msg_count_stash = channel::THREAD_MSG_COUNT.replace(self.context.msg_count);

        let result = SIMULATION_CONTEXT.set(&self.simulation_context, || {
            ACTIVE_TASKS.set(&self.active_tasks, || {
                EXECUTOR_CONTEXT.set(&self.context, || {
                    panic::catch_unwind(AssertUnwindSafe(|| {
                        loop {
                            // The queue borrow must be released before
                            // `task.run()`, which may push tasks onto this same
                            // queue through `schedule_task`. A `let` statement
                            // drops the temporary guard at its end, whereas a
                            // `while let` scrutinee would keep it alive for the
                            // whole loop body.
                            let task = match self.context.queue.borrow_mut().pop() {
                                Some(task) => task,
                                None => break,
                            };

                            task.run();

                            if self.abort_signal.is_set() {
                                return;
                            }
                        }
                    }))
                })
            })
        });

        // Save this executor's counter and restore the caller's one before
        // looking at `result`, so that a panicking task does not leave an
        // enclosing executor on this thread with the wrong counter.
        self.context.msg_count = channel::THREAD_MSG_COUNT.replace(msg_count_stash);

        if let Err(payload) = result {
            let model_id = CURRENT_MODEL_ID.take();
            return Err(ExecutorError::Panic(model_id, payload));
        }

        if self.context.msg_count != 0 {
            let msg_count: usize = self
                .context
                .msg_count
                .try_into()
                .expect("the count of in-flight messages should never be negative");

            return Err(ExecutorError::UnprocessedMessages(msg_count));
        }

        Ok(())
    }
}
// Cancels every task still registered in this executor's active-task slab.
impl Drop for ExecutorInner {
fn drop(&mut self) {
// Make this executor's context current, presumably so that any task
// rescheduling triggered by cancellation reaches `schedule_task` with this
// executor's queue (and passes its executor-ID check) — TODO confirm.
EXECUTOR_CONTEXT.set(&self.context, || {
// Run with `ACTIVE_TASKS` unset so that `CancellableFuture::drop` does not
// remove a key from whatever slab is current on this thread (possibly an
// enclosing executor's). Our own slab is mutably borrowed below anyway, so
// its `try_borrow_mut` would fail.
ACTIVE_TASKS.unset(|| {
let mut tasks = self.active_tasks.borrow_mut();
for task in tasks.drain() {
task.cancel();
}
});
});
}
}
impl fmt::Debug for Executor {
    /// Formats the executor opaquely; its internal state is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Executor");
        builder.finish_non_exhaustive()
    }
}
/// Per-executor state installed in `EXECUTOR_CONTEXT` while tasks run.
struct ExecutorContext {
// Runnable tasks. Tasks are pushed and popped at the back (LIFO order).
queue: RefCell<Vec<Runnable>>,
// Used by `schedule_task` to check that a task is rescheduled on the executor
// that spawned it.
executor_id: usize,
// In-flight message counter saved between runs; swapped with
// `channel::THREAD_MSG_COUNT` by `ExecutorInner::run`.
msg_count: isize,
}
impl ExecutorContext {
    /// Creates a context with an empty, pre-allocated queue and a zero
    /// message counter for the executor with the given ID.
    fn new(executor_id: usize) -> Self {
        let queue = RefCell::new(Vec::with_capacity(QUEUE_MIN_CAPACITY));

        Self {
            executor_id,
            msg_count: 0,
            queue,
        }
    }
}
/// Future wrapper that removes its task's cancellation token from the
/// active-task slab when dropped.
// The `T: Future` bound must stay on the struct because the `Drop` impl is
// bounded by it (Drop impls cannot be more specialized than the type).
struct CancellableFuture<T: Future> {
// Wrapped future, polled through a pin projection.
inner: T,
// Key of this task's `CancelToken` in the active-task slab.
cancellation_key: usize,
}
impl<T: Future> CancellableFuture<T> {
    /// Wraps `inner`, recording the slab key of its task's cancellation token.
    fn new(inner: T, cancellation_key: usize) -> Self {
        Self {
            inner,
            cancellation_key,
        }
    }
}
impl<T: Future> Future for CancellableFuture<T> {
    type Output = T::Output;

    /// Polls the wrapped future.
    #[inline(always)]
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: structural pin projection of `inner`. `inner` is never moved
        // out of a pinned `CancellableFuture`: no method moves it, the `Drop`
        // impl only reads `cancellation_key`, and the type is `Unpin` only when
        // `T` is (no manual `Unpin` impl).
        let pinned_inner = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
        pinned_inner.poll(cx)
    }
}
impl<T: Future> Drop for CancellableFuture<T> {
    /// Removes this task's cancellation token from the active-task slab that is
    /// current on this thread, if any.
    ///
    /// Best effort: nothing happens when no slab is current, or when the slab
    /// is already borrowed (e.g. while `ExecutorInner::drop` drains it).
    fn drop(&mut self) {
        let key = self.cancellation_key;

        let _ = ACTIVE_TASKS.map(|slab_cell| {
            if let Ok(mut slab) = slab_cell.try_borrow_mut() {
                // The removed token, if any, is dropped while the slab is
                // still borrowed.
                let _removed_token = slab.try_remove(key);
            }
        });
    }
}
/// Schedule function passed to `task::spawn`/`task::spawn_and_forget`: pushes
/// `task` onto the queue of the executor currently running on this thread.
///
/// # Panics
///
/// Panics if no executor context is installed on this thread, or if the
/// installed executor is not the one with ID `executor_id`.
fn schedule_task(task: Runnable, executor_id: usize) {
    let scheduled = EXECUTOR_CONTEXT.map(|current| {
        assert_eq!(
            executor_id, current.executor_id,
            "Tasks must be awaken on the same executor they are spawned on"
        );

        current.queue.borrow_mut().push(task);
    });

    scheduled.expect("Tasks may not be awaken outside executor threads");
}