use std::future::Future;
use std::pin::Pin;
use std::sync::OnceLock;
use std::task::{Context, Poll, Waker};
/// Process-wide tokio runtime backing all tokio interop in this module.
///
/// Built lazily by [`tokio_runtime`]. Because it lives in a `static`, it is
/// never dropped, so the runtime is never shut down.
static TOKIO_RT: OnceLock<tokio::runtime::Runtime> = OnceLock::new();

/// Returns the shared tokio runtime, building it on the first call.
///
/// The runtime uses the multi-thread scheduler with exactly one worker thread
/// and has both the IO and time drivers enabled.
///
/// # Panics
///
/// Panics if tokio fails to build the runtime.
fn tokio_runtime() -> &'static tokio::runtime::Runtime {
    fn build_runtime() -> tokio::runtime::Runtime {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.worker_threads(1).enable_io().enable_time();
        builder
            .build()
            .expect("failed to create tokio compatibility runtime")
    }

    TOKIO_RT.get_or_init(build_runtime)
}
use std::cell::Cell;

thread_local! {
    /// Whether the current thread has already (permanently) entered the
    /// shared tokio runtime's context via `ensure_tokio_context`.
    static TOKIO_ENTERED: Cell<bool> = const { Cell::new(false) };
}

/// Enters the shared tokio runtime's context on the current thread, once.
///
/// The `EnterGuard` returned by `Runtime::enter` is intentionally leaked with
/// `mem::forget`, so the thread stays inside the runtime context for the rest
/// of its life. The thread-local flag turns every later call into a no-op.
///
/// NOTE(review): tokio requires `EnterGuard`s to be dropped in the reverse
/// order they were created. If this runs on a thread that is already inside
/// another tokio context, dropping that outer guard later may panic. Confirm
/// callers never reach this from inside a tokio runtime thread.
fn ensure_tokio_context() {
    TOKIO_ENTERED.with(|already_entered| {
        if already_entered.get() {
            return;
        }
        let guard = tokio_runtime().enter();
        // Deliberate leak: keep the context entered for the thread's lifetime.
        std::mem::forget(guard);
        already_entered.set(true);
    });
}
/// Creates a future that may use tokio APIs and adapts it to this runtime.
///
/// The calling thread is first placed inside the shared tokio runtime's
/// context (once per thread; see `ensure_tokio_context`). Only then is `f`
/// invoked to build the future. The returned [`TokioCompat`] polls that future
/// with a waker wrapped by `make_cross_waker`, so wake-ups may come from other
/// threads.
pub fn with_tokio<F, Fut>(f: F) -> TokioCompat<Fut>
where
    F: FnOnce() -> Fut,
    Fut: Future,
{
    ensure_tokio_context();
    TokioCompat { future: f() }
}
/// Future returned by [`with_tokio`].
///
/// It polls the wrapped future with the waker produced by `make_cross_waker`
/// instead of the caller's waker.
pub struct TokioCompat<F> {
    future: F,
}

impl<F: Future> Future for TokioCompat<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `future` is structurally pinned. It is never moved out of a
        // pinned `TokioCompat`, this file defines no `Drop` for `TokioCompat`,
        // and `Unpin` is left to the auto-trait, so `TokioCompat<F>: Unpin`
        // holds only when `F: Unpin`.
        let pinned_inner = unsafe { self.map_unchecked_mut(|compat| &mut compat.future) };
        let bridged_waker = make_cross_waker(cx);
        pinned_inner.poll(&mut Context::from_waker(&bridged_waker))
    }
}
/// Returns a waker suitable for handing to code running on other threads.
///
/// If `cx`'s waker is one of this runtime's local task wakers (as recognised
/// by `task_ptr_from_local_waker`), it is replaced by an `Arc`-backed waker
/// built with `make_cross_task_waker`. Any other waker is simply cloned.
///
/// # Panics
///
/// Panics with "with_tokio() requires runtime context" when a local task waker
/// is seen but `cross_wake_context()` yields no context.
fn make_cross_waker(cx: &Context<'_>) -> Waker {
    let original = cx.waker();
    match crate::waker::task_ptr_from_local_waker(original) {
        Some(task_ptr) => {
            let wake_ctx = crate::cross_wake::cross_wake_context()
                .expect("with_tokio() requires runtime context");
            make_cross_task_waker(task_ptr, wake_ctx)
        }
        None => original.clone(),
    }
}
/// Shared state behind a cross-thread task waker built by `make_cross_task_waker`.
///
/// All clones of one waker share a single `Arc` allocation of this struct. The
/// task-level reference held in `task_ref` is therefore taken once, when the
/// waker is built, and released once, when the last clone drops the `Arc`.
/// The `arc_tests` module asserts exactly this.
struct CrossTaskWakerInner {
    /// Task-level reference obtained via `TaskRef::acquire`; its pointer is what
    /// gets passed to `wake_task_cross_thread`.
    task_ref: crate::task::TaskRef,
    /// Wake context handed to `wake_task_cross_thread` alongside the task pointer.
    ctx: std::sync::Arc<crate::cross_wake::CrossWakeContext>,
}
// SAFETY: NOTE(review): these impls assert that `TaskRef` may be shared with,
// and dropped on, threads other than the task's home thread (wakers can be
// invoked and dropped wherever the wrapped future stores them, e.g. tokio's
// worker thread). The visible code only calls `task_ref.as_ptr()` and hands the
// pointer to `wake_task_cross_thread`. Confirm that `TaskRef`'s refcounting and
// `wake_task_cross_thread` are thread-safe.
unsafe impl Send for CrossTaskWakerInner {}
unsafe impl Sync for CrossTaskWakerInner {}
use std::task::RawWaker;
use std::task::RawWakerVTable;
static CROSS_TASK_VTABLE: RawWakerVTable = RawWakerVTable::new(
cross_task_clone,
cross_task_wake,
cross_task_wake_by_ref,
cross_task_drop,
);
/// Builds a `Waker` for the task at `task_ptr` that routes wake-ups through
/// `wake_task_cross_thread` with `ctx`.
///
/// Exactly one task-level reference is acquired here. It is shared by every
/// clone of the returned waker and released when the last clone is dropped.
fn make_cross_task_waker(
    task_ptr: *mut u8,
    ctx: std::sync::Arc<crate::cross_wake::CrossWakeContext>,
) -> Waker {
    use std::sync::Arc;

    // SAFETY: NOTE(review): in `make_cross_waker`, `task_ptr` comes from
    // `task_ptr_from_local_waker` on a live waker, so it presumably points at a
    // live task. Confirm this matches `TaskRef::acquire`'s contract.
    let task_ref = unsafe { crate::task::TaskRef::acquire(task_ptr) };
    let shared = Arc::new(CrossTaskWakerInner { task_ref, ctx });
    let data = Arc::into_raw(shared).cast::<()>();
    let raw_waker = RawWaker::new(data, &CROSS_TASK_VTABLE);
    // SAFETY: `data` is a raw `Arc<CrossTaskWakerInner>` carrying one strong
    // count, which is exactly what every `CROSS_TASK_VTABLE` entry expects.
    unsafe { Waker::from_raw(raw_waker) }
}
unsafe fn cross_task_clone(data: *const ()) -> RawWaker {
let arc = unsafe { std::sync::Arc::from_raw(data.cast::<CrossTaskWakerInner>()) };
let cloned = std::sync::Arc::clone(&arc);
let _ = std::sync::Arc::into_raw(arc);
RawWaker::new(
std::sync::Arc::into_raw(cloned).cast::<()>(),
&CROSS_TASK_VTABLE,
)
}
unsafe fn cross_task_wake(data: *const ()) {
let arc = unsafe { std::sync::Arc::from_raw(data.cast::<CrossTaskWakerInner>()) };
unsafe {
crate::cross_wake::wake_task_cross_thread(arc.task_ref.as_ptr(), &arc.ctx);
}
}
unsafe fn cross_task_wake_by_ref(data: *const ()) {
let inner = unsafe { &*data.cast::<CrossTaskWakerInner>() };
unsafe {
crate::cross_wake::wake_task_cross_thread(inner.task_ref.as_ptr(), &inner.ctx);
}
}
/// `RawWakerVTable::drop` entry: releases this waker's strong count.
///
/// When the count reaches zero, the shared state, and with it the task-level
/// `TaskRef`, is dropped.
///
/// # Safety
///
/// `data` must come from `Arc::into_raw` for a `CrossTaskWakerInner` whose
/// strong count this call may release.
unsafe fn cross_task_drop(data: *const ()) {
    // SAFETY: `data` carries one strong count owned by the waker being dropped.
    unsafe { std::sync::Arc::decrement_strong_count(data.cast::<CrossTaskWakerInner>()) };
}
/// Spawns `future` onto the shared tokio runtime.
///
/// The returned [`TokioJoinHandle`] can be awaited from this runtime. Dropping
/// it aborts the spawned tokio task.
pub fn spawn_on_tokio<F, T>(future: F) -> TokioJoinHandle<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let rt_handle = tokio_runtime().handle();
    let inner = rt_handle.spawn(future);
    TokioJoinHandle {
        _not_send: std::marker::PhantomData,
        inner,
    }
}
#[must_use = "dropping a TokioJoinHandle aborts the tokio task"]
pub struct TokioJoinHandle<T> {
inner: tokio::task::JoinHandle<T>,
_not_send: std::marker::PhantomData<*const ()>,
}
impl<T> Future for TokioJoinHandle<T> {
type Output = Result<T, TokioJoinError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let cross_waker = make_cross_waker(cx);
let mut cross_cx = Context::from_waker(&cross_waker);
let inner = Pin::new(&mut self.get_mut().inner);
match inner.poll(&mut cross_cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
Poll::Ready(Err(e)) => Poll::Ready(Err(TokioJoinError(e))),
}
}
}
impl<T> TokioJoinHandle<T> {
pub fn is_finished(&self) -> bool {
self.inner.is_finished()
}
pub fn abort(&self) {
self.inner.abort();
}
}
impl<T> Drop for TokioJoinHandle<T> {
fn drop(&mut self) {
self.inner.abort();
}
}
/// Error produced when a task started with [`spawn_on_tokio`] did not complete
/// normally, because it panicked or was cancelled.
///
/// This is a transparent wrapper around `tokio::task::JoinError`: `Display`,
/// `Debug` and `Error::source` all forward to the inner error.
pub struct TokioJoinError(tokio::task::JoinError);

impl TokioJoinError {
    /// Returns `true` if the task was cancelled (e.g. via [`TokioJoinHandle::abort`]
    /// or by dropping its handle).
    pub fn is_cancelled(&self) -> bool {
        self.0.is_cancelled()
    }

    /// Returns `true` if the task terminated by panicking.
    pub fn is_panic(&self) -> bool {
        self.0.is_panic()
    }

    /// Consumes the error and returns the value the task panicked with.
    ///
    /// This lets callers propagate the panic, e.g. with
    /// `std::panic::resume_unwind`.
    ///
    /// # Panics
    ///
    /// Panics if the error does not represent a panic (see [`Self::is_panic`]).
    #[track_caller]
    pub fn into_panic(self) -> Box<dyn std::any::Any + Send + 'static> {
        self.0.into_panic()
    }

    /// Consumes the error and returns the panic payload. If the task did not
    /// panic, the error is returned unchanged.
    pub fn try_into_panic(self) -> Result<Box<dyn std::any::Any + Send + 'static>, TokioJoinError> {
        self.0.try_into_panic().map_err(TokioJoinError)
    }
}

impl std::fmt::Display for TokioJoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::fmt::Debug for TokioJoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.0, f)
    }
}

impl std::error::Error for TokioJoinError {
    // Transparent: report the inner error's source rather than the inner error
    // itself, because `Display` already prints the inner message.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        std::error::Error::source(&self.0)
    }
}
#[cfg(test)]
mod arc_tests {
    //! Reference-count tests for the `Arc`-backed cross-task waker.
    //!
    //! Invariant under test: `make_cross_task_waker` takes exactly one
    //! task-level reference. Cloning or dropping individual waker clones only
    //! moves the `Arc` strong count. The task-level reference is released
    //! exactly once, when the last clone is dropped.

    use super::*;
    use crate::cross_wake::{CrossWakeContext, CrossWakeQueue};
    use crate::task::{self, Task};
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc as StdArc;
    use std::sync::atomic::AtomicBool;
    use std::task::{Context, Poll};

    /// Future that completes immediately; it only gives the test task a body.
    struct ArcNoop;

    impl Future for ArcNoop {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Ready(())
        }
    }

    /// Allocates a task wrapping `ArcNoop` and leaks it as a raw task pointer.
    ///
    /// The task starts with a task-level refcount of 1 (asserted in
    /// `multi_clone_arc_terminal_ref_count`). Each test tears it down with
    /// `drop_task_future` + `complete_and_unref` + `free_task`.
    fn make_test_task() -> *mut u8 {
        let task = Box::new(Task::new_boxed(ArcNoop, 0));
        Box::into_raw(task) as *mut u8
    }

    /// Builds a standalone `CrossWakeContext`.
    ///
    /// Its `mio::Waker` is registered with a throwaway `mio::Poll`; the poll is
    /// dropped on return and only the waker is kept.
    fn make_test_ctx() -> StdArc<CrossWakeContext> {
        let poll = mio::Poll::new().expect("mio::Poll");
        let waker = StdArc::new(
            mio::Waker::new(poll.registry(), mio::Token(usize::MAX)).expect("mio::Waker"),
        );
        StdArc::new(CrossWakeContext {
            queue: CrossWakeQueue::new(),
            mio_waker: waker,
            parked: AtomicBool::new(false),
        })
    }

    /// Many clones dropped in arbitrary order release the task-level
    /// reference exactly once, on the last drop.
    #[test]
    fn multi_clone_arc_terminal_ref_count() {
        let ctx = make_test_ctx();
        let task_ptr = make_test_task();
        assert_eq!(unsafe { task::ref_count(task_ptr) }, 1);
        let waker0 = make_cross_task_waker(task_ptr, StdArc::clone(&ctx));
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            2,
            "make_cross_task_waker must take exactly one task-level ref"
        );
        let waker1 = waker0.clone();
        let waker2 = waker0.clone();
        let waker3 = waker0.clone();
        let waker4 = waker0.clone();
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            2,
            "Arc::clone must NOT bump task-level refcount"
        );
        // Drop in non-creation order; waker3 is deliberately the survivor.
        drop(waker2);
        drop(waker4);
        drop(waker0);
        drop(waker1);
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            2,
            "intermediate Arc drops must NOT decrement task-level refcount"
        );
        drop(waker3);
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            1,
            "last Arc drop must produce exactly ONE task-level ref_dec"
        );
        unsafe {
            task::drop_task_future(task_ptr);
            assert!(matches!(
                task::complete_and_unref(task_ptr),
                task::FreeAction::FreeBox
            ));
            task::free_task(task_ptr);
        }
    }

    /// `Waker::wake` (by value) consumes only its own `Arc` strong count and
    /// leaves the task-level reference to the surviving sibling.
    #[test]
    fn wake_by_value_consumes_one_arc_only() {
        let ctx = make_test_ctx();
        let task_ptr = make_test_task();
        let waker0 = make_cross_task_waker(task_ptr, StdArc::clone(&ctx));
        let waker1 = waker0.clone();
        assert_eq!(unsafe { task::ref_count(task_ptr) }, 2);
        waker0.wake();
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            2,
            "wake-by-value with surviving sibling Arc must not ref_dec the task"
        );
        drop(waker1);
        assert_eq!(unsafe { task::ref_count(task_ptr) }, 1);
        // Undo the side effects of the wake before tearing the task down.
        let _ = ctx.queue.pop();
        if unsafe { task::is_queued(task_ptr) } {
            unsafe { task::clear_queued(task_ptr) };
        }
        unsafe {
            task::drop_task_future(task_ptr);
            assert!(matches!(
                task::complete_and_unref(task_ptr),
                task::FreeAction::FreeBox
            ));
            task::free_task(task_ptr);
        }
    }

    /// Rough timing of cross-task waker clone and drop.
    #[test]
    #[ignore = "performance benchmark, run with --release --nocapture"]
    fn bench_cross_task_clone() {
        use std::time::Instant;
        let ctx = make_test_ctx();
        let task_ptr = make_test_task();
        let waker = make_cross_task_waker(task_ptr, StdArc::clone(&ctx));
        let warmup: Vec<Waker> = (0..10_000).map(|_| waker.clone()).collect();
        drop(warmup);
        const ITERS: usize = 1_000_000;
        let mut clones = Vec::with_capacity(ITERS);
        let start = Instant::now();
        for _ in 0..ITERS {
            clones.push(waker.clone());
        }
        let clone_elapsed = start.elapsed();
        let drop_start = Instant::now();
        drop(clones);
        let drop_elapsed = drop_start.elapsed();
        let ns_per_clone = clone_elapsed.as_nanos() / ITERS as u128;
        let ns_per_drop = drop_elapsed.as_nanos() / ITERS as u128;
        println!("cross_task_waker: clone={ns_per_clone}ns, drop={ns_per_drop}ns ({ITERS} iters)");
        drop(waker);
        // Same teardown as the other tests: only free the box once
        // `complete_and_unref` confirms the last reference is gone.
        unsafe {
            task::drop_task_future(task_ptr);
            assert!(matches!(
                task::complete_and_unref(task_ptr),
                task::FreeAction::FreeBox
            ));
            task::free_task(task_ptr);
        }
    }
}