use std::{
    cell::RefCell,
    io,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Condvar, Mutex,
    },
};

use tokio::runtime::Runtime;

/// Shutdown synchronization state shared by all threads of a runtime.
#[derive(Default)]
struct ShutdownBarrier {
    /// Number of threads currently registered with the barrier.
    guard_count: AtomicUsize,
    /// Set to `true` once the last thread has deregistered.
    shutdown_finalized: Mutex<bool>,
    cvar: Condvar,
}

#[derive(PartialEq, Eq, Debug)]
enum BarrierContext {
    /// Tokio Runtime Worker
    RuntimeWorker,
    /// Tokio Pool Worker
    PoolWorker,
}

thread_local! {
    /// Records whether the current thread is a runtime worker or a blocking pool worker.
    static CONTEXT: RefCell<Option<BarrierContext>> = RefCell::new(None);
}

#[derive(PartialEq, Eq)]
pub(crate) enum Kind {
    CurrentThread,
    MultiThread,
}

/// Builds a Tokio runtime configured with a shutdown barrier
pub struct Builder {
    kind: Kind,
    worker_threads: usize,
    inner: tokio::runtime::Builder,
}

impl Builder {
    /// Returns a new builder with the current thread scheduler selected.
    pub fn new_current_thread() -> Builder {
        Builder {
            kind: Kind::CurrentThread,
            worker_threads: 1,
            inner: tokio::runtime::Builder::new_current_thread(),
        }
    }

    /// Returns a new builder with the multi thread scheduler selected.
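    ///
    /// The default number of worker threads is taken from the `TOKIO_WORKER_THREADS`
    /// environment variable when set, falling back to the number of logical CPUs.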
    pub fn new_multi_thread() -> Builder {
        let worker_threads = std::env::var("TOKIO_WORKER_THREADS")
            .ok()
            .and_then(|worker_threads| worker_threads.parse().ok())
            .unwrap_or_else(num_cpus::get);

        Builder {
            kind: Kind::MultiThread,
            worker_threads,
            inner: tokio::runtime::Builder::new_multi_thread(),
        }
    }

    /// Enables both I/O and time drivers.
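    ///
    /// # Example
    ///
    /// A minimal sketch of enabling all drivers on a barrier-configured runtime:
    ///
    /// ```
    /// use async_local::runtime;
    ///
    /// let rt = runtime::Builder::new_multi_thread()
    ///     .enable_all()
    ///     .build()
    ///     .unwrap();
    ///
    /// rt.block_on(async move {});
    /// ```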
    pub fn enable_all(&mut self) -> &mut Self {
        self.inner.enable_all();
        self
    }

    /// Sets the number of worker threads the `Runtime` will use.
    ///
    /// This can be any number above 0, though it is advised to keep this value
    /// on the smaller side.
    ///
    /// This will override the value read from the environment variable `TOKIO_WORKER_THREADS`.
    ///
    /// # Default
    ///
    /// The default value is the number of cores available to the system.
    ///
    /// When using the `current_thread` runtime this method has no effect.
    ///
    /// # Examples
    ///
    /// ## Multi threaded runtime with 4 threads
    ///
    /// ```
    /// use async_local::runtime;
    ///
    /// // This will spawn a work-stealing runtime with 4 worker threads.
    /// let rt = runtime::Builder::new_multi_thread()
    ///     .worker_threads(4)
    ///     .build()
    ///     .unwrap();
    ///
    /// rt.spawn(async move {});
    /// ```
    ///
    /// ## Current thread runtime (will only run on the current thread via `Runtime::block_on`)
    ///
    /// ```
    /// use async_local::runtime;
    ///
    /// // Create a runtime that _must_ be driven from a call
    /// // to `Runtime::block_on`.
    /// let rt = runtime::Builder::new_current_thread().build().unwrap();
    ///
    /// // This will run the runtime and future on the current thread
    /// rt.block_on(async move {});
    /// ```
    ///
    /// # Panics
    ///
    /// This will panic if `val` is not larger than `0`.
    #[track_caller]
    pub fn worker_threads(&mut self, val: usize) -> &mut Self {
        assert!(val > 0, "Worker threads cannot be set to 0");
        if self.kind.ne(&Kind::CurrentThread) {
            self.worker_threads = val;
            self.inner.worker_threads(val);
        }
        self
    }

    /// Creates a Tokio runtime configured with a barrier that rendezvouses worker
    /// threads during shutdown, ensuring tasks never outlive local data owned by
    /// worker threads.
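    ///
    /// # Example
    ///
    /// A minimal usage sketch; when the runtime is dropped, worker threads rendezvous
    /// before any of their thread-local data is destroyed:
    ///
    /// ```
    /// use async_local::runtime;
    ///
    /// let rt = runtime::Builder::new_multi_thread()
    ///     .worker_threads(2)
    ///     .enable_all()
    ///     .build()
    ///     .unwrap();
    ///
    /// rt.block_on(async move {});
    /// // Dropping `rt` shuts the runtime down through the barrier.
    /// ```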
    pub fn build(&mut self) -> io::Result<Runtime> {
        let worker_threads = self.worker_threads;
        let barrier = Arc::new(ShutdownBarrier::default());

        let on_thread_start = {
            let barrier = barrier.clone();
            move || {
                // `fetch_add` returns the prior count: the first `worker_threads`
                // threads to start are runtime workers; any later thread belongs
                // to the blocking pool.
                let thread_count = barrier.guard_count.fetch_add(1, Ordering::Release);
                CONTEXT.with(|context| {
                    if thread_count.ge(&worker_threads) {
                        *context.borrow_mut() = Some(BarrierContext::PoolWorker)
                    } else {
                        *context.borrow_mut() = Some(BarrierContext::RuntimeWorker)
                    }
                });
            }
        };

        let on_thread_stop = move || {
            let thread_count = barrier.guard_count.fetch_sub(1, Ordering::AcqRel);
            CONTEXT.with(|context| {
                if thread_count.eq(&1) {
                    // The last thread to stop finalizes shutdown and wakes any
                    // runtime workers parked on the condvar.
                    *barrier.shutdown_finalized.lock().unwrap() = true;
                    barrier.cvar.notify_all();
                } else if context.borrow().eq(&Some(BarrierContext::RuntimeWorker)) {
                    // Runtime workers park until shutdown is finalized so their
                    // thread-local data outlives every remaining task.
                    let mut shutdown_finalized = barrier.shutdown_finalized.lock().unwrap();
                    while !*shutdown_finalized {
                        shutdown_finalized = barrier.cvar.wait(shutdown_finalized).unwrap();
                    }
                }
            });
        };

        self
            .inner
            .on_thread_start(on_thread_start)
            .on_thread_stop(on_thread_stop)
            .build()
    }
}
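
// Minimal smoke tests sketching how the builder above is expected to be used;
// they assume the crate enables the Tokio features required by `enable_all`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn multi_thread_runtime_builds_and_runs() {
        let rt = Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("runtime should build");

        assert_eq!(rt.block_on(async { 40 + 2 }), 42);
    }

    #[test]
    fn current_thread_runtime_builds_and_runs() {
        let rt = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("runtime should build");

        assert_eq!(rt.block_on(async { 1 + 1 }), 2);
    }
}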