use once_cell::sync::{Lazy, OnceCell};
use rand::Rng;
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Duration;
use thread_local::ThreadLocal;
use tokio::runtime::{Builder, Handle};
use tokio::sync::oneshot::{channel, Sender};
/// An async runtime in one of two flavors.
///
/// `Steal` wraps a multi-threaded tokio runtime (tokio's work-stealing scheduler).
/// `NoSteal` wraps a [`NoStealRuntime`]: several independent single-threaded tokio runtimes,
/// each driven by its own OS thread.
pub enum Runtime {
    Steal(tokio::runtime::Runtime),
    NoSteal(NoStealRuntime),
}

impl Runtime {
    /// Creates a work-stealing runtime with `threads` worker threads named `name`.
    ///
    /// # Panics
    /// Panics if the tokio runtime fails to build (tokio also rejects `threads == 0`).
    pub fn new_steal(threads: usize, name: &str) -> Self {
        let mut builder = Builder::new_multi_thread();
        builder
            .enable_all()
            .worker_threads(threads)
            .thread_name(name);
        let runtime = builder.build().unwrap();
        Runtime::Steal(runtime)
    }

    /// Creates a non-stealing runtime made of `threads` single-threaded runtimes whose
    /// threads are named `name`.
    ///
    /// # Panics
    /// Panics if `threads` is 0.
    pub fn new_no_steal(threads: usize, name: &str) -> Self {
        let runtime = NoStealRuntime::new(threads, name);
        Runtime::NoSteal(runtime)
    }

    /// Returns a handle for spawning work onto this runtime.
    ///
    /// For `NoSteal`, each call picks one of the inner runtimes at random, starting the
    /// worker threads on first use.
    pub fn get_handle(&self) -> &Handle {
        match self {
            Runtime::NoSteal(no_steal) => no_steal.get_runtime(),
            Runtime::Steal(steal) => steal.handle(),
        }
    }

    /// Consumes the runtime and shuts it down, giving spawned work up to `timeout` to stop.
    pub fn shutdown_timeout(self, timeout: Duration) {
        match self {
            Runtime::NoSteal(no_steal) => no_steal.shutdown_timeout(timeout),
            Runtime::Steal(steal) => steal.shutdown_timeout(timeout),
        }
    }
}
/// Per-thread registry of runtime pools.
///
/// In this file, only worker threads spawned by a [`NoStealRuntime`] register an entry here.
/// The entry is that runtime's shared pool of handles.
///
/// NOTE(review): the `thread_local` crate recycles thread IDs after a thread exits. Confirm
/// that a new thread reusing an ID cannot observe a previous worker's entry, and that such a
/// stale entry cannot keep `get_or` from registering a new pool.
static CURRENT_HANDLE: Lazy<ThreadLocal<Pools>> = Lazy::new(ThreadLocal::new);

/// Returns a runtime handle to spawn work on from the calling thread.
///
/// On a [`NoStealRuntime`] worker thread, this returns a clone of a randomly chosen handle
/// from that runtime's pool, which may belong to another worker. On any other thread it falls
/// back to [`Handle::current`].
///
/// # Panics
/// Panics if the registered pool has not been filled yet. Also panics if `Handle::current()`
/// panics, which per tokio happens outside a runtime context.
pub fn current_handle() -> Handle {
    match CURRENT_HANDLE.get() {
        None => Handle::current(),
        Some(registered) => {
            let handles = registered.get().unwrap();
            let pick = rand::thread_rng().gen_range(0..handles.len());
            handles[pick].clone()
        }
    }
}
/// Lazily filled, shared list of one [`NoStealRuntime`]'s worker handles.
///
/// It is wrapped in `Arc` so each worker thread can keep a reference in `CURRENT_HANDLE`.
type Pools = Arc<OnceCell<Box<[Handle]>>>;

/// Shutdown control for one worker: the sender that delivers the shutdown timeout, and the
/// worker thread's join handle.
type Control = (Sender<Duration>, JoinHandle<()>);

/// A runtime built from `threads` independent single-threaded tokio runtimes.
///
/// Each inner runtime runs on its own OS thread, so tasks are never stolen between them.
/// Worker threads are started lazily on first use.
pub struct NoStealRuntime {
    /// Number of worker threads, which is also the number of inner runtimes.
    threads: usize,
    /// Name given to every worker thread.
    name: String,
    /// Worker handles, filled when the workers are first started.
    pools: Pools,
    /// Per-worker shutdown controls, set together with `pools`.
    controls: OnceCell<Vec<Control>>,
}
impl NoStealRuntime {
    /// Describes a runtime with `threads` single-threaded workers whose threads are named `name`.
    ///
    /// No threads are spawned here. They start on the first call that needs a handle
    /// ([`NoStealRuntime::get_runtime`] / [`NoStealRuntime::get_runtime_at`]).
    ///
    /// # Panics
    /// Panics if `threads` is 0.
    pub fn new(threads: usize, name: &str) -> Self {
        assert!(threads != 0, "NoStealRuntime requires at least one thread");
        NoStealRuntime {
            threads,
            name: name.to_string(),
            pools: Arc::new(OnceCell::new()),
            controls: OnceCell::new(),
        }
    }

    /// Spawns `self.threads` workers and returns their handles and controls.
    ///
    /// The two returned collections are in matching order.
    fn init_pools(&self) -> (Box<[Handle]>, Vec<Control>) {
        let (handles, controls): (Vec<Handle>, Vec<Control>) =
            (0..self.threads).map(|_| self.spawn_worker()).unzip();
        (handles.into_boxed_slice(), controls)
    }

    /// Builds one current-thread runtime and drives it on a dedicated OS thread.
    ///
    /// The thread first registers `self.pools` in `CURRENT_HANDLE`, so [`current_handle`]
    /// called on it picks from this runtime's pool. It then blocks until a shutdown timeout
    /// arrives on the returned sender and shuts the runtime down with that timeout. If the
    /// sender is dropped instead, the thread returns and the runtime is simply dropped.
    fn spawn_worker(&self) -> (Handle, Control) {
        let rt = Builder::new_current_thread().enable_all().build().unwrap();
        let handle = rt.handle().clone();
        let (tx, rx) = channel::<Duration>();
        let pools_ref = self.pools.clone();
        let join = std::thread::Builder::new()
            .name(self.name.clone())
            .spawn(move || {
                CURRENT_HANDLE.get_or(|| pools_ref);
                if let Ok(timeout) = rt.block_on(rx) {
                    rt.shutdown_timeout(timeout);
                }
            })
            .unwrap();
        (handle, (tx, join))
    }

    /// Returns the handle of a randomly chosen worker, starting the workers on first use.
    pub fn get_runtime(&self) -> &Handle {
        let mut rng = rand::thread_rng();
        let index = rng.gen_range(0..self.threads);
        self.get_runtime_at(index)
    }

    /// Returns the number of worker threads this runtime was configured with.
    pub fn threads(&self) -> usize {
        self.threads
    }

    /// Returns all worker handles, starting the workers on the first call.
    ///
    /// `get_or_init` runs the initializer at most once; concurrent first callers block until
    /// it finishes. Without this, every racing caller would spawn (and then discard) a full
    /// set of OS threads and runtimes.
    fn get_pools(&self) -> &[Handle] {
        self.pools.get_or_init(|| {
            let (pools, controls) = self.init_pools();
            // The initializer runs once per `NoStealRuntime`, so `controls` is still unset.
            self.controls
                .set(controls)
                .expect("controls are initialized together with pools");
            pools
        })
    }

    /// Returns the handle of the worker at `index`, starting the workers on first use.
    ///
    /// # Panics
    /// Panics if `index >= self.threads()`.
    pub fn get_runtime_at(&self, index: usize) -> &Handle {
        &self.get_pools()[index]
    }

    /// Shuts down every worker, then joins the worker threads.
    ///
    /// Each inner runtime gets up to `timeout` to shut down. Does nothing if the workers were
    /// never started.
    pub fn shutdown_timeout(mut self, timeout: Duration) {
        if let Some(controls) = self.controls.take() {
            let (txs, joins): (Vec<Sender<Duration>>, Vec<JoinHandle<()>>) =
                controls.into_iter().unzip();
            // Signal every worker before joining any, so the shutdowns run concurrently.
            for tx in txs {
                // Err means the worker's receiver is already gone; there is nothing to signal.
                let _ = tx.send(timeout);
            }
            for join in joins {
                // Best effort: a worker that panicked has nothing further to clean up here.
                let _ = join.join();
            }
        }
    }
}
#[test]
fn test_steal_runtime() {
    use tokio::time::{sleep, Duration};

    let worker_count = 2;
    let runtime = Runtime::new_steal(worker_count, "test");
    let rt_handle = runtime.get_handle();

    // Sleep, then spawn a child task through `current_handle()` and await it.
    // The block yields 1 once both have finished.
    let result = rt_handle.block_on(async {
        sleep(Duration::from_secs(1)).await;
        let child = current_handle().spawn(async {
            sleep(Duration::from_secs(1)).await;
        });
        child.await.unwrap();
        1
    });

    #[cfg(target_os = "linux")]
    assert_eq!(rt_handle.metrics().num_workers(), worker_count);
    assert_eq!(result, 1);
}
#[test]
fn test_no_steal_runtime() {
    use tokio::time::{sleep, Duration};

    let runtime = Runtime::new_no_steal(2, "test");
    let rt_handle = runtime.get_handle();

    // Sleep, then spawn a child task through `current_handle()` and await it.
    // The block yields 1 once both have finished.
    let result = rt_handle.block_on(async {
        sleep(Duration::from_secs(1)).await;
        let child = current_handle().spawn(async {
            sleep(Duration::from_secs(1)).await;
        });
        child.await.unwrap();
        1
    });

    assert_eq!(result, 1);
}
#[test]
fn test_no_steal_shutdown() {
    use tokio::time::{sleep, Duration};

    let runtime = Runtime::new_no_steal(2, "test");
    let rt_handle = runtime.get_handle();

    // Same workload as the plain no-steal test; afterwards the runtime is shut down
    // explicitly to exercise the worker shutdown/join path.
    let result = rt_handle.block_on(async {
        sleep(Duration::from_secs(1)).await;
        let child = current_handle().spawn(async {
            sleep(Duration::from_secs(1)).await;
        });
        child.await.unwrap();
        1
    });
    assert_eq!(result, 1);

    runtime.shutdown_timeout(Duration::from_secs(1));
}