mod injector;
mod pool_manager;
use std::cell::Cell;
use std::fmt;
use std::future::Future;
use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{AtomicIsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use parking::{Parker, Unparker};
use slab::Slab;
use crate::channel;
use crate::executor::task::{self, CancelToken, Promise, Runnable};
use crate::executor::{
ExecutorError, NEXT_EXECUTOR_ID, SIMULATION_CONTEXT, Signal, SimulationContext,
};
use crate::macros::scoped_thread_local::scoped_thread_local;
use crate::simulation::CURRENT_MODEL_ID;
use crate::util::rng::Rng;
use pool_manager::PoolManager;
// Number of tasks moved between a local queue and the global injector in one
// batch.
const BUCKET_SIZE: usize = 128;
// Capacity of each worker's local queue; twice the bucket size, presumably so
// that draining one full bucket always leaves room for more tasks — TODO
// confirm against `schedule_task`'s overflow handling.
const QUEUE_SIZE: usize = BUCKET_SIZE * 2;

type Bucket = injector::Bucket<Runnable, BUCKET_SIZE>;
type Injector = injector::Injector<Runnable, BUCKET_SIZE>;
type LocalQueue = st3::fifo::Worker<Runnable>;
type Stealer = st3::fifo::Stealer<Runnable>;

// Handle to the worker that owns the current executor thread; set for the
// lifetime of each worker thread and used by `schedule_task`.
scoped_thread_local!(static LOCAL_WORKER: Worker);
// Handle to the slab of cancellation tokens for in-flight tasks; unset during
// executor shutdown (see `Executor`'s `Drop` implementation).
scoped_thread_local!(static ACTIVE_TASKS: Mutex<Slab<CancelToken>>);
/// A multi-threaded `async` executor backed by a pool of work-stealing worker
/// threads.
pub(crate) struct Executor {
    /// State shared with all worker threads (injector queue, pool manager,
    /// message counter, ...).
    context: Arc<ExecutorContext>,
    /// Cancellation tokens for all tasks that have been spawned but not yet
    /// completed; shared with the worker threads via `ACTIVE_TASKS`.
    active_tasks: Arc<Mutex<Slab<CancelToken>>>,
    /// Parker used to block the thread calling `run` until the pool becomes
    /// idle (its unparker is held by the workers).
    parker: Parker,
    /// Join handles of the worker threads; drained and joined on drop.
    worker_handles: Vec<JoinHandle<()>>,
    /// Signal requesting all workers to stop as soon as possible.
    abort_signal: Signal,
}
impl Executor {
    /// Creates an executor running on `num_threads` dedicated worker threads.
    ///
    /// Blocks until all worker threads have started and parked, so the
    /// executor is fully initialized and idle when this returns.
    ///
    /// # Panics
    ///
    /// Panics if too many executors have been instantiated, or if a worker
    /// thread cannot be spawned.
    pub(crate) fn new(
        num_threads: usize,
        simulation_context: SimulationContext,
        abort_signal: Signal,
    ) -> Self {
        // Parker for the main (executor) thread; each worker receives a clone
        // of its unparker so it can wake `run` when the pool goes idle or a
        // panic occurs.
        let parker = Parker::new();
        let unparker = parker.unparker().clone();

        // One local queue and parker per worker; the matching stealer and
        // unparker of each worker are shared with all workers via the
        // executor context.
        let (local_queues_and_parkers, stealers_and_unparkers): (Vec<_>, Vec<_>) = (0..num_threads)
            .map(|_| {
                let parker = Parker::new();
                let unparker = parker.unparker().clone();
                let local_queue = LocalQueue::new(QUEUE_SIZE);
                let stealer = local_queue.stealer();

                ((local_queue, parker), (stealer, unparker))
            })
            .unzip();

        // Each executor gets a unique ID so `schedule_task` can assert that
        // tasks are only awoken on the executor that spawned them.
        let executor_id = NEXT_EXECUTOR_ID.fetch_add(1, Ordering::Relaxed);
        // NOTE(review): IDs are capped at half the `usize` range — presumably
        // a bit of the ID is reserved elsewhere (e.g. in the task state);
        // TODO confirm against the `task` module.
        assert!(
            executor_id <= usize::MAX / 2,
            "too many executors have been instantiated"
        );

        let context = Arc::new(ExecutorContext::new(
            executor_id,
            unparker,
            stealers_and_unparkers.into_iter(),
        ));
        let active_tasks = Arc::new(Mutex::new(Slab::new()));

        // Mark all workers as active *before* spawning them so that idle
        // detection cannot fire while they are still starting up.
        context.pool_manager.set_all_workers_active();

        let worker_handles: Vec<_> = local_queues_and_parkers
            .into_iter()
            .enumerate()
            .map(|(id, (local_queue, worker_parker))| {
                let thread_builder = thread::Builder::new().name(format!("Worker #{id}"));

                thread_builder
                    .spawn({
                        let context = context.clone();
                        let active_tasks = active_tasks.clone();
                        let simulation_context = simulation_context.clone();
                        let abort_signal = abort_signal.clone();
                        move || {
                            // Install the scoped thread-locals that running
                            // tasks and `schedule_task` rely on, then enter
                            // the worker's main loop.
                            let worker = Worker::new(local_queue, context);
                            SIMULATION_CONTEXT.set(&simulation_context, || {
                                ACTIVE_TASKS.set(&active_tasks, || {
                                    LOCAL_WORKER.set(&worker, || {
                                        run_local_worker(&worker, id, worker_parker, abort_signal)
                                    })
                                })
                            });
                        }
                    })
                    .unwrap()
            })
            .collect();

        // Wait until all workers have parked: the last worker to go inactive
        // unparks this thread (see `run_local_worker`).
        parker.park();
        assert!(context.pool_manager.pool_is_idle());

        Self {
            context,
            active_tasks,
            parker,
            worker_handles,
            abort_signal,
        }
    }

    /// Spawns a task and returns a promise that can be used to retrieve the
    /// task's output once it completes.
    ///
    /// The task does not start executing until `run` is called.
    pub(crate) fn spawn<T>(&self, future: T) -> Promise<T::Output>
    where
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        // Register a cancellation token for the task *before* it can run:
        // the slab key is embedded in the `CancellableFuture` so the token
        // is deregistered when the future is dropped.
        let mut active_tasks = self.active_tasks.lock().unwrap();
        let task_entry = active_tasks.vacant_entry();

        let future = CancellableFuture::new(future, task_entry.key());

        let (promise, runnable, cancel_token) =
            task::spawn(future, schedule_task, self.context.executor_id);

        task_entry.insert(cancel_token);
        self.context.injector.insert_task(runnable);

        promise
    }

    /// Spawns a task whose output will never be retrieved.
    ///
    /// Identical to `spawn` except that no promise is returned; the task
    /// does not start executing until `run` is called.
    pub(crate) fn spawn_and_forget<T>(&self, future: T)
    where
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        // Same registration protocol as `spawn`: token first, then make the
        // task available to the workers via the injector.
        let mut active_tasks = self.active_tasks.lock().unwrap();
        let task_entry = active_tasks.vacant_entry();

        let future = CancellableFuture::new(future, task_entry.key());

        let (runnable, cancel_token) =
            task::spawn_and_forget(future, schedule_task, self.context.executor_id);

        task_entry.insert(cancel_token);
        self.context.injector.insert_task(runnable);
    }

    /// Runs all spawned tasks to completion.
    ///
    /// Blocks until the pool becomes idle, a task panics, or the timeout
    /// elapses. A zero `timeout` means "no timeout".
    ///
    /// # Errors
    ///
    /// Returns `ExecutorError::Panic` if a task panicked,
    /// `ExecutorError::UnprocessedMessages` if the pool went idle with a
    /// non-zero message tally, or `ExecutorError::Timeout` if the deadline
    /// was reached.
    pub(crate) fn run(&mut self, timeout: Duration) -> Result<(), ExecutorError> {
        // Wake one worker; it will activate further workers as needed.
        self.context.pool_manager.activate_worker();

        loop {
            // A recorded panic takes precedence over any other outcome.
            if let Some((model_id, payload)) = self.context.pool_manager.take_panic() {
                return Err(ExecutorError::Panic(model_id, payload));
            }

            if self.context.pool_manager.pool_is_idle() {
                // The pool is idle: either all work completed, or some
                // channel messages were counted but never processed —
                // presumably indicating a deadlock; TODO confirm the exact
                // semantics of `msg_count` against the `channel` module.
                let msg_count = self.context.msg_count.load(Ordering::Relaxed);
                if msg_count != 0 {
                    // A non-zero tally at idle is expected to be positive.
                    let msg_count: usize = msg_count.try_into().unwrap();
                    return Err(ExecutorError::UnprocessedMessages(msg_count));
                }

                return Ok(());
            }

            if timeout.is_zero() {
                // No deadline: sleep until a worker wakes us.
                self.parker.park();
            } else if !self.parker.park_timeout(timeout) {
                // Deadline reached: request all workers to stop and wake
                // them so they can observe the abort signal.
                self.abort_signal.set();
                self.context.pool_manager.activate_all_workers();
                return Err(ExecutorError::Timeout);
            }
        }
    }
}
impl Drop for Executor {
    fn drop(&mut self) {
        // Request all workers to stop, wake them so they can observe the
        // abort signal, and wait for them to exit.
        self.abort_signal.set();
        self.context.pool_manager.activate_all_workers();
        for handle in self.worker_handles.drain(0..) {
            handle.join().unwrap();
        }

        // Cancel all remaining tasks. A throwaway worker is installed in
        // LOCAL_WORKER so that `schedule_task` — which may be invoked while
        // a cancelled task's future is dropped — finds a valid scheduling
        // context. ACTIVE_TASKS is explicitly unset so that
        // `CancellableFuture`'s drop handler does not try to lock the very
        // slab being drained here, which would deadlock.
        let worker = Worker::new(LocalQueue::new(QUEUE_SIZE), self.context.clone());
        LOCAL_WORKER.set(&worker, || {
            ACTIVE_TASKS.unset(|| {
                let mut tasks = self.active_tasks.lock().unwrap();
                for task in tasks.drain() {
                    task.cancel();
                }
            });
        });
    }
}
impl fmt::Debug for Executor {
    /// Formats the executor as an opaque struct, hiding all internal state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Executor");

        builder.finish_non_exhaustive()
    }
}
/// State shared between the `Executor` handle and all of its worker threads.
struct ExecutorContext {
    /// Global queue into which newly spawned and overflowing tasks are pushed.
    injector: Injector,
    /// Unique identifier of the owning executor.
    executor_id: usize,
    /// Unparker for the main executor thread (the one blocked in `run`).
    executor_unparker: Unparker,
    /// Tracks worker activity, stealing and panic bookkeeping.
    pool_manager: PoolManager,
    /// Executor-wide message tally, aggregated from each worker's
    /// `channel::THREAD_MSG_COUNT` (see `run_local_worker`); a non-zero
    /// value at idle presumably means messages were left unprocessed — TODO
    /// confirm against the `channel` module.
    msg_count: AtomicIsize,
}
impl ExecutorContext {
    /// Creates the shared context for the executor with the given ID.
    ///
    /// `stealers_and_unparkers` yields, for each worker, a stealer handle to
    /// that worker's local queue together with the unparker of its parker.
    pub(super) fn new(
        executor_id: usize,
        executor_unparker: Unparker,
        stealers_and_unparkers: impl Iterator<Item = (Stealer, Unparker)>,
    ) -> Self {
        // Split the per-worker pairs into parallel collections for the pool
        // manager.
        let (all_stealers, all_unparkers): (Vec<_>, Vec<_>) = stealers_and_unparkers.unzip();
        let worker_count = all_unparkers.len();

        let pool_manager = PoolManager::new(
            worker_count,
            all_stealers.into_boxed_slice(),
            all_unparkers.into_boxed_slice(),
        );

        Self {
            injector: Injector::new(),
            executor_id,
            executor_unparker,
            pool_manager,
            msg_count: AtomicIsize::new(0),
        }
    }
}
/// A future wrapper that deregisters the task's cancellation token from the
/// `ACTIVE_TASKS` slab when dropped.
struct CancellableFuture<T: Future> {
    /// The wrapped future.
    inner: T,
    /// Key of this task's `CancelToken` within the `ACTIVE_TASKS` slab.
    cancellation_key: usize,
}
impl<T: Future> CancellableFuture<T> {
fn new(fut: T, cancellation_key: usize) -> Self {
Self {
inner: fut,
cancellation_key,
}
}
}
impl<T: Future> Future for CancellableFuture<T> {
    type Output = T::Output;

    /// Polls the wrapped future, forwarding its output unchanged.
    #[inline(always)]
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: structural pin projection to the `inner` field — `inner`
        // is never moved out of `self`, including in this type's `Drop`
        // implementation, which only touches `cancellation_key`.
        unsafe { self.map_unchecked_mut(|s| &mut s.inner).poll(cx) }
    }
}
impl<T: Future> Drop for CancellableFuture<T> {
    fn drop(&mut self) {
        // The future can no longer be cancelled, so deregister its token.
        // `map` yields nothing when ACTIVE_TASKS is not set on this thread
        // (it is explicitly unset while the executor itself is being
        // dropped), in which case there is nothing to do. `try_remove` is
        // used because the entry may already be gone.
        let _ = ACTIVE_TASKS.map(|active_tasks| {
            // A poisoned lock is deliberately ignored: deregistration is
            // best-effort during unwinding.
            if let Ok(mut active_tasks) = active_tasks.lock() {
                let _cancel_token = active_tasks.try_remove(self.cancellation_key);
            }
        });
    }
}
/// The execution environment of a single worker thread.
pub(crate) struct Worker {
    /// Queue of tasks owned by this worker, from which other workers may
    /// steal via the matching `Stealer`.
    local_queue: LocalQueue,
    /// Single-task slot that bypasses the local queue for the most recently
    /// scheduled task; it is executed first (see `run_local_worker`).
    fast_slot: Cell<Option<Runnable>>,
    /// Executor state shared with all workers.
    executor_context: Arc<ExecutorContext>,
}
impl Worker {
fn new(local_queue: LocalQueue, executor_context: Arc<ExecutorContext>) -> Self {
Self {
local_queue,
fast_slot: Cell::new(None),
executor_context,
}
}
}
/// Schedules a woken task on the executor.
///
/// This is the wake callback passed to `task::spawn`; it must therefore run
/// on a thread where `LOCAL_WORKER` is set, i.e. on a worker thread of the
/// executor identified by `executor_id` (or the throwaway worker installed
/// during executor shutdown).
///
/// # Panics
///
/// Panics if called outside an executor thread or with a task spawned on a
/// different executor.
fn schedule_task(task: Runnable, executor_id: usize) {
    LOCAL_WORKER
        .map(|worker| {
            let pool_manager = &worker.executor_context.pool_manager;
            let injector = &worker.executor_context.injector;
            let local_queue = &worker.local_queue;
            let fast_slot = &worker.fast_slot;

            // Tasks may only be awoken on the executor that spawned them;
            // `executor_id` was captured at spawn time.
            assert_eq!(
                executor_id, worker.executor_context.executor_id,
                "Tasks must be awaken on the same executor they are spawned on"
            );

            // Place the newly woken task in the fast slot; if the slot was
            // empty we are done, otherwise the evicted task must be queued.
            let prev_task = match fast_slot.replace(Some(task)) {
                Some(t) => t,
                None => return,
            };

            // Queue the evicted task locally. If the local queue is full,
            // offload one full bucket of tasks to the global injector to
            // make room; should draining fail (presumably due to a
            // concurrent steal — TODO confirm st3 semantics), fall back to
            // pushing the single task to the injector.
            if let Err(prev_task) = local_queue.push(prev_task) {
                if let Ok(drain) = local_queue.drain(|_| Bucket::capacity()) {
                    injector.push_bucket(Bucket::from_iter(drain));
                    local_queue.push(prev_task).unwrap();
                } else {
                    injector.insert_task(prev_task);
                }
            }

            // If no worker is currently searching for work, wake one so the
            // newly queued task gets picked up.
            if pool_manager.searching_worker_count() == 0 {
                pool_manager.activate_worker_relaxed();
            }
        })
        .expect("Tasks may not be awaken outside executor threads");
}
/// Main loop of a worker thread.
///
/// The worker alternates between parking when there is no work, searching
/// for work (popping buckets from the global injector or stealing from other
/// workers), and running the tasks it acquired. Any panic escaping a task is
/// caught, recorded with the pool manager, and triggers a global abort so
/// `Executor::run` can report it.
fn run_local_worker(worker: &Worker, id: usize, parker: Parker, abort_signal: Signal) {
    let pool_manager = &worker.executor_context.pool_manager;
    let injector = &worker.executor_context.injector;
    let executor_unparker = &worker.executor_context.executor_unparker;
    let local_queue = &worker.local_queue;
    let fast_slot = &worker.fast_slot;

    // Flushes this thread's message counter into the executor-wide tally.
    let update_msg_count = || {
        let thread_msg_count = channel::THREAD_MSG_COUNT.replace(0);
        worker
            .executor_context
            .msg_count
            .fetch_add(thread_msg_count, Ordering::Relaxed);
    };

    let result = panic::catch_unwind(AssertUnwindSafe(|| {
        // Budget for fruitless searching before the worker parks again.
        const MAX_SEARCH_DURATION: Duration = Duration::from_nanos(1000);

        // Per-worker RNG used to randomize the order in which victims are
        // visited when stealing, seeded with the worker's index.
        let rng = Rng::new(id as u64);

        loop {
            // Publish the message tally before potentially parking.
            update_msg_count();

            if pool_manager.try_set_worker_inactive(id) {
                // Deactivated: sleep until another worker (or the executor)
                // reactivates this one.
                parker.park();
            } else if injector.is_empty() {
                // Deactivation was refused — presumably because this is the
                // last active worker (TODO confirm `PoolManager` semantics) —
                // and no global work remains: the pool is idle, so notify
                // the main executor thread and park.
                pool_manager.set_all_workers_inactive();
                executor_unparker.unpark();
                parker.park();
            } else {
                // Deactivation refused but the injector still holds tasks:
                // skip parking and go straight into the search phase.
                pool_manager.begin_worker_search();
            }

            if abort_signal.is_set() {
                return;
            }

            // Search phase: acquire tasks from the injector or by stealing,
            // then run everything acquired, until the search budget runs out.
            let mut search_start = Instant::now();
            loop {
                if let Some(bucket) = injector.pop_bucket() {
                    let bucket_iter = bucket.into_iter();
                    // Busy-wait until the local queue can absorb the whole
                    // bucket — capacity is presumably freed by concurrent
                    // stealers; TODO confirm this cannot spin indefinitely.
                    while local_queue.spare_capacity() < bucket_iter.len() {}
                    local_queue.extend(bucket_iter);
                } else {
                    // The injector is empty: try each other worker in random
                    // order, stealing about half of the victim's queue and
                    // moving the popped task into the fast slot.
                    let mut stealers = pool_manager.shuffled_stealers(Some(id), &rng);
                    if stealers.all(|stealer| {
                        stealer
                            .steal_and_pop(local_queue, |n| n - n / 2)
                            .map(|(task, _)| {
                                let prev_task = fast_slot.replace(Some(task));
                                assert!(prev_task.is_none());
                            })
                            .is_err()
                    }) {
                        // Nothing could be stolen anywhere: keep spinning
                        // until the search budget is exhausted, then park.
                        if (Instant::now() - search_start) > MAX_SEARCH_DURATION {
                            pool_manager.end_worker_search();
                            break;
                        }
                        continue;
                    }
                }
                pool_manager.end_worker_search();

                // Execution phase: run the fast-slot task first, then drain
                // the local queue, checking the abort signal between tasks.
                while let Some(task) = fast_slot.take().or_else(|| local_queue.pop()) {
                    if abort_signal.is_set() {
                        return;
                    }
                    task.run();
                }
                pool_manager.begin_worker_search();
                search_start = Instant::now();
            }
        }
    }));

    // A task panicked: record the payload (and the model that was running),
    // then request a global shutdown and wake all workers and the main
    // executor thread so the panic can be observed by `Executor::run`.
    if let Err(payload) = result {
        let model_id = CURRENT_MODEL_ID.take();
        pool_manager.register_panic(model_id, payload);
        abort_signal.set();
        pool_manager.activate_all_workers();
        executor_unparker.unpark();
    }
}