use std::{
cell::{Cell, OnceCell},
future::Future,
mem::MaybeUninit,
sync::{
atomic::{AtomicUsize, Ordering},
mpsc::{sync_channel, Receiver},
Arc, Mutex,
},
thread::JoinHandle,
};
use crate::{
executor::ExecutorTask,
spawner::{new_executor_spawner, JoinHandle as SpawnerJoinHandle, Spawner},
};
thread_local! {
    /// Per-thread "inside the runtime" flag. Set by `set_runtime_guard` while a thread is in
    /// `Handle::run_blocking` or running its executor loop; a second `set_runtime_guard` on
    /// the same thread panics with "Cannot run nested runtimes".
    pub static RUNTIME_GUARD: Cell<bool> = const { Cell::new(false) };
    /// The runtime visible from this thread. Initialized by `RuntimeBuilder::build` on the
    /// building thread, and by each executor thread from the `Arc<Runtime>` it receives at
    /// startup, so every participating thread shares the same `Runtime`.
    pub static RUNTIME: OnceCell<Arc<Runtime>> = const { OnceCell::new() };
}
/// Configuration for creating the thread-local runtime; finish with `RuntimeBuilder::build`.
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    /// Number of executor worker threads to spawn (`new`/`default` set this to 1).
    num_workers: usize,
}
impl RuntimeBuilder {
pub fn new() -> Self {
Self { num_workers: 1 }
}
pub fn num_workers(mut self, num_workers: usize) -> Self {
self.num_workers = num_workers;
self
}
pub fn build(self) -> Handle {
let (panic_tx, panic_rx) = sync_channel(1);
let panic_rx_arc = Arc::new(Mutex::new(panic_rx));
RUNTIME.with(|cell| {
cell.get_or_init(|| {
let panic_rx_clone = Arc::clone(&panic_rx_arc);
let mut executor_handles = Vec::with_capacity(self.num_workers);
let mut spawners = Vec::with_capacity(self.num_workers);
let mut runtime_txs = Vec::with_capacity(self.num_workers);
for i in 0..self.num_workers {
let (runtime_tx, runtime_rx) = std::sync::mpsc::sync_channel(1);
runtime_txs.push(runtime_tx);
let (executor, spawner) = new_executor_spawner(panic_tx.clone());
spawners.push(spawner);
let panic_tx_clone = panic_tx.clone();
let handle = std::thread::Builder::new()
.name(format!("executor_thread_{}", i))
.spawn(move || {
let runtime = runtime_rx.recv().unwrap();
RUNTIME.with(move |cell| {
cell.get_or_init(move || runtime);
});
if let Err(err) = std::panic::catch_unwind(|| {
set_runtime_guard();
executor.run();
exit_runtime_guard();
}) {
println!("EXECUTOR HAS PANICKED. ERROR: {:?}", err);
let _ = panic_tx_clone.send(());
}
})
.unwrap();
executor_handles.push(Some(handle));
}
let handle = Handle {
panic_rx: panic_rx_clone,
};
let rt = Arc::new(Runtime {
dispatch_worker: AtomicUsize::new(0),
spawners,
handles: executor_handles,
runtime_handle: handle,
});
runtime_txs.into_iter().for_each(|tx| {
tx.send(Arc::clone(&rt)).unwrap();
});
rt
});
});
Runtime::handle()
}
}
impl Default for RuntimeBuilder {
fn default() -> Self {
Self::new()
}
}
/// Shared state of a runtime: one spawner and one worker thread per executor.
///
/// Built by `RuntimeBuilder::build` and shared through `Arc` by the building thread and every
/// executor thread via the `RUNTIME` thread-local.
#[derive(Debug)]
pub struct Runtime {
    /// Monotonic counter used to pick the next worker in round-robin order.
    dispatch_worker: AtomicUsize,
    /// Spawner for each worker; index `i` feeds the thread named `executor_thread_{i}`.
    spawners: Vec<Spawner>,
    /// Worker thread join handles; an entry becomes `None` once taken during drop.
    handles: Vec<Option<JoinHandle<()>>>,
    /// Handle cloned out to callers by `Runtime::handle`.
    runtime_handle: Handle,
}
impl Drop for Runtime {
    /// Tells every worker to finish, then joins the worker threads.
    fn drop(&mut self) {
        // Signal every worker before waiting on any of them. If one worker is slow or stuck,
        // the others still get their stop message and can shut down.
        for spawner in &self.spawners {
            spawner.spawn_task(ExecutorTask::Finished);
        }

        // Each worker keeps an `Arc<Runtime>` in its own `RUNTIME` thread-local, so the last
        // reference (and therefore this drop) may be released on a worker thread. Joining the
        // current thread may panic or deadlock (see `JoinHandle::join`), so skip it.
        let current = std::thread::current().id();
        for handle in self.handles.iter_mut().filter_map(Option::take) {
            if handle.thread().id() != current {
                let _ = handle.join();
            }
        }
    }
}
impl Runtime {
    /// Returns a clone of the [`Handle`] owned by the runtime installed on this thread.
    ///
    /// # Panics
    ///
    /// Panics if no runtime is installed on this thread (see `RuntimeBuilder::build`).
    pub fn handle() -> Handle {
        RUNTIME.with(|cell| {
            cell.get()
                .expect("no runtime is installed on this thread; call RuntimeBuilder::build first")
                .runtime_handle
                .clone()
        })
    }

    /// Spawns `future` on the next worker in round-robin order and returns its join handle
    /// without waiting for it.
    pub fn dispatch_job<Fut, T>(&self, future: Fut) -> SpawnerJoinHandle<T>
    where
        T: Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let dispatch_idx = self.get_dispatch_worker_idx();
        self.spawners[dispatch_idx].spawn_self(future)
    }

    /// Runs `future` on the next worker in round-robin order and blocks the calling thread
    /// until the future's output arrives.
    ///
    /// # Panics
    ///
    /// Panics if the task is dropped before it sends its result. Presumably that happens when
    /// the worker running it panics.
    pub fn run_blocking<Fut, T>(&self, future: Fut) -> T
    where
        T: Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        let (tx, rx) = sync_channel(1);
        let dispatch_idx = self.get_dispatch_worker_idx();
        self.spawners[dispatch_idx].spawn_self(async move {
            let res = future.await;
            // Cannot fail: the caller keeps `rx` alive, blocked in `recv` below.
            tx.send(res).unwrap();
        });
        rx.recv()
            .expect("task was dropped before producing a result (its executor may have panicked)")
    }

    /// Returns the index of the next worker in round-robin order.
    ///
    /// `Relaxed` is enough: the counter only has to hand out increasing values and publishes
    /// no other memory. `fetch_add` wraps on overflow, so this never panics for a non-empty
    /// `spawners`.
    fn get_dispatch_worker_idx(&self) -> usize {
        self.dispatch_worker.fetch_add(1, Ordering::Relaxed) % self.spawners.len()
    }
}
fn set_runtime_guard() {
if RUNTIME_GUARD.get() {
panic!("Cannot run nested runtimes");
}
RUNTIME_GUARD.replace(true);
}
/// Clears the current thread's runtime flag so a later `set_runtime_guard` succeeds.
fn exit_runtime_guard() {
    RUNTIME_GUARD.set(false);
}
/// Cloneable entry point for submitting work to the runtime installed on the current thread.
///
/// Every clone shares the same receiving end of the runtime's executor-panic channel.
#[derive(Clone, Debug)]
pub struct Handle {
    /// Executor threads send `()` here after catching a panic.
    panic_rx: Arc<Mutex<Receiver<()>>>,
}
impl Handle {
pub fn run_blocking<Fut, T>(&self, future: Fut) -> T
where
T: Send + 'static,
Fut: Future<Output = T> + Send + 'static,
{
set_runtime_guard();
let mut res = None;
RUNTIME.with(|cell| {
let runtime = cell.get().unwrap();
let blocking_res = runtime.run_blocking(future);
res = Some(blocking_res)
});
if self.panic_rx.lock().unwrap().try_recv().is_ok() {
exit_runtime_guard();
panic!("Executor panicked");
}
exit_runtime_guard();
res.unwrap()
}
pub fn spawn<Fut, T>(&self, future: Fut) -> SpawnerJoinHandle<T>
where
T: Send + 'static,
Fut: Future<Output = T> + Send + 'static,
{
let mut join_handle: MaybeUninit<SpawnerJoinHandle<T>> = MaybeUninit::uninit();
let join_handle_ptr: *mut SpawnerJoinHandle<T> = join_handle.as_mut_ptr();
RUNTIME.with(|cell| {
let runtime = cell.get().unwrap();
let join_handle = runtime.dispatch_job(future);
unsafe { join_handle_ptr.write(join_handle) }
});
unsafe { join_handle.assume_init() }
}
pub fn num_workers(&self) -> usize {
let mut num_workers = 0;
RUNTIME.with(|cell| {
let runtime = cell.get().unwrap();
num_workers = runtime.spawners.len();
});
num_workers
}
}
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicU8, Ordering},
        Arc,
    };

    use super::*;

    /// Adds one to the shared counter.
    async fn increment(ctr: Arc<AtomicU8>) {
        ctr.fetch_add(1, Ordering::Relaxed);
    }

    /// Reports the name of the thread that polls this future.
    async fn get_thread_name() -> String {
        std::thread::current().name().unwrap().to_string()
    }

    #[test]
    fn test_runtime_builder_default() {
        assert_eq!(RuntimeBuilder::default().num_workers, 1);
    }

    #[test]
    fn test_runtime_builder_num_workers() {
        assert_eq!(RuntimeBuilder::default().num_workers(2).num_workers, 2);
    }

    #[test]
    fn test_runtime_num_workers() {
        let handle = RuntimeBuilder::default().num_workers(4).build();
        assert_eq!(handle.num_workers(), 4);
    }

    #[test]
    fn test_get_dispatch_worker_idx() {
        // Assumes each test runs on a fresh thread (libtest's default), so this thread-local
        // runtime is new and its dispatch counter starts at 0.
        let _ = RuntimeBuilder::default().num_workers(2).build();
        RUNTIME.with(|cell| {
            let runtime = cell.get().unwrap();
            let picked: Vec<usize> = (0..4).map(|_| runtime.get_dispatch_worker_idx()).collect();
            assert_eq!(picked, [0, 1, 0, 1]);
        });
    }

    #[test]
    fn test_run_blocking() {
        let runtime = RuntimeBuilder::default().num_workers(2).build();
        let ctr = Arc::new(AtomicU8::new(0));
        runtime.run_blocking(increment(Arc::clone(&ctr)));
        assert_eq!(ctr.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_multithread_round_robin_dispatch() {
        let handle = RuntimeBuilder::default().num_workers(3).build();
        let expected_order = [
            "executor_thread_0",
            "executor_thread_1",
            "executor_thread_2",
            "executor_thread_0",
        ];
        for expected in expected_order {
            handle.run_blocking(async move {
                assert_eq!(get_thread_name().await, expected);
            });
        }
    }

    #[test]
    fn test_run_blocking_return() {
        let handle = RuntimeBuilder::default().build();
        let ctr = 1;
        assert_eq!(handle.run_blocking(async move { ctr + 1 }), 2);
    }

    #[test]
    #[should_panic]
    fn test_handle_panicking_task() {
        let handle = RuntimeBuilder::default().build();
        // NOTE(review): executor threads install the runtime in their own `RUNTIME` slot
        // before running, so this nested `build` finds it already initialized. It is unclear
        // from this file what actually panics here; confirm against the executor/spawner code.
        handle.run_blocking(async {
            let _ = RuntimeBuilder::default().build();
        });
    }
}