anput-generator 0.20.4

Async generator library
Documentation
use std::{
    cell::Cell,
    future::poll_fn,
    pin::Pin,
    rc::Rc,
    sync::Arc,
    task::{Context, Poll, RawWaker, RawWakerVTable, Wake, Waker},
};

/// Conversion of a value into a [`GeneratorIter`] producing items of type `T`.
///
/// Blanket-implemented for every `'static` future whose output is `()`.
/// `T` is the iterator's item type. Values handed to `gen_yield` inside the
/// future are only delivered when they have this same type, because
/// `gen_yield` looks the slot up through a `GeneratorWaker<T>`.
pub trait IntoGenerator<T> {
    /// Consumes `self` and wraps it in a [`GeneratorIter`] yielding `T`.
    fn into_generator(self) -> GeneratorIter<T>;
}

impl<T, F> IntoGenerator<T> for F
where
    F: Future<Output = ()> + 'static,
{
    fn into_generator(self) -> GeneratorIter<T> {
        GeneratorIter::<T>::new::<F>(self)
    }
}

pub struct GeneratorIter<T> {
    future: Pin<Box<dyn Future<Output = ()>>>,
    yielded_value: Rc<Cell<Option<T>>>,
}

impl<T> GeneratorIter<T> {
    pub fn new<F: Future<Output = ()> + 'static>(future: F) -> Self {
        GeneratorIter {
            future: Box::pin(future),
            yielded_value: Default::default(),
        }
    }
}

impl<T> Iterator for GeneratorIter<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        let waker = GeneratorWaker::<T>::new_waker(self.yielded_value.clone());
        let mut context = Context::from_waker(&waker);
        match self.future.as_mut().poll(&mut context) {
            Poll::Ready(_) => None,
            Poll::Pending => self.yielded_value.take(),
        }
    }
}

/// Suspends the enclosing generator body and hands `value` to the
/// [`GeneratorIter`] driving it.
///
/// The first poll stores `value` in the generator's slot and returns
/// `Poll::Pending`. The slot is found by down-casting the poll context's
/// waker. The following poll completes with `Poll::Ready(())`. Both polls
/// call `wake_by_ref` on the context's waker.
///
/// If the context's waker is not a `GeneratorWaker<T>`, `value` is dropped
/// instead of being delivered. This happens, for example, when the body is
/// run by some other executor, or normally when `T` differs from the
/// generator's item type.
pub async fn gen_yield<T>(value: T) {
    let mut pending = Some(value);
    poll_fn(move |cx| {
        let current = cx.waker();
        let Some(item) = pending.take() else {
            // Second poll: the value was already handed over.
            current.wake_by_ref();
            return Poll::Ready(());
        };
        if let Some(generator) = GeneratorWaker::<T>::try_cast(current) {
            generator.yielded_value.set(Some(item));
        }
        current.wake_by_ref();
        Poll::Pending
    })
    .await
}

/// Waker that `GeneratorIter` hands to generator bodies.
///
/// Waking it never schedules anything. Its job is to carry the shared
/// `yielded_value` slot, so that `gen_yield` can recover the slot from the
/// poll [`Context`] through `try_cast`.
///
/// NOTE(review): `Waker` is `Send + Sync`, but this payload owns an `Rc`. If a
/// body moves a clone of the waker to another thread and the last reference
/// is dropped there, the `Rc` count is modified off-thread. Confirm that
/// bodies never send the waker across threads.
struct GeneratorWaker<T> {
    /// Slot shared with the driving iterator; `gen_yield` writes into it.
    yielded_value: Rc<Cell<Option<T>>>,
}

impl<T> GeneratorWaker<T> {
    /// Vtable for wakers whose data pointer came from `Arc::into_raw` on an
    /// `Arc<Self>`.
    ///
    /// Waking never schedules anything, so `wake_by_ref` does nothing. `wake`
    /// only releases the waker's reference, exactly like `drop`. `wake` must
    /// not be a no-op: `Waker::wake` consumes the waker without running its
    /// `drop` entry, so a no-op would leak the `Arc` and keep the shared slot
    /// alive forever.
    const VTABLE: RawWakerVTable = RawWakerVTable::new(
        Self::vtable_clone,
        Self::vtable_drop,
        Self::vtable_wake_by_ref,
        Self::vtable_drop,
    );

    /// `clone` entry: adds one strong reference and reuses the same pointer.
    fn vtable_clone(data: *const ()) -> RawWaker {
        // SAFETY: `data` was produced by `Arc::into_raw` on an `Arc<Self>`
        // (see `new_waker`). The waker being cloned still owns one strong
        // reference, so the allocation is live.
        unsafe { Arc::increment_strong_count(data as *const Self) };
        RawWaker::new(data, &Self::VTABLE)
    }

    /// `wake_by_ref` entry: there is no task to reschedule.
    fn vtable_wake_by_ref(_data: *const ()) {}

    /// `drop` and `wake` entry: releases the strong reference owned by the
    /// waker.
    fn vtable_drop(data: *const ()) {
        // SAFETY: `data` came from `Arc::into_raw` on an `Arc<Self>`. This call
        // consumes the strong reference owned by the waker being dropped or
        // woken.
        drop(unsafe { Arc::from_raw(data as *const Self) });
    }

    /// Builds a [`Waker`] whose data pointer is a freshly allocated
    /// `Arc<Self>` holding `yielded_value`.
    fn new_waker(yielded_value: Rc<Cell<Option<T>>>) -> Waker {
        let data = Arc::into_raw(Arc::new(Self { yielded_value })) as *const ();
        // SAFETY: `data` is an `Arc<Self>` pointer, and `VTABLE` implements the
        // `RawWaker` contract for exactly that representation.
        unsafe { Waker::from_raw(RawWaker::new(data, &Self::VTABLE)) }
    }

    /// Returns the `GeneratorWaker<T>` behind `waker`. Returns `None` when
    /// `waker` uses a different vtable.
    ///
    /// NOTE(review): detection compares vtables by their function pointers.
    /// Rust does not guarantee that function pointers are unique per
    /// instantiation, so identical code for two different `T` could in
    /// principle be merged and compare equal. Consider a stronger type check
    /// if `T` may vary within one program.
    fn try_cast(waker: &Waker) -> Option<&Self> {
        if waker.vtable() != &Self::VTABLE {
            return None;
        }
        // SAFETY: the vtable matches, so the data pointer is an `Arc<Self>`.
        // `waker` owns a strong reference, which keeps it alive for the
        // returned borrow.
        unsafe { waker.data().cast::<Self>().as_ref() }
    }
}

/// `Wake` implementation in which waking does nothing beyond releasing the
/// `Arc`.
///
/// `new_waker` builds its `Waker` from the hand-written vtable, not from
/// this impl. `Waker::from(Arc<_>)` would additionally require `Send + Sync`,
/// which the `Rc`-holding payload does not provide.
impl<T> Wake for GeneratorWaker<T> {
    fn wake(self: Arc<Self>) {
        // There is no task to schedule; just let go of this reference.
        drop(self);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::marker::PhantomData;

    /// Values yielded from two consecutive loops come out in order.
    #[test]
    fn test_generator() {
        let body = async {
            for n in 0..5 {
                gen_yield(n).await;
            }
            for n in -10..-5 {
                gen_yield(n).await;
            }
        };
        let collected: Vec<i32> = body.into_generator().collect();

        assert_eq!(collected, vec![0, 1, 2, 3, 4, -10, -9, -8, -7, -6]);
    }

    /// Yielded values do not need to be `Send` or `Sync`.
    #[test]
    fn test_generator_no_send_sync() {
        /// Wrapper made `!Send + !Sync` by its raw-pointer marker.
        struct Unshareable {
            value: i32,
            _marker: PhantomData<*const ()>,
        }

        impl Unshareable {
            fn new(value: i32) -> Self {
                Self {
                    value,
                    _marker: PhantomData,
                }
            }

            fn value(&self) -> i32 {
                self.value
            }
        }

        let body = async {
            for n in 0..5 {
                gen_yield(Unshareable::new(n)).await;
            }
            for n in -10..-5 {
                gen_yield(Unshareable::new(n)).await;
            }
        };
        let collected: Vec<i32> = body
            .into_generator()
            .map(|item: Unshareable| item.value())
            .collect();

        assert_eq!(collected, vec![0, 1, 2, 3, 4, -10, -9, -8, -7, -6]);
    }
}