use core::{
fmt,
future::Future,
marker::{PhantomData, Unpin},
pin::Pin,
ptr::NonNull,
task::{Context, Poll},
};
#[cfg(feature = "debugging")]
use crate::task::debugging::TaskDebugger;
use crate::{
dbg_context,
task::{header::Header, state::*},
};
use std::sync::atomic::Ordering;
/// Handle to a spawned task whose eventual output is of type `R`.
///
/// The handle only stores a type-erased pointer to the task allocation; every operation on it
/// (`cancel`, polling, dropping, debugging) reinterprets that pointer as a task `Header`.
pub struct JoinHandle<R> {
    /// Ties the handle to the output type `R` without storing an `R`.
    pub(crate) _marker: PhantomData<R>,
    /// Type-erased pointer to the task allocation.
    pub(crate) raw_task: NonNull<()>,
}

// The handle never pins anything itself (it only holds a raw pointer), so it is `Unpin`
// regardless of whether `R` is.
impl<R> Unpin for JoinHandle<R> {}
impl<R> JoinHandle<R> {
    /// Cancels the task.
    ///
    /// Does nothing if the task is already `COMPLETED` or `CLOSED`. Otherwise the task is
    /// marked `CLOSED`. If it is currently neither `SCHEDULED` nor `RUNNING`, it is also
    /// marked `SCHEDULED`, its reference counter is incremented, and it is passed to the
    /// vtable's `schedule` function. Finally `Header::notify(None)` is called.
    ///
    /// # Panics
    ///
    /// Panics if the reference counter was already `i16::MAX` when the extra reference for
    /// scheduling was taken.
    pub fn cancel(&self) {
        let ptr = self.raw_task.as_ptr();
        dbg_context!(ptr, "cancel", {
            let header = ptr as *mut Header;
            // SAFETY: relies on `raw_task` pointing to a live task allocation that starts with
            // a `Header` (the cast above assumes this layout; TODO confirm the invariant is
            // documented where tasks are allocated).
            unsafe {
                let state = (*header).state;
                // A finished or already-closed task cannot be cancelled.
                if state & (COMPLETED | CLOSED) != 0 {
                    return;
                }

                // Neither queued nor being polled: nothing else will observe CLOSED, so the
                // task is scheduled once more (presumably so the executor drops its future).
                let idle = state & (SCHEDULED | RUNNING) == 0;
                (*header).state = if idle {
                    state | SCHEDULED | CLOSED
                } else {
                    state | CLOSED
                };

                if idle {
                    // Take a reference on behalf of whatever `schedule` enqueues
                    // (NOTE(review): assumed ownership; confirm against the vtable impl).
                    let refs = (*header).references.fetch_add(1, Ordering::Relaxed);
                    assert_ne!(refs, i16::MAX);
                    ((*header).vtable.schedule)(ptr);
                }

                // Wake the registered awaiter, if `notify` has one; `None` presumably means
                // "no current waker to skip" — TODO confirm against `Header::notify`.
                (*header).notify(None);
            }
        });
    }
}
impl<R> Drop for JoinHandle<R> {
    /// Detaches the handle from its task.
    ///
    /// - Fast path: if the state is exactly `SCHEDULED | HANDLE`, only `HANDLE` is cleared.
    /// - If the task is `COMPLETED` but not `CLOSED`, it is closed, its output is moved out
    ///   (and dropped at the very end), `HANDLE` is cleared, and the task is destroyed when
    ///   the reference counter is zero.
    /// - Otherwise `HANDLE` is cleared. If the reference counter is zero, an already-closed
    ///   task is destroyed, and an open one is set to `SCHEDULED | CLOSED` and scheduled one
    ///   final time (taking a new reference).
    ///
    /// # Panics
    ///
    /// Panics if the reference counter was already `i16::MAX` when the extra reference for
    /// the final scheduling was taken.
    fn drop(&mut self) {
        let ptr = self.raw_task.as_ptr();
        dbg_context!(ptr, "drop_join_handle", {
            let header = ptr as *mut Header;
            // Output moved out of a completed task. It is dropped only after all header
            // bookkeeping (and a possible `destroy`), so `R`'s destructor never runs mid-update.
            let mut output = None;
            // SAFETY: relies on `raw_task` pointing to a live task allocation that starts with
            // a `Header`, and on the output slot holding a valid `R` while the task is
            // COMPLETED but not yet CLOSED (TODO confirm both invariants at the task's
            // allocation site).
            unsafe {
                // Fast path: the task is only scheduled and this handle is its sole flag.
                if (*header).state == SCHEDULED | HANDLE {
                    (*header).state = SCHEDULED;
                    return;
                }

                let state = (*header).state;
                let refs = (*header).references.load(Ordering::Relaxed);
                let closed = state & CLOSED != 0;

                if state & COMPLETED != 0 && !closed {
                    // Close first so nobody else tries to take the output, then move it out.
                    (*header).state |= CLOSED;
                    output = Some((((*header).vtable.get_output)(ptr) as *mut R).read());
                    (*header).state &= !HANDLE;
                    if refs == 0 {
                        ((*header).vtable.destroy)(ptr);
                    }
                } else {
                    let last = refs == 0;
                    (*header).state = if last && !closed {
                        // Last owner of an open task: close it and schedule it one last time.
                        SCHEDULED | CLOSED
                    } else {
                        state & !HANDLE
                    };

                    if last {
                        if closed {
                            ((*header).vtable.destroy)(ptr);
                        } else {
                            let refs = (*header).references.fetch_add(1, Ordering::Relaxed);
                            assert_ne!(refs, i16::MAX);
                            ((*header).vtable.schedule)(ptr);
                        }
                    }
                }
            }
            drop(output);
        });
    }
}
impl<R> Future for JoinHandle<R> {
    type Output = Option<R>;

    /// Polls for the task's output.
    ///
    /// - `CLOSED` and still `SCHEDULED`/`RUNNING`: registers the waker and returns `Pending`.
    /// - `CLOSED` and idle: calls `notify` with the current waker and returns `Ready(None)`.
    /// - Not yet `COMPLETED`: registers the waker and returns `Pending`.
    /// - `COMPLETED`: marks the task `CLOSED`, calls `notify` with the current waker, and
    ///   moves the output out, returning `Ready(Some(output))`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let task = self.raw_task.as_ptr();
        let hdr = task as *mut Header;
        let waker = cx.waker();
        // SAFETY: relies on `raw_task` pointing to a live task allocation starting with a
        // `Header`, and on the output slot holding a valid `R` while COMPLETED and not CLOSED
        // (TODO confirm both invariants at the task's allocation site).
        unsafe {
            let state = (*hdr).state;
            let is_closed = state & CLOSED != 0;
            let is_busy = state & (SCHEDULED | RUNNING) != 0;
            let is_done = state & COMPLETED != 0;

            // Closed and no longer queued or running: there is no output to hand out.
            if is_closed && !is_busy {
                (*hdr).notify(Some(waker));
                return Poll::Ready(None);
            }

            // Either closed but still in flight, or simply not finished yet: wait.
            if is_closed || !is_done {
                (*hdr).register(waker);
                return Poll::Pending;
            }

            // Completed and still open: close it and take the output exactly once.
            (*hdr).state |= CLOSED;
            (*hdr).notify(Some(waker));
            let slot = ((*hdr).vtable.get_output)(task) as *mut R;
            Poll::Ready(Some(slot.read()))
        }
    }
}
impl<R> fmt::Debug for JoinHandle<R> {
    /// Formats as `JoinHandle { header: .. }`, delegating to `Header`'s `Debug` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // SAFETY: relies on `raw_task` pointing to a live task allocation starting with a
        // `Header` for at least as long as this handle is borrowed.
        let hdr: &Header = unsafe { &*(self.raw_task.as_ptr() as *const Header) };
        let mut out = f.debug_struct("JoinHandle");
        out.field("header", hdr);
        out.finish()
    }
}