use crate::{JoinHandle, LocalWaker};
use std::any::Any;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
/// One-shot future that turns a not-yet-running future into a new task.
///
/// `future` holds the future to be spawned. It is `Some` until the first
/// poll, which moves it out and leaves `None` behind.
struct Spawn<T> {
    future: Option<T>,
}

// `poll` only ever moves the stored future out by value and never polls it in
// place, so `Spawn` does not depend on being pinned even when `T` is `!Unpin`.
impl<T> Unpin for Spawn<T> {}

impl<T: Future + 'static> Future for Spawn<T> {
    type Output = JoinHandle<T>;

    /// Hands the stored future to `LocalWaker::task_new` and resolves
    /// immediately with the value it returns.
    ///
    /// This future never returns `Poll::Pending`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Poll::Ready`,
    /// because the wrapped future has been moved out by then.
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // Obtain the waker first so the stored future is only taken once the
        // executor state is available (same order as argument evaluation in
        // a single `waker.task_new(...)` call).
        let waker = LocalWaker::waker_mut(ctx);
        let future = self
            .future
            .take()
            .expect("`Spawn` polled after completion: the future was already handed to the executor");
        Poll::Ready(waker.task_new(future))
    }
}
/// Creates a future that spawns `future` as a new task.
///
/// Nothing happens until the returned future is polled. On that poll the
/// wrapped future is passed to `LocalWaker::task_new`, and the returned future
/// resolves immediately with the resulting [`JoinHandle`].
///
/// NOTE(review): the returned future extracts a `LocalWaker` from the polling
/// context, so it presumably has to be awaited on this crate's executor —
/// confirm.
pub fn spawn<T: Future + 'static>(future: T) -> impl Future<Output = JoinHandle<T>> {
    let pending = Some(future);
    Spawn { future: pending }
}
/// Future that returns `Poll::Pending` exactly once before completing.
///
/// The wrapped flag records whether that single `Pending` has already been
/// returned.
struct Yield(bool);

impl Future for Yield {
    type Output = ();

    /// On the first poll, calls `wake_by_ref` on the local waker, marks the
    /// yield as done and returns `Poll::Pending`. Every later poll returns
    /// `Poll::Ready(())`.
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        if !self.0 {
            // Request another poll before reporting `Pending`; without the
            // wake-up nothing visible here would schedule this task again.
            let local = LocalWaker::waker(ctx);
            local.wake_by_ref();
            self.0 = true;
            return Poll::Pending;
        }
        Poll::Ready(())
    }
}
/// Creates a future that gives up the current poll once.
///
/// The first poll of the returned future calls `wake_by_ref` on the local
/// waker and returns `Poll::Pending`; the next poll completes with `()`.
/// Presumably this lets other ready tasks run in between — confirm against
/// the executor's scheduling.
pub fn sched_yield() -> impl Future<Output = ()> {
    let has_yielded = false;
    Yield(has_yielded)
}
/// Stateless future that resolves on its first poll with the value returned
/// by `LocalWaker::new_event_id`.
struct EventId;

impl Future for EventId {
    type Output = u64;

    /// Always ready: looks up the local waker in `ctx` and returns the id it
    /// produces.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let id = LocalWaker::waker(ctx).new_event_id();
        Poll::Ready(id)
    }
}
pub fn new_event_id() -> impl Future<Output = u64> {
EventId
}
/// Stateless future that resolves on its first poll with the value returned
/// by `LocalWaker::current_task_id`.
struct CurrentTask;

impl Future for CurrentTask {
    type Output = u64;

    /// Always ready: looks up the local waker in `ctx` and returns the task
    /// id it reports.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let task_id = LocalWaker::waker(ctx).current_task_id();
        Poll::Ready(task_id)
    }
}
pub fn current_task_id() -> impl Future<Output = u64> {
CurrentTask
}
/// Future that waits for an event and yields its payload as a `&'a T`.
///
/// * `event_id` — id of the event to wait for.
/// * `mark` — zero-sized marker carrying the payload type `T` and the
///   lifetime `'a` of the reference handed out on completion.
struct WaitEvent<'a, T: Any> {
    event_id: u64,
    mark: PhantomData<&'a T>,
}

impl<'a, T: Any> Future for WaitEvent<'a, T> {
    type Output = Option<&'a T>;

    /// Checks whether the event's payload is available.
    ///
    /// If `LocalWaker::event_remove` returns a payload, resolves with the
    /// result of `downcast_ref::<T>()` on it, which is `None` when the
    /// downcast does not produce a `T`. Otherwise registers interest in the
    /// event via `LocalWaker::event_register` and returns `Poll::Pending`.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let id = self.event_id;
        let local = LocalWaker::waker(ctx);
        if let Some(payload) = local.event_remove(id) {
            return Poll::Ready(payload.downcast_ref::<T>());
        }
        // No payload yet: register on every pending poll, as nothing visible
        // here tracks whether an earlier registration is still active.
        local.event_register(id);
        Poll::Pending
    }
}
/// Creates a future that waits for the event identified by `event_id`.
///
/// The returned future stays pending until a payload for `event_id` can be
/// removed from the local waker. It then resolves to the result of
/// `downcast_ref::<T>()` on that payload, so the output is `None` when the
/// downcast does not produce a `T`.
///
/// NOTE(review): `'a` is not tied to any argument, so the caller chooses it
/// freely. Whether the returned reference is really valid for `'a` depends on
/// what `LocalWaker::event_remove` hands out — confirm.
pub fn wait_event<'a, T: Any>(event_id: u64) -> impl Future<Output = Option<&'a T>> {
    WaitEvent {
        event_id,
        mark: PhantomData::<&'a T>,
    }
}