use std::task::{RawWaker, RawWakerVTable, Waker};
use crate::task;
std::thread_local! {
    /// Ready queue of the poll cycle currently installed on this thread, or null
    /// when none is installed. `wake_impl` pushes woken task pointers onto it.
    static READY_QUEUE: std::cell::Cell<*mut Vec<*mut u8>> =
        const { std::cell::Cell::new(std::ptr::null_mut::<Vec<*mut u8>>()) };

    /// Deferred-free list of the poll cycle currently installed on this thread,
    /// or null when none is installed. `free_completed_slot` pushes released
    /// task slots onto it.
    static DEFERRED_FREE: std::cell::Cell<*mut Vec<*mut u8>> =
        const { std::cell::Cell::new(std::ptr::null_mut::<Vec<*mut u8>>()) };
}
/// Installs `ready` and `deferred_free` as this thread's poll-cycle context.
///
/// `ready` becomes the queue that `wake_impl` pushes woken tasks onto, and
/// `deferred_free` the list that `free_completed_slot` pushes released slots
/// onto. The values installed before this call are captured in the returned
/// [`PollContextGuard`]; dropping it writes them back, so calls can nest as
/// long as the guards are dropped in reverse order.
#[inline]
pub(crate) fn set_poll_context(
    ready: *mut Vec<*mut u8>,
    deferred_free: *mut Vec<*mut u8>,
) -> PollContextGuard {
    // Struct fields are evaluated in the order written: ready queue first,
    // then the deferred-free list.
    PollContextGuard {
        prev_ready: READY_QUEUE.with(|slot| slot.replace(ready)),
        prev_free: DEFERRED_FREE.with(|slot| slot.replace(deferred_free)),
    }
}
/// RAII guard returned by `set_poll_context`.
///
/// Holds the poll-cycle context pointers that were installed before the
/// current one; its `Drop` impl writes them back into `READY_QUEUE` and
/// `DEFERRED_FREE`. Discarding the guard in a bare statement would restore the
/// outer context immediately, so the type is `#[must_use]`.
#[must_use = "dropping the guard immediately restores the previous poll context"]
pub(crate) struct PollContextGuard {
    /// `READY_QUEUE` value to reinstate on drop.
    prev_ready: *mut Vec<*mut u8>,
    /// `DEFERRED_FREE` value to reinstate on drop.
    prev_free: *mut Vec<*mut u8>,
}
impl Drop for PollContextGuard {
    /// Writes back the context pointers that were active before the matching
    /// `set_poll_context` call.
    #[inline]
    fn drop(&mut self) {
        let (outer_ready, outer_free) = (self.prev_ready, self.prev_free);
        READY_QUEUE.with(|slot| slot.set(outer_ready));
        DEFERRED_FREE.with(|slot| slot.set(outer_free));
    }
}
/// The vtable behind every waker this module creates. `data` is always a task
/// pointer; each entry forwards to the matching `*_fn` in this file.
pub(crate) static VTABLE: RawWakerVTable = RawWakerVTable::new(
    clone_fn,       // clone: take another task reference
    wake_fn,        // wake: schedule, then release this waker's reference
    wake_by_ref_fn, // wake_by_ref: schedule only
    drop_fn,        // drop: release this waker's reference
);
/// Builds a [`Waker`] for the task at `ptr`, taking a new reference on it via
/// `task::ref_inc`.
///
/// # Safety
///
/// `ptr` must be a task pointer that the `task::` reference-count helpers
/// accept. The returned waker owns the reference taken here; it is released by
/// `VTABLE`'s `wake`/`drop` entries.
#[inline]
pub(crate) unsafe fn task_waker(ptr: *mut u8) -> Waker {
    // SAFETY: the caller guarantees `ptr` is a valid task pointer.
    unsafe { task::ref_inc(ptr) };
    let data = ptr as *const ();
    // SAFETY: `VTABLE`'s entries all expect a task pointer as `data`, which is
    // what `data` is.
    unsafe { Waker::from_raw(RawWaker::new(data, &VTABLE)) }
}
/// Returns the task pointer behind `waker` if it was built on this module's
/// [`VTABLE`], or `None` for wakers from any other source.
///
/// The vtables are compared with `==` (by value), exactly as before.
pub(crate) fn task_ptr_from_local_waker(waker: &Waker) -> Option<*mut u8> {
    match waker.vtable() {
        vtable if vtable == &VTABLE => Some(waker.data() as *mut u8),
        _ => None,
    }
}
/// `RawWakerVTable` clone entry: takes another reference on the task and
/// returns a raw waker with the same data pointer and [`VTABLE`].
///
/// # Safety
///
/// `data` must be a task pointer produced by this module's wakers.
unsafe fn clone_fn(data: *const ()) -> RawWaker {
    let task_ptr = data as *mut u8;
    // SAFETY: `data` came from one of our wakers, so it is a task pointer.
    unsafe { task::ref_inc(task_ptr) };
    RawWaker::new(data, &VTABLE)
}
/// `RawWakerVTable` wake entry (consumes the waker): schedules the task, then
/// releases the reference this waker held. If `task::ref_dec` reports the slot
/// should be freed, it is handed to `free_completed_slot`.
///
/// # Safety
///
/// `data` must be a task pointer owned by the waker being consumed.
unsafe fn wake_fn(data: *const ()) {
    let task_ptr = data as *mut u8;
    // SAFETY: `data` is a task pointer whose reference this waker owns; the
    // reference is released exactly once below.
    unsafe {
        wake_impl(data);
        if task::ref_dec(task_ptr) {
            free_completed_slot(task_ptr);
        }
    }
}
/// `RawWakerVTable` wake_by_ref entry: schedules the task without touching its
/// reference count, because the waker itself stays alive.
///
/// # Safety
///
/// `data` must be a task pointer produced by this module's wakers.
unsafe fn wake_by_ref_fn(data: *const ()) {
    // SAFETY: forwarded unchanged; same contract as `wake_impl`.
    unsafe { wake_impl(data) }
}
/// `RawWakerVTable` drop entry: releases the reference this waker held. If
/// `task::ref_dec` reports the slot should be freed, it is handed to
/// `free_completed_slot`.
///
/// # Safety
///
/// `data` must be a task pointer owned by the waker being dropped.
unsafe fn drop_fn(data: *const ()) {
    let task_ptr = data as *mut u8;
    // SAFETY: this waker owns one reference on `task_ptr`; release it once.
    let slot_released = unsafe { task::ref_dec(task_ptr) };
    if !slot_released {
        return;
    }
    // SAFETY: `ref_dec` said the slot should be freed.
    unsafe { free_completed_slot(task_ptr) };
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;
    use std::task::{RawWaker, Waker};

    /// A waker whose vtable is `VTABLE` must hand its data pointer back.
    #[test]
    fn task_ptr_from_local_waker_roundtrip() {
        let sentinel = 0xDEAD_BEEF_usize as *mut u8;
        // The sentinel is not a real task: `ManuallyDrop` keeps `drop_fn` (and
        // therefore `task::ref_dec`) from ever running on it.
        let local = ManuallyDrop::new(unsafe {
            Waker::from_raw(RawWaker::new(sentinel.cast(), &VTABLE))
        });
        assert_eq!(task_ptr_from_local_waker(&local), Some(sentinel));
    }

    /// A waker built on any other vtable must be rejected.
    #[test]
    fn task_ptr_from_foreign_waker_returns_none() {
        static OTHER: RawWakerVTable =
            RawWakerVTable::new(|p| RawWaker::new(p, &OTHER), |_| {}, |_| {}, |_| {});
        let foreign = ManuallyDrop::new(unsafe {
            Waker::from_raw(RawWaker::new(std::ptr::null(), &OTHER))
        });
        assert_eq!(task_ptr_from_local_waker(&foreign), None);
    }
}
/// Crate-visible entry point for handing a task slot to the current poll
/// cycle's deferred-free list; forwards directly to `free_completed_slot`.
///
/// # Safety
///
/// Same contract as `free_completed_slot`: `ptr` must be a task slot that is
/// ready to be freed.
#[cold]
#[inline(never)]
pub(crate) unsafe fn defer_free(ptr: *mut u8) {
    // SAFETY: forwarded unchanged; the caller upholds the callee's contract.
    unsafe { free_completed_slot(ptr) }
}
/// Pushes `ptr` onto the deferred-free list installed on this thread by
/// `set_poll_context`.
///
/// Called from `wake_fn`/`drop_fn` when `task::ref_dec` reports the slot
/// should be freed, and from `defer_free`. When no list is installed
/// (`DEFERRED_FREE` is null), `ptr` is silently discarded.
/// NOTE(review): that looks like a leak if the last waker is dropped outside a
/// poll cycle; confirm whether another path reclaims such slots.
///
/// # Safety
///
/// `ptr` must be a task slot that is ready to be freed.
#[cold]
#[inline(never)]
unsafe fn free_completed_slot(ptr: *mut u8) {
    DEFERRED_FREE.with(|cell| {
        // SAFETY: a non-null list pointer is installed by `set_poll_context`
        // and presumably stays valid until its guard restores the slot.
        if let Some(pending) = unsafe { cell.get().as_mut() } {
            pending.push(ptr);
        }
    });
}
/// Schedules the task at `data` on this thread's current ready queue.
///
/// Does nothing if the task has already completed or is already queued.
/// The task's queued flag is set only once its pointer has actually been
/// pushed. A wake that happens while no ready queue is installed therefore
/// does not leave the task marked as queued. Previously that stale flag made
/// every later wake return early, wedging the task forever.
///
/// NOTE(review): `Waker` is `Send + Sync`, but `READY_QUEUE` is thread-local.
/// A wake from a thread without an installed queue sees null and the wake is
/// dropped (release builds) or trips the `debug_assert!` (debug builds).
///
/// # Safety
///
/// `data` must be a task pointer accepted by the `task::` state accessors.
unsafe fn wake_impl(data: *const ()) {
    let task_ptr = data as *mut u8;
    // SAFETY: the caller guarantees `task_ptr` is a valid task pointer.
    if unsafe { task::is_completed(task_ptr) || task::is_queued(task_ptr) } {
        return;
    }
    READY_QUEUE.with(|cell| {
        let queue_ptr = cell.get();
        debug_assert!(
            !queue_ptr.is_null(),
            "waker fired outside poll cycle — task will be lost. \
             Ensure wakers are only used within Runtime::block_on or \
             Executor::poll scope."
        );
        // SAFETY: a non-null queue pointer is installed by `set_poll_context`
        // for the duration of a poll cycle on this thread.
        if let Some(queue) = unsafe { queue_ptr.as_mut() } {
            // Mark as queued only when the push really happens, so the flag
            // never claims a queue slot the task does not have.
            unsafe { task::set_queued(task_ptr, true) };
            queue.push(task_ptr);
        }
    });
}