use core_affinity::{CoreId, set_for_current};
use scoped_tls::scoped_thread_local;
use std::{
future::Future,
io::Result,
pin::Pin,
task::{Context, Poll},
thread::{self, JoinHandle},
};
use tokio::task::LocalSet;
// Scoped thread-local slot holding a reference to the `LocalSet` that drives
// the executor on the current thread. In this file it is set only for the
// duration of `LocalExecutorBuilder::run` / `LocalExecutorBuilder::spawn` and
// read by `spawn_local`.
scoped_thread_local!(pub(super) static LOCAL: LocalSet);
/// Handle to a task created by [`spawn_local`].
///
/// Awaiting the handle yields `Ok(T)` with the task's output, or a
/// [`TaskError`] if the underlying tokio join handle reports a failure.
pub struct Task<T> {
// The tokio join handle this type wraps; all operations delegate to it.
inner: tokio::task::JoinHandle<T>,
}
impl<T> Future for Task<T> {
    /// The task's output, or a [`TaskError`] wrapping the join failure.
    type Output = std::result::Result<T, TaskError>;

    /// Polls the wrapped tokio join handle and converts a join failure into
    /// a [`TaskError`]; a successful output is passed through unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The wrapped join handle is `Unpin`, so it can be re-pinned from a
        // plain mutable reference.
        let handle = Pin::new(&mut self.get_mut().inner);
        match handle.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(output)) => Poll::Ready(Ok(output)),
            Poll::Ready(Err(join_error)) => Poll::Ready(Err(TaskError { inner: join_error })),
        }
    }
}
impl<T> Task<T> {
    /// Gives up the handle without aborting the task.
    ///
    /// Only the join handle is dropped here; no abort is requested, so the
    /// task's output simply can no longer be retrieved through this handle.
    pub fn detach(self) {
        let Self { inner } = self;
        drop(inner);
    }

    /// Requests that the task be aborted, then discards the handle.
    ///
    /// Because the handle is consumed, the caller cannot await the task
    /// afterwards to observe the cancellation.
    pub fn cancel(self) {
        let Self { inner } = self;
        inner.abort();
    }
}
/// Error produced when awaiting a [`Task`] whose join handle resolved with a
/// failure.
///
/// This is an opaque wrapper around tokio's `JoinError` (which tokio
/// documents as being reported when a task panics or is cancelled).
#[derive(Debug)]
pub struct TaskError {
// The tokio join error being wrapped; `Display` forwards to it.
inner: tokio::task::JoinError,
}
impl std::fmt::Display for TaskError {
    /// Writes the message of the wrapped tokio join error.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let join_error = &self.inner;
        formatter.write_fmt(format_args!("{}", join_error))
    }
}
// Uses the default `Error` methods (no `source()` override); the wrapped join
// error's message is already surfaced through the `Display` impl.
impl std::error::Error for TaskError {}
/// Builder for a single-threaded executor backed by a tokio current-thread
/// runtime and a `LocalSet`.
///
/// Configure it with [`LocalExecutorBuilder::name`] and
/// [`LocalExecutorBuilder::core_id`], then start it with
/// [`LocalExecutorBuilder::run`] (on the calling thread) or
/// [`LocalExecutorBuilder::spawn`] (on a new thread).
#[derive(Debug, Default)]
pub struct LocalExecutorBuilder {
// CPU core to pin the executor thread to; `None` means no pinning is attempted.
core_id: Option<CoreId>,
// Name passed to `thread::Builder` by `spawn`; defaults to the empty string.
name: String,
}
impl LocalExecutorBuilder {
    /// Creates a builder with no core pinning and no thread name.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the name given to the thread created by [`Self::spawn`].
    ///
    /// The name is not used by [`Self::run`], which executes on the calling
    /// thread.
    pub fn name(mut self, name: &str) -> Self {
        self.name = String::from(name);
        self
    }

    /// Requests that the executor thread be pinned to `core_id`.
    pub fn core_id(mut self, core_id: CoreId) -> Self {
        self.core_id = Some(core_id);
        self
    }

    /// Runs `f` to completion on the current thread and returns its output.
    ///
    /// If a core was configured, the current thread is pinned to it first.
    /// While `f` runs, [`spawn_local`] can be used to spawn additional tasks
    /// onto the same executor.
    ///
    /// # Panics
    ///
    /// Panics if the tokio runtime cannot be built.
    pub fn run<T>(self, f: impl Future<Output = T>) -> T {
        Self::execute(self.core_id, move || f)
    }

    /// Spawns a new OS thread and runs the future produced by `fut_gen` to
    /// completion on it.
    ///
    /// `fut_gen` is invoked on the new thread, so the future itself does not
    /// need to be `Send`; only the generator and the output do. The returned
    /// handle yields the future's output when joined.
    ///
    /// The thread is given the configured name; if no (or an empty) name was
    /// configured the thread is left unnamed rather than being explicitly
    /// named `""`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the OS if the thread cannot be
    /// created.
    ///
    /// # Panics
    ///
    /// `std::thread::Builder::spawn` panics if the configured name contains
    /// interior NUL bytes. A failure to build the tokio runtime panics on the
    /// spawned thread and surfaces when the handle is joined.
    pub fn spawn<G, F, T>(self, fut_gen: G) -> Result<JoinHandle<T>>
    where
        G: FnOnce() -> F + Send + 'static,
        F: Future<Output = T> + 'static,
        T: Send + 'static,
    {
        let Self { core_id, name } = self;
        let mut builder = thread::Builder::new();
        // Only name the thread when a name was actually provided; passing the
        // default empty string would override the OS-level thread name with "".
        if !name.is_empty() {
            builder = builder.name(name);
        }
        builder.spawn(move || Self::execute(core_id, fut_gen))
    }

    /// Shared body of [`Self::run`] and [`Self::spawn`]: pins the current
    /// thread (if requested), builds a current-thread tokio runtime and drives
    /// the future produced by `fut_gen` on a fresh `LocalSet`.
    fn execute<G, F>(core_id: Option<CoreId>, fut_gen: G) -> F::Output
    where
        G: FnOnce() -> F,
        F: Future,
    {
        if let Some(core_id) = core_id {
            // Best effort: the outcome of the pinning request is not checked.
            set_for_current(core_id);
        }
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to build tokio runtime");
        let local_set = LocalSet::new();
        // `fut_gen` is called inside the `LOCAL` scope so that it may already
        // use `spawn_local` while constructing the root future.
        LOCAL.set(&local_set, || rt.block_on(local_set.run_until(fut_gen())))
    }
}
/// Spawns `future` onto the executor running on the current thread and
/// returns a [`Task`] handle for it.
///
/// # Panics
///
/// Panics if the thread-local `LOCAL` slot is not set, i.e. when called
/// outside of an executor started through `LocalExecutorBuilder`.
pub fn spawn_local<T: 'static>(future: impl Future<Output = T> + 'static) -> Task<T> {
    // Check up front so the caller gets this crate's own message.
    if !LOCAL.is_set() {
        panic!("`spawn_local()` must be called from a tokio `LocalSet`")
    }
    LOCAL.with(|executor| {
        let inner = executor.spawn_local(future);
        Task { inner }
    })
}
/// Yields execution back to the tokio scheduler once by awaiting
/// `tokio::task::yield_now`.
pub async fn yield_local() {
    let yield_point = tokio::task::yield_now();
    yield_point.await
}