use std::{
cell::{Cell, OnceCell},
future::Future,
sync::{
atomic::AtomicBool,
mpsc::{sync_channel, Receiver},
Arc, Mutex,
},
thread::JoinHandle,
};
use crate::{
pool::thread_pool::{HoochPool, HOOCH_POOL},
spawner::{new_executor_spawner, JoinHandle as SpawnerJoinHandle, Spawner},
task::{
manager::{TaskManager, TASK_MANAGER},
Task,
},
};
thread_local! {
    /// `true` while the current thread is inside a runtime.
    ///
    /// It is set on the thread that calls `RuntimeBuilder::build`, and by each executor thread
    /// around `executor.run()`. `set_runtime_guard` panics with "Cannot run nested runtimes"
    /// when the flag is already set.
    pub static RUNTIME_GUARD: Cell<bool> = const { Cell::new(false) };
}

thread_local! {
    /// The runtime reachable from the current thread.
    ///
    /// It is filled once per thread: by `RuntimeBuilder::build` on the building thread, and on
    /// every executor thread from the `Arc<Runtime>` that thread receives at startup.
    pub static RUNTIME: OnceCell<Arc<Runtime>> = const { OnceCell::new() };
}
/// Configuration for starting a runtime.
///
/// Create it with `RuntimeBuilder::new` (or `Default`), adjust it with `num_workers`, then call
/// `build` to spawn the executor threads and get a `Handle`.
///
/// `Debug` is derived so the builder can be logged, like the other public types in this module.
/// `Clone` is derived so one configuration can be reused.
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    /// Number of executor threads `build` spawns (`new` sets this to 1).
    num_workers: usize,
}
impl RuntimeBuilder {
    /// Creates a builder configured for a single executor thread.
    pub fn new() -> Self {
        Self { num_workers: 1 }
    }

    /// Sets how many executor threads `build` will spawn.
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }

    /// Starts a runtime for the calling thread and returns a [`Handle`] to it.
    ///
    /// The first time this runs on a thread (while its `RUNTIME` cell is still empty), it:
    /// 1. initialises the `HoochPool` and fetches the shared `TaskManager`;
    /// 2. creates one executor per worker, registers it with the task manager, and spawns a
    ///    thread named `executor_thread_{i}` for it;
    /// 3. sends each executor thread the shared task manager, pool and `Arc<Runtime>` over
    ///    one-shot channels, so that thread's thread-locals point at the same objects as the
    ///    caller's.
    ///
    /// If this thread's `RUNTIME` cell is already initialised, no new executors are started,
    /// and a handle to the existing runtime is returned.
    ///
    /// # Panics
    /// - Panics with "Cannot run nested runtimes" if the calling thread's runtime guard is
    ///   already set. This includes calling `build` twice on the same thread, and calling it
    ///   from inside a task running on an executor thread.
    /// - Panics if an executor thread cannot be spawned.
    pub fn build(self) -> Handle {
        set_runtime_guard();
        // Executors report panics through `panic_tx`. The receiver is wrapped in
        // `Arc<Mutex<..>>` so that every clone of the returned `Handle` shares it.
        let (panic_tx, panic_rx) = sync_channel(1);
        let panic_rx_arc = Arc::new(Mutex::new(panic_rx));
        // NOTE(review): `spawner` is only ever written, never read. Each loop iteration drops
        // the previous spawner, and the last one is dropped when `build` returns. Confirm
        // whether keeping it alive until then matters before removing it.
        let mut spawner = None;
        RUNTIME.with(|cell| {
            cell.get_or_init(|| {
                let num_workers = self.num_workers;
                let panic_rx_clone = Arc::clone(&panic_rx_arc);
                let mut executor_handles = Vec::with_capacity(num_workers);
                // NOTE(review): the pool size is hard-coded to 1, independent of
                // `num_workers`. Confirm this is intended.
                HoochPool::init(1);
                let hooch_pool = HoochPool::get();
                let tm = TaskManager::get();
                // Sending halves of the per-executor channels used below to hand shared state
                // to each executor thread once it is running.
                let mut runtime_txs = Vec::with_capacity(num_workers);
                let mut tm_txs = Vec::with_capacity(num_workers);
                let mut hp_txs = Vec::with_capacity(num_workers);
                for i in 0..num_workers {
                    // Buffer of 1: each channel carries exactly one value, so the sends after
                    // the loop do not block.
                    let (runtime_tx, runtime_rx) = std::sync::mpsc::sync_channel(1);
                    let (tm_tx, tm_rx) = std::sync::mpsc::sync_channel(1);
                    let (hp_tx, hp_rx) = std::sync::mpsc::sync_channel(1);
                    runtime_txs.push(runtime_tx);
                    tm_txs.push(tm_tx);
                    hp_txs.push(hp_tx);
                    let (executor, spawner_inner, exec_sender) =
                        new_executor_spawner(panic_tx.clone(), i);
                    // Register the executor's sender with the task manager (presumably so
                    // the manager can dispatch tasks to this executor; verify in TaskManager).
                    tm.register_executor(executor.id(), exec_sender);
                    spawner = Some(spawner_inner);
                    let panic_tx_clone = panic_tx.clone();
                    let handle = std::thread::Builder::new()
                        .name(format!("executor_thread_{}", i))
                        .spawn(move || {
                            // Block until the building thread sends the shared state, then
                            // install it in this thread's thread-locals before running.
                            let tm = tm_rx.recv().unwrap();
                            TASK_MANAGER.with(move |cell| {
                                cell.get_or_init(move || tm);
                            });
                            let hooch_pool = hp_rx.recv().unwrap();
                            HOOCH_POOL.with(move |cell| {
                                cell.get_or_init(move || hooch_pool);
                            });
                            let runtime = runtime_rx.recv().unwrap();
                            RUNTIME.with(move |cell| {
                                cell.get_or_init(move || runtime);
                            });
                            // Run the executor with the nested-runtime guard set. A panic is
                            // caught, printed, and reported over the panic channel, where
                            // `Handle::run_blocking` checks for it.
                            if let Err(err) = std::panic::catch_unwind(|| {
                                set_runtime_guard();
                                executor.run();
                                exit_runtime_guard();
                            }) {
                                println!("EXECUTOR HAS PANICKED. ERROR: {:?}", err);
                                let _ = panic_tx_clone.send(());
                            }
                        })
                        .unwrap();
                    executor_handles.push(Some(handle));
                }
                // Release the executor threads, which are blocked on these receivers.
                tm_txs
                    .into_iter()
                    .for_each(|tx| tx.send(Arc::clone(&tm)).unwrap());
                hp_txs
                    .into_iter()
                    .for_each(|tx| tx.send(Arc::clone(&hooch_pool)).unwrap());
                let handle = Handle {
                    panic_rx: panic_rx_clone,
                };
                // The runtime owns the executors' JoinHandles, so it can only be created, and
                // sent to the executors, after every thread has been spawned.
                let rt = Arc::new(Runtime {
                    handles: executor_handles,
                    runtime_handle: handle,
                });
                runtime_txs.into_iter().for_each(|tx| {
                    tx.send(Arc::clone(&rt)).unwrap();
                });
                rt
            });
        });
        // Read the handle back from this thread's runtime: either the one just created or a
        // pre-existing one.
        Runtime::handle()
    }
}
impl Default for RuntimeBuilder {
fn default() -> Self {
Self::new()
}
}
/// A running set of executor threads together with the handle used to reach them.
///
/// `RuntimeBuilder::build` wraps it in an `Arc` and stores it in the `RUNTIME` thread-local of
/// the building thread and of every executor thread.
#[derive(Debug)]
pub struct Runtime {
    /// Join handles of the spawned executor threads, one entry per worker.
    handles: Vec<Option<JoinHandle<()>>>,
    /// Handle that `Runtime::handle` returns clones of.
    runtime_handle: Handle,
}
impl Drop for Runtime {
    /// Clears the nested-runtime guard of whichever thread drops the last reference to the
    /// runtime.
    fn drop(&mut self) {
        exit_runtime_guard()
    }
}
impl Runtime {
    /// Returns a clone of the [`Handle`] of the runtime installed on the current thread.
    ///
    /// # Panics
    /// Panics with a descriptive message if this thread's `RUNTIME` cell is empty. The cell is
    /// filled by `RuntimeBuilder::build` on the building thread and on each executor thread.
    pub fn handle() -> Handle {
        RUNTIME.with(|cell| {
            cell.get()
                .expect("no runtime on this thread; call RuntimeBuilder::build first")
                .runtime_handle
                .clone()
        })
    }

    /// Spawns `future` via `Spawner::spawn` and returns the resulting join handle.
    pub fn dispatch_job<Fut, T>(&self, future: Fut) -> SpawnerJoinHandle<T>
    where
        T: Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        Spawner::spawn(future)
    }

    /// Submits `future` to the task manager and blocks the calling thread until it produces a
    /// value, then returns that value.
    ///
    /// # Panics
    /// Panics with "Task failed to complete" if the task is dropped before sending its result.
    pub fn run_blocking<Fut, T>(&self, future: Fut) -> T
    where
        T: Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        // Capacity 1: the task sends exactly one result, so that send never blocks.
        let (tx, rx) = sync_channel(1);
        let tm = TaskManager::get();
        let task = Task {
            future: Mutex::new(Some(Box::pin(async move {
                let res = future.await;
                // A send failure only means the receiver is gone; nothing left to do then.
                let _ = tx.send(res);
            }))),
            task_tag: Task::generate_tag(),
            manager: Arc::downgrade(&tm),
            abort: Arc::new(AtomicBool::new(false)),
        };
        tm.register_or_execute_non_blocking_task(Arc::new(task));
        // `recv` fails only when every sender was dropped without sending, i.e. the wrapped
        // future was dropped before it finished.
        rx.recv()
            .unwrap_or_else(|_| panic!("Task failed to complete"))
    }
}
fn set_runtime_guard() {
if RUNTIME_GUARD.get() {
panic!("Cannot run nested runtimes");
}
RUNTIME_GUARD.replace(true);
}
/// Clears the current thread's "inside a runtime" flag so a runtime may be built again.
fn exit_runtime_guard() {
    RUNTIME_GUARD.set(false);
}
/// Cheaply clonable handle used to run and spawn futures on the current thread's runtime.
///
/// Every clone shares one receiver for executor panic notifications. A notification is
/// consumed by whichever clone checks first.
#[derive(Debug, Clone)]
pub struct Handle {
    /// Receives `()` when an executor thread reports a panic.
    panic_rx: Arc<Mutex<Receiver<()>>>,
}
impl Handle {
pub fn run_blocking<Fut, T>(&self, future: Fut) -> T
where
T: Send + Unpin + 'static,
Fut: Future<Output = T> + Send + 'static,
{
let mut res = None;
RUNTIME.with(|cell| {
let runtime = cell.get().unwrap();
let blocking_res = runtime.run_blocking(future);
res = Some(blocking_res)
});
if self.panic_rx.lock().unwrap().try_recv().is_ok() {
exit_runtime_guard();
panic!("Executor panicked");
}
res.unwrap()
}
pub fn spawn<Fut, T>(&self, future: Fut) -> SpawnerJoinHandle<T>
where
T: Send + 'static,
Fut: Future<Output = T> + Send + 'static,
{
let mut join_handle = None;
RUNTIME.with(|cell| {
let runtime = cell.get().unwrap();
let join_handle_inner = runtime.dispatch_job(future);
join_handle = Some(join_handle_inner);
});
join_handle.unwrap()
}
pub fn num_workers(&self) -> usize {
let mut num_workers = 0;
RUNTIME.with(|cell| {
let runtime = cell.get().unwrap();
num_workers = runtime.handles.len();
});
num_workers
}
}
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicU8, Ordering},
        Arc,
    };

    use super::*;

    /// Bumps the shared counter by one.
    async fn increment(ctr: Arc<AtomicU8>) {
        ctr.fetch_add(1, Ordering::Relaxed);
    }

    #[test]
    fn test_runtime_builder_default() {
        assert_eq!(RuntimeBuilder::default().num_workers, 1);
    }

    #[test]
    fn test_runtime_builder_num_workers() {
        let builder = RuntimeBuilder::default().num_workers(2);
        assert_eq!(builder.num_workers, 2);
    }

    #[test]
    fn test_runtime_num_workers() {
        let handle = RuntimeBuilder::default().num_workers(4).build();
        assert_eq!(handle.num_workers(), 4);
    }

    #[test]
    fn test_run_blocking() {
        let runtime = RuntimeBuilder::default().num_workers(2).build();
        let counter = Arc::new(AtomicU8::new(0));
        runtime.run_blocking(increment(Arc::clone(&counter)));
        assert_eq!(counter.swap(0, Ordering::Relaxed), 1);
    }

    #[test]
    fn test_run_blocking_return() {
        let handle = RuntimeBuilder::default().build();
        let ctr = 1;
        let res = handle.run_blocking(async move { ctr + 1 });
        assert_eq!(res, 2);
    }

    #[test]
    #[should_panic]
    fn test_handle_nested_panicking_task() {
        let handle = RuntimeBuilder::default().build();
        handle.run_blocking(async {
            // Building a runtime from inside an executor must be rejected.
            let _ = RuntimeBuilder::default().build();
        });
    }

    #[test]
    #[should_panic]
    fn test_handle_scoped_panicking_task() {
        // A second build on the same thread trips the nested-runtime guard.
        RuntimeBuilder::default().build();
        RuntimeBuilder::default().build();
    }

    #[test]
    fn test_obtaining_multiple_handles_from_same_runtime() {
        let first = RuntimeBuilder::default().build();
        assert_eq!(first.run_blocking(async { 1 }), 1);
        let second = Runtime::handle();
        assert_eq!(second.run_blocking(async { 2 }), 2);
    }

    #[test]
    fn test_multiple_runtimes_thread_task() {
        // Each OS thread builds its own runtime and bumps its own counter once.
        let counters: Vec<Arc<Mutex<i32>>> = (0..2).map(|_| Arc::new(Mutex::new(0))).collect();
        let workers: Vec<_> = counters
            .iter()
            .map(|counter| {
                let counter = Arc::clone(counter);
                std::thread::spawn(move || {
                    let handle = RuntimeBuilder::default().build();
                    handle.run_blocking(async move {
                        *counter.lock().unwrap() += 1;
                    });
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        for counter in &counters {
            assert_eq!(*counter.lock().unwrap(), 1);
        }
    }
}