use std::task::{Context, RawWaker, RawWakerVTable, Waker};
use crate::task;
std::thread_local! {
    /// Ready queue of the poll cycle currently running on this thread, or null
    /// when no poll context is installed. Swapped in by `set_poll_context` and
    /// restored when its guard drops.
    static READY_QUEUE: std::cell::Cell<*mut Vec<*mut u8>> =
        const { std::cell::Cell::new(std::ptr::null_mut::<Vec<*mut u8>>()) };
}

std::thread_local! {
    /// Deferred-free list of the poll cycle currently running on this thread,
    /// or null when no poll context is installed. Swapped in by
    /// `set_poll_context` and restored when its guard drops.
    static DEFERRED_FREE: std::cell::Cell<*mut Vec<*mut u8>> =
        const { std::cell::Cell::new(std::ptr::null_mut::<Vec<*mut u8>>()) };
}
/// Installs `ready` and `deferred_free` as this thread's wake and free targets
/// for the duration of a poll cycle.
///
/// While installed, `wake_impl` pushes woken task pointers into `ready`, and
/// `free_completed_slot` pushes slots that `task::ref_dec` reports as freeable
/// into `deferred_free`. The returned guard reinstalls whatever pointers were
/// active before this call, so poll contexts can nest.
///
/// Only raw pointers are kept. The caller must keep both vectors alive, and
/// must not hold other references into them while a waker may fire, until the
/// guard has been dropped.
#[inline]
pub(crate) fn set_poll_context(
    ready: &mut Vec<*mut u8>,
    deferred_free: &mut Vec<*mut u8>,
) -> PollContextGuard {
    let ready_ptr: *mut Vec<*mut u8> = ready;
    let free_ptr: *mut Vec<*mut u8> = deferred_free;
    // Swap the new targets in, remembering what was installed before.
    let prev_ready = READY_QUEUE.with(|slot| slot.replace(ready_ptr));
    let prev_free = DEFERRED_FREE.with(|slot| slot.replace(free_ptr));
    PollContextGuard {
        prev_free,
        prev_ready,
    }
}
/// Guard returned by `set_poll_context`. When dropped, it reinstalls the
/// ready-queue and deferred-free pointers that were active before that call.
///
/// The raw-pointer fields make this type `!Send` and `!Sync`, so it is dropped
/// on the thread whose thread-locals it changed.
///
/// `#[must_use]`: discarding the guard as a bare expression statement drops it
/// immediately, which uninstalls the poll context before any waker can use it.
#[must_use = "the poll context is uninstalled as soon as this guard is dropped"]
pub(crate) struct PollContextGuard {
    /// `READY_QUEUE` value in effect before the guard was created.
    prev_ready: *mut Vec<*mut u8>,
    /// `DEFERRED_FREE` value in effect before the guard was created.
    prev_free: *mut Vec<*mut u8>,
}
impl Drop for PollContextGuard {
    /// Puts back the queue pointers that were active before the matching
    /// `set_poll_context` call, unwinding exactly one level of nesting.
    #[inline]
    fn drop(&mut self) {
        let (ready, free) = (self.prev_ready, self.prev_free);
        READY_QUEUE.with(|slot| slot.set(ready));
        DEFERRED_FREE.with(|slot| slot.set(free));
    }
}
/// Inline storage for a `Waker` plus a `Context` that borrows it. `set_task`
/// re-points the same storage at different tasks, so no new `Waker`/`Context`
/// is built per task.
///
/// Word layout of `raw`, as filled in by `new`, `init` and `set_task`:
/// - `raw[0]`: vtable pointer (`&raw const VTABLE`)
/// - `raw[1]`: data pointer (the current task)
/// - `raw[2..6]`: reinterpreted as a `Context`. `raw[2]` and `raw[3]` are set
///   to the address of `raw[0]`, and `raw[4]`/`raw[5]` stay null.
///
/// This mirrors std's `Waker`/`Context` layouts, which are not guaranteed by the
/// language; the layout tests in this module check the assumption.
///
/// `raw[2]`/`raw[3]` point into this struct itself, so they are only valid
/// while the value stays at the address they were written at.
pub(crate) struct ReusableWaker {
    raw: [*const (); 6],
}
impl ReusableWaker {
    /// Word index of the task (data) pointer inside the embedded `Waker`.
    const DATA_SLOT: usize = 1;
    /// Word index at which the embedded `Context` begins.
    const CONTEXT_SLOT: usize = 2;

    /// Creates storage whose `Waker` vtable word already points at `VTABLE`;
    /// every other word starts out null.
    #[inline]
    pub(crate) fn new() -> Self {
        Self {
            raw: [
                (&raw const VTABLE).cast::<()>(), // Waker: vtable
                std::ptr::null(),                 // Waker: data (task pointer)
                std::ptr::null(),                 // Context: first word (&Waker)
                std::ptr::null(),                 // Context: second word, also aimed at the Waker
                std::ptr::null(),                 // Context: remaining words, left null
                std::ptr::null(),
            ],
        }
    }

    /// Points the embedded `Context` words at this struct's own `Waker` words.
    ///
    /// Retained for existing callers. `set_task` now does this itself on every
    /// call, so skipping `init`, or moving the value after calling it, no
    /// longer leaves the `Context` holding a stale or null `&Waker`.
    #[inline]
    pub(crate) fn init(&mut self) {
        let base = self.raw.as_mut_ptr();
        // SAFETY: `base` is the start of `self.raw`, live and writable for the
        // duration of this `&mut self` borrow.
        unsafe { Self::link_context(base) };
    }

    /// Installs `task_ptr` as the waker's data pointer and returns the embedded
    /// `Context`, whose waker now refers to that task.
    ///
    /// The self-referential `Context` words are rewritten on every call from
    /// the current address of `self`, so the returned `Context` cannot point
    /// at a location the value has since moved away from. All pointers are
    /// derived from one `as_mut_ptr` taken before the writes, so the stored
    /// `&Waker` is not invalidated by later writes through `&mut self`.
    ///
    /// # Safety
    /// `task_ptr` must be a pointer the `VTABLE` functions accept, and it must
    /// stay valid for as long as the returned `Context` is used. The layouts of
    /// `Waker` and `Context` must match the ones documented on the struct and
    /// checked by this module's layout tests.
    #[inline]
    pub(crate) unsafe fn set_task(&mut self, task_ptr: *mut u8) -> &mut Context<'_> {
        let base = self.raw.as_mut_ptr();
        // SAFETY: every offset used stays within the 6-word `self.raw`, which we
        // borrow exclusively. Reinterpreting words 2..6 as `Context` relies on
        // the layout contract stated above.
        unsafe {
            base.add(Self::DATA_SLOT)
                .write(task_ptr.cast::<()>().cast_const());
            Self::link_context(base);
            &mut *base.add(Self::CONTEXT_SLOT).cast::<Context<'_>>()
        }
    }

    /// Writes the address of the embedded `Waker` (word 0) into the first two
    /// `Context` words.
    ///
    /// # Safety
    /// `base` must point to the first element of a live, writable
    /// `[*const (); 6]`.
    #[inline]
    unsafe fn link_context(base: *mut *const ()) {
        let waker_ptr = base.cast::<()>().cast_const();
        // SAFETY: words 2 and 3 are in bounds per the caller's contract.
        unsafe {
            base.add(Self::CONTEXT_SLOT).write(waker_ptr);
            base.add(Self::CONTEXT_SLOT + 1).write(waker_ptr);
        }
    }
}
/// The vtable shared by every waker this module creates.
///
/// It must remain a `static` rather than a `const`: `task_ptr_from_local_waker`
/// recognises this module's wakers by comparing the vtable pointer against this
/// item's address, and only a `static` has a single, stable address.
pub(crate) static VTABLE: RawWakerVTable = RawWakerVTable::new(
    clone_fn,       // clone: take one more task reference
    wake_fn,        // wake (by value): schedule, then release the reference
    wake_by_ref_fn, // wake_by_ref: schedule only
    drop_fn,        // drop: release the reference
);
/// Returns the task pointer carried by `waker` if it was built on this module's
/// `VTABLE`, and `None` for any foreign waker.
///
/// Relies on `Waker` being two pointer-sized words, vtable first and data
/// second. std does not guarantee this layout; the
/// `reusable_waker_layout_matches_std` test checks it.
pub(crate) fn task_ptr_from_local_waker(waker: &Waker) -> Option<*mut u8> {
    let our_vtable = (&raw const VTABLE).cast::<()>();
    // SAFETY: `Waker` is two pointer-sized, initialized words (see the layout
    // test), so copying it out as `[*const (); 2]` stays in bounds.
    let [vtable_word, data_word] = unsafe { *(waker as *const Waker).cast::<[*const (); 2]>() };
    if vtable_word != our_vtable {
        return None;
    }
    Some(data_word as *mut u8)
}
/// `RawWakerVTable::clone` entry: takes one more reference on the task, so the
/// new waker keeps it alive, then returns a raw waker with the same data
/// pointer and vtable.
unsafe fn clone_fn(data: *const ()) -> RawWaker {
    let task_ptr = data.cast_mut().cast::<u8>();
    // SAFETY: std only calls this with the data pointer of a waker built on
    // `VTABLE`, and this module always stores a task pointer there.
    unsafe { task::ref_inc(task_ptr) };
    RawWaker::new(data, &VTABLE)
}
unsafe fn wake_fn(data: *const ()) {
unsafe { wake_impl(data) };
let should_free = unsafe { task::ref_dec(data as *mut u8) };
if should_free {
unsafe { free_completed_slot(data as *mut u8) };
}
}
/// `RawWakerVTable::wake_by_ref` entry. Schedules the task and leaves its
/// reference count untouched, because the waker is only borrowed.
unsafe fn wake_by_ref_fn(data: *const ()) {
    // SAFETY: `data` is the task pointer of a live waker built on `VTABLE`.
    unsafe { wake_impl(data) }
}
/// `RawWakerVTable::drop` entry. Releases the waker's task reference. If
/// `task::ref_dec` reports that the slot may now be freed, the slot is handed
/// to the deferred-free list.
unsafe fn drop_fn(data: *const ()) {
    let task_ptr = data as *mut u8;
    // SAFETY: `data` is the task pointer of the waker being dropped, which owns
    // one reference on the task.
    let release_slot = unsafe { task::ref_dec(task_ptr) };
    if !release_slot {
        return;
    }
    // SAFETY: `ref_dec` just reported this slot as freeable.
    unsafe { free_completed_slot(task_ptr) };
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::task::{Poll, RawWaker, Waker};

    /// Bytes in one pointer-sized word. Layout checks are phrased in words, so
    /// they hold on 32-bit targets too, instead of hard-coding 8-byte `u64`s.
    const WORD: usize = std::mem::size_of::<usize>();

    /// Reads the `index`-th pointer-sized word of `value`.
    ///
    /// Reads one word at a time instead of viewing the whole object as a
    /// slice, so words that may be uninitialized are never referenced.
    ///
    /// # Safety
    /// The addressed word must be initialized.
    ///
    /// # Panics
    /// If the word would lie outside `value`.
    unsafe fn word_at<T>(value: &T, index: usize) -> usize {
        assert!(
            (index + 1) * WORD <= std::mem::size_of::<T>(),
            "word {} is out of bounds for a {}-byte value",
            index,
            std::mem::size_of::<T>()
        );
        // SAFETY: in bounds per the assertion above; initialization is the
        // caller's obligation. `read_unaligned` places no alignment demand.
        unsafe { (value as *const T).cast::<usize>().add(index).read_unaligned() }
    }

    #[test]
    fn reusable_waker_layout_matches_std() {
        assert_eq!(std::mem::size_of::<Waker>(), 2 * WORD);
        assert_eq!(std::mem::align_of::<Waker>(), std::mem::align_of::<usize>());
        let sentinel = 0xDEAD_BEEF_usize as *const ();
        let raw_waker = RawWaker::new(sentinel, &VTABLE);
        let waker = std::mem::ManuallyDrop::new(unsafe { Waker::from_raw(raw_waker) });
        let waker: &Waker = &waker;
        assert_eq!(
            unsafe { word_at(waker, 0) },
            (&raw const VTABLE) as usize,
            "Waker layout changed: vtable not in the first word"
        );
        assert_eq!(
            unsafe { word_at(waker, 1) },
            sentinel as usize,
            "Waker layout changed: data not in the second word"
        );
    }

    #[test]
    fn reusable_waker_delivers_correct_task_ptr() {
        use crate::task::Task;
        use std::future::Future;
        use std::pin::Pin;

        struct Noop;
        impl Future for Noop {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }

        let task_a = Box::new(Task::new_boxed(Noop, 0));
        let task_b = Box::new(Task::new_boxed(Noop, 0));
        let ptr_a = Box::into_raw(task_a) as *mut u8;
        let ptr_b = Box::into_raw(task_b) as *mut u8;

        let mut reusable = ReusableWaker::new();
        reusable.init();
        // Re-point the same storage at each task in turn. Every clone must
        // carry the current task and take (then release) exactly one reference.
        for task_ptr in [ptr_a, ptr_b] {
            let cx = unsafe { reusable.set_task(task_ptr) };
            assert_eq!(unsafe { crate::task::ref_count(task_ptr) }, 1);
            let cloned = cx.waker().clone();
            assert_eq!(unsafe { word_at(&cloned, 1) }, task_ptr as usize);
            assert_eq!(task_ptr_from_local_waker(&cloned), Some(task_ptr));
            assert_eq!(unsafe { crate::task::ref_count(task_ptr) }, 2);
            drop(cloned);
            assert_eq!(unsafe { crate::task::ref_count(task_ptr) }, 1);
        }

        unsafe {
            drop(Box::from_raw(ptr_a as *mut Task<Noop>));
            drop(Box::from_raw(ptr_b as *mut Task<Noop>));
        }
    }

    #[test]
    fn context_layout_matches_assumption() {
        let raw_waker = RawWaker::new(std::ptr::null(), &VTABLE);
        let waker = std::mem::ManuallyDrop::new(unsafe { Waker::from_raw(raw_waker) });
        let waker: &Waker = &waker;
        let cx = Context::from_waker(waker);
        let cx_size = std::mem::size_of::<Context<'_>>();
        // ReusableWaker reserves four words (raw[2..6]) for the Context.
        let reserved = 4 * WORD;
        assert!(
            cx_size <= reserved,
            "Context size {} exceeds our {}-byte allocation",
            cx_size,
            reserved
        );
        assert_eq!(
            unsafe { word_at(&cx, 0) },
            waker as *const Waker as usize,
            "Context first field is not &Waker"
        );
    }
}
/// Queues `ptr` on the deferred-free list of the poll context installed on
/// this thread.
///
/// Marked `#[cold]` / `#[inline(never)]`: it is reached only from the
/// should-free branch after `task::ref_dec`.
///
/// NOTE(review): with no poll context installed (`DEFERRED_FREE` is null), the
/// pointer is silently discarded and the slot is never freed.
#[cold]
#[inline(never)]
unsafe fn free_completed_slot(ptr: *mut u8) {
    DEFERRED_FREE.with(|cell| {
        // SAFETY: a non-null value was installed by `set_poll_context` from a
        // live `&mut Vec`; this assumes its guard has not been dropped yet.
        if let Some(pending) = unsafe { cell.get().as_mut() } {
            pending.push(ptr);
        }
    });
}
/// Schedules the task at `data` on this thread's ready queue.
///
/// Completed tasks, and tasks already flagged as queued, are ignored. A task
/// therefore enters the ready queue at most once until its queued flag is
/// cleared again.
///
/// The queued flag is set only when the task is actually pushed. If no poll
/// context is installed (the waker fired outside `Runtime::block_on` /
/// `Executor::poll`), this wake is dropped, but the flag stays clear. A later
/// wake from inside a poll cycle can then still schedule the task, instead of
/// finding it permanently marked queued and doing nothing.
unsafe fn wake_impl(data: *const ()) {
    let task_ptr = data as *mut u8;
    // SAFETY: `data` is the task pointer of a live waker built on `VTABLE`.
    if unsafe { task::is_completed(task_ptr) || task::is_queued(task_ptr) } {
        return;
    }
    READY_QUEUE.with(|cell| {
        let queue_ptr = cell.get();
        debug_assert!(
            !queue_ptr.is_null(),
            "waker fired outside poll cycle — task will be lost. \
             Ensure wakers are only used within Runtime::block_on or \
             Executor::poll scope."
        );
        // SAFETY: a non-null value was installed by `set_poll_context` from a
        // live `&mut Vec`; this assumes its guard is still alive on this thread.
        if let Some(queue) = unsafe { queue_ptr.as_mut() } {
            // SAFETY: same task pointer as checked above.
            unsafe { task::set_queued(task_ptr, true) };
            queue.push(task_ptr);
        }
    });
}