use core::fmt;
use core::future::Future;
use core::marker::PhantomData;
use core::mem;
use core::ptr::NonNull;
use core::sync::atomic::Ordering;
use core::task::Waker;
use alloc::boxed::Box;
use crate::header::Header;
use crate::raw::RawTask;
use crate::state::*;
use crate::Task;
mod sealed {
    use super::*;

    /// Private supertrait of `Schedule`.
    ///
    /// This module is not `pub`, so the trait cannot be named outside this crate. Code outside
    /// the crate therefore cannot add `Schedule` implementations beyond the ones below.
    pub trait Sealed<M> {}

    /// Plain scheduling closures that receive only the runnable.
    impl<M, F: Fn(Runnable<M>)> Sealed<M> for F {}

    /// Scheduling closures that also receive a `ScheduleInfo`, wrapped in `WithInfo`.
    impl<M, F: Fn(Runnable<M>, ScheduleInfo)> Sealed<M> for WithInfo<F> {}
}
/// Builder for spawning tasks with custom settings.
///
/// Holds the metadata attached to every task it spawns. With the `std` feature, it also holds
/// a panic-propagation flag. Start from `Builder::new()` (metadata `()`), or from `Default`
/// when `M: Default`.
#[derive(Debug)]
pub struct Builder<M> {
    // Metadata for each spawned task. The future constructor given to the `spawn*` methods
    // receives a reference to it. Presumably this is the value `Runnable::metadata` later
    // exposes from the task header — confirm in raw.rs.
    pub(crate) metadata: M,
    // Whether a panic inside the task's future should be propagated.
    // NOTE(review): the flag is consumed by `RawTask::allocate` (raw.rs, not visible here).
    // Presumably the panic is re-raised to whoever awaits the `Task` — confirm.
    #[cfg(feature = "std")]
    pub(crate) propagate_panic: bool,
}
impl<M: Default> Default for Builder<M> {
    /// Creates a builder whose metadata is `M::default()`. All other settings are defaults.
    fn default() -> Self {
        let metadata = M::default();
        Builder::new().metadata(metadata)
    }
}
/// Extra context handed to a scheduler together with the `Runnable` it must schedule.
///
/// Marked `#[non_exhaustive]` so fields can be added later. As a consequence, code outside
/// this crate cannot build it with a struct literal.
#[derive(Debug, Copy, Clone)]
#[non_exhaustive]
pub struct ScheduleInfo {
    /// Presumably set when the task is woken while it is being run.
    /// `Runnable::schedule` always passes `false`.
    pub woken_while_running: bool,
}

impl ScheduleInfo {
    /// Builds a `ScheduleInfo` carrying the given `woken_while_running` flag.
    pub(crate) fn new(woken_while_running: bool) -> Self {
        Self { woken_while_running }
    }
}
/// A scheduler: receives a `Runnable` to be run, together with a `ScheduleInfo`.
///
/// The trait is sealed through `sealed::Sealed`, so it cannot be implemented outside this
/// crate. It is implemented for:
///
/// - every `Fn(Runnable<M>)` closure, which ignores the `ScheduleInfo`;
/// - `WithInfo<F>` where `F: Fn(Runnable<M>, ScheduleInfo)`.
pub trait Schedule<M = ()>: sealed::Sealed<M> {
    /// Schedules `runnable`. `info` carries extra context (see `ScheduleInfo`).
    fn schedule(&self, runnable: Runnable<M>, info: ScheduleInfo);
}
/// Any `Fn(Runnable<M>)` closure works as a scheduler. The `ScheduleInfo` is discarded.
impl<M, F: Fn(Runnable<M>)> Schedule<M> for F {
    fn schedule(&self, runnable: Runnable<M>, _info: ScheduleInfo) {
        (*self)(runnable)
    }
}
/// Adapter marking a two-argument closure `Fn(Runnable<M>, ScheduleInfo)` as a scheduler.
///
/// The wrapped closure receives the `ScheduleInfo` on every call.
#[derive(Debug)]
pub struct WithInfo<F>(pub F);

impl<F> From<F> for WithInfo<F> {
    /// Wraps `value` in `WithInfo`.
    fn from(value: F) -> Self {
        Self(value)
    }
}
impl<M, F> Schedule<M> for WithInfo<F>
where
F: Fn(Runnable<M>, ScheduleInfo),
{
fn schedule(&self, runnable: Runnable<M>, info: ScheduleInfo) {
(self.0)(runnable, info)
}
}
impl Builder<()> {
    /// Creates a task builder with `()` metadata.
    ///
    /// With the `std` feature, panic propagation starts out disabled.
    pub fn new() -> Builder<()> {
        Builder {
            metadata: (),
            #[cfg(feature = "std")]
            propagate_panic: false,
        }
    }

    /// Replaces the metadata attached to tasks spawned by this builder.
    ///
    /// All other settings are carried over unchanged. The result is a builder for the new
    /// metadata type `M`.
    pub fn metadata<M>(self, metadata: M) -> Builder<M> {
        #[cfg(feature = "std")]
        let propagate_panic = self.propagate_panic;
        Builder {
            metadata,
            #[cfg(feature = "std")]
            propagate_panic,
        }
    }
}
impl<M> Builder<M> {
    /// Sets whether panics in spawned futures are propagated. Requires the `std` feature.
    ///
    /// Returns a builder with the same metadata and the new flag.
    #[cfg(feature = "std")]
    pub fn propagate_panic(self, propagate_panic: bool) -> Builder<M> {
        Builder {
            metadata: self.metadata,
            propagate_panic,
        }
    }

    /// Spawns a task whose future is built by `future` from a reference to the metadata.
    ///
    /// Returns the `Runnable` used to run the task and the `Task` handle for its output.
    /// The future and its output must be `Send + 'static`. The scheduler must be
    /// `Send + Sync + 'static`.
    pub fn spawn<F, Fut, S>(self, future: F, schedule: S) -> (Runnable<M>, Task<Fut::Output, M>)
    where
        F: FnOnce(&M) -> Fut,
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
        S: Schedule<M> + Send + Sync + 'static,
    {
        // SAFETY: `spawn_unchecked` drops exactly the `Send`/`Sync`/`'static` bounds listed
        // above. Requiring them here is what makes the unchecked call sound.
        unsafe { self.spawn_unchecked(future, schedule) }
    }

    /// Spawns a task whose future need not be `Send`. Requires the `std` feature.
    ///
    /// The built future is wrapped in a `Checked` guard, which records the id of the thread
    /// that built it. The guard panics if the future is polled or dropped on any other thread.
    #[cfg(feature = "std")]
    pub fn spawn_local<F, Fut, S>(
        self,
        future: F,
        schedule: S,
    ) -> (Runnable<M>, Task<Fut::Output, M>)
    where
        F: FnOnce(&M) -> Fut,
        Fut: Future + 'static,
        Fut::Output: 'static,
        S: Schedule<M> + Send + Sync + 'static,
    {
        use std::mem::ManuallyDrop;
        use std::pin::Pin;
        use std::task::{Context, Poll};
        use std::thread::{self, ThreadId};

        // Returns the current thread's id, cached in a thread-local.
        // Falls back to `thread::current().id()` when `try_with` cannot access the
        // thread-local. std documents this happens while the thread-local is being destroyed.
        #[inline]
        fn thread_id() -> ThreadId {
            std::thread_local! {
                static ID: ThreadId = thread::current().id();
            }
            ID.try_with(|id| *id)
                .unwrap_or_else(|_| thread::current().id())
        }

        // Future wrapper that enforces thread affinity at runtime.
        struct Checked<F> {
            // Id of the thread that created this wrapper.
            id: ThreadId,
            // The wrapped future. It is held in `ManuallyDrop` so it is dropped only after
            // the thread check in `Drop` passes.
            inner: ManuallyDrop<F>,
        }

        impl<F> Drop for Checked<F> {
            fn drop(&mut self) {
                // Refuse to run the future's destructor on a foreign thread.
                assert!(
                    self.id == thread_id(),
                    "local task dropped by a thread that didn't spawn it"
                );
                // SAFETY: `inner` is dropped exactly once, here, and never used afterwards.
                // If the assertion above panics, this line is never reached. Dropping a
                // `ManuallyDrop` field is a no-op, so `inner` is leaked rather than dropped.
                unsafe {
                    ManuallyDrop::drop(&mut self.inner);
                }
            }
        }

        impl<F: Future> Future for Checked<F> {
            type Output = F::Output;
            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                assert!(
                    self.id == thread_id(),
                    "local task polled by a thread that didn't spawn it"
                );
                // SAFETY: this is a structural pin projection to `inner`. `Checked` never
                // moves `inner` out (its `Drop` drops it in place), so pinning `Checked`
                // also pins `inner`.
                unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
            }
        }

        // Wrap the user's constructor so the built future is `Checked`. The guard records
        // the id of the thread that runs the constructor.
        // NOTE(review): assumes `RawTask::allocate` calls the constructor synchronously on
        // the spawning thread, so the recorded id is the spawner's — confirm in raw.rs.
        let future = move |meta| {
            let future = future(meta);
            Checked {
                id: thread_id(),
                inner: ManuallyDrop::new(future),
            }
        };

        // SAFETY: the future is not `Send`, but `Checked` panics rather than let it be
        // polled or dropped off the spawning thread. The future, its output and the scheduler
        // are still required to be `'static`, and the scheduler `Send + Sync`.
        unsafe { self.spawn_unchecked(future, schedule) }
    }

    /// Spawns a task without checking `Send`, `Sync` or `'static` bounds.
    ///
    /// `future` receives a reference to the builder's metadata (borrowed for `'a`) and
    /// returns the future to run. Futures of 2048 bytes or more are boxed and pinned on the
    /// heap before being handed to `RawTask::allocate`. Smaller ones are passed unboxed.
    ///
    /// # Safety
    ///
    /// Compared with `spawn`, this method drops several bounds:
    ///
    /// - `Send` on the future and its output;
    /// - `Send + Sync` on the scheduler;
    /// - `'static` on all of them.
    ///
    /// The caller must guarantee by other means whatever those bounds would have guaranteed.
    // NOTE(review): this file does not spell out the precise contract, e.g. which thread may
    // poll/drop the task or its wakers, or how long borrowed data must live. Document it here
    // once confirmed against raw.rs.
    pub unsafe fn spawn_unchecked<'a, F, Fut, S>(
        self,
        future: F,
        schedule: S,
    ) -> (Runnable<M>, Task<Fut::Output, M>)
    where
        F: FnOnce(&'a M) -> Fut,
        Fut: Future + 'a,
        S: Schedule<M>,
        M: 'a,
    {
        // Box large futures. Presumably this keeps the raw task allocation (and any moves of
        // the future during allocation) small — TODO confirm the rationale.
        let ptr = if mem::size_of::<Fut>() >= 2048 {
            let future = |meta| {
                let future = future(meta);
                Box::pin(future)
            };
            RawTask::<_, Fut::Output, S, M>::allocate(future, schedule, self)
        } else {
            RawTask::<Fut, Fut::Output, S, M>::allocate(future, schedule, self)
        };
        // Both handles refer to the same raw task allocation.
        let runnable = Runnable::from_raw(ptr);
        let task = Task {
            ptr,
            _marker: PhantomData,
        };
        (runnable, task)
    }
}
pub fn spawn<F, S>(future: F, schedule: S) -> (Runnable, Task<F::Output>)
where
F: Future + Send + 'static,
F::Output: Send + 'static,
S: Schedule + Send + Sync + 'static,
{
unsafe { spawn_unchecked(future, schedule) }
}
/// Spawns a `'static` future that need not be `Send`, with `()` metadata.
///
/// Delegates to `Builder::spawn_local`. That method panics if the task's future is polled or
/// dropped on a thread other than the one that spawned it.
#[cfg(feature = "std")]
pub fn spawn_local<F, S>(future: F, schedule: S) -> (Runnable, Task<F::Output>)
where
    F: Future + 'static,
    F::Output: 'static,
    S: Schedule + Send + Sync + 'static,
{
    let builder = Builder::new();
    builder.spawn_local(move |_| future, schedule)
}
/// Spawns a future with `()` metadata, without `Send`, `Sync` or `'static` bounds.
///
/// # Safety
///
/// Same contract as `Builder::spawn_unchecked`, which this function calls. The compiler no
/// longer checks thread-safety or lifetimes of the future and scheduler, so the caller must
/// guarantee them.
pub unsafe fn spawn_unchecked<F, S>(future: F, schedule: S) -> (Runnable, Task<F::Output>)
where
    F: Future,
    S: Schedule,
{
    let builder = Builder::new();
    builder.spawn_unchecked(move |_| future, schedule)
}
/// A handle to a task that can be run or scheduled.
///
/// Wraps a type-erased pointer to the raw task allocation, whose start is read as a
/// `Header<M>`. It is used as follows:
///
/// - `run` executes the task through the header's vtable.
/// - `schedule` hands the task to its scheduler.
/// - Dropping the handle closes the task and drops its future.
pub struct Runnable<M = ()> {
    // Type-erased pointer to the raw task allocation.
    pub(crate) ptr: NonNull<()>,
    // Ties the handle to the metadata type `M` stored in the header.
    pub(crate) _marker: PhantomData<M>,
}
// SAFETY: a `Runnable` is a pointer to the shared task. Whichever thread holds the handle
// can reach the metadata `M` through `metadata()`, hence the `M: Send + Sync` bound.
// The future is not covered by this bound; the `spawn*` constructors are responsible for it:
// - `spawn` requires `Send` bounds;
// - `spawn_local` uses a runtime thread check;
// - `spawn_unchecked` relies on the caller's contract.
unsafe impl<M: Send + Sync> Send for Runnable<M> {}
unsafe impl<M: Send + Sync> Sync for Runnable<M> {}
// Declared unwind-safe unconditionally, regardless of `M`.
// NOTE(review): this file does not show the justification. Presumably the task state
// managed by `RawTask` stays consistent across a panic — confirm.
#[cfg(feature = "std")]
impl<M> std::panic::UnwindSafe for Runnable<M> {}
#[cfg(feature = "std")]
impl<M> std::panic::RefUnwindSafe for Runnable<M> {}
impl<M> Runnable<M> {
    /// Returns a reference to the metadata stored in the task's header.
    pub fn metadata(&self) -> &M {
        &self.header().metadata
    }

    /// Hands the task to its scheduler.
    ///
    /// Calls the `schedule` entry of the header's vtable with `ScheduleInfo::new(false)`.
    /// Consumes `self` without running `Drop`. The vtable function presumably takes over the
    /// reference this handle held.
    pub fn schedule(self) {
        let ptr = self.ptr.as_ptr();
        let header = ptr as *const Header<M>;
        // Skip `Drop`: it would close the task and drop its future.
        mem::forget(self);
        // SAFETY: `ptr` points at a live task whose allocation starts with `Header<M>`
        // (the same assumption `header()` makes). This handle's ownership is passed on.
        unsafe {
            ((*header).vtable.schedule)(ptr, ScheduleInfo::new(false));
        }
    }

    /// Runs the task by calling the `run` entry of the header's vtable.
    ///
    /// Consumes `self` without running `Drop`. Returns the `bool` produced by the vtable.
    // NOTE(review): `RawTask::run` (raw.rs, not visible here) defines the meaning of the
    // returned `bool`. Presumably it reports whether the task was woken while running;
    // confirm and document.
    pub fn run(self) -> bool {
        let ptr = self.ptr.as_ptr();
        let header = ptr as *const Header<M>;
        // Skip `Drop`: the vtable `run` function now owns this handle's reference.
        mem::forget(self);
        // SAFETY: same layout assumption as in `schedule`.
        unsafe { ((*header).vtable.run)(ptr) }
    }

    /// Returns a `Waker` for this task, created through the `clone_waker` vtable entry.
    ///
    /// Only borrows `self`, so the runnable remains usable.
    pub fn waker(&self) -> Waker {
        let ptr = self.ptr.as_ptr();
        let header = ptr as *const Header<M>;
        // SAFETY: assumes `clone_waker` (raw.rs) returns a `RawWaker` whose vtable upholds
        // the contract required by `Waker::from_raw`.
        unsafe {
            let raw_waker = ((*header).vtable.clone_waker)(ptr);
            Waker::from_raw(raw_waker)
        }
    }

    /// Views the task pointer as its `Header<M>`.
    fn header(&self) -> &Header<M> {
        // SAFETY: assumes the raw task allocation begins with a `Header<M>` and stays alive
        // while this `Runnable` exists. `RawTask` maintains these invariants.
        unsafe { &*(self.ptr.as_ptr() as *const Header<M>) }
    }

    /// Converts the runnable into a raw pointer to the task without running `Drop`.
    ///
    /// Pass the pointer to `from_raw` to get a `Runnable` back. Until then, the cleanup
    /// performed by `Drop` does not happen.
    pub fn into_raw(self) -> NonNull<()> {
        let ptr = self.ptr;
        mem::forget(self);
        ptr
    }

    /// Rebuilds a `Runnable` from a raw task pointer.
    ///
    /// # Safety
    ///
    /// `ptr` must come from `Runnable::<M>::into_raw` with the same `M`. It must not be used
    /// to create more than one `Runnable`, because the result takes over that handle's
    /// ownership.
    // NOTE(review): `spawn_unchecked` also calls this directly on a fresh pointer from
    // `RawTask::allocate`. The real requirement is therefore "a raw task pointer that owns a
    // runnable reference" — confirm the exact wording.
    pub unsafe fn from_raw(ptr: NonNull<()>) -> Self {
        Self {
            ptr,
            _marker: Default::default(),
        }
    }
}
impl<M> Drop for Runnable<M> {
    /// Dropping a `Runnable` without running or scheduling it closes the task.
    ///
    /// The steps, in order:
    ///
    /// 1. Set `CLOSED`, unless the task is already `COMPLETED` or `CLOSED`.
    /// 2. Drop the future through the vtable.
    /// 3. Clear `SCHEDULED`.
    /// 4. Notify the awaiter if `AWAITER` was set.
    /// 5. Release this handle's reference via `drop_ref`.
    fn drop(&mut self) {
        let ptr = self.ptr.as_ptr();
        let header = self.header();
        // SAFETY: `ptr` is a live raw task owned by this handle. The vtable functions are
        // assumed to accept it in this state (contract defined in raw.rs).
        unsafe {
            let mut state = header.state.load(Ordering::Acquire);
            // CAS loop that retries with the freshly observed state. `compare_exchange_weak`
            // is fine here because a spurious failure simply loops again.
            loop {
                // A completed or already-closed task needs no further marking.
                if state & (COMPLETED | CLOSED) != 0 {
                    break;
                }
                match header.state.compare_exchange_weak(
                    state,
                    state | CLOSED,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => break,
                    Err(s) => state = s,
                }
            }
            // The future is dropped regardless of which branch ended the loop.
            (header.vtable.drop_future)(ptr);
            // Clear SCHEDULED. `fetch_and` returns the previous state, which is checked for
            // AWAITER below.
            let state = header.state.fetch_and(!SCHEDULED, Ordering::AcqRel);
            if state & AWAITER != 0 {
                // NOTE(review): `None` presumably means no waker is excluded from the
                // notification — confirm against `Header::notify`.
                (*header).notify(None);
            }
            // Release this handle's reference last. It may free the allocation that `header`
            // points into, so `header` must not be used after this call.
            (header.vtable.drop_ref)(ptr);
        }
    }
}
impl<M: fmt::Debug> fmt::Debug for Runnable<M> {
    /// Formats the runnable as a `Runnable` struct showing its task header.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let header: &Header<M> = self.header();
        let mut out = f.debug_struct("Runnable");
        out.field("header", header);
        out.finish()
    }
}