use std::cell::{Cell, UnsafeCell};
use std::sync::Arc;
use std::sync::atomic::{AtomicPtr, Ordering};
use crate::task;
thread_local! {
    // Address of the `Arc<CrossWakeContext>` most recently installed on this
    // thread by `install_cross_wake`, or null when none is installed.
    // `cross_wake_context` dereferences it, so it must only hold addresses of
    // `Arc`s that are still alive.
    static CTX_CROSS_WAKE: Cell<*const Arc<CrossWakeContext>> =
        const { Cell::new(std::ptr::null()) };
    // Address of the `CrossWakeContext` of the runtime currently running on
    // this thread (set by `install_runtime_cross_wake`), or null. In this file
    // it is only compared by address (`on_owning_executor`) or returned raw
    // (`current_runtime_ctx`); it is never dereferenced here.
    static CURRENT_RUNTIME_CTX: Cell<*const CrossWakeContext> =
        const { Cell::new(std::ptr::null()) };
}
/// Publishes `ctx` as this thread's cross-wake context until the returned
/// guard is dropped. On drop, the guard restores whatever was installed before.
///
/// Only the *address* of `ctx` is recorded, and no reference count is taken.
/// NOTE(review): `cross_wake_context` dereferences that address, so the caller
/// must keep the `Arc` alive and in place for as long as the guard lives.
/// Nested guards are expected to be dropped in LIFO order.
pub(crate) fn install_cross_wake(ctx: &Arc<CrossWakeContext>) -> CrossWakeGuard {
    let raw: *const Arc<CrossWakeContext> = ctx;
    CrossWakeGuard {
        prev: CTX_CROSS_WAKE.with(|slot| slot.replace(raw)),
    }
}

/// Restores the previously installed cross-wake context when dropped.
pub(crate) struct CrossWakeGuard {
    /// Value of `CTX_CROSS_WAKE` before the install this guard belongs to.
    prev: *const Arc<CrossWakeContext>,
}

impl Drop for CrossWakeGuard {
    fn drop(&mut self) {
        let restored = self.prev;
        CTX_CROSS_WAKE.with(move |slot| slot.set(restored));
    }
}
pub(crate) fn cross_wake_context() -> Option<Arc<CrossWakeContext>> {
CTX_CROSS_WAKE.with(|c| {
let ptr = c.get();
if ptr.is_null() {
None
} else {
Some(unsafe { (*ptr).clone() })
}
})
}
/// Records `arc`, by address, as the context of the runtime running on this
/// thread until the returned guard is dropped. On drop, the previous value is
/// restored. No reference count is taken.
pub(crate) fn install_runtime_cross_wake(arc: &Arc<CrossWakeContext>) -> RuntimeCrossWakeGuard {
    let installed = Arc::as_ptr(arc);
    let prev = CURRENT_RUNTIME_CTX.with(|slot| {
        let previous = slot.get();
        slot.set(installed);
        previous
    });
    RuntimeCrossWakeGuard { prev }
}

/// Restores the previously installed runtime context when dropped.
pub(crate) struct RuntimeCrossWakeGuard {
    /// Value of `CURRENT_RUNTIME_CTX` before the install this guard belongs to.
    prev: *const CrossWakeContext,
}

impl Drop for RuntimeCrossWakeGuard {
    fn drop(&mut self) {
        let restored = self.prev;
        CURRENT_RUNTIME_CTX.with(|slot| slot.set(restored));
    }
}
/// Returns the address of the runtime context installed on this thread by
/// `install_runtime_cross_wake`, or null when none is installed.
#[inline]
pub(crate) fn current_runtime_ctx() -> *const CrossWakeContext {
    CURRENT_RUNTIME_CTX.with(Cell::get)
}

/// Reports whether `ctx` is the context of the runtime running on this thread.
#[inline]
fn on_owning_executor(ctx: &CrossWakeContext) -> bool {
    // `ctx` is a reference and therefore never null, so address equality on
    // its own already returns `false` when nothing (null) is installed.
    std::ptr::eq(current_runtime_ctx(), ctx)
}
/// Disposes of a task that has reached its terminal state.
///
/// The free is handed to `crate::waker::try_defer_free` (its result is
/// ignored) in two cases:
/// - the task has no cross-wake context, or
/// - this thread is running the executor that owns that context.
///
/// Otherwise the task is marked queued and pushed onto the owner's cross-wake
/// queue, and the owner's `mio::Waker` is signalled if it is parked. If the task
/// was already marked queued, nothing further happens.
///
/// # Safety
/// NOTE(review): `task_ptr` presumably must point at a live task header, and a
/// non-null context pointer stored in it must still be valid. Confirm against
/// `crate::task`.
pub(crate) unsafe fn dispose_terminal(task_ptr: *mut u8) {
    let ctx_ptr = unsafe { task::header_cross_wake_ctx(task_ptr) };
    // `Some` only when the task belongs to an executor other than this thread's.
    let foreign_ctx = unsafe { ctx_ptr.as_ref() }.filter(|ctx| !on_owning_executor(ctx));
    let Some(ctx) = foreign_ctx else {
        let _ = unsafe { crate::waker::try_defer_free(task_ptr) };
        return;
    };
    if unsafe { !task::try_set_queued(task_ptr) } {
        // Already queued on the owner; it will see the task there.
        return;
    }
    unsafe { ctx.queue.push(task_ptr) };
    if ctx.parked.load(Ordering::Acquire) {
        let _ = ctx.mio_waker.wake();
    }
}
/// Intrusive multi-producer / single-consumer queue of raw task pointers.
///
/// The push/pop logic matches Dmitry Vyukov's intrusive MPSC node-based queue:
/// - Producers (`push`) swap their node into `tail`, then link the previous
///   tail's `next` to it.
/// - The single consumer (`pop`) advances `head`, re-pushing a heap-allocated
///   stub node when it needs to detach the last real node.
///
/// Task nodes are linked through the `AtomicPtr` returned by
/// `task::cross_next(node)`. The stub's own `AtomicPtr` is its link.
pub(crate) struct CrossWakeQueue {
    /// Consumer cursor, read and written only by `pop`.
    head: UnsafeCell<*mut u8>,
    /// Most recently pushed node; swapped by producers.
    tail: AtomicPtr<u8>,
    /// Heap-allocated stub node, owned by the queue and freed in `Drop`.
    stub: *mut AtomicPtr<u8>,
}
// NOTE(review): `pop` mutates `head` through the `UnsafeCell` without any
// synchronization, so these impls are only sound if `pop` is never called from
// two threads at once. That single-consumer contract is not enforced by the type.
unsafe impl Send for CrossWakeQueue {}
unsafe impl Sync for CrossWakeQueue {}
impl CrossWakeQueue {
    /// Creates an empty queue. `head` and `tail` both start at a freshly
    /// heap-allocated stub node whose own `AtomicPtr` serves as its link.
    pub(crate) fn new() -> Self {
        let stub: *mut AtomicPtr<u8> =
            Box::into_raw(Box::new(AtomicPtr::new(std::ptr::null_mut())));
        let sentinel: *mut u8 = stub.cast();
        Self {
            stub,
            tail: AtomicPtr::new(sentinel),
            head: UnsafeCell::new(sentinel),
        }
    }

    /// Address of the stub node, typed as a queue node pointer.
    #[inline]
    fn stub_ptr(&self) -> *mut u8 {
        self.stub.cast()
    }

    /// Returns the `next` link of `node`. For the stub this is the stub's own
    /// `AtomicPtr`; for a task it is the link from `task::cross_next`.
    ///
    /// # Safety
    /// `node` must be the stub or a task pointer accepted by
    /// `task::cross_next`, and must stay valid while the returned borrow is used.
    #[inline]
    unsafe fn next_of(&self, node: *mut u8) -> &AtomicPtr<u8> {
        if node != self.stub_ptr() {
            unsafe { &*task::cross_next(node) }
        } else {
            unsafe { &*self.stub }
        }
    }
}

impl Drop for CrossWakeQueue {
    /// Frees the stub node. Task nodes still linked into the queue are not
    /// touched here.
    fn drop(&mut self) {
        // SAFETY: `stub` came from `Box::into_raw` in `new` and is freed only here.
        let stub_box = unsafe { Box::from_raw(self.stub) };
        drop(stub_box);
    }
}
impl CrossWakeQueue {
    /// Appends `task_ptr` to the queue. May be called from any number of threads.
    ///
    /// # Safety
    /// `task_ptr` must be the queue's stub or a task whose `task::cross_next`
    /// link is valid. Its link is overwritten with null here, so it must not
    /// currently be linked into a queue. NOTE(review): callers in this file
    /// guard that with `task::try_set_queued`.
    pub(crate) unsafe fn push(&self, task_ptr: *mut u8) {
        // Reset the node's own link before publishing it.
        unsafe { self.next_of(task_ptr) }.store(std::ptr::null_mut(), Ordering::Relaxed);
        // Claim the tail position; the node is now reachable only via `prev`'s link.
        let prev = self.tail.swap(task_ptr, Ordering::AcqRel);
        // Link it in. Until this store lands, `pop` can see a temporarily
        // unlinked chain and will report `None`.
        unsafe { self.next_of(prev) }.store(task_ptr, Ordering::Release);
    }
    /// Removes and returns the oldest node, or `None`.
    ///
    /// `None` is also returned transiently while a producer has swapped `tail`
    /// but not yet linked its node, so `None` does not prove the queue is empty.
    ///
    /// NOTE(review): this mutates `head` through an `UnsafeCell` without
    /// synchronization. Although it is a safe fn, it must only ever run on one
    /// thread at a time (single consumer).
    pub(crate) fn pop(&self) -> Option<*mut u8> {
        let head_ref = unsafe { &mut *self.head.get() };
        let mut head = *head_ref;
        let mut next = unsafe { self.next_of(head) }.load(Ordering::Acquire);
        let stub = self.stub_ptr();
        // The stub is at the front: step past it, or report empty if nothing follows.
        if head == stub {
            if next.is_null() {
                return None;
            }
            *head_ref = next;
            head = next;
            next = unsafe { self.next_of(head) }.load(Ordering::Acquire);
        }
        // Common case: `head` has a successor, so it can be detached.
        if !next.is_null() {
            *head_ref = next;
            return Some(head);
        }
        // `head` has no successor yet. If it is not also the tail, a producer
        // is mid-push (swapped `tail`, not yet linked), so give up for now.
        let tail = self.tail.load(Ordering::Acquire);
        if head != tail {
            return None;
        }
        // `head` is the last node: enqueue the stub behind it so `head` gains a
        // successor and can be detached.
        unsafe { self.push(stub) };
        next = unsafe { self.next_of(head) }.load(Ordering::Acquire);
        if !next.is_null() {
            *head_ref = next;
            return Some(head);
        }
        // A producer swapped `tail` between the check above and the stub push
        // and has not linked its node yet.
        None
    }
}
/// State another thread needs in order to hand a task back to the executor
/// that owns it.
pub(crate) struct CrossWakeContext {
    /// Tasks pushed from other threads (by `wake_task_cross_thread` and `dispose_terminal`).
    pub(crate) queue: CrossWakeQueue,
    /// Signalled after a push whenever `parked` is observed `true`.
    pub(crate) mio_waker: Arc<mio::Waker>,
    /// Read by pushers to decide whether to signal `mio_waker`. It is
    /// presumably set by the executor around its blocking poll; TODO confirm.
    pub(crate) parked: std::sync::atomic::AtomicBool,
}
// NOTE(review): manual impls. `CrossWakeQueue` already has unsafe Send/Sync
// impls, and `Arc<mio::Waker>` and `AtomicBool` are presumably Send + Sync, so
// these may be redundant. Confirm before removing.
unsafe impl Send for CrossWakeContext {}
unsafe impl Sync for CrossWakeContext {}
/// Hands `task_ptr` back to the executor owning `ctx` from another thread.
///
/// Does nothing if the task has already completed, or if `task::try_set_queued`
/// fails (the task is already queued). Otherwise it pushes the task onto
/// `ctx.queue`. If the owner has flagged itself as parked, it also signals the
/// owner's `mio::Waker`, ignoring the result.
///
/// # Safety
/// NOTE(review): `task_ptr` presumably must point at a live task header.
/// Confirm against `crate::task`.
pub(crate) unsafe fn wake_task_cross_thread(task_ptr: *mut u8, ctx: &CrossWakeContext) {
    // `&&` short-circuits, so `try_set_queued` only runs for unfinished tasks.
    let claimed = unsafe { !task::is_completed(task_ptr) && task::try_set_queued(task_ptr) };
    if claimed {
        unsafe { ctx.queue.push(task_ptr) };
        let owner_parked = ctx.parked.load(Ordering::Acquire);
        if owner_parked {
            let _ = ctx.mio_waker.wake();
        }
    }
}
// Values of the `AtomicU8` `state` field shared by `TaskWakerSlot`,
// `FallbackWaker` and `TxWakerSlot`.
/// No waker is held.
const EMPTY: u8 = 0;
/// A waker is held and can be taken by `wake` (or `clear`).
const STORED: u8 = 1;
/// A `register` call is in progress; `wake` returns `false` in this state.
const REGISTERING: u8 = 2;
/// Waker slot holding a strong reference to a task (as a raw task pointer).
/// It wakes the task by pushing it onto a `CrossWakeContext` queue.
pub(crate) struct TaskWakerSlot {
    /// Task reference acquired by `register` and owned by the slot, or null.
    task_ptr: AtomicPtr<u8>,
    /// Context used (and dereferenced) by `wake`.
    cross_ctx: *const CrossWakeContext,
    /// One of `EMPTY`, `STORED`, `REGISTERING`.
    state: std::sync::atomic::AtomicU8,
}
// NOTE(review): needed because of the raw `cross_ctx` pointer. These impls are
// sound only if the pointed-to context outlives every slot that refers to it.
unsafe impl Send for TaskWakerSlot {}
unsafe impl Sync for TaskWakerSlot {}
impl TaskWakerSlot {
    /// Creates an empty slot whose wakes are routed through `cross_ctx`.
    ///
    /// NOTE(review): `wake` dereferences `cross_ctx` whenever it finds a stored
    /// task, so it presumably must be non-null and outlive the slot. Confirm
    /// at the call sites.
    pub(crate) fn new(cross_ctx: *const CrossWakeContext) -> Self {
        Self {
            state: std::sync::atomic::AtomicU8::new(EMPTY),
            task_ptr: AtomicPtr::new(std::ptr::null_mut()),
            cross_ctx,
        }
    }

    /// Stores `task_ptr` and takes a new strong reference to it, owned by the slot.
    ///
    /// Any task pointer previously held has its reference released. Concurrent
    /// `register` calls on one slot are a contract violation (debug-asserted).
    pub(crate) fn register(&self, task_ptr: *mut u8) {
        debug_assert!(
            !task_ptr.is_null(),
            "TaskWakerSlot::register called with null task_ptr — \
contract violation by caller (typically RecvFut::poll)"
        );
        let previous_state = self.state.swap(REGISTERING, Ordering::Acquire);
        debug_assert_ne!(
            previous_state, REGISTERING,
            "concurrent register on TaskWakerSlot"
        );
        // Acquire a reference and turn the guard into a raw pointer the slot owns.
        let owned = {
            let guard = unsafe { crate::task::TaskRef::acquire(task_ptr) };
            let raw = guard.as_ptr();
            std::mem::forget(guard);
            raw
        };
        // Whatever was stored before is released. It may be null after a racing
        // `wake`/`clear` already took it.
        let displaced = self.task_ptr.swap(owned, Ordering::AcqRel);
        unsafe { Self::release_owned(displaced) };
        self.state.store(STORED, Ordering::Release);
    }

    /// If `crate::waker::task_ptr_from_local_waker` recognises `waker`, this
    /// registers its task and returns `true`. Otherwise the slot is left
    /// untouched and it returns `false`.
    pub(crate) fn try_register_local(&self, waker: &std::task::Waker) -> bool {
        let Some(task_ptr) = crate::waker::task_ptr_from_local_waker(waker) else {
            return false;
        };
        self.register(task_ptr);
        true
    }

    /// Takes the stored task, if any, and wakes it via `wake_task_cross_thread`,
    /// then releases the slot's reference. Returns `true` only if a task was woken.
    pub(crate) fn wake(&self) -> bool {
        let task_ptr = self.take_stored();
        if task_ptr.is_null() {
            return false;
        }
        let ctx = unsafe { &*self.cross_ctx };
        unsafe { wake_task_cross_thread(task_ptr, ctx) };
        // The slot's reference is held across the dispatch above, so the task
        // cannot be freed while it is being woken.
        unsafe { Self::release_owned(task_ptr) };
        true
    }

    /// Reports whether the slot is in the `STORED` state.
    ///
    /// After a `register` races with `wake`, the state can be `STORED` while
    /// the task pointer is null.
    pub(crate) fn has_waker(&self) -> bool {
        STORED == self.state.load(Ordering::Acquire)
    }

    /// Releases the stored task reference, if any, without waking it.
    pub(crate) fn clear(&self) {
        let task_ptr = self.take_stored();
        unsafe { Self::release_owned(task_ptr) };
    }

    /// Claims the slot (`STORED` → `EMPTY`) and swaps the task pointer out.
    /// Returns null if nothing was stored or the pointer was already taken.
    fn take_stored(&self) -> *mut u8 {
        let claimed = self
            .state
            .compare_exchange(STORED, EMPTY, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok();
        if !claimed {
            return std::ptr::null_mut();
        }
        self.task_ptr.swap(std::ptr::null_mut(), Ordering::Acquire)
    }

    /// Releases one owned task reference. Null is a no-op.
    ///
    /// # Safety
    /// A non-null `task_ptr` must carry a reference owned by the caller, as
    /// stored by `register`.
    unsafe fn release_owned(task_ptr: *mut u8) {
        if !task_ptr.is_null() {
            drop(unsafe { crate::task::TaskRef::from_owned(task_ptr) });
        }
    }
}
impl Drop for TaskWakerSlot {
    /// Releases the task reference still owned by a `STORED` slot.
    fn drop(&mut self) {
        // `&mut self` rules out concurrent register/wake, so `get_mut` suffices.
        if *self.state.get_mut() != STORED {
            return;
        }
        let remaining = *self.task_ptr.get_mut();
        if remaining.is_null() {
            // STORED with a null pointer: a racing `wake` already released it.
            return;
        }
        drop(unsafe { crate::task::TaskRef::from_owned(remaining) });
    }
}
/// Thread-safe single slot holding an arbitrary [`std::task::Waker`].
///
/// `state` doubles as a tiny lock around `waker`:
/// - `EMPTY` / `STORED`: unlocked. `STORED` means a waker is present.
/// - `REGISTERING`: held by `register` while it swaps a new waker in.
/// - `FallbackWaker::WAKING`: held by `wake` while it moves the waker out.
///
/// The cell is therefore only touched by whoever holds one of the two lock
/// states, or through `&mut self`.
pub(crate) struct FallbackWaker {
    state: std::sync::atomic::AtomicU8,
    waker: UnsafeCell<Option<std::task::Waker>>,
}

// SAFETY: `waker` is accessed only by the thread that moved `state` into
// `REGISTERING` or `WAKING`, or through `&mut self`. `Waker` is `Send + Sync`.
unsafe impl Send for FallbackWaker {}
unsafe impl Sync for FallbackWaker {}

impl FallbackWaker {
    /// Lock state held by `wake` while it moves the stored waker out. It must
    /// differ from the shared `EMPTY`, `STORED` and `REGISTERING` states, which
    /// is checked at compile time.
    const WAKING: u8 = {
        let waking = 3;
        assert!(waking != EMPTY && waking != STORED && waking != REGISTERING);
        waking
    };

    /// Creates an empty slot.
    pub(crate) fn new() -> Self {
        Self {
            state: std::sync::atomic::AtomicU8::new(EMPTY),
            waker: UnsafeCell::new(None),
        }
    }

    /// Stores a clone of `waker`, dropping any waker it replaces.
    ///
    /// Must not race with another `register` on the same slot (debug-asserted).
    /// If a concurrent `wake` is moving the previous waker out, this waits for
    /// it to finish first. A `wake` that arrives while this call holds the slot
    /// returns `false`.
    pub(crate) fn register(&self, waker: &std::task::Waker) {
        // Clone before locking so no foreign waker code runs while the slot is held.
        let fresh = waker.clone();
        self.lock_for_register();
        // SAFETY: we hold the REGISTERING lock, so no other thread touches the cell.
        let displaced = unsafe { (*self.waker.get()).replace(fresh) };
        self.state.store(STORED, Ordering::Release);
        // Drop the replaced waker only after unlocking, for the same reason.
        drop(displaced);
    }

    /// Moves `state` from an unlocked value (`EMPTY`/`STORED`) to
    /// `REGISTERING`. While another party holds the slot, it spins briefly and
    /// then yields.
    fn lock_for_register(&self) {
        let mut spins: u32 = 0;
        let mut current = self.state.load(Ordering::Relaxed);
        loop {
            debug_assert_ne!(current, REGISTERING);
            if current == Self::WAKING || current == REGISTERING {
                if spins < 64 {
                    spins += 1;
                    std::hint::spin_loop();
                } else {
                    std::thread::yield_now();
                }
                current = self.state.load(Ordering::Relaxed);
                continue;
            }
            // Acquire pairs with the Release stores in `wake` and `register`, so
            // the previous owner's cell access happens-before ours.
            match self.state.compare_exchange_weak(
                current,
                REGISTERING,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => return,
                Err(actual) => current = actual,
            }
        }
    }

    /// Removes the stored waker and wakes it.
    ///
    /// Returns `true` only if a waker was present and has been woken. Returns
    /// `false` if the slot was empty, a `register` is in progress, or another
    /// `wake` claimed it first.
    pub(crate) fn wake(&self) -> bool {
        // Lock the slot (STORED → WAKING) instead of releasing it straight to
        // EMPTY. The cell must stay exclusive until the waker has been moved
        // out; otherwise a concurrent `register` could write it at the same time
        // (a data race on the `UnsafeCell`).
        if self
            .state
            .compare_exchange(STORED, Self::WAKING, Ordering::AcqRel, Ordering::Relaxed)
            .is_err()
        {
            return false;
        }
        // SAFETY: we hold the WAKING lock.
        let taken = unsafe { (*self.waker.get()).take() };
        self.state.store(EMPTY, Ordering::Release);
        // Invoke the waker only after unlocking.
        match taken {
            Some(w) => {
                w.wake();
                true
            }
            None => false,
        }
    }

    /// Reports whether a waker is stored and not currently being registered or woken.
    pub(crate) fn has_waker(&self) -> bool {
        self.state.load(Ordering::Acquire) == STORED
    }
}

impl Drop for FallbackWaker {
    fn drop(&mut self) {
        // `&mut self` guarantees exclusivity; drop any waker still stored.
        *self.waker.get_mut() = None;
    }
}
/// Thread-safe single slot holding an arbitrary [`std::task::Waker`].
/// The name suggests the sender-side slot of a channel; TODO confirm.
///
/// `state` doubles as a tiny lock around `waker`:
/// - `EMPTY` / `STORED`: unlocked. `STORED` means a waker is present.
/// - `REGISTERING`: held by `register` while it swaps a new waker in.
/// - `TxWakerSlot::WAKING`: held by `wake` while it moves the waker out.
///
/// The cell is therefore only touched by whoever holds one of the two lock
/// states, or through `&mut self`.
pub(crate) struct TxWakerSlot {
    state: std::sync::atomic::AtomicU8,
    waker: UnsafeCell<Option<std::task::Waker>>,
}

// SAFETY: `waker` is accessed only by the thread that moved `state` into
// `REGISTERING` or `WAKING`, or through `&mut self`. `Waker` is `Send + Sync`.
unsafe impl Send for TxWakerSlot {}
unsafe impl Sync for TxWakerSlot {}

impl TxWakerSlot {
    /// Lock state held by `wake` while it moves the stored waker out. It must
    /// differ from the shared `EMPTY`, `STORED` and `REGISTERING` states, which
    /// is checked at compile time.
    const WAKING: u8 = {
        let waking = 3;
        assert!(waking != EMPTY && waking != STORED && waking != REGISTERING);
        waking
    };

    /// Creates an empty slot.
    pub(crate) fn new() -> Self {
        Self {
            state: std::sync::atomic::AtomicU8::new(EMPTY),
            waker: UnsafeCell::new(None),
        }
    }

    /// Stores a clone of `waker`, dropping any waker it replaces.
    ///
    /// Must not race with another `register` on the same slot (debug-asserted).
    /// If a concurrent `wake` is moving the previous waker out, this waits for
    /// it to finish first. A `wake` that arrives while this call holds the slot
    /// returns `false`.
    pub(crate) fn register(&self, waker: &std::task::Waker) {
        // Clone before locking so no foreign waker code runs while the slot is held.
        let fresh = waker.clone();
        self.lock_for_register();
        // SAFETY: we hold the REGISTERING lock, so no other thread touches the cell.
        let displaced = unsafe { (*self.waker.get()).replace(fresh) };
        self.state.store(STORED, Ordering::Release);
        // Drop the replaced waker only after unlocking, for the same reason.
        drop(displaced);
    }

    /// Moves `state` from an unlocked value (`EMPTY`/`STORED`) to
    /// `REGISTERING`. While another party holds the slot, it spins briefly and
    /// then yields.
    fn lock_for_register(&self) {
        let mut spins: u32 = 0;
        let mut current = self.state.load(Ordering::Relaxed);
        loop {
            debug_assert_ne!(current, REGISTERING);
            if current == Self::WAKING || current == REGISTERING {
                if spins < 64 {
                    spins += 1;
                    std::hint::spin_loop();
                } else {
                    std::thread::yield_now();
                }
                current = self.state.load(Ordering::Relaxed);
                continue;
            }
            // Acquire pairs with the Release stores in `wake` and `register`, so
            // the previous owner's cell access happens-before ours.
            match self.state.compare_exchange_weak(
                current,
                REGISTERING,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => return,
                Err(actual) => current = actual,
            }
        }
    }

    /// Removes the stored waker and wakes it.
    ///
    /// Returns `true` only if a waker was present and has been woken. Returns
    /// `false` if the slot was empty, a `register` is in progress, or another
    /// `wake` claimed it first.
    pub(crate) fn wake(&self) -> bool {
        // Lock the slot (STORED → WAKING) instead of releasing it straight to
        // EMPTY. The cell must stay exclusive until the waker has been moved
        // out; otherwise a concurrent `register` could write it at the same time
        // (a data race on the `UnsafeCell`).
        if self
            .state
            .compare_exchange(STORED, Self::WAKING, Ordering::AcqRel, Ordering::Relaxed)
            .is_err()
        {
            return false;
        }
        // SAFETY: we hold the WAKING lock.
        let taken = unsafe { (*self.waker.get()).take() };
        self.state.store(EMPTY, Ordering::Release);
        // Invoke the waker only after unlocking.
        match taken {
            Some(w) => {
                w.wake();
                true
            }
            None => false,
        }
    }

    /// Reports whether a waker is stored and not currently being registered or woken.
    pub(crate) fn has_waker(&self) -> bool {
        self.state.load(Ordering::Acquire) == STORED
    }
}

impl Drop for TxWakerSlot {
    fn drop(&mut self) {
        // `&mut self` guarantees exclusivity; drop any waker still stored.
        *self.waker.get_mut() = None;
    }
}
#[cfg(test)]
pub(crate) mod uaf_scenarios {
    // Regression scenarios for reference counting around `TaskWakerSlot`.
    // The functions are `pub(crate)` rather than `#[test]`, so they are
    // presumably driven from elsewhere in the crate (possibly under miri).
    // TODO confirm.
    use super::*;
    use crate::task::{self, FreeAction, Task, TaskRef};
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::AtomicBool;
    use std::task::{Context, Poll};

    // Future that is ready on its first poll; only used to build a task.
    struct UafNoop;
    impl Future for UafNoop {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Ready(())
        }
    }

    // Heap-allocates a `UafNoop` task and returns it as a raw task pointer.
    fn make_uaf_task() -> *mut u8 {
        let task = Box::new(Task::new_boxed(UafNoop, 0));
        Box::into_raw(task) as *mut u8
    }

    // Standalone context. The `mio::Poll` whose registry the waker uses is
    // dropped when this function returns.
    fn make_uaf_cross_ctx() -> Arc<CrossWakeContext> {
        let poll = mio::Poll::new().unwrap();
        let mio_waker = Arc::new(mio::Waker::new(poll.registry(), mio::Token(usize::MAX)).unwrap());
        Arc::new(CrossWakeContext {
            queue: CrossWakeQueue::new(),
            mio_waker,
            parked: AtomicBool::new(false),
        })
    }

    // Replays `TaskWakerSlot::wake` by hand (CAS, then swap). It lets the task
    // complete and drop its own reference before the wake is dispatched.
    // - If `register` took its own reference, the task survives (`Retain`).
    // - If it did not, `complete_and_unref` yields `FreeBox`. Outside miri that
    //   panics. Under miri the task is freed, so the following
    //   `wake_task_cross_thread` is reported as a use-after-free.
    pub(crate) fn waker_slot_uaf_when_task_freed_mid_dispatch() {
        let cross_ctx = make_uaf_cross_ctx();
        let task_ptr = make_uaf_task();
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            1,
            "make_uaf_task should produce refcount=1"
        );
        let slot = TaskWakerSlot::new(Arc::as_ptr(&cross_ctx));
        slot.register(task_ptr);
        // First half of `wake`: claim the slot...
        assert!(
            slot.state
                .compare_exchange(STORED, EMPTY, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok(),
            "slot was registered; CAS STORED→EMPTY must succeed"
        );
        // ...and take the stored task pointer.
        let captured = slot.task_ptr.swap(std::ptr::null_mut(), Ordering::Acquire);
        assert_eq!(captured, task_ptr);
        // The task completes and drops its own reference mid-dispatch.
        let action = unsafe { task::complete_and_unref(task_ptr) };
        // `pre_fix` is true only on the miri path, where the task was just freed.
        let pre_fix = match action {
            FreeAction::FreeBox => {
                #[cfg(not(miri))]
                panic!(
                    "BUG-2 regression detected: register skipped ref_inc, \
so complete_and_unref produced FreeBox instead of \
Retain. Run under miri for the full UAF trace."
                );
                #[cfg(miri)]
                {
                    unsafe { task::free_task(task_ptr) };
                    true
                }
            }
            FreeAction::Retain => false,
            FreeAction::FreeSlab => {
                panic!("box-allocated test task must not produce FreeSlab");
            }
        };
        // Second half of `wake`: dispatch, then (post-fix) release the slot's reference.
        unsafe { wake_task_cross_thread(captured, &cross_ctx) };
        if !pre_fix {
            drop(unsafe { TaskRef::from_owned(captured) });
            drop(slot);
            return;
        }
        drop(slot);
    }

    // Checks that `register` takes exactly one reference and that dropping a
    // still-STORED slot releases exactly that reference.
    pub(crate) fn slot_drop_releases_ref_when_still_registered() {
        let cross_ctx = make_uaf_cross_ctx();
        let task_ptr = make_uaf_task();
        // Extra reference so completing the task leaves it alive at refcount 1.
        unsafe { task::ref_inc(task_ptr) };
        let action = unsafe { task::complete_and_unref(task_ptr) };
        assert!(matches!(action, FreeAction::Retain));
        let baseline_refcount = unsafe { task::ref_count(task_ptr) };
        assert_eq!(baseline_refcount, 1, "after complete_and_unref, refcount=1");
        let slot = TaskWakerSlot::new(Arc::as_ptr(&cross_ctx));
        slot.register(task_ptr);
        let after_register = unsafe { task::ref_count(task_ptr) };
        drop(slot);
        let after_drop = unsafe { task::ref_count(task_ptr) };
        assert_eq!(
            after_register,
            after_drop + 1,
            "Post-fix Drop must release the ref that register acquired. \
If this fires pre-fix (register skipped ref_inc), there's no \
Drop ref_dec to compensate, so the net is 0 instead of -1."
        );
        assert_eq!(
            after_register,
            baseline_refcount + 1,
            "Post-fix register must bump refcount by 1. If this fires \
pre-fix, register skipped ref_inc — that's BUG-2's root cause."
        );
        // Drop the baseline reference and free the task.
        let action = unsafe { task::ref_dec(task_ptr) };
        match action {
            FreeAction::FreeBox => unsafe { task::free_task(task_ptr) },
            other => panic!("expected FreeBox on final ref_dec, got {other:?}"),
        }
    }

    // Replays a `register` landing between `wake`'s CAS (STORED→EMPTY) and its
    // `task_ptr` swap, and checks that the slot's references net out.
    pub(crate) fn register_during_wake_does_not_leak_ref() {
        let cross_ctx = make_uaf_cross_ctx();
        let task_ptr = make_uaf_task();
        unsafe { task::ref_inc(task_ptr) };
        let action = unsafe { task::complete_and_unref(task_ptr) };
        assert!(matches!(action, FreeAction::Retain));
        let baseline = unsafe { task::ref_count(task_ptr) };
        assert_eq!(baseline, 1, "baseline must be 1 (executor-style ref)");
        let slot = TaskWakerSlot::new(Arc::as_ptr(&cross_ctx));
        slot.register(task_ptr);
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            baseline + 1,
            "initial register must take a ref (slot owns +1)"
        );
        // Step 1 of `wake`: claim the slot.
        let cas_ok = slot
            .state
            .compare_exchange(STORED, EMPTY, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok();
        assert!(cas_ok, "wake's CAS must succeed when state is STORED");
        // Racing `register`: swaps in a fresh reference and releases the displaced one.
        slot.register(task_ptr);
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            baseline + 1,
            "race register must NET to baseline+1 (slot still owns one ref). \
Pre-fix the gate skipped the release of the original; this \
assertion would fire baseline+2 — the leak."
        );
        // Step 2 of `wake`: take the pointer (now the racing register's reference) and release it.
        let captured = slot.task_ptr.swap(std::ptr::null_mut(), Ordering::Acquire);
        assert_eq!(captured, task_ptr);
        drop(unsafe { TaskRef::from_owned(captured) });
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            baseline,
            "after wake's release, slot owes 0 refs to task. Pre-fix \
this is baseline+1 (the leaked original)."
        );
        // The slot is now STORED (set by the racing register) with a null pointer.
        drop(slot);
        assert_eq!(
            unsafe { task::ref_count(task_ptr) },
            baseline,
            "Drop on a STORED-but-null-task_ptr slot must be a no-op for refcount"
        );
        match unsafe { task::ref_dec(task_ptr) } {
            FreeAction::FreeBox => unsafe { task::free_task(task_ptr) },
            other => panic!("expected FreeBox on final ref_dec, got {other:?}"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::Task;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Future that is ready on its first poll; only used to build tasks.
    struct Noop;
    impl Future for Noop {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Ready(())
        }
    }

    // Heap-allocates a `Noop` task and returns it as a raw task pointer.
    fn make_task() -> *mut u8 {
        let task = Box::new(Task::new_boxed(Noop, 0));
        Box::into_raw(task) as *mut u8
    }

    // Frees a task built by `make_task` via `task::free_task`.
    unsafe fn free(ptr: *mut u8) {
        unsafe { task::free_task(ptr) };
    }

    // One push, then pops: the node comes back once, after which the queue is empty.
    #[test]
    fn queue_push_pop_single() {
        let q = CrossWakeQueue::new();
        let t1 = make_task();
        unsafe { q.push(t1) };
        assert_eq!(q.pop(), Some(t1));
        assert_eq!(q.pop(), None);
        unsafe { free(t1) };
    }

    // Several pushes are popped in FIFO order.
    #[test]
    fn queue_push_pop_multiple() {
        let q = CrossWakeQueue::new();
        let t1 = make_task();
        let t2 = make_task();
        let t3 = make_task();
        unsafe { q.push(t1) };
        unsafe { q.push(t2) };
        unsafe { q.push(t3) };
        assert_eq!(q.pop(), Some(t1));
        assert_eq!(q.pop(), Some(t2));
        assert_eq!(q.pop(), Some(t3));
        assert_eq!(q.pop(), None);
        unsafe { free(t1) };
        unsafe { free(t2) };
        unsafe { free(t3) };
    }

    // Alternating push/pop, fully draining the queue between pushes.
    #[test]
    fn queue_interleaved_push_pop() {
        let q = CrossWakeQueue::new();
        let t1 = make_task();
        let t2 = make_task();
        unsafe { q.push(t1) };
        assert_eq!(q.pop(), Some(t1));
        unsafe { q.push(t2) };
        assert_eq!(q.pop(), Some(t2));
        assert_eq!(q.pop(), None);
        unsafe { free(t1) };
        unsafe { free(t2) };
    }

    // A fresh queue reports empty, repeatedly.
    #[test]
    fn queue_empty() {
        let q = CrossWakeQueue::new();
        assert_eq!(q.pop(), None);
        assert_eq!(q.pop(), None);
    }

    // The same node can be pushed again after each time it is popped.
    #[test]
    fn queue_reuse_after_drain() {
        let q = CrossWakeQueue::new();
        let t1 = make_task();
        for _ in 0..100 {
            unsafe { q.push(t1) };
            assert_eq!(q.pop(), Some(t1));
        }
        assert_eq!(q.pop(), None);
        unsafe { free(t1) };
    }

    // Standalone context. The `mio::Poll` whose registry the waker uses is
    // dropped when this function returns.
    fn make_ctx() -> Arc<CrossWakeContext> {
        let poll = mio::Poll::new().expect("mio::Poll");
        let waker = mio::Waker::new(poll.registry(), mio::Token(0)).expect("mio::Waker");
        Arc::new(CrossWakeContext {
            queue: CrossWakeQueue::new(),
            mio_waker: Arc::new(waker),
            parked: std::sync::atomic::AtomicBool::new(false),
        })
    }

    // Spawns a joinable boxed task (output `u64`) bound to `ctx` via `box_spawn_joinable`.
    fn make_spawned_task(ctx: &Arc<CrossWakeContext>) -> *mut u8 {
        struct Noop;
        impl std::future::Future for Noop {
            type Output = u64;
            fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u64> {
                Poll::Ready(0)
            }
        }
        crate::task::box_spawn_joinable(Noop, 0, Arc::as_ptr(ctx))
    }

    // Drops the future, completes the task, clears its join flag and drops the
    // remaining reference, asserting it ends up terminal. The task memory is
    // not freed here; callers free it themselves.
    unsafe fn drive_to_terminal(ptr: *mut u8) {
        unsafe {
            crate::task::drop_task_future(ptr);
            assert!(matches!(
                crate::task::complete_and_unref(ptr),
                crate::task::FreeAction::Retain
            ));
            crate::task::clear_has_join(ptr);
            assert!(matches!(
                crate::task::ref_dec(ptr),
                crate::task::FreeAction::FreeBox
            ));
            assert!(crate::task::is_terminal(ptr));
        }
    }

    // A task built with `Task::new_boxed` (presumably with a null cross-wake
    // context), with no runtime or poll context installed. After
    // `dispose_terminal` the task is still terminal, and the test frees it.
    #[test]
    fn dispose_terminal_null_ctx_no_tls_leaks() {
        struct Noop;
        impl std::future::Future for Noop {
            type Output = ();
            fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }
        let task = Box::new(crate::task::Task::new_boxed(Noop, 0));
        let ptr = Box::into_raw(task) as *mut u8;
        unsafe {
            crate::task::drop_task_future(ptr);
            assert!(matches!(
                crate::task::complete_and_unref(ptr),
                crate::task::FreeAction::FreeBox
            ));
            assert!(crate::task::is_terminal(ptr));
            dispose_terminal(ptr);
            assert!(crate::task::is_terminal(ptr));
            crate::task::free_task(ptr);
        }
    }

    // The owning runtime and a poll context are installed on this thread, so
    // `dispose_terminal` defers the free: the task lands in `deferred`.
    #[test]
    fn dispose_terminal_on_executor_defers_when_tls_set() {
        let ctx = make_ctx();
        let _guard = install_runtime_cross_wake(&ctx);
        let ptr = make_spawned_task(&ctx);
        let mut deferred: Vec<*mut u8> = Vec::new();
        let mut ready: Vec<*mut u8> = Vec::new();
        let _poll_guard = crate::waker::set_poll_context(&raw mut ready, &raw mut deferred);
        unsafe {
            drive_to_terminal(ptr);
            dispose_terminal(ptr);
        }
        assert_eq!(deferred.len(), 1);
        assert_eq!(deferred[0], ptr);
        unsafe { crate::task::free_task(ptr) };
    }

    // The owning runtime is installed but no poll context is. The deferred free
    // has nowhere to go, so the task is left terminal and the test frees it.
    #[test]
    fn dispose_terminal_on_executor_leaks_when_tls_null() {
        let ctx = make_ctx();
        let _guard = install_runtime_cross_wake(&ctx);
        let ptr = make_spawned_task(&ctx);
        unsafe {
            drive_to_terminal(ptr);
            dispose_terminal(ptr);
        }
        assert!(unsafe { crate::task::is_terminal(ptr) });
        unsafe { crate::task::free_task(ptr) };
    }

    // No runtime is installed on this thread, so the task is marked queued and
    // pushed onto its owner's cross-wake queue.
    #[test]
    fn dispose_terminal_off_thread_queues() {
        let ctx = make_ctx();
        let ptr = make_spawned_task(&ctx);
        unsafe {
            drive_to_terminal(ptr);
            dispose_terminal(ptr);
        }
        let popped = ctx.queue.pop();
        assert_eq!(popped, Some(ptr));
        assert!(unsafe { crate::task::is_queued(ptr) });
        unsafe {
            crate::task::clear_queued(ptr);
            crate::task::free_task(ptr);
        }
    }

    // Leaves a terminal task in the cross-wake queue when the executor is
    // dropped. The test passes if `Executor`'s drop copes with it (presumably
    // by freeing it; that logic is not visible here).
    #[test]
    fn executor_drop_handles_terminal_in_cross_queue() {
        use crate::Executor;
        let ctx = make_ctx();
        let mut exec = Executor::new(8);
        exec.install_cross_wake_for_drop(Arc::clone(&ctx));
        struct OnceFuture;
        impl std::future::Future for OnceFuture {
            type Output = ();
            fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }
        let handle = exec.spawn_boxed(OnceFuture);
        drop(handle);
        exec.poll();
        let kept_handle = exec.spawn_boxed(OnceFuture);
        exec.poll();
        let task_ptr = kept_handle.raw_ptr();
        // Leak the handle so the task is made terminal by hand below.
        std::mem::forget(kept_handle);
        unsafe {
            crate::task::clear_has_join(task_ptr);
            let action = crate::task::ref_dec(task_ptr);
            assert!(matches!(action, crate::task::FreeAction::FreeBox));
        }
        unsafe {
            assert!(crate::task::try_set_queued(task_ptr));
            ctx.queue.push(task_ptr);
        }
        drop(exec);
    }
}