use core::{
alloc::Layout,
future::Future,
pin::Pin,
ptr::{
self,
NonNull,
},
task::{
Context,
Poll,
},
};
/// Type-erased header stored at the start of every heap-allocated `Task<F>`.
///
/// It holds the reference count plus two function pointers that `Task::new`
/// monomorphises for the concrete future type. A task can therefore be polled
/// and destroyed through a `NonNull<TaskHeader>` without knowing `F`.
pub(crate) struct TaskHeader {
    /// Number of outstanding references to the task allocation.
    ref_count: usize,
    /// Polls the future stored behind this header (set to `poll::<F>` by `Task::new`).
    poll: unsafe fn(ptr: NonNull<TaskHeader>, cx: &mut Context<'_>) -> Poll<()>,
    /// Drops the whole task and frees its memory (set by `Task::new`).
    drop_and_dealloc: unsafe fn(ptr: NonNull<TaskHeader>),
}
impl TaskHeader {
    /// Polls the task behind `ptr` through its type-erased `poll` entry.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live header whose `poll` entry accepts it (as set
    /// up by `Task::new`). The caller must also guarantee that nothing else
    /// accesses the task while it is polled, because the entry takes a mutable
    /// reference to the future.
    pub(crate) unsafe fn poll(ptr: NonNull<TaskHeader>, cx: &mut Context<'_>) -> Poll<()> {
        (ptr.as_ref().poll)(ptr, cx)
    }

    /// Records one more reference to this task.
    ///
    /// # Panics
    ///
    /// Panics if the count would overflow `usize`. The count is left unchanged
    /// in that case. Wrapping to zero would let a later `decrement_ref_count`
    /// free the task while references remain. That is why overflow is
    /// rejected in release builds as well, not only under debug overflow
    /// checks.
    pub(crate) fn increment_ref_count(&mut self) {
        self.ref_count = self
            .ref_count
            .checked_add(1)
            .expect("task reference count overflowed");
    }

    /// Releases one reference and destroys the task when the count hits zero.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live header whose count is at least one. The
    /// caller gives up the reference being released: once the count reaches
    /// zero, `drop_and_dealloc` runs and `ptr` dangles.
    pub(crate) unsafe fn decrement_ref_count(mut ptr: NonNull<TaskHeader>) {
        // Scope the `&mut` so it ends before `drop_and_dealloc` frees the memory.
        let is_last = {
            let header = ptr.as_mut();
            header.ref_count -= 1;
            header.ref_count == 0
        };
        if is_last {
            (ptr.as_ref().drop_and_dealloc)(ptr);
        }
    }
}
/// A heap-allocated task: a type-erased `TaskHeader` followed by the future.
///
/// `#[repr(C)]` places `header` at offset 0, so a `NonNull<Task<F>>` and a
/// `NonNull<TaskHeader>` to the same allocation share one address. The
/// type-erased `poll` and `drop_and_dealloc` entries rely on this to cast a
/// header pointer back to the full task.
#[repr(C)]
pub(crate) struct Task<F> {
    header: TaskHeader,
    future: F,
}
impl<F> Task<F>
where
    F: Future<Output = ()>,
{
    /// Moves `future` into a fresh heap allocation whose reference count is one.
    ///
    /// If the global allocator returns null, this calls
    /// `alloc::alloc::handle_alloc_error`.
    pub(crate) fn new(future: F) -> NonNull<Self> {
        let layout = Layout::new::<Self>();
        // Allocating a zero-sized layout is undefined behaviour. `Self` always
        // contains a `TaskHeader`, so this only guards that invariant.
        assert_ne!(layout.size(), 0);
        // SAFETY: `layout` has a non-zero size (asserted above).
        let raw: *mut Self = unsafe { alloc::alloc::alloc(layout).cast() };
        let task = NonNull::new(raw).unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout));
        // SAFETY: `task` is a fresh, non-null allocation sized and aligned for
        // `Self`. `ptr::write` initialises it without reading or dropping the
        // uninitialised bytes.
        unsafe {
            ptr::write(
                task.as_ptr(),
                Task {
                    header: TaskHeader {
                        ref_count: 1,
                        poll: poll::<F>,
                        drop_and_dealloc: drop_and_dealloc::<Task<F>>,
                    },
                    future,
                },
            );
        }
        task
    }

    /// Returns a pointer to the task's header.
    ///
    /// This is a plain pointer cast: `header` is the first field of a
    /// `#[repr(C)]` struct, so it lives at the task's own address. Nothing is
    /// dereferenced, so this safe function is sound for any `ptr`, including
    /// dangling ones. The result also keeps provenance over the whole
    /// allocation, so the type-erased entry points may cast it back to
    /// `Task<F>` and access `future`.
    pub(crate) fn header(ptr: NonNull<Self>) -> NonNull<TaskHeader> {
        ptr.cast::<TaskHeader>()
    }
}
/// Type-erased `poll` entry installed in a `TaskHeader` by `Task::<F>::new`.
///
/// Casts the header pointer back to the full `Task<F>` and polls its future
/// in place.
///
/// # Safety
///
/// `ptr` must address a live `Task<F>` for exactly this `F`, with provenance
/// over the whole task. Nothing else may access the task while it is being
/// polled. The future must not have been moved since it was first polled,
/// which `Pin::new_unchecked` below relies on.
unsafe fn poll<F>(ptr: NonNull<TaskHeader>, cx: &mut Context<'_>) -> Poll<()>
where
    F: Future<Output = ()>,
{
    let task: &mut Task<F> = ptr.cast::<Task<F>>().as_mut();
    Pin::new_unchecked(&mut task.future).poll(cx)
}
/// Type-erased destructor installed in a `TaskHeader` by `Task::<F>::new`
/// (instantiated there with `T = Task<F>`).
///
/// Runs `T`'s destructor in place, then returns the memory to the global
/// allocator using `Layout::new::<T>()`.
///
/// # Safety
///
/// `ptr` must point to a live `T` that the global allocator allocated with
/// `Layout::new::<T>()`. The pointer must not be used again afterwards.
unsafe fn drop_and_dealloc<T>(ptr: NonNull<TaskHeader>) {
    let raw: *mut T = ptr.as_ptr().cast();
    ptr::drop_in_place(raw);
    alloc::alloc::dealloc(raw.cast::<u8>(), Layout::new::<T>());
}