extern crate alloc;
use std::alloc::{Layout, alloc, dealloc, handle_alloc_error};
use std::future::Future;
use std::mem::{self, ManuallyDrop};
use std::task::{RawWaker, RawWakerVTable};
use crate::loom_exports::cell::UnsafeCell;
use crate::loom_exports::sync::atomic::{self, AtomicU64, Ordering};
mod cancel_token;
mod promise;
mod runnable;
mod util;
#[cfg(test)]
mod tests;
pub(crate) use cancel_token::CancelToken;
pub(crate) use promise::Promise;
pub(crate) use runnable::Runnable;
use self::util::{RunOnDrop, runnable_exists};
// Layout of the task state word (a single `AtomicU64`):
//
//   bit 0         POLLING      set while `core` still holds the future
//   bit 1         CLOSED       task closed; with POLLING clear, the waker
//                              drop paths find nothing left in `core` to drop
//   bits 2..=32   ref count    in units of `REF_INC`
//   bits 33..=63  wake count   pending wake-ups, in units of `WAKE_INC`

/// Flag: the task's `core` still holds the future (not yet its output).
const POLLING: u64 = 0b01;
/// Flag: the task has been closed.
const CLOSED: u64 = 0b10;
/// One unit of the reference-count field.
const REF_INC: u64 = 1 << 2;
/// One unit of the wake-count field.
const WAKE_INC: u64 = 1 << 33;
/// Every bit from `REF_INC` up to, but excluding, `WAKE_INC`.
const REF_MASK: u64 = (WAKE_INC - 1) & !(REF_INC - 1);
/// Every bit from `WAKE_INC` upward.
const WAKE_MASK: u64 = !(WAKE_INC - 1);
/// Reference count (about half of the field's range) above which cloning a
/// waker panics; presumably chosen to keep headroom before the field overflows.
const REF_CRITICAL: u64 = REF_MASK & (REF_MASK >> 1);
/// Wake count (about half of the field's range) above which waking panics.
const WAKE_CRITICAL: u64 = WAKE_MASK & (WAKE_MASK >> 1);
/// In-place storage for either the task's future or its output.
///
/// Which field is live is not recorded here but in the task state: the waker
/// drop paths treat `future` as live while `POLLING` is set, and `output` as
/// live when both `POLLING` and `CLOSED` are clear. Both fields are wrapped in
/// `ManuallyDrop` because a union never drops its fields; the live one must be
/// dropped explicitly.
union TaskCore<F: Future> {
/// The future being driven; live while `POLLING` is set.
future: ManuallyDrop<F>,
/// The future's result; live once `POLLING` is cleared, unless `CLOSED` is set.
output: ManuallyDrop<F::Output>,
}
/// A heap-allocated task.
///
/// The `Runnable`, `Promise` and `CancelToken` handles are all created from
/// the same raw pointer in `spawn`. Wakers carry that pointer as their
/// type-erased data pointer.
///
/// NOTE(review): the deallocation paths in this file drop only the contents
/// of `core` before freeing the memory. `schedule_fn` and `tag` are never
/// dropped here, so any destructor they have does not run. Confirm whether the
/// other task modules account for this.
struct Task<F: Future, S, T> {
/// Packed flags, reference count and wake-up count (see the constants above).
state: AtomicU64,
/// The future, or its output once completed; which one is live is tracked by
/// `state`.
core: UnsafeCell<TaskCore<F>>,
/// Invoked with a new `Runnable` and a clone of `tag` to (re)schedule the task.
schedule_fn: S,
/// Value cloned and handed to `schedule_fn` each time the task is scheduled.
tag: T,
}
/// Waker vtable entry points and the shared wake-up logic.
///
/// Every function here receives the task as a type-erased `*const ()` and
/// casts it back to `*const Task<F, S, T>`.
impl<F, S, T> Task<F, S, T>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
S: Fn(Runnable, T) + Send + Sync + 'static,
T: Clone + Send + Sync + 'static,
{
/// `RawWakerVTable` clone entry: takes one more reference on the task and
/// returns a `RawWaker` sharing the same task pointer.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F, S, T>` on which the caller holds a
/// reference.
///
/// # Panics
///
/// Panics if the reference count read before the increment exceeds
/// `REF_CRITICAL`. The increment has already happened at that point.
unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
let this = unsafe { &*(ptr as *const Self) };
// Relaxed is enough: as with `Arc::clone`, the new reference is derived
// from one the caller already holds, so no other memory needs to be
// synchronized.
let ref_count = this.state.fetch_add(REF_INC, Ordering::Relaxed) & REF_MASK;
if ref_count > REF_CRITICAL {
panic!("Attack of the clones: the waker was cloned too many times");
}
RawWaker::new(ptr, raw_waker_vtable::<F, S, T>())
}
/// `RawWakerVTable` wake entry: wakes the task and consumes this waker's
/// reference. If that was the last reference and `POLLING` is clear, it drops
/// the stored output (if any) and frees the task.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F, S, T>`, and the caller gives up the
/// waker reference it holds on it.
///
/// # Panics
///
/// Panics if `S` is not zero-sized (after releasing the waker reference), or
/// if the task was woken too many times (see `wake`).
unsafe fn wake_by_val(ptr: *const ()) {
// The waker reference is released by the same `fetch_add` that may schedule
// the task (delta `WAKE_INC - REF_INC` below). The task may therefore already
// have been released by the scheduled `Runnable` while `schedule_fn` is still
// running. Requiring a zero-sized `S` means that call reads no task data.
// NOTE(review): rationale inferred from the ref-count protocol; confirm
// against how `Runnable` releases the task.
if mem::size_of::<S>() != 0 {
unsafe { Self::drop_waker(ptr) };
panic!("Scheduling functions with captured variables are not supported");
}
// One Release read-modify-write adds a wake-up and drops this waker's
// reference; `state` is the value *before* the update.
let state = unsafe { Self::wake(ptr, WAKE_INC - REF_INC) };
// Before the update, this waker held the only reference and `POLLING` was
// clear (so `wake` did not schedule a `Runnable`): release the task.
if state & (REF_MASK | POLLING) == REF_INC {
// Synchronize with the Release decrements of earlier holders so the
// latest contents of `core` are visible before they are dropped.
atomic::fence(Ordering::Acquire);
let this = unsafe { &*(ptr as *const Self) };
// The guard frees the allocation when it goes out of scope, presumably
// also when dropping the output below panics.
let _drop_guard = RunOnDrop::new(|| {
unsafe { dealloc(ptr as *mut u8, Layout::new::<Self>()) };
});
// With `POLLING` clear, the output is dropped unless `CLOSED` is set,
// in which case nothing is dropped here.
if state & CLOSED == 0 {
this.core
.with_mut(|c| unsafe { ManuallyDrop::drop(&mut (*c).output) });
}
}
}
/// `RawWakerVTable` wake_by_ref entry: wakes the task without changing the
/// reference count.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F, S, T>` on which the caller holds a
/// reference.
unsafe fn wake_by_ref(ptr: *const ()) {
unsafe { Self::wake(ptr, WAKE_INC) };
}
/// Adds `state_delta` to the task state. If, before the addition, there were
/// no pending wake-ups, `POLLING` was set and `CLOSED` was clear, it creates a
/// `Runnable` and passes it, with a clone of `tag`, to `schedule_fn`.
///
/// Returns the state as it was *before* the addition.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F, S, T>`.
///
/// # Panics
///
/// Panics if the previous wake count exceeds `WAKE_CRITICAL`. The state has
/// already been updated at that point.
#[inline(always)]
unsafe fn wake(ptr: *const (), state_delta: u64) -> u64 {
let this = unsafe { &*(ptr as *const Self) };
// Release: publishes prior writes to whichever thread later observes this
// update. In `wake_by_val` it also serves as the release of the dropped
// reference.
let state = this.state.fetch_add(state_delta, Ordering::Release);
if state & WAKE_MASK > WAKE_CRITICAL {
panic!("The task was woken too many times: {state:0x}");
}
// Only the transition from zero pending wake-ups schedules the task;
// further wake-ups are merely counted.
if state & (WAKE_MASK | CLOSED | POLLING) == POLLING {
let runnable = unsafe { Runnable::new_unchecked(ptr as *const Self) };
(this.schedule_fn)(runnable, this.tag.clone());
}
state
}
/// `RawWakerVTable` drop entry: releases this waker's reference. If it was the
/// last reference and no `Runnable` exists, it drops whatever `core` still
/// holds and frees the task.
///
/// # Safety
///
/// `ptr` must point to a live `Task<F, S, T>`, and the caller gives up the
/// waker reference it holds on it.
unsafe fn drop_waker(ptr: *const ()) {
let this = unsafe { &*(ptr as *const Self) };
// Release: makes this handle's prior accesses to `core` visible to whichever
// handle ends up freeing the task.
let state = this.state.fetch_sub(REF_INC, Ordering::Release);
// `state` is the pre-decrement value, so one remaining reference means this
// waker was the last one.
if state & REF_MASK == REF_INC && !runnable_exists(state) {
// Pairs with the Release decrements of earlier holders.
atomic::fence(Ordering::Acquire);
// The guard frees the allocation when it goes out of scope, presumably
// also when dropping the future/output below panics.
let _drop_guard = RunOnDrop::new(|| {
unsafe { dealloc(ptr as *mut u8, Layout::new::<Self>()) };
});
unsafe {
// `POLLING` set: drop the future. Otherwise drop the output unless
// `CLOSED` is set, in which case nothing is dropped here.
if state & POLLING == POLLING {
this.core.with_mut(|c| ManuallyDrop::drop(&mut (*c).future));
} else if state & CLOSED == 0 {
this.core.with_mut(|c| ManuallyDrop::drop(&mut (*c).output));
}
}
}
}
}
/// Returns the `RawWakerVTable` wiring `Task<F, S, T>`'s clone/wake/drop
/// functions into a `Waker`.
///
/// The `&RawWakerVTable::new(..)` expression is promoted to a `'static`
/// constant (`RawWakerVTable::new` is a promotable `const fn`), which is why a
/// reference can be returned. Keep it as a single expression: binding the
/// vtable to a local first would return a reference to a temporary.
///
/// NOTE(review): `#[inline(never)]` is presumably meant to keep one vtable
/// address per instantiation (e.g. for `Waker::will_wake` comparisons);
/// confirm.
#[inline(never)]
fn raw_waker_vtable<F, S, T>() -> &'static RawWakerVTable
where
F: Future + Send + 'static,
F::Output: Send + 'static,
S: Fn(Runnable, T) + Send + Sync + 'static,
T: Clone + Send + Sync + 'static,
{
&RawWakerVTable::new(
Task::<F, S, T>::clone_waker,
Task::<F, S, T>::wake_by_val,
Task::<F, S, T>::wake_by_ref,
Task::<F, S, T>::drop_waker,
)
}
/// Allocates a new task driving `future` and returns its `Promise`, its first
/// `Runnable` and its `CancelToken`.
///
/// The task is not scheduled here: the returned `Runnable` is handed back to
/// the caller. Later wake-ups schedule the task by calling `schedule_fn` with a
/// new `Runnable` and a clone of `tag`.
///
/// # Panics
///
/// Calls `handle_alloc_error` (which aborts by default) if the task cannot be
/// allocated.
pub(crate) fn spawn<F, S, T>(
    future: F,
    schedule_fn: S,
    tag: T,
) -> (Promise<F::Output>, Runnable, CancelToken)
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
    S: Fn(Runnable, T) + Send + Sync + 'static,
    T: Clone + Send + Sync + 'static,
{
    // Initial state: two references (one more than `spawn_and_forget`, which
    // returns no `Promise`), one pending wake-up presumably accounting for the
    // returned `Runnable`, and `POLLING` set because `core` holds the future.
    let task = Task {
        state: AtomicU64::new((2 * REF_INC) | WAKE_INC | POLLING),
        core: UnsafeCell::new(TaskCore {
            future: ManuallyDrop::new(future),
        }),
        schedule_fn,
        tag,
    };

    let layout = Layout::new::<Task<F, S, T>>();
    // SAFETY: `Task` always contains the `state` atomic, so `layout` is never
    // zero-sized, as `alloc` requires.
    let ptr: *mut Task<F, S, T> = unsafe { alloc(layout) }.cast();
    if ptr.is_null() {
        handle_alloc_error(layout);
    }
    // SAFETY: `ptr` is non-null and was allocated with the layout of
    // `Task<F, S, T>`, so it is valid and aligned for a write. The memory is
    // still uninitialized, so the task must be moved in with `write`. A plain
    // `*ptr = task` would first run the destructor of the garbage "old" value
    // (e.g. an `S` or `T` with drop glue), which is undefined behavior.
    unsafe { ptr.write(task) };

    // SAFETY: `ptr` now points to a fully initialized task whose state accounts
    // for the three handles created here.
    unsafe {
        let runnable = Runnable::new_unchecked(ptr);
        let promise = Promise::new_unchecked(ptr);
        let cancel_token = CancelToken::new_unchecked(ptr);
        (promise, runnable, cancel_token)
    }
}
/// Allocates a new task driving `future` without a `Promise`, returning its
/// first `Runnable` and its `CancelToken`.
///
/// The task is not scheduled here: the returned `Runnable` is handed back to
/// the caller. Later wake-ups schedule the task by calling `schedule_fn` with a
/// new `Runnable` and a clone of `tag`.
///
/// # Panics
///
/// Calls `handle_alloc_error` (which aborts by default) if the task cannot be
/// allocated.
pub(crate) fn spawn_and_forget<F, S, T>(
    future: F,
    schedule_fn: S,
    tag: T,
) -> (Runnable, CancelToken)
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
    S: Fn(Runnable, T) + Send + Sync + 'static,
    T: Clone + Send + Sync + 'static,
{
    // Initial state: a single reference (no `Promise` is created), one pending
    // wake-up presumably accounting for the returned `Runnable`, and `POLLING`
    // set because `core` holds the future.
    let task = Task {
        state: AtomicU64::new(REF_INC | WAKE_INC | POLLING),
        core: UnsafeCell::new(TaskCore {
            future: ManuallyDrop::new(future),
        }),
        schedule_fn,
        tag,
    };

    let layout = Layout::new::<Task<F, S, T>>();
    // SAFETY: `Task` always contains the `state` atomic, so `layout` is never
    // zero-sized, as `alloc` requires.
    let ptr: *mut Task<F, S, T> = unsafe { alloc(layout) }.cast();
    if ptr.is_null() {
        handle_alloc_error(layout);
    }
    // SAFETY: `ptr` is non-null and was allocated with the layout of
    // `Task<F, S, T>`, so it is valid and aligned for a write. The memory is
    // still uninitialized, so the task must be moved in with `write`. A plain
    // `*ptr = task` would first run the destructor of the garbage "old" value
    // (e.g. an `S` or `T` with drop glue), which is undefined behavior.
    unsafe { ptr.write(task) };

    // SAFETY: `ptr` now points to a fully initialized task whose state accounts
    // for the two handles created here.
    unsafe {
        let runnable = Runnable::new_unchecked(ptr);
        let cancel_token = CancelToken::new_unchecked(ptr);
        (runnable, cancel_token)
    }
}