use switcheroo::Generator;
use switcheroo::Yielder;
use std::cell::Cell;
use std::future::Future;
use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
pub use switcheroo::stack;
/// A future that runs a closure on a caller-supplied `Stack` via a
/// `switcheroo` generator, letting that closure wait on other futures
/// synchronously through an [`AsyncYielder`].
///
/// Type parameters:
/// - `Stack`: the stack the generator executes on.
/// - `Output`: the value the closure returns, yielded as this future's output.
/// - `P`: an optional hook (see `set_pre_post_poll`) invoked around generator
///   resumption in `poll`, and before tear-down in `stack` and `Drop`.
pub struct AsyncWormhole<'a, Stack, Output, P>
where
Stack: stack::Stack + Send,
P: FnMut() + Send,
{
/// The generator running the closure. It is resumed with a `Waker` and yields
/// `None` while waiting on a future or `Some(output)` once the closure has
/// returned. It is `Some` from construction until `stack()` takes it out.
generator: Option<Cell<Generator<'a, Waker, Option<Output>, Stack>>>,
/// Optional user hook; `None` until `set_pre_post_poll` is called.
pre_post_poll: Option<P>,
}
impl<'a, Stack, Output, P> AsyncWormhole<'a, Stack, Output, P>
where
    Stack: stack::Stack + Send,
    P: FnMut() + Send,
{
    /// Creates a wormhole that will execute `f` on `stack`.
    ///
    /// `f` is handed to the generator rather than called here. Its second
    /// argument is the `Waker` supplied by the generator's first `resume`.
    /// Inside `f`, the [`AsyncYielder`] can be used to block on futures.
    /// The value `f` returns becomes this future's output.
    ///
    /// This implementation always returns `Ok`.
    pub fn new<F>(stack: Stack, f: F) -> Result<Self, Error>
    where
        F: FnOnce(AsyncYielder<Output>) -> Output + 'a + Send,
    {
        let generator = Generator::new(stack, |yielder, initial_waker| {
            let output = f(AsyncYielder::new(yielder, initial_waker));
            // `Some(..)` marks completion; `None` is what `async_suspend`
            // yields while it is still waiting on a future.
            yielder.suspend(Some(output));
        });
        let wormhole = Self {
            generator: Some(Cell::new(generator)),
            pre_post_poll: None,
        };
        Ok(wormhole)
    }

    /// Installs the hook that `poll` runs before resuming the generator and
    /// again after a `Poll::Pending`. It is also run once by `stack` and by
    /// `Drop` when the generator was started but has not finished.
    /// Any previously installed hook is replaced.
    pub fn set_pre_post_poll(&mut self, f: P) {
        self.pre_post_poll = Some(f);
    }

    /// Consumes the wormhole and returns the underlying stack.
    ///
    /// If the generator has started but not finished, the hook (when set)
    /// is invoked once before the stack is extracted from the generator.
    pub fn stack(mut self) -> Stack {
        let generator = self.generator.take().unwrap().into_inner();
        let interrupted = generator.started() && !generator.finished();
        if interrupted {
            if let Some(hook) = self.pre_post_poll.as_mut() {
                hook();
            }
        }
        generator.stack()
    }
}
impl<'a, Stack, Output, P> Future for AsyncWormhole<'a, Stack, Output, P>
where
    Stack: stack::Stack + Unpin + Send,
    P: FnMut() + Unpin + Send,
{
    type Output = Output;

    /// Resumes the generator with a clone of the current task's waker.
    ///
    /// The hook (when set) runs before the resume. If the generator then
    /// yields `Some(output)`, it is resumed one extra time so the closure body
    /// can run to its end, and `Poll::Ready(output)` is returned. If it yields
    /// `Some(None)` or `resume` returns `None`, the hook runs again and
    /// `Poll::Pending` is returned.
    ///
    /// NOTE(review): on the `Poll::Ready` path the hook is not invoked a second
    /// time, unlike the `Poll::Pending` path. Confirm this asymmetry is intended
    /// for hooks that swap state in and out.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `get_mut` relies on `Self: Unpin` (hence the `Unpin` bounds above).
        let this = self.get_mut();

        if let Some(hook) = this.pre_post_poll.as_mut() {
            hook();
        }

        let generator = this.generator.as_mut().unwrap().get_mut();
        match generator.resume(cx.waker().clone()) {
            Some(Some(out)) => {
                // One more resume lets the closure return past its final suspend.
                generator.resume(cx.waker().clone());
                Poll::Ready(out)
            }
            None | Some(None) => {
                if let Some(hook) = this.pre_post_poll.as_mut() {
                    hook();
                }
                Poll::Pending
            }
        }
    }
}
impl<'a, Stack, Output, P> Drop for AsyncWormhole<'a, Stack, Output, P>
where
    Stack: stack::Stack + Send,
    P: FnMut() + Send,
{
    /// Invokes the hook once, but only when a hook is set and the generator is
    /// still present, has started, and has not finished. The generator field
    /// itself is dropped after this method returns.
    fn drop(&mut self) {
        let hook = match self.pre_post_poll.as_mut() {
            Some(hook) => hook,
            None => return,
        };
        let interrupted = match self.generator.as_mut() {
            Some(cell) => {
                let generator = cell.get_mut();
                generator.started() && !generator.finished()
            }
            None => false,
        };
        if interrupted {
            hook();
        }
    }
}
/// Handle given to the closure running inside an `AsyncWormhole`, used to
/// block on futures from within it via `async_suspend`.
///
/// NOTE(review): clones share the same `yielder` but each keeps its own copy
/// of `waker`. Only the handle that called `async_suspend` has its waker
/// refreshed on resume, so confirm clones are never used to poll afterwards.
#[derive(Clone)]
pub struct AsyncYielder<'a, Output> {
/// The generator's yielder; suspending through it returns control to `poll`.
yielder: &'a Yielder<Waker, Option<Output>>,
/// The waker received on the most recent resume, used to build the
/// `Context` for polling futures in `async_suspend`.
waker: Waker,
}
impl<'a, Output> AsyncYielder<'a, Output> {
    /// Wraps the generator's `yielder` together with the waker it was resumed with.
    pub(crate) fn new(yielder: &'a Yielder<Waker, Option<Output>>, waker: Waker) -> Self {
        Self { yielder, waker }
    }

    /// Drives `future` to completion from inside the wormhole and returns its output.
    ///
    /// Each time `future` returns `Poll::Pending`, the generator is suspended
    /// with `None`, which the outer `poll` reports as `Poll::Pending`. The
    /// waker passed back on the next resume replaces the stored one, and
    /// `future` is polled again with it.
    pub fn async_suspend<Fut, R>(&mut self, mut future: Fut) -> R
    where
        Fut: Future<Output = R>,
    {
        // SAFETY: `future` is owned by this stack frame and is immediately
        // shadowed by the pinned reference, so it cannot be moved again before
        // it is dropped at the end of this function.
        let mut future = unsafe { Pin::new_unchecked(&mut future) };
        loop {
            // `Context::from_waker` only needs a shared borrow of the waker.
            let mut cx = Context::from_waker(&self.waker);
            self.waker = match future.as_mut().poll(&mut cx) {
                Poll::Pending => self.yielder.suspend(None),
                Poll::Ready(result) => return result,
            };
        }
    }
}