#![allow(irrefutable_let_patterns)]
use crate::runtime::blocking::BlockingPool;
use crate::runtime::scheduler::CurrentThread;
use crate::runtime::{context, Builder, EnterGuard, Handle, BOX_FUTURE_THRESHOLD};
use crate::task::JoinHandle;
use crate::util::trace::SpawnMeta;
use std::future::Future;
use std::marker::PhantomData;
use std::mem;
use std::time::Duration;
/// A runtime whose `spawn_local` accepts futures that are not `Send`.
///
/// It is backed by the current-thread scheduler, which is the only variant of
/// `LocalRuntimeScheduler`. Because of the `_phantom` field, `LocalRuntime` is itself
/// neither `Send` nor `Sync`, so it stays on the thread that owns it.
#[derive(Debug)]
pub struct LocalRuntime {
    // NOTE(review): fields are dropped in declaration order after `Drop::drop` runs
    // (scheduler, then handle, then blocking pool); keep this order unless the shutdown
    // sequence has been re-verified.
    /// The scheduler that drives tasks; always the current-thread variant.
    scheduler: LocalRuntimeScheduler,
    /// Handle shared with spawn/enter/metrics calls and passed to the scheduler.
    handle: Handle,
    /// Pool used for blocking work; shut down explicitly by `shutdown_timeout`.
    blocking_pool: BlockingPool,
    /// Raw-pointer marker that opts `LocalRuntime` out of the `Send`/`Sync` auto traits.
    _phantom: PhantomData<*mut u8>,
}
/// The scheduler implementation owned by a `LocalRuntime`.
///
/// Only a current-thread variant is defined.
#[derive(Debug)]
pub(crate) enum LocalRuntimeScheduler {
    /// Single-threaded scheduler that runs every task on the thread that owns the runtime.
    CurrentThread(CurrentThread),
}
impl LocalRuntime {
pub(crate) fn from_parts(
scheduler: LocalRuntimeScheduler,
handle: Handle,
blocking_pool: BlockingPool,
) -> LocalRuntime {
LocalRuntime {
scheduler,
handle,
blocking_pool,
_phantom: Default::default(),
}
}
pub fn new() -> std::io::Result<LocalRuntime> {
Builder::new_current_thread()
.enable_all()
.build_local(Default::default())
}
pub fn handle(&self) -> &Handle {
&self.handle
}
#[track_caller]
pub fn spawn_local<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + 'static,
F::Output: 'static,
{
let fut_size = std::mem::size_of::<F>();
let meta = SpawnMeta::new_unnamed(fut_size);
unsafe {
if std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.handle.spawn_local_named(Box::pin(future), meta)
} else {
self.handle.spawn_local_named(future, meta)
}
}
}
#[track_caller]
pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
self.handle.spawn_blocking(func)
}
#[track_caller]
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
let fut_size = mem::size_of::<F>();
let meta = SpawnMeta::new_unnamed(fut_size);
if std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.block_on_inner(Box::pin(future), meta)
} else {
self.block_on_inner(future, meta)
}
}
#[track_caller]
fn block_on_inner<F: Future>(&self, future: F, _meta: SpawnMeta<'_>) -> F::Output {
#[cfg(all(
tokio_unstable,
feature = "taskdump",
feature = "rt",
target_os = "linux",
any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")
))]
let future = crate::runtime::task::trace::Trace::root(future);
#[cfg(all(tokio_unstable, feature = "tracing"))]
let future = crate::util::trace::task(
future,
"block_on",
_meta,
crate::runtime::task::Id::next().as_u64(),
);
let _enter = self.enter();
if let LocalRuntimeScheduler::CurrentThread(exec) = &self.scheduler {
exec.block_on(&self.handle.inner, future)
} else {
unreachable!("LocalRuntime only supports current_thread")
}
}
pub fn enter(&self) -> EnterGuard<'_> {
self.handle.enter()
}
pub fn shutdown_timeout(mut self, duration: Duration) {
self.handle.inner.shutdown();
self.blocking_pool.shutdown(Some(duration));
}
pub fn shutdown_background(self) {
self.shutdown_timeout(Duration::from_nanos(0));
}
pub fn metrics(&self) -> crate::runtime::RuntimeMetrics {
self.handle.metrics()
}
}
impl Drop for LocalRuntime {
    /// Shuts down the current-thread scheduler before the fields are dropped.
    fn drop(&mut self) {
        // Exhaustive match: a new scheduler variant is a compile error, not a runtime panic.
        match &mut self.scheduler {
            LocalRuntimeScheduler::CurrentThread(current_thread) => {
                // Try to install this runtime as the current context while shutting down.
                // NOTE(review): presumably this makes tasks dropped during shutdown see this
                // runtime; the `try_` call may fail silently, which is tolerated here.
                let _guard = context::try_set_current(&self.handle.inner);
                current_thread.shutdown(&self.handle.inner);
            }
        }
    }
}
// Unwind safety is asserted manually for `LocalRuntime` instead of being left to the
// compiler's auto-trait inference.
// NOTE(review): confirm that observing a `LocalRuntime` after a caught panic cannot expose
// broken invariants; presumably the auto impls would not apply to one of its fields.
impl std::panic::UnwindSafe for LocalRuntime {}
impl std::panic::RefUnwindSafe for LocalRuntime {}