#![doc = include_str!("../README.md")]
#![cfg_attr(doc, feature(doc_cfg))]
#![no_std]
mod state;
mod task;
pub use task::Runnable;
use task::{RawHandle, RawJoinHandle, Task};
#[cfg(test)]
mod test;
use core::{
fmt,
pin::Pin,
task::{Context, Poll},
};
extern crate alloc;
use alloc::boxed::Box;
#[cfg(any(feature = "std", test))]
extern crate std;
#[cfg(feature = "std")]
use std::thread::{self, ThreadId};
use pin_project::pin_project;
use thiserror::Error;
/// Configures and spawns a task.
///
/// Created with [`Builder::new`] (or `Default`), optionally customized with
/// [`Builder::metadata`] and, with the `std` feature, `Builder::catch_unwind`, then
/// consumed by one of the `spawn*` methods.
pub struct Builder<M> {
    // Metadata handed to the task allocation when the builder spawns.
    metadata: M,
    // When `true`, a panic while polling the future is caught and becomes the task's
    // output as `Err(Error::Panicked { .. })`; when `false`, the panic propagates out
    // of `poll`. Only settable to `true` with the `std` feature.
    catch_unwind: bool,
}
impl Builder<()> {
#[inline]
#[must_use]
pub const fn new() -> Builder<()> {
Builder {
metadata: (),
catch_unwind: false,
}
}
}
impl<M> Builder<M> {
    /// Replaces the builder's metadata with `metadata`, keeping every other setting.
    ///
    /// The metadata is stored with the spawned task and is exposed through
    /// [`JoinHandle::metadata`].
    #[inline]
    #[must_use]
    pub fn metadata<T>(self, metadata: T) -> Builder<T> {
        Builder {
            metadata,
            catch_unwind: self.catch_unwind,
        }
    }

    /// Sets whether panics raised while polling the future are caught.
    ///
    /// When `true`, a panic becomes the task's output as
    /// `Err(Error::Panicked { .. })`. When `false` (the default), the panic
    /// propagates out of the poll call.
    #[inline]
    #[must_use]
    #[cfg(feature = "std")]
    pub fn catch_unwind(self, catch_unwind: bool) -> Builder<M> {
        Builder {
            catch_unwind,
            ..self
        }
    }

    /// Spawns a `Send + 'static` future.
    ///
    /// Returns the [`Runnable`] for the task and a [`JoinHandle`] whose output is
    /// a [`Result<T>`]. `scheduler` receives the runnable through [`Schedule`].
    #[inline]
    pub fn spawn<F, T, S>(self, future: F, scheduler: S) -> (Runnable<M>, JoinHandle<T, M>)
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
        S: Schedule<M> + Sync + 'static,
    {
        // SAFETY: the `Send`/`Sync` and `'static` bounds above provide the thread-
        // affinity and lifetime guarantees `spawn_unchecked` asks its caller for.
        // NOTE(review): comparable task crates also require `S: Send` here, because
        // the scheduler may be dropped on whichever thread releases the last waker or
        // handle. Confirm against `task` (tightening the bound would be breaking).
        unsafe { self.spawn_unchecked(future, scheduler) }
    }

    /// Spawns a future that does not need to be `Send`, tied to the current thread.
    ///
    /// # Panics
    ///
    /// The task's future panics if it is polled, or dropped, on any thread other
    /// than the one that called `spawn_local`. In the drop case the future is
    /// leaked rather than destroyed on the wrong thread.
    #[cfg(feature = "std")]
    #[inline]
    pub fn spawn_local<F, T, S>(self, future: F, scheduler: S) -> (Runnable<M>, JoinHandle<T, M>)
    where
        F: Future<Output = T> + 'static,
        T: 'static,
        S: Schedule<M> + 'static,
    {
        use core::mem::ManuallyDrop;

        // Confines a (possibly non-`Send`) future to the thread that spawned it.
        struct ThreadLocal<F> {
            // `ManuallyDrop` lets `Drop` refuse to run `F`'s destructor on a foreign
            // thread; in that case the future is leaked instead of dropped.
            future: ManuallyDrop<F>,
            // The thread that created this wrapper.
            thread: ThreadId,
        }

        impl<F> Drop for ThreadLocal<F> {
            fn drop(&mut self) {
                // Dropping a non-`Send` future on another thread is as unsound as
                // polling it there, so apply the same check before destroying it.
                assert_eq!(
                    self.thread,
                    thread::current().id(),
                    "a local future can only be dropped on the thread on which it was spawned"
                );
                // SAFETY: this is the only place `future` is dropped, and it is never
                // accessed afterwards. It is dropped in place (not moved), which upholds
                // the pinning guarantee that `poll` relies on.
                unsafe { ManuallyDrop::drop(&mut self.future) }
            }
        }

        impl<F: Future> Future for ThreadLocal<F> {
            type Output = F::Output;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                assert_eq!(
                    self.thread,
                    thread::current().id(),
                    "a local future can only be run on the thread on which it was spawned"
                );
                // SAFETY: `future` is structurally pinned. It is never moved out of the
                // wrapper, `ThreadLocal<F>` is `Unpin` only when `F` is, and `Drop`
                // drops it in place.
                let future = unsafe { self.map_unchecked_mut(|this| &mut *this.future) };
                future.poll(cx)
            }
        }

        // Read the thread id first so a panic here cannot leak the wrapped future.
        let thread = thread::current().id();
        let future = ThreadLocal {
            future: ManuallyDrop::new(future),
            thread,
        };
        // SAFETY: `F`, `T` and `S` are `'static`, and `ThreadLocal` ensures the
        // future is only polled and dropped on this thread. The `JoinHandle` is only
        // `Send` when `T: Send`, so a non-`Send` output stays on the awaiting thread.
        // NOTE(review): `S` is not required to be `Send`/`Sync` even though wakers may
        // invoke the scheduler; confirm the waker implementation in `task` keeps this
        // sound.
        unsafe { self.spawn_unchecked(future, scheduler) }
    }

    /// Wraps `future` so that a panic during `poll` is caught and reported as the
    /// output `Err(Error::Panicked { .. })`.
    #[cfg(feature = "std")]
    fn wrap_catch_unwind<F: Future<Output = R>, R>(future: F) -> impl Future<Output = Result<R>> {
        use std::panic::{AssertUnwindSafe, catch_unwind};

        #[pin_project]
        struct CatchUnwind<F>(#[pin] F);

        impl<F, R> Future for CatchUnwind<F>
        where
            F: Future<Output = R>,
        {
            type Output = Result<R>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // `AssertUnwindSafe`: after a panic the future completes with an error
                // and is (presumably) not polled again, so broken state is not observed.
                let res = catch_unwind(AssertUnwindSafe(|| self.project().0.poll(cx)));
                match res {
                    Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
                    Ok(Poll::Pending) => Poll::Pending,
                    Err(err) => Poll::Ready(Err(Error::panicked(err))),
                }
            }
        }

        CatchUnwind(future)
    }

    /// Without `std` there is no `catch_unwind`, so panics cannot be caught. This
    /// falls back to `wrap_panicking`.
    #[cfg(not(feature = "std"))]
    fn wrap_catch_unwind<F: Future<Output = R>, R>(future: F) -> impl Future<Output = Result<R>> {
        Self::wrap_panicking(future)
    }

    /// Wraps `future` so its output becomes `Ok(output)`. Panics propagate out of `poll`.
    fn wrap_panicking<F: Future<Output = R>, R>(future: F) -> impl Future<Output = Result<R>> {
        #[pin_project]
        struct Wrap<F>(#[pin] F);

        impl<F, R> Future for Wrap<F>
        where
            F: Future<Output = R>,
        {
            type Output = Result<R>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.project().0.poll(cx).map(Ok)
            }
        }

        Wrap(future)
    }

    /// Spawns a future without the `Send`, `Sync` and `'static` bounds that
    /// [`Builder::spawn`] enforces.
    ///
    /// With panic catching enabled, the future is wrapped so that panics become
    /// `Err(Error::Panicked { .. })`. Otherwise its output is simply wrapped in
    /// `Ok`. Either way the [`JoinHandle`] yields a [`Result<T>`].
    ///
    /// # Safety
    ///
    /// The caller must uphold by hand what `spawn`'s bounds would otherwise guarantee:
    ///
    /// - If `F` or `T` is not `Send`, the returned [`Runnable`] (and therefore the
    ///   future and its output) must only be run and dropped on the spawning thread.
    /// - If `F`, `T` or `S` is not `'static`, any data they borrow must outlive the
    ///   task and every handle or waker referring to it.
    /// - If `S` is not `Sync`, the scheduler must not be invoked concurrently from
    ///   several threads (for example through wakers).
    ///
    /// NOTE(review): the precise contract is determined by `task::Task::allocate`;
    /// confirm these conditions against it.
    #[inline]
    pub unsafe fn spawn_unchecked<F, T, S>(
        self,
        future: F,
        scheduler: S,
    ) -> (Runnable<M>, JoinHandle<T, M>)
    where
        F: Future<Output = T>,
        S: Schedule<M>,
    {
        let (runnable, handle) = if self.catch_unwind {
            Task::allocate(Self::wrap_catch_unwind(future), scheduler, self.metadata)
        } else {
            Task::allocate(Self::wrap_panicking(future), scheduler, self.metadata)
        };
        (runnable, JoinHandle { raw: handle })
    }
}
#[inline]
pub fn spawn<F, T, S>(future: F, scheduler: S) -> (Runnable, JoinHandle<T>)
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
S: Schedule + Sync + 'static,
{
Builder::new().spawn(future, scheduler)
}
/// Spawns a future that need not be `Send`, using a default [`Builder`].
///
/// # Panics
///
/// The task's future panics if it is polled on a thread other than the one that
/// called `spawn_local`.
#[cfg(feature = "std")]
#[inline]
pub fn spawn_local<F, T, S>(future: F, scheduler: S) -> (Runnable, JoinHandle<T>)
where
    F: Future<Output = T> + 'static,
    T: 'static,
    S: Schedule + 'static,
{
    let builder: Builder<()> = Builder::default();
    builder.spawn_local(future, scheduler)
}
#[inline]
pub unsafe fn spawn_unchecked<F, T, S>(future: F, scheduler: S) -> (Runnable, JoinHandle<T>)
where
F: Future<Output = T>,
S: Schedule,
{
unsafe { Builder::new().spawn_unchecked(future, scheduler) }
}
impl Default for Builder<()> {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Handle to a spawned task, returned alongside its [`Runnable`].
///
/// Awaiting it yields the task's output as a [`Result<T>`]. It also exposes the
/// task's metadata, can cancel the task, and can produce [`AbortHandle`]s.
pub struct JoinHandle<T, M = ()> {
    // Type-erased handle into the task allocation (defined in the `task` module).
    raw: RawJoinHandle<T, M>,
}
impl<T, M> JoinHandle<T, M> {
    /// Returns a reference to the metadata the task was spawned with.
    #[inline]
    pub fn metadata(&self) -> &M {
        let raw = &self.raw;
        raw.metadata()
    }

    /// Cancels the task through its raw handle.
    ///
    /// NOTE(review): whether awaiting this handle afterwards yields
    /// `Err(Error::Cancelled)` is decided by the task internals; confirm there.
    #[inline]
    pub fn cancel(&self) {
        let raw = &self.raw;
        raw.cancel();
    }

    /// Creates an [`AbortHandle`] for the same task.
    ///
    /// The abort handle can cancel the task and be cloned, but cannot retrieve the
    /// task's output.
    #[inline]
    pub fn abort_handle(&self) -> AbortHandle {
        let raw = self.raw.handle().clone();
        AbortHandle { raw }
    }

    /// Reports whether the task has finished, as tracked by the raw handle.
    #[inline]
    pub fn finished(&self) -> bool {
        let raw = &self.raw;
        raw.finished()
    }
}
impl<T, M> Future for JoinHandle<T, M> {
    type Output = Result<T>;

    /// Polls for the task's output. See [`Error`] for the failure cases.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T>> {
        let raw = &self.raw;
        raw.poll(cx)
    }
}
impl<T, M: fmt::Debug> fmt::Debug for JoinHandle<T, M> {
    /// Shows only the task's metadata; the remaining state is elided.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let metadata = self.metadata();
        let mut out = f.debug_struct("JoinHandle");
        out.field("metadata", &metadata);
        out.finish_non_exhaustive()
    }
}
// SAFETY: a `JoinHandle` moved to another thread can receive the task's output
// (hence `T: Send`) and hands out `&M` through `metadata` (hence `M: Sync`).
// NOTE(review): if releasing the last handle can destroy the task allocation on the
// releasing thread, `M` would be dropped there and `M: Send` would also be needed
// (comparable crates require `M: Send + Sync`). Confirm against the drop path of
// `task::RawJoinHandle`; tightening this bound would be a breaking change.
unsafe impl<T: Send, M: Sync> Send for JoinHandle<T, M> {}
// SAFETY: through `&JoinHandle`, this file only exposes `&M` (`metadata`, `Debug`),
// `cancel`, `finished` and `abort_handle`. The output `T` is only produced by
// `poll`, which requires `Pin<&mut Self>`, so no `T: Sync` bound is needed.
unsafe impl<T, M: Sync> Sync for JoinHandle<T, M> {}
/// Cloneable handle that can cancel a task but cannot observe its output.
///
/// Obtained from [`JoinHandle::abort_handle`].
#[derive(Clone)]
pub struct AbortHandle {
    // Type-erased handle into the task allocation (defined in the `task` module).
    raw: RawHandle,
}
impl AbortHandle {
    /// Cancels the task this handle refers to, via the shared raw handle.
    #[inline]
    pub fn cancel(&self) {
        let raw = &self.raw;
        raw.cancel();
    }
}
impl fmt::Debug for AbortHandle {
    /// Prints `AbortHandle { .. }`; the internals are intentionally elided.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AbortHandle");
        out.finish_non_exhaustive()
    }
}
// SAFETY: an `AbortHandle` only exposes `cancel` and `Clone`, and it carries no
// `T`, `M` or future type parameters.
// NOTE(review): these impls are unconditional even though the underlying task may
// hold non-`Send` data (e.g. a future spawned with `spawn_local`). Soundness relies
// on `RawHandle`'s `cancel`, `clone` and drop never touching that data from a
// foreign thread; confirm in `task`.
unsafe impl Send for AbortHandle {}
unsafe impl Sync for AbortHandle {}
/// Extra context passed to [`Schedule::schedule`] alongside the runnable.
#[derive(Debug, Clone, Copy)]
pub struct ScheduleInfo {
    /// Presumably `true` when the task was woken while it was being run (polled).
    /// Set by the task internals. NOTE(review): confirm the exact semantics in `task`.
    pub woken_while_running: bool,
}
/// Receives a task's [`Runnable`] when the task is scheduled.
///
/// Implemented for every `Fn(Runnable<M>)` closure, which ignores the
/// [`ScheduleInfo`], and for [`WithInfo`] wrappers around closures that want it.
pub trait Schedule<M = ()> {
    /// Schedules `runnable`. `info` carries extra context supplied by the task internals.
    fn schedule(&self, runnable: Runnable<M>, info: ScheduleInfo);
}
impl<F, M> Schedule<M> for F
where
    F: Fn(Runnable<M>),
{
    /// Calls the closure with the runnable, discarding the [`ScheduleInfo`].
    #[inline]
    fn schedule(&self, runnable: Runnable<M>, _info: ScheduleInfo) {
        (self)(runnable)
    }
}
/// Adapter that turns a closure taking `(Runnable<M>, ScheduleInfo)` into a
/// [`Schedule`] implementation that also receives the [`ScheduleInfo`].
#[derive(Debug)]
pub struct WithInfo<F>(pub F);
impl<F, M> Schedule<M> for WithInfo<F>
where
F: Fn(Runnable<M>, ScheduleInfo),
{
#[inline]
fn schedule(&self, runnable: Runnable<M>, info: ScheduleInfo) {
self.0(runnable, info);
}
}
/// Reason a task did not produce its output.
#[derive(Debug, Error)]
pub enum Error {
    /// The task was cancelled (presumably via `JoinHandle::cancel` or
    /// `AbortHandle::cancel`; reported by the task internals).
    #[error("the task was cancelled")]
    Cancelled,
    /// The task's future panicked while being polled. This is reported when panic
    /// catching was enabled on the [`Builder`] (requires the `std` feature).
    #[error("the task panicked")]
    Panicked {
        /// The panic payload captured by `std::panic::catch_unwind`.
        payload: Box<dyn core::any::Any + Send>,
    },
}
impl Error {
const fn panicked(payload: Box<dyn core::any::Any + Send>) -> Error {
Error::Panicked { payload }
}
}
/// `Result` whose error type defaults to this crate's [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;