use std::{future::Future, pin::Pin, task::Poll};
use derive_more::{Deref, From};
use futures::FutureExt;
use pin_project::pin_project;
use crate::{control::ControlHandle, platform};
pub use platform::task::wasm_nonsend;
#[cfg(not(target_os = "unknown"))]
pub use platform::task::make_native_multithreaded_runtime;
/// Spawns `fut` on the runtime returned by [`RuntimeHandle::current`].
///
/// Both the future and its output must be `Send + 'static`. Awaiting the returned
/// [`JoinHandle`] yields `Ok(output)` or a [`JoinError`].
pub fn spawn<F, T>(fut: F) -> JoinHandle<T>
where
    F: 'static + Send + Future<Output = T>,
    T: 'static + Send,
{
    let runtime = RuntimeHandle::current();
    runtime.spawn(fut)
}
/// Runs `f` through the current runtime's `block_in_place` and returns its result.
///
/// NOTE(review): the exact blocking semantics (e.g. whether the worker thread is handed
/// off while `f` runs) come from the platform implementation — confirm there.
pub fn block_in_place<R, F>(f: F) -> R
where
    F: FnOnce() -> R,
{
    let runtime = RuntimeHandle::current();
    runtime.block_in_place(f)
}
/// Submits the closure `f` to the current runtime's `spawn_blocking`.
///
/// Awaiting the returned [`JoinHandle`] yields the closure's return value or a [`JoinError`].
pub fn spawn_blocking<R, F>(f: F) -> JoinHandle<R>
where
    F: 'static + Send + FnOnce() -> R,
    R: 'static + Send,
{
    let runtime = RuntimeHandle::current();
    runtime.spawn_blocking(f)
}
pub fn spawn_local<F, Fut, T>(func: F) -> ControlHandle<T>
where
F: 'static + FnOnce() -> Fut + Send,
Fut: 'static + Future<Output = T>,
T: 'static + Send,
{
platform::task::spawn_local(func)
}
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum JoinError {
#[error("The future was aborted")]
Aborted,
#[error("The future panicked")]
#[allow(dead_code)]
Panicked,
}
pub struct JoinHandle<T>(pub(crate) platform::task::JoinHandle<T>);
impl<T: std::fmt::Debug> std::fmt::Debug for JoinHandle<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<T> JoinHandle<T> {
pub fn abort(&self) {
self.0.abort()
}
pub fn is_finished(&self) -> bool {
self.0.is_finished()
}
}
impl<T> Future for JoinHandle<T> {
type Output = Result<T, JoinError>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
self.0.poll_unpin(cx)
}
}
/// A [`JoinHandle`] that aborts its task when dropped.
///
/// Awaiting a `ChildTask` yields the task's result exactly like the wrapped handle.
/// Dropping it always calls [`JoinHandle::abort`], whether or not it was awaited, which
/// ties the task's lifetime to this value. The other `JoinHandle` methods (e.g.
/// `is_finished`) are reachable through `Deref`.
///
/// `Debug` is derived so this public type is debuggable like the `JoinHandle` it wraps
/// (requires `T: Debug`, matching `JoinHandle`'s `Debug` impl).
#[derive(Debug, From, Deref)]
pub struct ChildTask<T>(JoinHandle<T>);

impl<T> Future for ChildTask<T> {
    type Output = Result<T, JoinError>;

    /// Polls the wrapped [`JoinHandle`] in place.
    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        self.0.poll_unpin(cx)
    }
}

impl<T> Drop for ChildTask<T> {
    fn drop(&mut self) {
        // Abort unconditionally. NOTE(review): this relies on `abort` being harmless for a
        // task that has already completed — confirm for each platform backend.
        self.0.abort();
    }
}
/// Cloneable handle to the platform task runtime.
///
/// Wraps the platform-specific runtime handle and re-exposes its spawning and blocking
/// operations in terms of this module's [`JoinHandle`].
#[derive(Clone, Debug, From)]
pub struct RuntimeHandle(pub(crate) platform::task::RuntimeHandle);

impl RuntimeHandle {
    /// Returns the runtime for the calling context, as reported by the platform's `current()`.
    ///
    /// NOTE(review): behavior when no runtime is active (panic or otherwise) is decided by
    /// the platform implementation — confirm there.
    #[inline]
    pub fn current() -> Self {
        let platform_handle = platform::task::RuntimeHandle::current();
        Self(platform_handle)
    }

    /// Spawns a `Send + 'static` future on this runtime.
    #[inline]
    pub fn spawn<F, T>(&self, future: F) -> JoinHandle<T>
    where
        F: 'static + Send + Future<Output = T>,
        T: 'static + Send,
    {
        let inner = self.0.spawn(future);
        JoinHandle(inner)
    }

    /// Spawns a future that need not be `Send` (only on `target_os = "unknown"` builds).
    ///
    /// NOTE(review): this delegates to the platform `spawn`, not a dedicated local-spawn
    /// API; presumably the platform `spawn` on this target accepts `!Send` futures — confirm.
    #[cfg(target_os = "unknown")]
    #[inline]
    pub fn spawn_local<F, T>(&self, future: F) -> JoinHandle<T>
    where
        F: 'static + Future<Output = T>,
        T: 'static,
    {
        let inner = self.0.spawn(future);
        JoinHandle(inner)
    }

    /// Hands the closure `f` to the platform runtime's `spawn_blocking`.
    pub fn spawn_blocking<R, F>(&self, f: F) -> JoinHandle<R>
    where
        F: 'static + Send + FnOnce() -> R,
        R: 'static + Send,
    {
        let inner = self.0.spawn_blocking(f);
        JoinHandle(inner)
    }

    /// Runs `f` through the platform runtime's `block_in_place` and returns its result.
    pub fn block_in_place<R, F>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let runtime = &self.0;
        runtime.block_in_place(f)
    }

    /// Drives `future` to completion on this runtime and returns its output
    /// (unavailable on `target_os = "unknown"` builds).
    #[cfg(not(target_os = "unknown"))]
    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
        let runtime = &self.0;
        runtime.block_on(future)
    }
}
/// Heap-allocated, pinned future whose `Send` requirement depends on the target.
///
/// On `target_os = "unknown"` builds the wrapped future need not be `Send` and can be
/// unwrapped with `into_local`; on every other target `new` requires `Send` and the
/// future can be unwrapped with `into_shared`.
#[must_use = "Futures do nothing if not polled"]
pub struct PlatformBoxFuture<'a, T>(platform::task::PlatformBoxFutureImpl<'a, T>);

impl<'a, T> PlatformBoxFuture<'a, T> {
    /// Boxes and pins `future` (no `Send` bound on this target).
    #[cfg(target_os = "unknown")]
    pub fn new<F>(future: F) -> Self
    where
        F: 'a + Future<Output = T>,
    {
        Self(platform::task::PlatformBoxFutureImpl::from_boxed(Box::pin(future)))
    }

    /// Boxes and pins `future`, which must be `Send` on this target.
    #[cfg(not(target_os = "unknown"))]
    pub fn new<F>(future: F) -> Self
    where
        F: 'a + Future<Output = T> + Send,
    {
        Self(platform::task::PlatformBoxFutureImpl::from_boxed(Box::pin(future)))
    }

    /// Unwraps into a `futures` `LocalBoxFuture`.
    #[cfg(target_os = "unknown")]
    pub fn into_local(self) -> futures::future::LocalBoxFuture<'a, T> {
        let Self(inner) = self;
        inner.into_local()
    }

    /// Unwraps into a `futures` `BoxFuture`.
    #[cfg(not(target_os = "unknown"))]
    pub fn into_shared(self) -> futures::future::BoxFuture<'a, T> {
        let Self(inner) = self;
        inner.into_shared()
    }
}

impl<T> Future for PlatformBoxFuture<'_, T> {
    type Output = T;

    /// Polls the wrapped platform future in place; it must be `Unpin`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        Pin::new(&mut self.0).poll(cx)
    }
}
/// Wraps a future `fut` together with a cooperative-scheduling budget.
///
/// Only the budget bookkeeping is defined here; how the budget is spent and refilled is
/// up to this type's `Future` implementation. NOTE(review): `max_budget` is presumably
/// the value `current_budget` is refilled to once it runs out — confirm against that impl.
#[pin_project]
pub struct Cooperative<F> {
    /// Configured budget, set by `new` (1024) or `with_budget`.
    max_budget: usize,
    /// Budget remaining; both constructors start it equal to `max_budget`.
    current_budget: usize,
    /// The wrapped future, structurally pinned via `#[pin]`.
    #[pin]
    fut: F,
}

impl<F> Cooperative<F> {
    /// Budget used when none is given explicitly.
    const DEFAULT_BUDGET: usize = 1024;

    /// Wraps `fut` with the default budget of 1024.
    pub fn new(fut: F) -> Self {
        Self {
            max_budget: Self::DEFAULT_BUDGET,
            current_budget: Self::DEFAULT_BUDGET,
            fut,
        }
    }

    /// Sets both the configured and the remaining budget to `budget`.
    pub fn with_budget(mut self, budget: usize) -> Self {
        self.max_budget = budget;
        // Reset the remaining budget as well; otherwise it would keep the default
        // from `new` and the first run would ignore the requested budget.
        self.current_budget = budget;
        self
    }
}