use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Size in bytes of the fixed header that precedes the future in [`Task`].
///
/// The raw accessors in this module read header fields at fixed byte offsets
/// from an untyped task pointer and locate the future at
/// `ptr + TASK_HEADER_SIZE`.
pub const TASK_HEADER_SIZE: usize = 32;

/// A type-erased task: a fixed 32-byte header followed by the future `F`.
///
/// `#[repr(C)]` pins the field order so the header can be accessed through a
/// bare `*mut u8` (see `poll_task`, `ref_inc`, `set_queued`, ...). The byte
/// offsets noted on each field assume 8-byte function pointers. The `const`
/// block below turns any deviation into a build error.
#[repr(C)]
pub(crate) struct Task<F> {
    /// Offset 0: polls the future located at `ptr + TASK_HEADER_SIZE`.
    poll_fn: unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>,
    /// Offset 8: drops the future in place; does not release the allocation.
    drop_fn: unsafe fn(*mut u8),
    /// Offset 16: releases the allocation; called with the task pointer itself.
    free_fn: unsafe fn(*mut u8),
    /// Offset 24: accessed as a raw byte by `is_queued` / `set_queued`.
    is_queued: bool,
    /// Offset 25: accessed as a raw byte by `is_completed` / `set_completed`.
    is_completed: bool,
    /// Offset 26: reference count maintained by `ref_inc` / `ref_dec`.
    ref_count: u16,
    /// Offset 28: key supplied at construction, read back by `tracker_key`.
    tracker_key: u32,
    /// Offset `TASK_HEADER_SIZE`: the future itself.
    future: F,
}

// Compile-time guard for every offset the raw accessors hard-code. Reordering
// fields, changing a field type, or targeting a platform with 4-byte function
// pointers becomes a build error here rather than silent memory corruption.
const _: () = {
    assert!(std::mem::size_of::<Task<()>>() == TASK_HEADER_SIZE);
    assert!(std::mem::offset_of!(Task<()>, poll_fn) == 0);
    assert!(std::mem::offset_of!(Task<()>, drop_fn) == 8);
    assert!(std::mem::offset_of!(Task<()>, free_fn) == 16);
    assert!(std::mem::offset_of!(Task<()>, is_queued) == 24);
    assert!(std::mem::offset_of!(Task<()>, is_completed) == 25);
    assert!(std::mem::offset_of!(Task<()>, ref_count) == 26);
    assert!(std::mem::offset_of!(Task<()>, tracker_key) == 28);
    assert!(std::mem::offset_of!(Task<()>, future) == TASK_HEADER_SIZE);
};
impl<F: Future<Output = ()> + 'static> Task<F> {
#[inline]
pub(crate) fn new_boxed(future: F, tracker_key: u32) -> Self {
Self {
poll_fn: poll_fn::<F>,
drop_fn: drop_fn::<F>,
free_fn: box_free::<F>,
is_queued: false,
is_completed: false,
ref_count: 1, tracker_key,
future,
}
}
#[inline]
pub(crate) fn new_with_free(
future: F,
tracker_key: u32,
free_fn: unsafe fn(*mut u8),
) -> Self {
Self {
poll_fn: poll_fn::<F>,
drop_fn: drop_fn::<F>,
free_fn,
is_queued: false,
is_completed: false,
ref_count: 1,
tracker_key,
future,
}
}
}
/// Opaque task identity wrapping a raw task pointer.
///
/// The derived `PartialEq`/`Eq`/`Hash` compare and hash the pointer value only.
/// NOTE(review): presumably this is the task's allocation address; confirm at
/// the construction site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(pub(crate) *mut u8);

impl TaskId {
    /// Returns the raw pointer wrapped by this id.
    // `dead_code` is allowed because non-test builds may not call this.
    #[allow(dead_code)]
    pub(crate) fn as_ptr(&self) -> *mut u8 {
        let TaskId(raw) = *self;
        raw
    }
}
/// Reads the `tracker_key` header field: a `u32` at byte offset 28.
///
/// # Safety
/// `ptr` must point to a live, 4-byte-aligned task header that is valid for
/// reading a `u32` at offset 28.
#[inline]
pub(crate) unsafe fn tracker_key(ptr: *mut u8) -> u32 {
    // SAFETY: upheld by the caller; offset 28 is the `tracker_key` field.
    let field = unsafe { ptr.add(28) }.cast::<u32>();
    unsafe { field.read() }
}
/// Increments the task's reference count: the `u16` at byte offset 26.
///
/// # Panics
/// Panics with "waker refcount overflow" if the count is already `u16::MAX`.
///
/// # Safety
/// `ptr` must point to a live, suitably aligned task header that is valid
/// for reading and writing the `u16` at offset 26.
#[inline]
pub(crate) unsafe fn ref_inc(ptr: *mut u8) {
    // SAFETY: upheld by the caller; offset 26 is the `ref_count` field.
    let slot = unsafe { ptr.add(26) }.cast::<u16>();
    let next = unsafe { slot.read() }
        .checked_add(1)
        .expect("waker refcount overflow");
    unsafe { slot.write(next) };
}
/// Decrements the task's reference count: the `u16` at byte offset 26.
///
/// Returns `true` when the count has just reached zero, i.e. this call
/// released the last reference.
///
/// # Panics
/// Panics with "waker refcount underflow" if the count is already zero. This
/// mirrors `ref_inc`, which always checks for overflow. Previously the check
/// was only a `debug_assert!`, so release builds silently wrapped to
/// `u16::MAX` and returned `false`, hiding a double release.
///
/// # Safety
/// `ptr` must point to a live, suitably aligned task header that is valid
/// for reading and writing the `u16` at offset 26, with no other live
/// reference to that field.
#[inline]
pub(crate) unsafe fn ref_dec(ptr: *mut u8) -> bool {
    // SAFETY: upheld by the caller; offset 26 is the `ref_count` field.
    let rc = unsafe { &mut *ptr.add(26).cast::<u16>() };
    *rc = rc.checked_sub(1).expect("waker refcount underflow");
    *rc == 0
}
/// Returns the task's current reference count: the `u16` at byte offset 26.
///
/// # Safety
/// `ptr` must point to a live, suitably aligned task header that is valid
/// for reading the `u16` at offset 26.
// `dead_code` is allowed because non-test builds may not call this.
#[allow(dead_code)]
#[inline]
pub(crate) unsafe fn ref_count(ptr: *mut u8) -> u16 {
    let slot = unsafe { ptr.add(26) }.cast::<u16>();
    // SAFETY: upheld by the caller; offset 26 is the `ref_count` field.
    unsafe { slot.read() }
}
/// Marks the task as completed by writing `1` to the header byte at offset 25.
///
/// # Safety
/// `ptr` must point to a live task header that is valid for writing the byte
/// at offset 25.
#[inline]
pub(crate) unsafe fn set_completed(ptr: *mut u8) {
    // SAFETY: upheld by the caller; offset 25 is the `is_completed` field.
    let flag = unsafe { ptr.add(25) };
    unsafe { flag.write(1) }
}
/// Returns whether the header byte at offset 25 (the completed flag) is set.
///
/// # Safety
/// `ptr` must point to a live task header that is valid for reading the byte
/// at offset 25.
#[inline]
pub(crate) unsafe fn is_completed(ptr: *mut u8) -> bool {
    // SAFETY: upheld by the caller; offset 25 is the `is_completed` field.
    let byte = unsafe { ptr.add(25).read() };
    byte != 0
}
/// Returns whether the header byte at offset 24 (the queued flag) is set.
///
/// # Safety
/// `ptr` must point to a live task header that is valid for reading the byte
/// at offset 24.
#[inline]
pub(crate) unsafe fn is_queued(ptr: *mut u8) -> bool {
    // SAFETY: upheld by the caller; offset 24 is the `is_queued` field.
    let byte = unsafe { ptr.add(24).read() };
    byte != 0
}
/// Sets or clears the queued flag, stored as `0`/`1` in the header byte at
/// offset 24.
///
/// # Safety
/// `ptr` must point to a live task header that is valid for writing the byte
/// at offset 24.
#[inline]
pub(crate) unsafe fn set_queued(ptr: *mut u8, queued: bool) {
    let byte = u8::from(queued);
    // SAFETY: upheld by the caller; offset 24 is the `is_queued` field.
    unsafe { ptr.add(24).write(byte) }
}
/// Polls the future of the task at `ptr` through its type-erased `poll_fn`.
///
/// The function pointer is read from byte offset 0 of the header and is
/// called with the address of the future, `ptr + TASK_HEADER_SIZE`.
///
/// # Safety
/// `ptr` must point to a live task whose header was written by a `Task`
/// constructor, so offset 0 holds a `poll_fn` instantiated for the future
/// type stored after the header. That future must not have been dropped.
#[inline]
pub(crate) unsafe fn poll_task(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()> {
    type PollFn = unsafe fn(*mut u8, &mut Context<'_>) -> Poll<()>;

    // SAFETY: per the contract above, the first header word is `poll_fn`.
    let poll: PollFn = unsafe { ptr.cast::<PollFn>().read() };
    // SAFETY: the future begins immediately after the fixed-size header.
    let future = unsafe { ptr.add(TASK_HEADER_SIZE) };
    // SAFETY: `poll` was instantiated for exactly the future type at `future`.
    unsafe { poll(future, cx) }
}
/// Drops the future of the task at `ptr` through the type-erased `drop_fn`.
///
/// The function pointer is read from byte offset 8 of the header and is
/// called with the address of the future, `ptr + TASK_HEADER_SIZE`. The
/// task allocation itself is not released here (see `free_task`).
///
/// # Safety
/// `ptr` must point to a live task whose header was written by a `Task`
/// constructor, and its future must not already have been dropped. After
/// this call the future must be neither polled nor dropped again.
#[inline]
pub(crate) unsafe fn drop_task_future(ptr: *mut u8) {
    type DropFn = unsafe fn(*mut u8);

    // SAFETY: per the contract above, offset 8 holds `drop_fn`, which was
    // instantiated for the future stored right after the header.
    unsafe {
        let drop_future = ptr.add(8).cast::<DropFn>().read();
        drop_future(ptr.add(TASK_HEADER_SIZE));
    }
}
/// Releases the task allocation through the `free_fn` stored at byte offset 16.
///
/// `free_fn` receives the task pointer itself, not the future pointer.
///
/// # Safety
/// `ptr` must point to a live task whose header was written by a `Task`
/// constructor. `ptr` must not be used after this call.
#[inline]
pub(crate) unsafe fn free_task(ptr: *mut u8) {
    type FreeFn = unsafe fn(*mut u8);

    // SAFETY: per the contract above, offset 16 holds `free_fn`. It is read
    // out before the call, because the call may release the memory it lives in.
    let release: FreeFn = unsafe { ptr.add(16).cast::<FreeFn>().read() };
    // SAFETY: `release` is the deallocator paired with this task's allocation.
    unsafe { release(ptr) }
}
/// Type-erased poll shim stored in a task header; one instance per future `F`.
///
/// # Safety
/// `ptr` must point to a valid, initialized `F` that is never moved again
/// after this call (it is treated as pinned).
unsafe fn poll_fn<F: Future<Output = ()>>(ptr: *mut u8, cx: &mut Context<'_>) -> Poll<()> {
    let typed = ptr.cast::<F>();
    // SAFETY: the caller guarantees `typed` is a valid `F` that stays in place,
    // so creating a pinned mutable reference to it is sound.
    let pinned: Pin<&mut F> = unsafe { Pin::new_unchecked(&mut *typed) };
    F::poll(pinned, cx)
}
/// Type-erased destructor shim: runs `F`'s destructor in place at `ptr`.
///
/// # Safety
/// `ptr` must point to a valid, initialized `F` that is not used afterwards.
unsafe fn drop_fn<F: Future<Output = ()>>(ptr: *mut u8) {
    let typed: *mut F = ptr.cast();
    // SAFETY: the caller guarantees `typed` is a live `F` that is dropped once.
    unsafe { typed.drop_in_place() }
}
/// Deallocation shim for tasks created with `Task::new_boxed`.
///
/// Frees the memory only; it does not run the future's destructor. That is
/// done separately through the header's `drop_fn`.
///
/// # Safety
/// `ptr` must have been allocated by the global allocator with
/// `Layout::new::<Task<F>>()` (as `Box::new(Task<F>)` does) and must not be
/// used afterwards.
unsafe fn box_free<F>(ptr: *mut u8) {
    // SAFETY: the caller guarantees `ptr` came from the global allocator with
    // exactly this layout.
    unsafe { std::alloc::dealloc(ptr, std::alloc::Layout::new::<Task<F>>()) }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::task::{RawWaker, RawWakerVTable, Waker};

    /// Future that completes on its first poll; used to build real tasks.
    struct Noop;

    impl Future for Noop {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Ready(())
        }
    }

    /// Heap-allocates a `Task<Noop>` and returns the untyped pointer that the
    /// raw accessors expect.
    fn boxed_noop(tracker_key: u32) -> *mut u8 {
        Box::into_raw(Box::new(Task::new_boxed(Noop, tracker_key))).cast::<u8>()
    }

    /// Tears down a task from `boxed_noop`: drop the future, then free memory.
    unsafe fn release(ptr: *mut u8) {
        unsafe {
            drop_task_future(ptr);
            free_task(ptr);
        }
    }

    /// Builds a waker whose vtable entries ignore their data and do nothing.
    fn noop_waker() -> Waker {
        unsafe fn noop_clone(_: *const ()) -> RawWaker {
            RawWaker::new(std::ptr::null(), &VTABLE)
        }
        unsafe fn noop(_: *const ()) {}
        static VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        // SAFETY: every vtable entry is a no-op that never reads the data pointer.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    #[test]
    fn task_header_size() {
        assert_eq!(TASK_HEADER_SIZE, 32);
        assert_eq!(std::mem::size_of::<Task<()>>(), 32);
    }

    #[test]
    fn task_layout_offsets() {
        assert_eq!(std::mem::offset_of!(Task<()>, poll_fn), 0);
        assert_eq!(std::mem::offset_of!(Task<()>, drop_fn), 8);
        assert_eq!(std::mem::offset_of!(Task<()>, free_fn), 16);
        assert_eq!(std::mem::offset_of!(Task<()>, is_queued), 24);
        assert_eq!(std::mem::offset_of!(Task<()>, is_completed), 25);
        assert_eq!(std::mem::offset_of!(Task<()>, ref_count), 26);
        assert_eq!(std::mem::offset_of!(Task<()>, tracker_key), 28);
        assert_eq!(std::mem::offset_of!(Task<()>, future), 32);
    }

    #[test]
    fn task_size_with_future() {
        // The payload is never read; it only gives the future a known size.
        #[allow(dead_code)]
        struct SmallFuture([u8; 64]);

        impl Future for SmallFuture {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Ready(())
            }
        }

        assert_eq!(
            std::mem::size_of::<Task<SmallFuture>>(),
            TASK_HEADER_SIZE + 64
        );
    }

    #[test]
    fn queued_flag_via_pointer() {
        let ptr = boxed_noop(0);
        unsafe {
            assert!(!is_queued(ptr));
            set_queued(ptr, true);
            assert!(is_queued(ptr));
            set_queued(ptr, false);
            assert!(!is_queued(ptr));
            release(ptr);
        }
    }

    #[test]
    fn completed_flag_via_pointer() {
        let ptr = boxed_noop(0);
        unsafe {
            assert!(!is_completed(ptr));
            set_completed(ptr);
            assert!(is_completed(ptr));
            // The neighbouring flag byte must be untouched.
            assert!(!is_queued(ptr));
            release(ptr);
        }
    }

    #[test]
    fn refcount_round_trip() {
        let ptr = boxed_noop(7);
        unsafe {
            assert_eq!(ref_count(ptr), 1);
            ref_inc(ptr);
            assert_eq!(ref_count(ptr), 2);
            assert!(!ref_dec(ptr));
            assert!(ref_dec(ptr));
            assert_eq!(ref_count(ptr), 0);
            // The neighbouring `tracker_key` field must be untouched.
            assert_eq!(tracker_key(ptr), 7);
            release(ptr);
        }
    }

    #[test]
    fn poll_task_runs_future() {
        let ptr = boxed_noop(0);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        unsafe {
            assert_eq!(poll_task(ptr, &mut cx), Poll::Ready(()));
            release(ptr);
        }
    }

    #[test]
    fn box_free_works() {
        let ptr = boxed_noop(42);
        unsafe {
            assert_eq!(tracker_key(ptr), 42);
            assert_eq!(ref_count(ptr), 1);
            release(ptr);
        }
    }
}