use std::{
any::Any,
fmt,
future::Future,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
thread,
time::{Duration, Instant},
};
/// A handle to the eventual result of a closure or future executed as a
/// [`Coroutine`].
///
/// The value can be retrieved by blocking (`join`, `join_timeout`,
/// `join_deadline`) or by awaiting the handle, since it implements `Future`.
pub struct Task<T> {
    // State shared with the coroutine's poller, which stores the result (and
    // wakes any registered waker) when the work completes.
    inner: Arc<Mutex<Inner<T>>>,
}
/// State shared between a [`Task`] handle and the poller producing its value.
struct Inner<T> {
    // The finished value, or the panic payload if the work panicked. Stored by
    // the poller's `complete` and moved out by whoever retrieves the result.
    result: Option<thread::Result<T>>,
    // Waker registered by the most recent blocking joiner or async poller of
    // the task; woken by the poller's `complete` once `result` is stored.
    waker: Option<Waker>,
}
impl<T> Task<T> {
    /// Creates a task whose value is produced by calling `closure`.
    ///
    /// Returns the task handle together with the [`Coroutine`] that must be run
    /// to produce the value. Closure coroutines never yield: a single run calls
    /// the closure to completion.
    pub(crate) fn from_closure<F>(closure: F) -> (Self, Coroutine)
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let task = Self::pending();
        let coroutine = Coroutine {
            might_yield: false,
            waker: crate::wakers::empty_waker(),
            poller: Box::new(ClosurePoller {
                closure: Some(closure),
                result: None,
                task: task.inner.clone(),
            }),
        };
        (task, coroutine)
    }

    /// Creates a task whose value is produced by driving `future`.
    ///
    /// Returns the task handle together with the [`Coroutine`] that must be run
    /// to produce the value. Future coroutines may yield: each run polls the
    /// future once.
    pub(crate) fn from_future<F>(future: F) -> (Self, Coroutine)
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let task = Self::pending();
        let coroutine = Coroutine {
            might_yield: true,
            waker: crate::wakers::empty_waker(),
            poller: Box::new(FuturePoller {
                future,
                result: None,
                task: task.inner.clone(),
            }),
        };
        (task, coroutine)
    }

    /// Creates a task with no result and no registered waker.
    fn pending() -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                result: None,
                waker: None,
            })),
        }
    }

    /// Returns `true` if the task's result is available and has not yet been
    /// taken (by awaiting the task).
    pub fn is_done(&self) -> bool {
        self.inner.lock().unwrap().result.is_some()
    }

    /// Blocks the current thread until the task completes and returns its value.
    ///
    /// # Panics
    ///
    /// If the task's work panicked, that panic is resumed on this thread.
    pub fn join(self) -> T {
        match self.join_catch() {
            Ok(value) => value,
            Err(e) => resume_unwind(e),
        }
    }

    /// Blocks until the task completes, returning its value or panic payload.
    fn join_catch(self) -> thread::Result<T> {
        self.wait_until(None)
            .expect("waiting without a deadline only returns with a result")
    }

    /// Blocks for at most `timeout` waiting for the task to complete.
    ///
    /// Returns `Err(self)` if the timeout elapses first, so the caller can keep
    /// waiting later. A timeout too large to be represented as an `Instant`
    /// deadline waits indefinitely instead of overflowing.
    ///
    /// # Panics
    ///
    /// If the task's work panicked, that panic is resumed on this thread.
    pub fn join_timeout(self, timeout: Duration) -> Result<T, Self> {
        match Instant::now().checked_add(timeout) {
            Some(deadline) => self.join_deadline(deadline),
            // `Instant + Duration` would overflow (and panic); the deadline is
            // effectively "never", so wait without one.
            None => Ok(self.join()),
        }
    }

    /// Blocks until the task completes or `deadline` passes.
    ///
    /// Returns `Err(self)` if the deadline passes first, so the caller can keep
    /// waiting later.
    ///
    /// # Panics
    ///
    /// If the task's work panicked, that panic is resumed on this thread.
    pub fn join_deadline(self, deadline: Instant) -> Result<T, Self> {
        match self.wait_until(Some(deadline)) {
            Some(Ok(value)) => Ok(value),
            Some(Err(e)) => resume_unwind(e),
            None => Err(self),
        }
    }

    /// Blocks the current thread until a result is available or `deadline`
    /// (if any) has passed.
    ///
    /// Returns `None` only on timeout, so with `deadline == None` the return
    /// value is always `Some`.
    fn wait_until(&self, deadline: Option<Instant>) -> Option<thread::Result<T>> {
        {
            let mut inner = self.inner.lock().unwrap();
            if let Some(result) = inner.result.take() {
                return Some(result);
            }
            // Register a waker for this thread; the poller's `complete` wakes
            // it once the result has been stored.
            inner.waker = Some(crate::wakers::current_thread_waker());
        }

        loop {
            match deadline {
                None => thread::park(),
                Some(deadline) => match deadline.checked_duration_since(Instant::now()) {
                    Some(timeout) => thread::park_timeout(timeout),
                    None => {
                        // Timed out. Deregister our waker so a later completion
                        // does not unpark this thread after it stopped waiting,
                        // but still hand back a result that raced in at the end.
                        let mut inner = self.inner.lock().unwrap();
                        inner.waker = None;
                        return inner.result.take();
                    }
                },
            }

            // Parking may end spuriously, so only a present result ends the wait.
            if let Some(result) = self.inner.lock().unwrap().result.take() {
                return Some(result);
            }
        }
    }
}
impl<T> Future for Task<T> {
    type Output = T;

    /// Resolves to the task's value once the coroutine has completed.
    ///
    /// While no result is available, the context's waker is registered so the
    /// poller's `complete` can wake this task. The result is moved out on
    /// `Ready`, so polling again afterwards returns `Pending`.
    ///
    /// # Panics
    ///
    /// If the task's work panicked, that panic is resumed in the poller.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let result = {
            let mut inner = self.inner.lock().unwrap();
            match inner.result.take() {
                Some(result) => result,
                None => {
                    inner.waker = Some(cx.waker().clone());
                    return Poll::Pending;
                }
            }
        };

        // The guard is released above on purpose: resuming the panic while the
        // guard is alive would drop it during unwinding and poison the mutex,
        // making later `is_done()` / `Debug` calls on this task panic.
        match result {
            Ok(value) => Poll::Ready(value),
            Err(payload) => resume_unwind(payload),
        }
    }
}
impl<T> fmt::Debug for Task<T> {
    /// Formats as `Task { done: <bool> }`.
    ///
    /// Only the completion status is shown, so `T` need not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let done = self.is_done();
        let mut builder = f.debug_struct("Task");
        builder.field("done", &done);
        builder.finish()
    }
}
/// Type-erased unit of work that produces the result of a [`Task`].
pub(crate) struct Coroutine {
    // `false` for closure-backed coroutines (which always finish in one run),
    // `true` for future-backed ones (which may return `RunResult::Yield`).
    might_yield: bool,
    // Waker handed to the poller through a `Context` on every `run`; starts as
    // `crate::wakers::empty_waker()` and can be replaced with `set_waker`.
    waker: Waker,
    // The concrete `ClosurePoller` or `FuturePoller`, boxed behind a trait object.
    poller: Box<dyn CoroutinePoller>,
}
impl Coroutine {
    /// Returns whether a single [`run`](Self::run) may end with
    /// [`RunResult::Yield`] instead of completing.
    ///
    /// `false` for closure-backed coroutines, `true` for future-backed ones.
    pub(crate) fn might_yield(&self) -> bool {
        self.might_yield
    }

    /// Returns the address of the boxed poller as an integer.
    ///
    /// Moving the `Coroutine` only moves the `Box` pointer, not the poller's
    /// heap allocation, so this value is stable for the coroutine's lifetime.
    pub(crate) fn addr(&self) -> usize {
        let poller: *const dyn CoroutinePoller = &*self.poller;
        poller.cast::<()>() as usize
    }

    /// Replaces the waker passed to the poller on subsequent runs.
    pub(crate) fn set_waker(&mut self, waker: Waker) {
        self.waker = waker;
    }

    /// Runs the coroutine once using the current waker.
    ///
    /// A closure coroutine runs to completion; a future coroutine is polled
    /// once and may return [`RunResult::Yield`].
    pub(crate) fn run(&mut self) -> RunResult {
        let waker = &self.waker;
        let poller = &mut self.poller;
        poller.run(&mut Context::from_waker(waker))
    }

    /// Consumes the coroutine, publishing any result produced by `run` to the
    /// associated [`Task`] and waking its registered waker.
    pub(crate) fn complete(mut self) {
        let poller = &mut self.poller;
        poller.complete();
    }

    /// Recovers the closure this coroutine was created from.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine was not created from a closure of type `F`
    /// returning `T`, or if the closure has already been run.
    pub(crate) fn into_inner_closure<F, T>(self) -> F
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let any = self.poller.into_any();
        let mut poller = any.downcast::<ClosurePoller<F, T>>().unwrap();
        poller.closure.take().unwrap()
    }

    /// Recovers the future this coroutine was created from.
    ///
    /// NOTE(review): `FuturePoller::run` polls the future in place via
    /// `Pin::new_unchecked`, and this method moves it out of that location.
    /// This is only sound if the coroutine has never been run, or if
    /// `F: Unpin` — confirm that callers uphold this.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine was not created from a future of type `F`
    /// with output `T`.
    pub(crate) fn into_inner_future<F, T>(self) -> F
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let any = self.poller.into_any();
        let poller = any.downcast::<FuturePoller<F, T>>().unwrap();
        poller.future
    }
}
/// Outcome of a single [`Coroutine::run`] call.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum RunResult {
    /// The underlying future returned `Poll::Pending`; the coroutine has not
    /// finished.
    Yield,
    /// The coroutine finished; `panicked` is `true` if the work ended in a
    /// panic. The outcome is held until `Coroutine::complete` publishes it.
    Complete { panicked: bool },
}
/// Object-safe interface over the concrete pollers (`ClosurePoller`,
/// `FuturePoller`) stored inside a [`Coroutine`].
trait CoroutinePoller: Send + 'static {
    /// Advances the work once, using `cx` for wakeups, and reports whether it
    /// yielded or completed.
    fn run(&mut self, cx: &mut Context) -> RunResult;
    /// Publishes the outcome produced by `run` (if any) to the shared task
    /// state and wakes the task's registered waker.
    fn complete(&mut self);
    /// Converts the boxed poller into `Box<dyn Any>` so it can be downcast back
    /// to its concrete type.
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}
/// Poller that runs a closure to completion in a single `run`.
struct ClosurePoller<F, T> {
    // The closure, until `run` takes it to call it (or
    // `Coroutine::into_inner_closure` extracts it).
    closure: Option<F>,
    // Outcome of `run`, held until `complete` publishes it to `task`.
    result: Option<thread::Result<T>>,
    // Shared state of the `Task` handle awaiting this result.
    task: Arc<Mutex<Inner<T>>>,
}
impl<F, T> CoroutinePoller for ClosurePoller<F, T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// Calls the closure, catching any panic, and stores the outcome.
    ///
    /// Always returns [`RunResult::Complete`]; the context is unused because a
    /// closure cannot yield.
    ///
    /// # Panics
    ///
    /// Panics if called again after the closure has already been taken.
    fn run(&mut self, _cx: &mut Context) -> RunResult {
        let closure = self
            .closure
            .take()
            .expect("closure already ran to completion");
        let result = catch_unwind(AssertUnwindSafe(closure));
        let panicked = result.is_err();
        self.result = Some(result);
        RunResult::Complete { panicked }
    }

    /// Stores the outcome of `run` in the task's shared state and wakes the
    /// task's registered waker. Does nothing if there is no outcome to publish.
    fn complete(&mut self) {
        if let Some(result) = self.result.take() {
            let waker = {
                let mut task = self.task.lock().unwrap();
                task.result = Some(result);
                // The result is published exactly once, so the waker is only
                // needed for this single wakeup.
                task.waker.take()
            };

            // Wake only after the lock is released: a waker that synchronously
            // polls or joins the task would otherwise try to re-lock the
            // (non-reentrant) mutex and deadlock.
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}
/// Poller that drives a future, polling it once per `run`.
struct FuturePoller<F, T> {
    // Polled in place via `Pin::new_unchecked` in `run`, so once it has been
    // polled it must not be moved (unless `F: Unpin`).
    future: F,
    // Outcome of the future (value or panic payload), held until `complete`
    // publishes it to `task`.
    result: Option<thread::Result<T>>,
    // Shared state of the `Task` handle awaiting this result.
    task: Arc<Mutex<Inner<T>>>,
}
impl<F, T> CoroutinePoller for FuturePoller<F, T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    /// Polls the future once with `cx`, catching any panic it raises.
    ///
    /// Returns [`RunResult::Yield`] while the future is pending; otherwise the
    /// outcome (value or panic payload) is stored until `complete` publishes it.
    fn run(&mut self, cx: &mut Context) -> RunResult {
        // SAFETY: the future is polled in place inside this poller, which is
        // heap-allocated in a `Box` when the coroutine is created and is not
        // moved by `run` or `complete`.
        // NOTE(review): `Coroutine::into_inner_future` moves the future out of
        // the box; that is only sound before the first poll (or for
        // `F: Unpin`). Confirm that callers uphold this.
        let future = unsafe { Pin::new_unchecked(&mut self.future) };
        match catch_unwind(AssertUnwindSafe(|| future.poll(cx))) {
            Ok(Poll::Pending) => RunResult::Yield,
            Ok(Poll::Ready(value)) => {
                self.result = Some(Ok(value));
                RunResult::Complete { panicked: false }
            }
            Err(payload) => {
                self.result = Some(Err(payload));
                RunResult::Complete { panicked: true }
            }
        }
    }

    /// Stores the outcome of `run` in the task's shared state and wakes the
    /// task's registered waker. Does nothing if there is no outcome to publish.
    fn complete(&mut self) {
        if let Some(result) = self.result.take() {
            let waker = {
                let mut task = self.task.lock().unwrap();
                task.result = Some(result);
                // The result is published exactly once, so the waker is only
                // needed for this single wakeup.
                task.waker.take()
            };

            // Wake only after the lock is released: a waker that synchronously
            // polls or joins the task would otherwise try to re-lock the
            // (non-reentrant) mutex and deadlock.
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}