use std::cell::UnsafeCell;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU16, Ordering};
use std::task::{Context, Poll, Waker};
/// Flag bit: a `JoinHandle` for this task is still alive (set by the joinable
/// constructors, cleared when the handle is dropped).
const HAS_JOIN: u8 = 1 << 0;
/// Flag bit: the `JoinHandle` has moved the task's output out of its storage.
const OUTPUT_TAKEN: u8 = 1 << 1;
/// Flag bit: the task was aborted through its `JoinHandle`.
const ABORTED: u8 = 1 << 2;
/// Size in bytes of the fixed task header, i.e. `size_of::<Task<()>>()`.
pub const TASK_HEADER_SIZE: usize = 64;
/// Type-erased task header followed by the task's storage.
///
/// The struct is `repr(C)` so the raw accessors in this file can address each
/// header field by a fixed byte offset from the task pointer. The offsets are
/// listed on each field below and checked by the `task_layout_offsets` test.
/// `S` is the storage: the bare future for `new_boxed`, or a `FutureOrOutput`
/// union for joinable tasks.
#[repr(C)]
pub(crate) struct Task<S> {
    /// Offset 0. Type-erased poll entry point, invoked by `poll_task`.
    poll_fn: unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>,
    /// Offset 8. Drops whatever `storage` currently holds. `poll_join` swaps it
    /// from the future dropper to `drop_output` once the future completes.
    drop_fn: unsafe fn(*mut u8),
    /// Offset 16. Releases the task's memory; invoked by `free_task`.
    free_fn: unsafe fn(*mut u8),
    /// Offset 24. Queued flag (see `is_queued`, `set_queued`, `try_set_queued`).
    is_queued: AtomicBool,
    /// Offset 25. Completion flag: stored with `Release`, loaded with `Acquire`.
    is_completed: AtomicBool,
    /// Offset 26. Reference count; `ref_dec` reports when the last reference is released.
    ref_count: AtomicU16,
    /// Offset 28. Caller-supplied key, read back via `tracker_key`.
    tracker_key: u32,
    /// Offset 32. Intrusive link pointer exposed via `cross_next`.
    /// NOTE(review): presumably a cross-thread queue link; confirm against its users.
    cross_next: AtomicPtr<u8>,
    /// Offset 40. Waker registered by a pending `JoinHandle::poll`.
    join_waker: UnsafeCell<Option<Waker>>,
    /// Offset 56. Byte offset of `storage`, recorded because it depends on `S`'s alignment.
    storage_offset: u16,
    /// Offset 58. `HAS_JOIN` / `OUTPUT_TAKEN` / `ABORTED` bits; a plain `Cell`, not atomic.
    flags: std::cell::Cell<u8>,
    /// Offsets 59..64. Padding so the header is exactly `TASK_HEADER_SIZE` bytes.
    _pad: [u8; 5],
    /// The future or, for joinable tasks, the future/output union.
    storage: S,
}
/// Storage for joinable tasks: holds the future until it completes, then its output.
///
/// `repr(C)` places both fields at offset 0, so type-erased code can cast the
/// storage pointer directly to `F` or `T`. Neither field is ever dropped
/// automatically. The task's `drop_fn` records which one is currently live.
#[repr(C)]
pub(crate) union FutureOrOutput<F, T> {
    /// Live from construction until `poll_join` observes completion.
    pub(crate) future: std::mem::ManuallyDrop<F>,
    /// Live after `poll_join` has written the future's result into the slot.
    pub(crate) output: std::mem::ManuallyDrop<T>,
}
// Compile-time layout checks.
//
// The raw header accessors in this file address fields by fixed byte offsets,
// so any reordering or resizing of the header must fail the build (in every
// profile, not only under `cargo test`) rather than silently corrupt memory.
// Under `repr(C)` each header field's offset depends only on the fields before
// it, so checking `Task<()>` covers every `Task<S>`.
const _: () = {
    assert!(std::mem::size_of::<Task<()>>() == TASK_HEADER_SIZE);
    assert!(std::mem::offset_of!(Task<()>, poll_fn) == 0);
    assert!(std::mem::offset_of!(Task<()>, drop_fn) == 8);
    assert!(std::mem::offset_of!(Task<()>, free_fn) == 16);
    assert!(std::mem::offset_of!(Task<()>, is_queued) == 24);
    assert!(std::mem::offset_of!(Task<()>, is_completed) == 25);
    assert!(std::mem::offset_of!(Task<()>, ref_count) == 26);
    assert!(std::mem::offset_of!(Task<()>, tracker_key) == 28);
    assert!(std::mem::offset_of!(Task<()>, cross_next) == 32);
    assert!(std::mem::offset_of!(Task<()>, join_waker) == 40);
    assert!(std::mem::offset_of!(Task<()>, storage_offset) == 56);
    assert!(std::mem::offset_of!(Task<()>, flags) == 58);
    assert!(std::mem::offset_of!(Task<()>, storage) == TASK_HEADER_SIZE);
};
impl<F: Future<Output = ()> + 'static> Task<F> {
    /// Test-only constructor for a plain (non-joinable) task whose storage is the future itself.
    ///
    /// The task starts with one reference and no flags. It is polled by
    /// `poll_join::<F>`, its future is dropped by `drop_future::<F>`, and its
    /// memory is released by `box_free::<F>`. Because `box_free` deallocates with
    /// `Layout::new::<Task<F>>()`, the value must be moved into a `Box<Task<F>>`
    /// before `free_task` is called on it.
    #[cfg(test)]
    #[inline]
    pub(crate) fn new_boxed(future: F, tracker_key: u32) -> Self {
        let storage_offset = std::mem::offset_of!(Task<F>, storage) as u16;
        Task {
            poll_fn: poll_join::<F>,
            drop_fn: drop_future::<F>,
            free_fn: box_free::<F>,
            is_queued: AtomicBool::new(false),
            is_completed: AtomicBool::new(false),
            ref_count: AtomicU16::new(1),
            tracker_key,
            cross_next: AtomicPtr::new(std::ptr::null_mut()),
            join_waker: UnsafeCell::new(None),
            storage_offset,
            flags: std::cell::Cell::new(0),
            _pad: [0u8; 5],
            storage: future,
        }
    }
}
pub(crate) fn box_spawn_joinable<F>(future: F, tracker_key: u32) -> *mut u8
where
F: Future + 'static,
F::Output: 'static,
{
type Storage<F> = FutureOrOutput<F, <F as Future>::Output>;
let task: Task<Storage<F>> = Task {
poll_fn: poll_join::<F>,
drop_fn: drop_future_in_union::<F>,
free_fn: box_free::<Storage<F>>,
is_queued: AtomicBool::new(false),
is_completed: AtomicBool::new(false),
ref_count: AtomicU16::new(2), tracker_key,
cross_next: AtomicPtr::new(std::ptr::null_mut()),
join_waker: UnsafeCell::new(None),
flags: std::cell::Cell::new(HAS_JOIN),
storage_offset: std::mem::offset_of!(Task<Storage<F>>, storage) as u16,
_pad: [0; 5],
storage: FutureOrOutput {
future: std::mem::ManuallyDrop::new(future),
},
};
Box::into_raw(Box::new(task)) as *mut u8
}
/// Builds a joinable task by value, for placement in caller-managed (e.g. slab) memory.
///
/// It has the same initial state as the task built by `box_spawn_joinable`:
/// two references, the `HAS_JOIN` flag, and a `FutureOrOutput` storage union
/// holding the future. The caller supplies `free_fn`, which `free_task` invokes
/// to release the memory.
pub(crate) fn new_joinable_slab<F>(
    future: F,
    tracker_key: u32,
    free_fn: unsafe fn(*mut u8),
) -> Task<FutureOrOutput<F, F::Output>>
where
    F: Future + 'static,
    F::Output: 'static,
{
    type Slot<Fut> = FutureOrOutput<Fut, <Fut as Future>::Output>;
    let header_len = std::mem::offset_of!(Task<Slot<F>>, storage) as u16;
    Task {
        free_fn,
        tracker_key,
        poll_fn: poll_join::<F>,
        drop_fn: drop_future_in_union::<F>,
        is_queued: AtomicBool::new(false),
        is_completed: AtomicBool::new(false),
        ref_count: AtomicU16::new(2),
        cross_next: AtomicPtr::new(std::ptr::null_mut()),
        join_waker: UnsafeCell::new(None),
        storage_offset: header_len,
        flags: std::cell::Cell::new(HAS_JOIN),
        _pad: [0u8; 5],
        storage: FutureOrOutput {
            future: std::mem::ManuallyDrop::new(future),
        },
    }
}
/// Identity of a task: a thin wrapper over the raw pointer to its header.
///
/// Equality and hashing compare the pointer value only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct TaskId(pub(crate) *mut u8);

impl TaskId {
    /// Returns the raw task-header pointer this id wraps.
    // May be unused in some build configurations, hence the lint allowance.
    #[allow(dead_code)]
    pub(crate) fn as_ptr(&self) -> *mut u8 {
        let TaskId(raw) = *self;
        raw
    }
}
/// Handle to a joinable task; awaiting it yields the task's output.
///
/// The handle owns one of the task's references, which it releases in `Drop`.
/// It holds only raw pointers, so it is neither `Send` nor `Sync`. That matches
/// the non-atomic `flags` cell it reads and writes.
#[must_use = "dropping a JoinHandle detaches the task — await it or call .abort()"]
pub struct JoinHandle<T> {
    /// Type-erased pointer to the task header.
    ptr: *mut u8,
    /// Records the output type `T` the task produces.
    _marker: PhantomData<T>,
    /// Explicitly opts out of `Send`/`Sync`.
    _not_send: PhantomData<*const ()>,
}
impl<T: 'static> Future for JoinHandle<T> {
    type Output = T;

    /// Resolves to the task's output once the task has completed.
    ///
    /// While the task is still running, the current waker is registered in the
    /// task header and `Poll::Pending` is returned. On completion the output is
    /// moved out of the task's storage and `OUTPUT_TAKEN` is recorded, so
    /// `Drop` does not drop it a second time.
    ///
    /// # Panics
    /// - If the task was aborted.
    /// - If polled again after it already returned `Poll::Ready`. The output has
    ///   been moved out, and reading it again would duplicate `T` (a double
    ///   move and later double drop), so this must not be reachable from safe code.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        let ptr = self.ptr;
        if !unsafe { is_completed(ptr) } {
            unsafe { set_join_waker(ptr, cx.waker().clone()) };
            return Poll::Pending;
        }
        let flags = unsafe { task_flags(ptr) };
        assert!(
            flags & ABORTED == 0,
            "polled JoinHandle after task was aborted"
        );
        assert!(
            flags & OUTPUT_TAKEN == 0,
            "polled JoinHandle after its output was already taken"
        );
        // SAFETY: the task completed (Acquire load above) and the output has not
        // been taken yet, so the storage slot holds an initialized `T`.
        let output_ptr = unsafe { ptr.add(storage_offset(ptr)) };
        let value = unsafe { std::ptr::read(output_ptr.cast::<T>()) };
        unsafe { set_flag(ptr, OUTPUT_TAKEN) };
        Poll::Ready(value)
    }
}
impl<T> JoinHandle<T> {
pub(crate) fn new(ptr: *mut u8) -> Self {
Self {
ptr,
_marker: PhantomData,
_not_send: PhantomData,
}
}
pub fn is_finished(&self) -> bool {
unsafe { is_completed(self.ptr) }
}
#[must_use = "returns whether the task was still running"]
pub fn abort(self) -> bool {
let ptr = self.ptr;
let was_running = !unsafe { is_completed(ptr) };
if was_running {
unsafe { set_flag(ptr, ABORTED) };
}
was_running
}
}
impl<T> Drop for JoinHandle<T> {
    /// Releases this handle's claim on the task.
    ///
    /// If the task completed and its output was neither taken by `poll` nor
    /// abandoned by `abort`, the output is dropped here. The handle's release
    /// steps then run: clear `HAS_JOIN`, discard any registered join waker,
    /// drop the handle's reference, and schedule the slot for freeing if it was
    /// the last one. These steps live in a guard, so they still happen if the
    /// output's own destructor panics. Otherwise the slot would leak and the task
    /// would keep reporting a live join handle.
    fn drop(&mut self) {
        /// Runs the handle's release steps when dropped, including during unwinding.
        struct ReleaseOnDrop(*mut u8);

        impl Drop for ReleaseOnDrop {
            fn drop(&mut self) {
                let ptr = self.0;
                unsafe { clear_flag(ptr, HAS_JOIN) };
                let _ = unsafe { take_join_waker(ptr) };
                if unsafe { ref_dec(ptr) } {
                    unsafe { defer_free_slot(ptr) };
                }
            }
        }

        let ptr = self.ptr;
        // Created before the output is dropped so it runs even if that drop panics.
        let _release = ReleaseOnDrop(ptr);
        let flags = unsafe { task_flags(ptr) };
        let output_still_owned = flags & (OUTPUT_TAKEN | ABORTED) == 0;
        if unsafe { is_completed(ptr) } && output_still_owned {
            unsafe { drop_task_future(ptr) };
        }
    }
}
/// Hands a task slot to `crate::waker::defer_free` for deferred release.
///
/// # Safety
/// `JoinHandle::drop` calls this only after `ref_dec` reported that the last
/// reference was released. Any further requirements are those of
/// `crate::waker::defer_free`.
unsafe fn defer_free_slot(ptr: *mut u8) {
    let slot = ptr;
    // SAFETY: forwarded from this function's own contract.
    unsafe { crate::waker::defer_free(slot) };
}
/// Reads the caller-supplied tracker key stored in the task header.
///
/// # Safety
/// `ptr` must point to a live, 8-byte-aligned task header.
#[inline]
pub(crate) unsafe fn tracker_key(ptr: *mut u8) -> u32 {
    const TRACKER_KEY_OFFSET: usize = 28;
    // SAFETY: byte 28 of a live header holds the `tracker_key` u32 (4-byte aligned).
    let slot = unsafe { ptr.add(TRACKER_KEY_OFFSET) }.cast::<u32>();
    unsafe { slot.read() }
}
/// Adds one reference to the task.
///
/// # Panics
/// Panics with "waker refcount overflow" if the count is already `u16::MAX`.
/// The increment is undone first, so the counter is not left wrapped at zero
/// while unwinding code may still release references. A counter stuck at zero
/// could otherwise trigger a premature free.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn ref_inc(ptr: *mut u8) {
    const REF_COUNT_OFFSET: usize = 26;
    // SAFETY: byte 26 of a live header holds the `ref_count` AtomicU16.
    let rc = unsafe { &*ptr.add(REF_COUNT_OFFSET).cast::<AtomicU16>() };
    let prev = rc.fetch_add(1, Ordering::Relaxed);
    if prev == u16::MAX {
        // The add wrapped the counter to 0; restore it before reporting.
        rc.fetch_sub(1, Ordering::Relaxed);
        panic!("waker refcount overflow");
    }
}
/// Drops one reference to the task.
///
/// Returns `true` when this call released the last reference (the count went
/// from 1 to 0). `AcqRel` orders earlier uses of the task before the final release.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn ref_dec(ptr: *mut u8) -> bool {
    const REF_COUNT_OFFSET: usize = 26;
    // SAFETY: byte 26 of a live header holds the `ref_count` AtomicU16.
    let counter = unsafe { &*ptr.add(REF_COUNT_OFFSET).cast::<AtomicU16>() };
    let before = counter.fetch_sub(1, Ordering::AcqRel);
    debug_assert!(before > 0, "waker refcount underflow");
    before == 1
}
/// Returns the current reference count (a `Relaxed` snapshot, mainly for tests).
///
/// # Safety
/// `ptr` must point to a live task header.
#[allow(dead_code)]
#[inline]
pub(crate) unsafe fn ref_count(ptr: *mut u8) -> u16 {
    const REF_COUNT_OFFSET: usize = 26;
    // SAFETY: byte 26 of a live header holds the `ref_count` AtomicU16.
    let counter = unsafe { &*ptr.add(REF_COUNT_OFFSET).cast::<AtomicU16>() };
    counter.load(Ordering::Relaxed)
}
/// Marks the task as completed.
///
/// The `Release` store pairs with the `Acquire` load in `is_completed`, so
/// writes made before this call are visible to whoever observes the flag.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn set_completed(ptr: *mut u8) {
    const COMPLETED_OFFSET: usize = 25;
    // SAFETY: byte 25 of a live header holds the `is_completed` AtomicBool.
    let completed = unsafe { &*ptr.add(COMPLETED_OFFSET).cast::<AtomicBool>() };
    completed.store(true, Ordering::Release);
}
/// Returns whether the task has completed (`Acquire`, pairing with `set_completed`).
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn is_completed(ptr: *mut u8) -> bool {
    const COMPLETED_OFFSET: usize = 25;
    // SAFETY: byte 25 of a live header holds the `is_completed` AtomicBool.
    let completed = unsafe { &*ptr.add(COMPLETED_OFFSET).cast::<AtomicBool>() };
    completed.load(Ordering::Acquire)
}
/// Returns the header's intrusive `cross_next` link pointer.
///
/// # Safety
/// `ptr` must point to a live task header. The returned reference is `'static`
/// only nominally; it must not be used after the task is freed.
#[inline]
#[allow(dead_code)]
pub(crate) unsafe fn cross_next(ptr: *mut u8) -> &'static AtomicPtr<u8> {
    const CROSS_NEXT_OFFSET: usize = 32;
    // SAFETY: byte 32 of a live header holds the `cross_next` AtomicPtr (8-byte aligned).
    let link = unsafe { ptr.add(CROSS_NEXT_OFFSET) }.cast::<AtomicPtr<u8>>();
    unsafe { &*link }
}
/// Returns the task's queued flag (`Relaxed`).
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn is_queued(ptr: *mut u8) -> bool {
    const QUEUED_OFFSET: usize = 24;
    // SAFETY: byte 24 of a live header holds the `is_queued` AtomicBool.
    let queued = unsafe { &*ptr.add(QUEUED_OFFSET).cast::<AtomicBool>() };
    queued.load(Ordering::Relaxed)
}
/// Overwrites the task's queued flag (`Relaxed`).
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn set_queued(ptr: *mut u8, queued: bool) {
    const QUEUED_OFFSET: usize = 24;
    // SAFETY: byte 24 of a live header holds the `is_queued` AtomicBool.
    let flag = unsafe { &*ptr.add(QUEUED_OFFSET).cast::<AtomicBool>() };
    flag.store(queued, Ordering::Relaxed);
}
/// Atomically flips the queued flag from `false` to `true`.
///
/// Returns `true` only for the caller that performed the transition, and
/// `false` if the task was already marked queued.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn try_set_queued(ptr: *mut u8) -> bool {
    const QUEUED_OFFSET: usize = 24;
    // SAFETY: byte 24 of a live header holds the `is_queued` AtomicBool.
    let flag = unsafe { &*ptr.add(QUEUED_OFFSET).cast::<AtomicBool>() };
    let outcome = flag.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed);
    matches!(outcome, Ok(_))
}
/// Returns the byte offset of the task's storage from the header start.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn storage_offset(ptr: *mut u8) -> usize {
    const STORAGE_OFFSET_FIELD: usize = 56;
    // SAFETY: byte 56 of a live header holds the `storage_offset` u16 (2-byte aligned).
    let raw = unsafe { ptr.add(STORAGE_OFFSET_FIELD).cast::<u16>().read() };
    usize::from(raw)
}
/// Returns the task's flag byte (`HAS_JOIN` / `OUTPUT_TAKEN` / `ABORTED`).
///
/// # Safety
/// `ptr` must point to a live task header. The flags are a non-atomic `Cell`,
/// so callers must not race with other accessors.
#[inline]
unsafe fn task_flags(ptr: *mut u8) -> u8 {
    const FLAGS_OFFSET: usize = 58;
    // SAFETY: byte 58 of a live header holds the `flags` Cell<u8>.
    let cell = unsafe { &*ptr.add(FLAGS_OFFSET).cast::<std::cell::Cell<u8>>() };
    cell.get()
}
/// ORs `flag` into the task's flag byte.
///
/// # Safety
/// `ptr` must point to a live task header; the update is non-atomic.
#[inline]
unsafe fn set_flag(ptr: *mut u8, flag: u8) {
    const FLAGS_OFFSET: usize = 58;
    // SAFETY: byte 58 of a live header holds the `flags` Cell<u8>.
    let cell = unsafe { &*ptr.add(FLAGS_OFFSET).cast::<std::cell::Cell<u8>>() };
    let current = cell.get();
    cell.set(current | flag);
}
/// Clears the bits of `flag` in the task's flag byte.
///
/// # Safety
/// `ptr` must point to a live task header; the update is non-atomic.
#[inline]
unsafe fn clear_flag(ptr: *mut u8, flag: u8) {
    const FLAGS_OFFSET: usize = 58;
    // SAFETY: byte 58 of a live header holds the `flags` Cell<u8>.
    let cell = unsafe { &*ptr.add(FLAGS_OFFSET).cast::<std::cell::Cell<u8>>() };
    let keep_mask = !flag;
    cell.set(cell.get() & keep_mask);
}
/// Returns whether a `JoinHandle` for this task is still alive (`HAS_JOIN` set).
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn has_join(ptr: *mut u8) -> bool {
    let flags = unsafe { task_flags(ptr) };
    let join_bit = flags & HAS_JOIN;
    join_bit != 0
}
/// Returns whether the task was aborted through its `JoinHandle` (`ABORTED` set).
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn is_aborted(ptr: *mut u8) -> bool {
    let flags = unsafe { task_flags(ptr) };
    let abort_bit = flags & ABORTED;
    abort_bit != 0
}
/// Stores `waker` as the task's join waker, dropping any previously registered one.
///
/// # Safety
/// `ptr` must point to a live task header whose `join_waker` slot is
/// initialized, and no other reference into that slot may be live.
#[inline]
unsafe fn set_join_waker(ptr: *mut u8, waker: Waker) {
    const JOIN_WAKER_OFFSET: usize = 40;
    // SAFETY: byte 40 of a live header holds the `join_waker` UnsafeCell.
    let slot = unsafe { &*ptr.add(JOIN_WAKER_OFFSET).cast::<UnsafeCell<Option<Waker>>>() };
    let place = slot.get();
    unsafe { *place = Some(waker) };
}
/// Removes and returns the task's registered join waker, leaving `None` behind.
///
/// # Safety
/// `ptr` must point to a live task header whose `join_waker` slot is
/// initialized, and no other reference into that slot may be live.
#[inline]
pub(crate) unsafe fn take_join_waker(ptr: *mut u8) -> Option<Waker> {
    const JOIN_WAKER_OFFSET: usize = 40;
    // SAFETY: byte 40 of a live header holds the `join_waker` UnsafeCell.
    let slot = unsafe { &*ptr.add(JOIN_WAKER_OFFSET).cast::<UnsafeCell<Option<Waker>>>() };
    unsafe { std::mem::take(&mut *slot.get()) }
}
/// Polls the task once through its type-erased `poll_fn` (header offset 0).
///
/// # Safety
/// `ptr` must point to a live task header whose `poll_fn` is valid for it.
#[inline]
pub(crate) unsafe fn poll_task(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()> {
    type PollFn = unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>;
    // SAFETY: `poll_fn` is the first field of the `repr(C)` header.
    let entry = unsafe { ptr.cast::<PollFn>().read() };
    unsafe { entry(ptr, cx) }
}
/// Drops whatever the task's storage currently holds, via the `drop_fn` at offset 8.
///
/// # Safety
/// `ptr` must point to a live task header whose `drop_fn` matches the storage's
/// current contents. The caller must ensure this runs at most once per live value.
#[inline]
pub(crate) unsafe fn drop_task_future(ptr: *mut u8) {
    const DROP_FN_OFFSET: usize = 8;
    // SAFETY: byte 8 of a live header holds the `drop_fn` pointer.
    let dropper = unsafe { ptr.add(DROP_FN_OFFSET).cast::<unsafe fn(*mut u8)>().read() };
    unsafe { dropper(ptr) }
}
/// Releases the task's memory via the `free_fn` at offset 16.
///
/// # Safety
/// `ptr` must point to a task whose storage has already been dropped and which
/// no one references any more; `ptr` is dangling afterwards.
#[inline]
pub(crate) unsafe fn free_task(ptr: *mut u8) {
    const FREE_FN_OFFSET: usize = 16;
    // SAFETY: byte 16 of a live header holds the `free_fn` pointer.
    let releaser = unsafe { ptr.add(FREE_FN_OFFSET).cast::<unsafe fn(*mut u8)>().read() };
    unsafe { releaser(ptr) }
}
/// Type-erased `poll_fn` for tasks whose storage starts out holding a future `F`.
///
/// If the task has been aborted, returns `Poll::Ready(())` without polling.
/// `drop_fn` is left unchanged in that case, so a later `drop_task_future`
/// still drops the future. Otherwise the future is polled in place. When it
/// completes, the future is dropped, its output is written into the same
/// storage slot, and `drop_fn` is switched to `drop_output::<F::Output>`.
///
/// # Safety
/// `ptr` must point to a live task whose storage (at `storage_offset`) holds an
/// initialized `F` and is large and aligned enough for `F::Output` (e.g. a
/// `FutureOrOutput<F, F::Output>` union). Pinning in place assumes the task is
/// not moved once polling has started.
unsafe fn poll_join<F: Future>(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()>
where
    F::Output: 'static,
{
    if unsafe { is_aborted(ptr) } {
        return Poll::Ready(());
    }
    let future_ptr = unsafe { ptr.add(storage_offset(ptr)) };
    let future = unsafe { Pin::new_unchecked(&mut *future_ptr.cast::<F>()) };
    match future.poll(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(value) => {
            // Header offset 8 is the `drop_fn` slot.
            let drop_fn_slot = unsafe { ptr.add(8).cast::<unsafe fn(*mut u8)>() };
            // Disarm `drop_fn` before destroying the future. If F's destructor
            // panics, a later `drop_task_future` must not drop it a second time.
            unsafe { *drop_fn_slot = drop_noop };
            unsafe { std::ptr::drop_in_place(future_ptr.cast::<F>()) };
            // The storage slot is reused for the output.
            unsafe { std::ptr::write(future_ptr.cast::<F::Output>(), value) };
            // From here on, the live value in storage is the output.
            unsafe { *drop_fn_slot = drop_output::<F::Output> };
            Poll::Ready(())
        }
    }
}
/// Test-only `drop_fn` for tasks whose storage is the bare future `F`.
///
/// # Safety
/// `ptr` must point to a live task whose storage holds an initialized `F`.
#[cfg(test)]
unsafe fn drop_future<F>(ptr: *mut u8) {
    // SAFETY: forwarded from the function contract.
    unsafe {
        let offset = storage_offset(ptr);
        std::ptr::drop_in_place(ptr.add(offset).cast::<F>());
    }
}
/// `drop_fn` for joinable tasks whose `FutureOrOutput` storage still holds the future.
///
/// The union is `repr(C)`, so the future sits at the start of the storage slot.
///
/// # Safety
/// `ptr` must point to a live task whose storage holds an initialized `F`.
unsafe fn drop_future_in_union<F: Future>(ptr: *mut u8) {
    // SAFETY: forwarded from the function contract.
    unsafe {
        let offset = storage_offset(ptr);
        std::ptr::drop_in_place(ptr.add(offset).cast::<F>());
    }
}
/// `drop_fn` placeholder that does nothing. `poll_join` installs it while the
/// storage briefly holds no live value.
unsafe fn drop_noop(_task: *mut u8) {}
/// `drop_fn` for completed tasks: drops the output `T` stored in the task's storage slot.
///
/// # Safety
/// `ptr` must point to a live task whose storage holds an initialized `T`.
unsafe fn drop_output<T>(ptr: *mut u8) {
    // SAFETY: forwarded from the function contract.
    unsafe {
        let offset = storage_offset(ptr);
        std::ptr::drop_in_place(ptr.add(offset).cast::<T>());
    }
}
/// `free_fn` for tasks allocated as `Box<Task<F>>`: deallocates without dropping anything.
///
/// # Safety
/// `ptr` must come from a `Box<Task<F>>` allocation whose storage has already been dropped.
unsafe fn box_free<F>(ptr: *mut u8) {
    // SAFETY: the layout matches the `Box<Task<F>>` allocation per the contract above.
    unsafe { std::alloc::dealloc(ptr, std::alloc::Layout::new::<Task<F>>()) }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;
    use std::task::{RawWaker, RawWakerVTable, Waker};

    /// Returns a waker whose clone/wake/drop operations do nothing.
    fn noop_waker() -> Waker {
        static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(
            |p| RawWaker::new(p, &NOOP_VTABLE),
            |_| {},
            |_| {},
            |_| {},
        );
        // SAFETY: every vtable entry ignores the data pointer, so null is fine.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) }
    }

    /// A future that completes immediately with `()`.
    struct Noop;

    impl Future for Noop {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Ready(())
        }
    }

    #[test]
    fn task_header_size() {
        assert_eq!(TASK_HEADER_SIZE, 64);
        assert_eq!(std::mem::size_of::<Task<()>>(), 64);
    }

    #[test]
    fn task_layout_offsets() {
        assert_eq!(std::mem::offset_of!(Task<()>, poll_fn), 0);
        assert_eq!(std::mem::offset_of!(Task<()>, drop_fn), 8);
        assert_eq!(std::mem::offset_of!(Task<()>, free_fn), 16);
        assert_eq!(std::mem::offset_of!(Task<()>, is_queued), 24);
        assert_eq!(std::mem::offset_of!(Task<()>, is_completed), 25);
        assert_eq!(std::mem::offset_of!(Task<()>, ref_count), 26);
        assert_eq!(std::mem::offset_of!(Task<()>, tracker_key), 28);
        assert_eq!(std::mem::offset_of!(Task<()>, cross_next), 32);
        assert_eq!(std::mem::offset_of!(Task<()>, join_waker), 40);
        assert_eq!(std::mem::offset_of!(Task<()>, storage_offset), 56);
        assert_eq!(std::mem::offset_of!(Task<()>, flags), 58);
        assert_eq!(std::mem::offset_of!(Task<()>, _pad), 59);
        assert_eq!(std::mem::offset_of!(Task<()>, storage), 64);
    }

    #[test]
    fn task_size_with_future() {
        #[allow(dead_code)] // the payload only exists to give the future a known size
        struct SmallFuture([u8; 24]);
        impl Future for SmallFuture {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }
        assert_eq!(
            std::mem::size_of::<Task<SmallFuture>>(),
            TASK_HEADER_SIZE + 24
        );
    }

    #[test]
    fn queued_flag_via_pointer() {
        let ptr = Box::into_raw(Box::new(Task::new_boxed(Noop, 0))) as *mut u8;
        unsafe {
            assert!(!is_queued(ptr));
            set_queued(ptr, true);
            assert!(is_queued(ptr));
            set_queued(ptr, false);
            assert!(!is_queued(ptr));
            drop_task_future(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn box_free_works() {
        let ptr = Box::into_raw(Box::new(Task::new_boxed(Noop, 42))) as *mut u8;
        unsafe {
            assert_eq!(tracker_key(ptr), 42);
            assert_eq!(ref_count(ptr), 1);
            drop_task_future(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn joinable_task_flags() {
        struct ReadyU64;
        impl Future for ReadyU64 {
            type Output = u64;
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u64> {
                Poll::Ready(42)
            }
        }
        let ptr = box_spawn_joinable(ReadyU64, 0);
        unsafe {
            assert!(has_join(ptr));
            assert!(!is_aborted(ptr));
            assert_eq!(ref_count(ptr), 2);
            drop_task_future(ptr);
            ref_dec(ptr);
            ref_dec(ptr);
            free_task(ptr);
        }
    }

    /// Future whose destructor counts itself and then panics.
    struct PanickingDrop {
        drop_count: *mut u32,
    }

    impl Future for PanickingDrop {
        type Output = u64;
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u64> {
            Poll::Ready(42)
        }
    }

    impl Drop for PanickingDrop {
        fn drop(&mut self) {
            unsafe { *self.drop_count += 1 };
            panic!("intentional drop panic");
        }
    }

    #[test]
    fn poll_join_panic_in_drop_prevents_double_drop() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut drop_count: u32 = 0;
        let ptr = box_spawn_joinable(
            PanickingDrop {
                drop_count: &raw mut drop_count,
            },
            0,
        );
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
            poll_task(ptr, &mut cx)
        }));
        assert!(result.is_err(), "expected panic from PanickingDrop");
        assert_eq!(drop_count, 1, "future should be dropped exactly once");
        unsafe { drop_task_future(ptr) };
        assert_eq!(
            drop_count, 1,
            "drop_task_future after panic must be a no-op (drop_noop)"
        );
        unsafe {
            ref_dec(ptr);
            ref_dec(ptr);
            free_task(ptr);
        }
    }

    #[test]
    fn drop_fn_transitions_correctly_on_normal_completion() {
        static OUTPUT_DROP_COUNT: AtomicU32 = AtomicU32::new(0);

        struct TrackedOutput;
        impl Drop for TrackedOutput {
            fn drop(&mut self) {
                OUTPUT_DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        struct ProduceTracked;
        impl Future for ProduceTracked {
            type Output = TrackedOutput;
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<TrackedOutput> {
                Poll::Ready(TrackedOutput)
            }
        }

        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let ptr = box_spawn_joinable(ProduceTracked, 0);
        let result = unsafe { poll_task(ptr, &mut cx) };
        assert!(result.is_ready());
        OUTPUT_DROP_COUNT.store(0, Ordering::SeqCst);
        unsafe { drop_task_future(ptr) };
        assert_eq!(
            OUTPUT_DROP_COUNT.load(Ordering::SeqCst),
            1,
            "drop_fn should drop the output exactly once"
        );
        unsafe {
            ref_dec(ptr);
            ref_dec(ptr);
            free_task(ptr);
        }
    }
}