use std::future::Future;
use std::pin::Pin;
use std::sync::OnceLock;
use std::task::{Context, Poll, Waker};
/// Lazily initialised tokio runtime shared by the whole process for tokio interop.
static TOKIO_RT: OnceLock<tokio::runtime::Runtime> = OnceLock::new();

/// Returns the shared tokio compatibility runtime, building it on the first call.
///
/// The runtime is a multi-thread runtime with a single worker thread and with the
/// I/O and time drivers enabled.
///
/// # Panics
/// Panics if tokio fails to build the runtime.
fn tokio_runtime() -> &'static tokio::runtime::Runtime {
    fn build_runtime() -> tokio::runtime::Runtime {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.worker_threads(1).enable_io().enable_time();
        builder
            .build()
            .expect("failed to create tokio compatibility runtime")
    }

    TOKIO_RT.get_or_init(build_runtime)
}
thread_local! {
    /// Per-thread flag: `true` once `ensure_tokio_context` has entered (and leaked the
    /// guard for) the shared tokio runtime's context on this thread.
    static TOKIO_ENTERED: Cell<bool> = const { Cell::new(false) };
}
use std::cell::Cell;
/// Enters the shared tokio runtime's context on the calling thread, at most once per thread.
///
/// The `EnterGuard` returned by `Runtime::enter` is deliberately leaked with
/// `std::mem::forget`, so the context stays entered for the rest of the thread's life;
/// `TOKIO_ENTERED` records that this already happened so it is not repeated.
///
/// NOTE(review): tokio requires enter guards to be dropped in reverse order of creation.
/// Leaking one on a thread that already holds a scoped tokio context (e.g. a tokio worker
/// thread) may make that outer guard panic when it is dropped — confirm this is never
/// reached from such threads.
///
/// # Panics
/// Panics if the shared runtime cannot be built on first use.
fn ensure_tokio_context() {
    TOKIO_ENTERED.with(|already_entered| {
        if already_entered.get() {
            return;
        }
        let guard = tokio_runtime().enter();
        std::mem::forget(guard);
        already_entered.set(true);
    });
}
/// Builds a future with `f` inside the tokio compatibility runtime's context and wraps it
/// so it can be awaited from this crate's tasks.
///
/// The tokio context is entered on the calling thread *before* `f` runs. Tokio resources
/// that `f` creates eagerly can therefore locate a runtime handle. The returned
/// [`TokioCompat`] polls the inner future with a waker from `make_cross_waker`.
///
/// # Panics
/// Panics if the tokio runtime cannot be created on first use. Polling the result panics
/// with "with_tokio() requires runtime context" when it is polled by one of this crate's
/// task wakers but no cross-wake context is available.
pub fn with_tokio<F, Fut>(f: F) -> TokioCompat<Fut>
where
    F: FnOnce() -> Fut,
    Fut: Future,
{
    // Must precede `f()` so that anything `f` constructs sees the tokio context.
    ensure_tokio_context();
    TokioCompat { future: f() }
}
/// Future returned by `with_tokio`.
///
/// Each poll forwards to `future`, but with the waker produced by `make_cross_waker` in
/// place of the caller's waker. Like any future it does nothing until polled, hence
/// `#[must_use]` (which `async fn` futures get automatically).
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct TokioCompat<F> {
    /// The wrapped future; structurally pinned (never moved once `TokioCompat` is pinned).
    future: F,
}
impl<F: Future> Future for TokioCompat<F> {
    type Output = F::Output;

    /// Polls the wrapped future with the waker from `make_cross_waker` instead of the
    /// caller's own waker.
    ///
    /// NOTE(review): only `with_tokio` enters the tokio context, and it does so on the
    /// thread that *constructs* this future. If the future can be polled on a different
    /// thread, tokio resources created lazily during polling may not find a runtime.
    /// Confirm that this cannot happen.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let bridged_waker = make_cross_waker(cx);
        let mut bridged_cx = Context::from_waker(&bridged_waker);
        // SAFETY: `future` is structurally pinned. `TokioCompat` never moves it out after
        // construction, has no `Drop` impl, and is only auto-`Unpin` when `F: Unpin`.
        let pinned_future = unsafe { self.map_unchecked_mut(|compat| &mut compat.future) };
        pinned_future.poll(&mut bridged_cx)
    }
}
/// Chooses the waker to hand to tokio for a poll driven by `cx`.
///
/// If `cx`'s waker is one of this crate's local task wakers
/// (`crate::waker::task_ptr_from_local_waker` returns `Some`), a cross-thread task waker
/// for that task is built with `make_cross_task_waker`. Tokio may invoke it from its own
/// worker thread. Any other waker is simply cloned.
///
/// # Panics
/// Panics with "with_tokio() requires runtime context" if the waker is a local task waker
/// but `crate::cross_wake::cross_wake_context()` yields no context.
fn make_cross_waker(cx: &Context<'_>) -> Waker {
    let waker = cx.waker();
    match crate::waker::task_ptr_from_local_waker(waker) {
        Some(task_ptr) => {
            let ctx = crate::cross_wake::cross_wake_context()
                .expect("with_tokio() requires runtime context");
            make_cross_task_waker(task_ptr, ctx)
        }
        None => waker.clone(),
    }
}
/// Heap-allocated payload behind every cross-thread task waker built by
/// `make_cross_task_waker` or `cross_task_clone`.
struct CrossTaskWakerData {
    /// Pointer to one of the crate's tasks. Each payload owns one reference, taken with
    /// `crate::task::ref_inc` when the payload is created and released with
    /// `crate::task::ref_dec` when the waker is dropped or consumed by `wake`.
    task_ptr: *mut u8,
    /// Cross-wake context for the task. It is used to push the task onto `ctx.queue` and
    /// to wake `ctx.mio_waker` when `ctx.parked` is set.
    ctx: std::sync::Arc<crate::cross_wake::CrossWakeContext>,
}
// SAFETY (NOTE(review)): a `Waker` is `Send + Sync`, so this payload must be usable from
// any thread. Soundness relies on `crate::task::ref_inc` / `ref_dec` / `try_set_queued`,
// `crate::cross_wake::wake_task_cross_thread` and `ctx.queue.push` all being safe to call
// on `task_ptr` concurrently from other threads. Presumably they are atomic; confirm in
// `crate::task` / `crate::cross_wake`.
unsafe impl Send for CrossTaskWakerData {}
unsafe impl Sync for CrossTaskWakerData {}
use std::task::RawWaker;
use std::task::RawWakerVTable;
/// Vtable for cross-thread task wakers. The data pointer of every such waker is a
/// `Box<CrossTaskWakerData>` leaked with `Box::into_raw` (see `make_cross_task_waker`
/// and `cross_task_clone`).
static CROSS_TASK_VTABLE: RawWakerVTable = RawWakerVTable::new(
    cross_task_clone,
    cross_task_wake,
    cross_task_wake_by_ref,
    cross_task_drop,
);
fn make_cross_task_waker(
task_ptr: *mut u8,
ctx: std::sync::Arc<crate::cross_wake::CrossWakeContext>,
) -> Waker {
unsafe { crate::task::ref_inc(task_ptr) };
let data = Box::into_raw(Box::new(CrossTaskWakerData { task_ptr, ctx }));
let raw = RawWaker::new(data.cast::<()>(), &CROSS_TASK_VTABLE);
unsafe { Waker::from_raw(raw) }
}
unsafe fn cross_task_clone(data: *const ()) -> RawWaker {
let orig = unsafe { &*data.cast::<CrossTaskWakerData>() };
unsafe { crate::task::ref_inc(orig.task_ptr) };
let cloned = Box::new(CrossTaskWakerData {
task_ptr: orig.task_ptr,
ctx: orig.ctx.clone(),
});
RawWaker::new(Box::into_raw(cloned).cast::<()>(), &CROSS_TASK_VTABLE)
}
unsafe fn cross_task_wake(data: *const ()) {
unsafe { cross_task_wake_by_ref(data) };
let boxed = unsafe { Box::from_raw(data.cast_mut().cast::<CrossTaskWakerData>()) };
let task_ptr = boxed.task_ptr;
match unsafe { crate::task::ref_dec(task_ptr) } {
crate::task::FreeAction::Retain => {}
crate::task::FreeAction::FreeBox | crate::task::FreeAction::FreeSlab => {
if unsafe { crate::task::try_set_queued(task_ptr) } {
unsafe { boxed.ctx.queue.push(task_ptr) };
if boxed.ctx.parked.load(std::sync::atomic::Ordering::Acquire) {
let _ = boxed.ctx.mio_waker.wake();
}
}
}
}
}
unsafe fn cross_task_wake_by_ref(data: *const ()) {
let waker_data = unsafe { &*data.cast::<CrossTaskWakerData>() };
unsafe {
crate::cross_wake::wake_task_cross_thread(waker_data.task_ptr, &waker_data.ctx);
}
}
/// `RawWakerVTable::drop` entry: frees this waker's payload and releases its task
/// reference.
///
/// # Safety
/// `data` must be a live payload from `make_cross_task_waker` or `cross_task_clone` that
/// has not been consumed yet. Ownership passes to this call.
unsafe fn cross_task_drop(data: *const ()) {
    // Reclaim the Box. It (and its `Arc` to the context) is dropped at the end of this
    // function, after the context has been used below.
    let boxed = unsafe { Box::from_raw(data.cast_mut().cast::<CrossTaskWakerData>()) };
    let task_ptr = boxed.task_ptr;
    match unsafe { crate::task::ref_dec(task_ptr) } {
        // Other references remain; nothing more to do.
        crate::task::FreeAction::Retain => {}
        // `ref_dec` reports the task should now be freed. Instead of freeing it here, the
        // task is pushed onto its context's queue, unless `try_set_queued` says it is
        // already queued. NOTE(review): presumably the owning runtime thread performs the
        // actual free when it dequeues it; confirm against the queue consumer.
        crate::task::FreeAction::FreeBox | crate::task::FreeAction::FreeSlab => {
            if unsafe { crate::task::try_set_queued(task_ptr) } {
                unsafe { boxed.ctx.queue.push(task_ptr) };
                // Only poke the runtime's mio waker if it has flagged itself as parked.
                if boxed.ctx.parked.load(std::sync::atomic::Ordering::Acquire) {
                    let _ = boxed.ctx.mio_waker.wake();
                }
            }
        }
    }
}
/// Spawns `future` onto the shared tokio compatibility runtime and returns a handle to it.
///
/// The future is run by tokio's runtime rather than this crate's executor, so it must be
/// `Send + 'static`. The returned [`TokioJoinHandle`] can be awaited from this crate's
/// tasks. Dropping it aborts the spawned task.
///
/// # Panics
/// Panics if the tokio runtime cannot be created on first use.
pub fn spawn_on_tokio<F, T>(future: F) -> TokioJoinHandle<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let runtime_handle = tokio_runtime().handle();
    TokioJoinHandle {
        inner: runtime_handle.spawn(future),
        _not_send: std::marker::PhantomData,
    }
}
/// Handle to a future spawned on the tokio compatibility runtime by `spawn_on_tokio`.
///
/// Awaiting it yields the task's output, or a `TokioJoinError` if the task did not run to
/// completion (cancelled or panicked). Dropping the handle aborts the tokio task.
#[must_use = "dropping a TokioJoinHandle aborts the tokio task"]
pub struct TokioJoinHandle<T> {
    /// The underlying tokio join handle.
    inner: tokio::task::JoinHandle<T>,
    /// Raw-pointer marker that makes this handle neither `Send` nor `Sync`.
    /// NOTE(review): presumably so it is only polled on the thread that spawned it; confirm.
    _not_send: std::marker::PhantomData<*const ()>,
}
impl<T> Future for TokioJoinHandle<T> {
type Output = Result<T, TokioJoinError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let cross_waker = make_cross_waker(cx);
let mut cross_cx = Context::from_waker(&cross_waker);
let inner = Pin::new(&mut self.get_mut().inner);
match inner.poll(&mut cross_cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
Poll::Ready(Err(e)) => Poll::Ready(Err(TokioJoinError(e))),
}
}
}
impl<T> TokioJoinHandle<T> {
    /// Returns `true` once the tokio task has finished running. Never blocks.
    pub fn is_finished(&self) -> bool {
        let Self { inner, .. } = self;
        inner.is_finished()
    }

    /// Requests cancellation of the tokio task.
    ///
    /// Awaiting the handle afterwards most likely yields a cancelled [`TokioJoinError`],
    /// unless the task had already completed.
    pub fn abort(&self) {
        let Self { inner, .. } = self;
        inner.abort();
    }
}
impl<T> Drop for TokioJoinHandle<T> {
    /// Aborts the tokio task so it does not keep running detached once the handle is gone.
    /// Aborting a task that has already finished has no effect.
    fn drop(&mut self) {
        self.abort();
    }
}
/// Error produced by awaiting a [`TokioJoinHandle`] whose tokio task did not run to
/// completion.
pub struct TokioJoinError(tokio::task::JoinError);
impl TokioJoinError {
    /// Returns `true` if the tokio task was cancelled (for example via
    /// `TokioJoinHandle::abort` or by dropping the handle).
    pub fn is_cancelled(&self) -> bool {
        let Self(inner) = self;
        inner.is_cancelled()
    }

    /// Returns `true` if the tokio task panicked.
    pub fn is_panic(&self) -> bool {
        let Self(inner) = self;
        inner.is_panic()
    }
}
impl std::fmt::Display for TokioJoinError {
    /// Formats exactly as the wrapped tokio `JoinError`'s `Display` does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
impl std::fmt::Debug for TokioJoinError {
    /// Formats exactly as the wrapped tokio `JoinError`'s `Debug` does; no wrapper name is added.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.0, f)
    }
}
impl std::error::Error for TokioJoinError {
    /// Transparent wrapper: reports the wrapped `JoinError`'s own source, not the
    /// `JoinError` itself.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        std::error::Error::source(&self.0)
    }
}