use core::fmt;
use core::future::Future;
use core::marker::{PhantomData, Unpin};
use core::mem;
use core::pin::Pin;
use core::ptr::NonNull;
use core::sync::atomic::Ordering;
use core::task::{Context, Poll};
use crate::header::Header;
use crate::raw::Panic;
use crate::runnable::ScheduleInfo;
use crate::state::*;
/// A handle to a spawned task.
///
/// Awaiting it yields the task's output (see the `Future` impl). Dropping it cancels the task
/// (see the `Drop` impl); call [`Task::detach`] to let the task keep running in the background.
#[must_use = "tasks get canceled when dropped, use `.detach()` to run them in the background"]
pub struct Task<T, M = ()> {
    /// Pointer to the task's raw allocation; methods cast it to `*const Header<M>` to reach the
    /// task's state word, vtable and metadata.
    pub(crate) ptr: NonNull<()>,
    /// Ties the handle to the output type `T` and the metadata type `M` without storing them.
    pub(crate) _marker: PhantomData<(T, M)>,
}
// SAFETY (review note): sending the handle to another thread can move the task's output `T`
// out on that thread (`set_detached` / `poll_task` read it), hence `T: Send`. The
// `M: Send + Sync` bound is presumably required because the header and its metadata are shared
// with the runnable side of the task — confirm against the `Runnable` impls.
unsafe impl<T: Send, M: Send + Sync> Send for Task<T, M> {}
// SAFETY (review note): through `&Task` only `is_finished`, `metadata` and `Debug` are reachable.
// None of them touch `T`, so `T` needs no bound. `metadata` hands out `&M`, which requires
// `M: Sync`; `M: Send` presumably follows the same shared-header reasoning as above — confirm.
unsafe impl<T, M: Send + Sync> Sync for Task<T, M> {}
// The handle itself stores only a pointer and a marker; the task state lives behind `ptr`,
// so the handle does not need to stay pinned.
impl<T, M> Unpin for Task<T, M> {}
// Unwind-safety is asserted unconditionally for `T` and `M`.
// NOTE(review): the justification is not visible in this file — confirm it is intended.
#[cfg(feature = "std")]
impl<T, M> std::panic::UnwindSafe for Task<T, M> {}
#[cfg(feature = "std")]
impl<T, M> std::panic::RefUnwindSafe for Task<T, M> {}
impl<T, M> Task<T, M> {
    /// Detaches the task so it keeps running in the background without this handle.
    ///
    /// If the task had already completed, its output is taken out by `set_detached` and
    /// dropped here. `mem::forget` keeps `Drop for Task` from running on this handle.
    pub fn detach(self) {
        let mut this = self;
        // Bound to `_out` (not `_`) so the output lives until the end of this function, i.e. it
        // is dropped only after `mem::forget(this)` below.
        let _out = this.set_detached();
        mem::forget(this);
    }
    /// Cancels the task and waits until it has stopped.
    ///
    /// Returns `Some(output)` if the task had already completed before cancellation took
    /// effect, or `None` if it was closed without delivering an output. With the `std` feature,
    /// a panic that happened inside the task is resumed here (see `poll_task`).
    pub async fn cancel(self) -> Option<T> {
        let mut this = self;
        this.set_canceled();
        this.fallible().await
    }
    /// Converts this handle into a [`FallibleTask`].
    ///
    /// The result resolves to `None` instead of panicking when the task was closed (for
    /// example, canceled) without delivering its output.
    pub fn fallible(self) -> FallibleTask<T, M> {
        FallibleTask { task: self }
    }
    /// Marks the task `CLOSED` unless it has already completed or been closed.
    ///
    /// Retries a compare-and-swap on the header's state word until it either succeeds or
    /// observes the task already `COMPLETED` / `CLOSED`.
    fn set_canceled(&mut self) {
        let ptr = self.ptr.as_ptr();
        let header = ptr as *const Header<M>;
        // SAFETY (review note): assumes `ptr` points to a live task allocation that starts with
        // a `Header<M>` for as long as this handle exists — TODO confirm against the raw layout.
        unsafe {
            let mut state = (*header).state.load(Ordering::Acquire);
            loop {
                // A completed or already-closed task cannot be canceled.
                if state & (COMPLETED | CLOSED) != 0 {
                    break;
                }
                // Idle task (neither scheduled nor running): also mark it SCHEDULED and add one
                // REFERENCE (presumably owned by the runnable created by `schedule` below —
                // confirm). Otherwise it is already scheduled or running, so only set CLOSED.
                let new = if state & (SCHEDULED | RUNNING) == 0 {
                    (state | SCHEDULED | CLOSED) + REFERENCE
                } else {
                    state | CLOSED
                };
                match (*header).state.compare_exchange_weak(
                    state,
                    new,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => {
                        // Idle task: schedule it once more. Presumably the executor then sees
                        // CLOSED and drops the future instead of polling it — confirm in raw.rs.
                        if state & (SCHEDULED | RUNNING) == 0 {
                            ((*header).vtable.schedule)(ptr, ScheduleInfo::new(false));
                        }
                        // Notify the registered awaiter, if any, that the task was closed.
                        if state & AWAITER != 0 {
                            (*header).notify(None);
                        }
                        break;
                    }
                    // `compare_exchange_weak` may fail spuriously or because the state changed;
                    // retry with the freshly observed value.
                    Err(s) => state = s,
                }
            }
        }
    }
    /// Releases this handle's hold on the task by clearing the `TASK` bit in its state.
    ///
    /// If the task completed but its output was never taken, the output is moved out and
    /// returned, so the caller controls when it is dropped. If the reference-count bits are
    /// zero, the task is either scheduled one last time (not yet closed) or destroyed
    /// (already closed).
    fn set_detached(&mut self) -> Option<Result<T, Panic>> {
        let ptr = self.ptr.as_ptr();
        let header = ptr as *const Header<M>;
        // SAFETY (review note): assumes `ptr` points to a live task allocation that starts with
        // a `Header<M>` — TODO confirm against the raw layout.
        unsafe {
            // Holds the output if this call is the one that takes it out of the task.
            let mut output = None;
            // Fast path: the state is exactly SCHEDULED | TASK | REFERENCE and is swapped for the
            // same state minus TASK in one CAS. Any mismatch, including a spurious
            // `compare_exchange_weak` failure, falls through to the general loop below.
            if let Err(mut state) = (*header).state.compare_exchange_weak(
                SCHEDULED | TASK | REFERENCE,
                SCHEDULED | REFERENCE,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                loop {
                    // Completed but not closed: the output is still stored in the task and must
                    // be taken out before this handle lets go.
                    if state & COMPLETED != 0 && state & CLOSED == 0 {
                        // Claim the output by setting CLOSED. `poll_task` claims it the same
                        // way, so only one side ever reads it.
                        match (*header).state.compare_exchange_weak(
                            state,
                            state | CLOSED,
                            Ordering::AcqRel,
                            Ordering::Acquire,
                        ) {
                            Ok(_) => {
                                // Move the output out of the task's storage.
                                output = Some(
                                    (((*header).vtable.get_output)(ptr) as *mut Result<T, Panic>)
                                        .read(),
                                );
                                // Mirror the successful CAS locally; the next iteration takes the
                                // `else` branch and clears TASK.
                                state |= CLOSED;
                            }
                            Err(s) => state = s,
                        }
                    } else {
                        // `state & !(REFERENCE - 1)` isolates the reference-count bits, assuming
                        // REFERENCE is a power of two above every flag bit (see crate::state).
                        // Zero references and not closed: replace the state with
                        // SCHEDULED | CLOSED | REFERENCE for the final `schedule` call below.
                        // Otherwise just clear the TASK bit.
                        let new = if state & (!(REFERENCE - 1) | CLOSED) == 0 {
                            SCHEDULED | CLOSED | REFERENCE
                        } else {
                            state & !TASK
                        };
                        match (*header).state.compare_exchange_weak(
                            state,
                            new,
                            Ordering::AcqRel,
                            Ordering::Acquire,
                        ) {
                            Ok(_) => {
                                // No references were counted in the observed state.
                                if state & !(REFERENCE - 1) == 0 {
                                    if state & CLOSED == 0 {
                                        // Not closed yet: schedule once more; presumably the
                                        // executor drops the future and releases the reference
                                        // added above — confirm in raw.rs.
                                        ((*header).vtable.schedule)(ptr, ScheduleInfo::new(false));
                                    } else {
                                        // Already closed with no references left: destroy it.
                                        ((*header).vtable.destroy)(ptr);
                                    }
                                }
                                break;
                            }
                            Err(s) => state = s,
                        }
                    }
                }
            }
            output
        }
    }
    /// Polls the task for its output.
    ///
    /// The result depends on the task's state:
    /// - completed: returns `Poll::Ready(Some(output))`, moving the output out and marking the
    ///   task `CLOSED`;
    /// - closed and no longer scheduled or running: returns `Poll::Ready(None)`;
    /// - otherwise: registers `cx.waker()` and returns `Poll::Pending`.
    ///
    /// # Panics
    ///
    /// With the `std` feature, resumes the panic if the task's future panicked.
    fn poll_task(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let ptr = self.ptr.as_ptr();
        let header = ptr as *const Header<M>;
        // SAFETY (review note): assumes `ptr` points to a live task allocation that starts with
        // a `Header<M>` — TODO confirm against the raw layout.
        unsafe {
            let mut state = (*header).state.load(Ordering::Acquire);
            loop {
                if state & CLOSED != 0 {
                    // Closed but still scheduled or running: wait for that to finish
                    // (presumably until the executor has dropped the future).
                    if state & (SCHEDULED | RUNNING) != 0 {
                        (*header).register(cx.waker());
                        // Re-read after registering so a transition racing with registration is
                        // not missed.
                        state = (*header).state.load(Ordering::Acquire);
                        if state & (SCHEDULED | RUNNING) != 0 {
                            return Poll::Pending;
                        }
                    }
                    // NOTE(review): passing our own waker presumably lets `notify` skip waking
                    // the current task — confirm in header.rs.
                    (*header).notify(Some(cx.waker()));
                    return Poll::Ready(None);
                }
                if state & COMPLETED == 0 {
                    (*header).register(cx.waker());
                    // Re-read after registering; the task may have completed or been closed in
                    // between.
                    state = (*header).state.load(Ordering::Acquire);
                    if state & CLOSED != 0 {
                        continue;
                    }
                    if state & COMPLETED == 0 {
                        return Poll::Pending;
                    }
                }
                // Completed and not closed: claim the output by setting CLOSED.
                match (*header).state.compare_exchange(
                    state,
                    state | CLOSED,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => {
                        if state & AWAITER != 0 {
                            (*header).notify(Some(cx.waker()));
                        }
                        // Move the output out of the task's storage.
                        let output = ((*header).vtable.get_output)(ptr) as *mut Result<T, Panic>;
                        let output = output.read();
                        let output = match output {
                            Ok(output) => output,
                            Err(panic) => {
                                // Re-raise the task's panic in the awaiting context.
                                #[cfg(feature = "std")]
                                std::panic::resume_unwind(panic);
                                // Without `std`, `Panic` is presumably uninhabited, making this
                                // arm unreachable.
                                #[cfg(not(feature = "std"))]
                                match panic {}
                            }
                        };
                        return Poll::Ready(Some(output));
                    }
                    Err(s) => state = s,
                }
            }
        }
    }
    /// Borrows the task header that `ptr` points to.
    fn header(&self) -> &Header<M> {
        let ptr = self.ptr.as_ptr();
        let header = ptr as *const Header<M>;
        // SAFETY (review note): assumes the header stays alive for as long as `self` does.
        unsafe { &*header }
    }
    /// Returns `true` if the task has completed or has been closed.
    pub fn is_finished(&self) -> bool {
        let ptr = self.ptr.as_ptr();
        let header = ptr as *const Header<M>;
        unsafe {
            let state = (*header).state.load(Ordering::Acquire);
            state & (CLOSED | COMPLETED) != 0
        }
    }
    /// Returns a reference to the metadata stored in the task header.
    pub fn metadata(&self) -> &M {
        &self.header().metadata
    }
}
impl<T, M> Drop for Task<T, M> {
    /// Cancels the task and releases this handle.
    ///
    /// The task is first closed (unless it already completed). The handle is then detached.
    /// Any output the task had already produced is returned by `set_detached` and dropped
    /// immediately.
    fn drop(&mut self) {
        Self::set_canceled(self);
        let _ = Self::set_detached(self);
    }
}
impl<T, M> Future for Task<T, M> {
    type Output = T;
    /// Resolves to the task's output.
    ///
    /// # Panics
    ///
    /// Panics with "task has failed" if the task was closed without delivering an output to
    /// this handle (for example, after cancellation or after the output was already taken).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_task(cx)
            .map(|outcome| outcome.expect("task has failed"))
    }
}
impl<T, M: fmt::Debug> fmt::Debug for Task<T, M> {
    /// Formats the handle as `Task { header: .. }`, showing the shared task header.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let header = self.header();
        let mut builder = f.debug_struct("Task");
        builder.field("header", header);
        builder.finish()
    }
}
/// A task handle that resolves to `None` rather than panicking when the task was closed
/// (for example, canceled) without delivering its output.
///
/// A panic raised inside the task itself is still propagated when the `std` feature is
/// enabled. Dropping this handle drops the wrapped `Task`, which cancels the task; use
/// `.detach()` to let it keep running.
#[must_use = "tasks get canceled when dropped, use `.detach()` to run them in the background"]
pub struct FallibleTask<T, M = ()> {
    /// The wrapped handle; every operation delegates to it.
    task: Task<T, M>,
}
impl<T, M> FallibleTask<T, M> {
    /// Detaches the task so it keeps running in the background. See [`Task::detach`].
    pub fn detach(self) {
        self.task.detach()
    }
    /// Cancels the task and waits until it has stopped. See [`Task::cancel`].
    ///
    /// Returns `Some(output)` if the task had already completed, or `None` if it was closed
    /// without delivering an output.
    pub async fn cancel(self) -> Option<T> {
        self.task.cancel().await
    }
    /// Returns a reference to the metadata associated with the task.
    ///
    /// Mirrors [`Task::metadata`]. Without it, converting a `Task` into a `FallibleTask` would
    /// lose access to the metadata.
    pub fn metadata(&self) -> &M {
        self.task.metadata()
    }
    /// Returns `true` if the task has completed or has been closed. See [`Task::is_finished`].
    pub fn is_finished(&self) -> bool {
        self.task.is_finished()
    }
}
impl<T, M> Future for FallibleTask<T, M> {
    type Output = Option<T>;
    /// Resolves to `Some(output)` once the task completes, or `None` if it was closed first.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `FallibleTask` is `Unpin` (its only field is the `Unpin` `Task`), so the inner handle
        // can be borrowed mutably straight through the pin.
        let inner = &mut self.task;
        inner.poll_task(cx)
    }
}
impl<T, M: fmt::Debug> fmt::Debug for FallibleTask<T, M> {
    /// Formats as `FallibleTask { header: .. }` using the wrapped task's header.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let header = self.task.header();
        let mut builder = f.debug_struct("FallibleTask");
        builder.field("header", header);
        builder.finish()
    }
}