use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicPtr, AtomicU8, AtomicU16, Ordering};
use std::task::{Context, Poll};
/// Size in bytes of the fixed task header that precedes every future.
///
/// The free functions in this module address header fields through a raw `*mut u8` using
/// hard-coded byte offsets, and `poll_task` / `drop_task_future` expect the future to start at
/// exactly this offset. The 40-byte figure assumes 8-byte function pointers, so on targets where
/// that does not hold the layout assertions below fail the build.
pub const TASK_HEADER_SIZE: usize = 40;

/// A task: a `#[repr(C)]` header of type-erased function pointers and state words, followed
/// by the future itself.
#[repr(C)]
pub(crate) struct Task<F> {
    /// Offset 0. Polls the future; receives a pointer to the `future` field.
    poll_fn: unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>,
    /// Offset 8. Drops the future in place; receives a pointer to the `future` field.
    drop_fn: unsafe fn(*mut u8),
    /// Offset 16. Releases the task's memory; receives a pointer to the task itself.
    free_fn: unsafe fn(*mut u8),
    /// Offset 24. Queue flag (0/1), accessed via `is_queued` / `set_queued` / `try_set_queued`.
    is_queued: AtomicU8,
    /// Offset 25. Set to 1 by `set_completed`, read by `is_completed`.
    is_completed: AtomicU8,
    /// Offset 26. Reference count; constructors initialise it to 1.
    ref_count: AtomicU16,
    /// Offset 28. Opaque key supplied at construction, read back by `tracker_key`.
    tracker_key: u32,
    /// Offset 32. Intrusive link pointer exposed by `cross_next` (initially null).
    cross_next: AtomicPtr<u8>,
    /// Offset `TASK_HEADER_SIZE`, provided `F` is not aligned to more than 8 bytes.
    future: F,
}

// Compile-time layout checks. The pointer accessors in this module hard-code every one of these
// offsets, so a reordered or resized field must break the build instead of silently corrupting
// memory in release builds (where the unit tests do not run).
const _: () = {
    assert!(std::mem::size_of::<Task<()>>() == TASK_HEADER_SIZE);
    assert!(std::mem::offset_of!(Task<()>, poll_fn) == 0);
    assert!(std::mem::offset_of!(Task<()>, drop_fn) == 8);
    assert!(std::mem::offset_of!(Task<()>, free_fn) == 16);
    assert!(std::mem::offset_of!(Task<()>, is_queued) == 24);
    assert!(std::mem::offset_of!(Task<()>, is_completed) == 25);
    assert!(std::mem::offset_of!(Task<()>, ref_count) == 26);
    assert!(std::mem::offset_of!(Task<()>, tracker_key) == 28);
    assert!(std::mem::offset_of!(Task<()>, cross_next) == 32);
    assert!(std::mem::offset_of!(Task<()>, future) == TASK_HEADER_SIZE);
};
impl<F: Future<Output = ()> + 'static> Task<F> {
#[inline]
pub(crate) fn new_boxed(future: F, tracker_key: u32) -> Self {
Self {
poll_fn: poll_fn::<F>,
drop_fn: drop_fn::<F>,
free_fn: box_free::<F>,
is_queued: AtomicU8::new(0),
is_completed: AtomicU8::new(0),
ref_count: AtomicU16::new(1), tracker_key,
cross_next: AtomicPtr::new(std::ptr::null_mut()),
future,
}
}
#[inline]
pub(crate) fn new_with_free(
future: F,
tracker_key: u32,
free_fn: unsafe fn(*mut u8),
) -> Self {
Self {
poll_fn: poll_fn::<F>,
drop_fn: drop_fn::<F>,
free_fn,
is_queued: AtomicU8::new(0),
is_completed: AtomicU8::new(0),
ref_count: AtomicU16::new(1),
tracker_key,
cross_next: AtomicPtr::new(std::ptr::null_mut()),
future,
}
}
}
/// Identifier of a task, wrapping the raw address of its header.
///
/// The derived `PartialEq`/`Hash` compare and hash that address, so two ids are equal exactly
/// when they point at the same task header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(pub(crate) *mut u8);

impl TaskId {
    /// Returns the raw header pointer this id wraps.
    #[allow(dead_code)] // NOTE(review): appears to have no in-crate callers yet; confirm before removing.
    pub(crate) fn as_ptr(&self) -> *mut u8 {
        let TaskId(raw) = *self;
        raw
    }
}
/// Reads the opaque tracker key stored in the task header at `ptr`.
///
/// # Safety
/// `ptr` must point to a live, 8-byte-aligned task header; the `u32` key lives at byte offset 28.
#[inline]
pub(crate) unsafe fn tracker_key(ptr: *mut u8) -> u32 {
    // SAFETY: offset 28 is 4-byte aligned inside the 8-aligned header and holds an initialised u32.
    let key_slot = unsafe { ptr.add(28) }.cast::<u32>();
    unsafe { key_slot.read() }
}
/// Increments the reference count stored in the header of the task at `ptr`.
///
/// # Safety
/// `ptr` must point to a live, 8-byte-aligned task header; the count is an `AtomicU16` at byte
/// offset 26.
///
/// # Panics
/// Panics with "waker refcount overflow" if the count is already `u16::MAX`. In that case the
/// count is left untouched, so no other holder can ever observe a wrapped (zero) value.
#[inline]
pub(crate) unsafe fn ref_inc(ptr: *mut u8) {
    // SAFETY: the caller guarantees a live header; offset 26 is 2-byte aligned.
    let rc = unsafe { &*ptr.add(26).cast::<AtomicU16>() };
    // A plain `fetch_add` would wrap MAX -> 0 *before* the overflow check could panic, leaving a
    // corrupted count visible to other threads. `checked_add` inside `fetch_update` refuses to
    // store the overflowing value instead. Relaxed matches the original ordering; like
    // `Arc::clone`, an increment presumably happens while the caller already holds a reference.
    let bumped = rc.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_add(1));
    assert!(bumped.is_ok(), "waker refcount overflow");
}
/// Decrements the reference count of the task at `ptr`.
///
/// Returns `true` exactly when this call took the count from 1 to 0, i.e. it released the last
/// reference.
///
/// # Safety
/// `ptr` must point to a live, 8-byte-aligned task header with a non-zero count.
#[inline]
pub(crate) unsafe fn ref_dec(ptr: *mut u8) -> bool {
    // SAFETY: the reference count is an `AtomicU16` at byte offset 26 of a live header.
    let counter: &AtomicU16 = unsafe { &*ptr.add(26).cast::<AtomicU16>() };
    // AcqRel: Release publishes this holder's prior writes; Acquire lets whichever holder sees
    // the final decrement observe everyone else's writes before any teardown.
    let before = counter.fetch_sub(1, Ordering::AcqRel);
    debug_assert!(before > 0, "waker refcount underflow");
    before == 1
}
/// Returns a relaxed snapshot of the reference count of the task at `ptr`.
///
/// The value may already be stale when it is returned, so it is only suitable for tests and
/// diagnostics.
///
/// # Safety
/// `ptr` must point to a live, 8-byte-aligned task header.
#[allow(dead_code)] // NOTE(review): only the tests in this file call it; confirm before removing.
#[inline]
pub(crate) unsafe fn ref_count(ptr: *mut u8) -> u16 {
    // SAFETY: the reference count is an `AtomicU16` at byte offset 26 of a live header.
    let counter = unsafe { &*ptr.add(26).cast::<AtomicU16>() };
    counter.load(Ordering::Relaxed)
}
/// Marks the task at `ptr` as completed by storing 1 into its completion flag.
///
/// The store uses `Release`, pairing with the `Acquire` load in `is_completed`, so writes made
/// before this call are visible to a thread that observes the flag.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn set_completed(ptr: *mut u8) {
    // SAFETY: the completion flag is an `AtomicU8` at byte offset 25 of a live header.
    let flag: &AtomicU8 = unsafe { &*ptr.add(25).cast::<AtomicU8>() };
    flag.store(1, Ordering::Release);
}
/// Reports whether the completion flag of the task at `ptr` is set (any non-zero value).
///
/// The load uses `Acquire`, pairing with the `Release` store in `set_completed`.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn is_completed(ptr: *mut u8) -> bool {
    // SAFETY: the completion flag is an `AtomicU8` at byte offset 25 of a live header.
    let flag = unsafe { &*ptr.add(25).cast::<AtomicU8>() };
    let raw = flag.load(Ordering::Acquire);
    raw > 0
}
/// Returns the `cross_next` link word (an `AtomicPtr<u8>` at byte offset 32) of the task at `ptr`.
///
/// # Safety
/// `ptr` must point to a live, 8-byte-aligned task header. The returned reference is `'static`
/// only nominally: the caller must not use it after the task's memory has been freed.
#[inline]
#[allow(dead_code)] // NOTE(review): may be unused in some builds; confirm before removing.
pub(crate) unsafe fn cross_next(ptr: *mut u8) -> &'static AtomicPtr<u8> {
    // SAFETY: offset 32 is 8-byte aligned and holds an initialised `AtomicPtr<u8>`.
    let link = unsafe { ptr.add(32) }.cast::<AtomicPtr<u8>>();
    unsafe { &*link }
}
/// Reports whether the queue flag of the task at `ptr` is set (any non-zero value).
///
/// This is a relaxed load, so it imposes no ordering on surrounding memory accesses.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn is_queued(ptr: *mut u8) -> bool {
    // SAFETY: the queue flag is an `AtomicU8` at byte offset 24 of a live header.
    let flag = unsafe { &*ptr.add(24).cast::<AtomicU8>() };
    flag.load(Ordering::Relaxed) > 0
}
/// Unconditionally sets (`true` -> 1) or clears (`false` -> 0) the queue flag of the task at `ptr`.
///
/// This is a relaxed store; use `try_set_queued` for a race-free 0 -> 1 transition.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn set_queued(ptr: *mut u8, queued: bool) {
    // SAFETY: the queue flag is an `AtomicU8` at byte offset 24 of a live header.
    let flag = unsafe { &*ptr.add(24).cast::<AtomicU8>() };
    let encoded = u8::from(queued);
    flag.store(encoded, Ordering::Relaxed);
}
/// Atomically flips the queue flag of the task at `ptr` from 0 to 1.
///
/// Returns `true` if this call performed the transition and `false` if the flag held a non-zero
/// value already. The successful exchange uses `AcqRel`; the failure path is `Relaxed`.
///
/// # Safety
/// `ptr` must point to a live task header.
#[inline]
pub(crate) unsafe fn try_set_queued(ptr: *mut u8) -> bool {
    // SAFETY: the queue flag is an `AtomicU8` at byte offset 24 of a live header.
    let flag: &AtomicU8 = unsafe { &*ptr.add(24).cast::<AtomicU8>() };
    flag.compare_exchange(0, 1, Ordering::AcqRel, Ordering::Relaxed)
        .is_ok()
}
/// Polls the future of the task at `ptr` through the type-erased `poll_fn` stored at header
/// offset 0, passing it the address `TASK_HEADER_SIZE` bytes into the task.
///
/// # Safety
/// `ptr` must point to a live task whose future is initialised and not yet dropped, and the
/// caller must have exclusive access to that future for the duration of the poll (the future
/// is treated as pinned in place).
#[inline]
pub(crate) unsafe fn poll_task(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()> {
    // SAFETY: `poll_fn` is the first field of the `#[repr(C)]` header, so it is stored, suitably
    // aligned, at offset 0.
    let poll: unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()> =
        unsafe { ptr.cast::<unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>>().read() };
    // SAFETY: the future begins immediately after the fixed-size header.
    let future = unsafe { ptr.add(TASK_HEADER_SIZE) };
    unsafe { poll(future, cx) }
}
/// Drops the future of the task at `ptr` in place via the type-erased `drop_fn` stored at header
/// offset 8. The task's own memory is not released (that is `free_task`'s job).
///
/// # Safety
/// `ptr` must point to a live task whose future is initialised and not yet dropped; the future
/// must not be polled or dropped again afterwards.
#[inline]
pub(crate) unsafe fn drop_task_future(ptr: *mut u8) {
    // SAFETY: `drop_fn` is the second pointer-sized header field, at byte offset 8.
    let destructor = unsafe { ptr.add(8).cast::<unsafe fn(*mut u8)>().read() };
    // SAFETY: the future begins immediately after the fixed-size header.
    let future = unsafe { ptr.add(TASK_HEADER_SIZE) };
    unsafe { destructor(future) }
}
/// Releases the memory of the task at `ptr` via the type-erased `free_fn` stored at header
/// offset 16. The free function receives the task pointer itself, not the future pointer.
///
/// # Safety
/// `ptr` must point to a live task header; the task must not be used after this call.
/// NOTE(review): callers presumably drop the future first (as the tests do) — confirm.
#[inline]
pub(crate) unsafe fn free_task(ptr: *mut u8) {
    // SAFETY: `free_fn` is the third pointer-sized header field, at byte offset 16.
    let release = unsafe { ptr.add(16).cast::<unsafe fn(*mut u8)>().read() };
    unsafe { release(ptr) }
}
/// Type-erased poll entry point stored in `Task::poll_fn`: reinterprets `ptr` as an `F`, pins
/// it in place and polls it once.
///
/// # Safety
/// `ptr` must point to a live, initialised `F` that the caller accesses exclusively and never
/// moves again.
unsafe fn poll_fn<F: Future<Output = ()>>(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()> {
    // SAFETY: the caller guarantees `ptr` is a valid, exclusively borrowed `F`.
    let target: &mut F = unsafe { &mut *ptr.cast::<F>() };
    // SAFETY: the caller guarantees the future is never moved after this point.
    let pinned = unsafe { Pin::new_unchecked(target) };
    pinned.poll(cx)
}
/// Type-erased destructor stored in `Task::drop_fn`: runs `F`'s destructor on the value at
/// `ptr` without freeing the memory it occupies.
///
/// # Safety
/// `ptr` must point to a live, initialised `F` that is not used again afterwards.
unsafe fn drop_fn<F: Future<Output = ()>>(ptr: *mut u8) {
    let victim = ptr.cast::<F>();
    // SAFETY: the caller guarantees `victim` is a valid, initialised `F` dropped exactly once.
    unsafe { victim.drop_in_place() }
}
/// Default `free_fn` installed by `Task::new_boxed`: returns the task's memory to the global
/// allocator using the layout of `Task<F>`, which matches a `Box<Task<F>>` allocation.
///
/// The future is not dropped here; that must already have happened via the task's `drop_fn`.
///
/// # Safety
/// `ptr` must have been allocated by the global allocator with `Layout::new::<Task<F>>()`
/// (e.g. via `Box::into_raw(Box::new(task))`) and must not be used afterwards.
unsafe fn box_free<F>(ptr: *mut u8) {
    use std::alloc::{dealloc, Layout};
    let task_layout = Layout::new::<Task<F>>();
    // SAFETY: the caller guarantees `ptr` came from the global allocator with `task_layout`.
    unsafe { dealloc(ptr, task_layout) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Future that is ready on its first poll; shared by every test that needs a concrete task.
    struct Noop;

    impl Future for Noop {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Ready(())
        }
    }

    /// Allocates a `Noop` task in a `Box` and leaks it as a raw header pointer.
    fn boxed_noop(tracker_key: u32) -> *mut u8 {
        Box::into_raw(Box::new(Task::new_boxed(Noop, tracker_key))).cast::<u8>()
    }

    /// Drops the future and frees the allocation of a task produced by `boxed_noop`.
    ///
    /// # Safety
    /// `ptr` must come from `boxed_noop` and must not be used afterwards.
    unsafe fn destroy(ptr: *mut u8) {
        unsafe {
            drop_task_future(ptr);
            free_task(ptr);
        }
    }

    /// Builds a waker whose clone/wake/drop operations do nothing.
    fn noop_waker() -> std::task::Waker {
        use std::task::{RawWaker, RawWakerVTable, Waker};

        unsafe fn clone_raw(_: *const ()) -> RawWaker {
            RawWaker::new(std::ptr::null(), &VTABLE)
        }
        unsafe fn no_op(_: *const ()) {}
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, no_op, no_op, no_op);

        // SAFETY: every vtable entry ignores its data pointer, so a null pointer is acceptable.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    #[test]
    fn task_header_size() {
        assert_eq!(TASK_HEADER_SIZE, 40);
        assert_eq!(std::mem::size_of::<Task<()>>(), 40);
    }

    #[test]
    fn task_layout_offsets() {
        assert_eq!(std::mem::offset_of!(Task<()>, poll_fn), 0);
        assert_eq!(std::mem::offset_of!(Task<()>, drop_fn), 8);
        assert_eq!(std::mem::offset_of!(Task<()>, free_fn), 16);
        assert_eq!(std::mem::offset_of!(Task<()>, is_queued), 24);
        assert_eq!(std::mem::offset_of!(Task<()>, is_completed), 25);
        assert_eq!(std::mem::offset_of!(Task<()>, ref_count), 26);
        assert_eq!(std::mem::offset_of!(Task<()>, tracker_key), 28);
        assert_eq!(std::mem::offset_of!(Task<()>, cross_next), 32);
        assert_eq!(std::mem::offset_of!(Task<()>, future), 40);
    }

    #[test]
    fn task_size_with_future() {
        #[allow(dead_code)] // the payload only exists to give the future a known size
        struct SmallFuture([u8; 24]);
        impl Future for SmallFuture {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }
        assert_eq!(
            std::mem::size_of::<Task<SmallFuture>>(),
            TASK_HEADER_SIZE + 24
        );
    }

    #[test]
    fn queued_flag_via_pointer() {
        let ptr = boxed_noop(0);
        unsafe {
            assert!(!is_queued(ptr));
            set_queued(ptr, true);
            assert!(is_queued(ptr));
            set_queued(ptr, false);
            assert!(!is_queued(ptr));
            destroy(ptr);
        }
    }

    #[test]
    fn try_set_queued_succeeds_only_once() {
        let ptr = boxed_noop(0);
        unsafe {
            assert!(try_set_queued(ptr));
            assert!(!try_set_queued(ptr));
            assert!(is_queued(ptr));
            set_queued(ptr, false);
            assert!(try_set_queued(ptr));
            destroy(ptr);
        }
    }

    #[test]
    fn completed_flag_via_pointer() {
        let ptr = boxed_noop(0);
        unsafe {
            assert!(!is_completed(ptr));
            set_completed(ptr);
            assert!(is_completed(ptr));
            destroy(ptr);
        }
    }

    #[test]
    fn refcount_round_trip() {
        let ptr = boxed_noop(0);
        unsafe {
            assert_eq!(ref_count(ptr), 1);
            ref_inc(ptr);
            assert_eq!(ref_count(ptr), 2);
            assert!(!ref_dec(ptr));
            assert!(ref_dec(ptr));
            assert_eq!(ref_count(ptr), 0);
            destroy(ptr);
        }
    }

    #[test]
    fn cross_next_link_via_pointer() {
        let ptr = boxed_noop(0);
        unsafe {
            let link = cross_next(ptr);
            assert!(link.load(Ordering::Relaxed).is_null());
            link.store(ptr, Ordering::Relaxed);
            assert_eq!(cross_next(ptr).load(Ordering::Relaxed), ptr);
            destroy(ptr);
        }
    }

    #[test]
    fn poll_task_runs_future() {
        let ptr = boxed_noop(0);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        unsafe {
            assert_eq!(poll_task(ptr, &mut cx), Poll::Ready(()));
            destroy(ptr);
        }
    }

    #[test]
    fn box_free_works() {
        let ptr = boxed_noop(42);
        unsafe {
            assert_eq!(tracker_key(ptr), 42);
            assert_eq!(ref_count(ptr), 1);
            destroy(ptr);
        }
    }
}