anput_generator/
lib.rs

1use std::{
2    cell::Cell,
3    future::poll_fn,
4    pin::Pin,
5    rc::Rc,
6    sync::Arc,
7    task::{Context, Poll, RawWaker, RawWakerVTable, Wake, Waker},
8};
9
10pub trait IntoGenerator<T> {
11    fn into_generator(self) -> GeneratorIter<T>;
12}
13
14impl<T, F: Future<Output = ()> + 'static> IntoGenerator<T> for F {
15    fn into_generator(self) -> GeneratorIter<T> {
16        GeneratorIter::new(self)
17    }
18}
19
20pub struct GeneratorIter<T> {
21    future: Pin<Box<dyn Future<Output = ()>>>,
22    yielded_value: Rc<Cell<Option<T>>>,
23}
24
25impl<T> GeneratorIter<T> {
26    pub fn new<F: Future<Output = ()> + 'static>(future: F) -> Self {
27        GeneratorIter {
28            future: Box::pin(future),
29            yielded_value: Default::default(),
30        }
31    }
32}
33
34impl<T> Iterator for GeneratorIter<T> {
35    type Item = T;
36
37    fn next(&mut self) -> Option<Self::Item> {
38        let waker = GeneratorWaker::<T>::new_waker(self.yielded_value.clone());
39        let mut context = Context::from_waker(&waker);
40        match self.future.as_mut().poll(&mut context) {
41            Poll::Ready(_) => None,
42            Poll::Pending => self.yielded_value.take(),
43        }
44    }
45}
46
47pub async fn gen_yield<T>(value: T) {
48    let mut value = Some(value);
49    poll_fn(move |cx| {
50        let waker = cx.waker();
51        if let Some(value) = value.take() {
52            if let Some(waker) = GeneratorWaker::<T>::try_cast(waker) {
53                waker.yielded_value.set(Some(value));
54            }
55            waker.wake_by_ref();
56            Poll::Pending
57        } else {
58            waker.wake_by_ref();
59            Poll::Ready(())
60        }
61    })
62    .await
63}
64
/// Waker payload that carries a [`GeneratorIter`]'s output slot down to
/// [`gen_yield`].
///
/// The `Waker` is assembled by hand from [`GeneratorWaker::VTABLE`] rather than
/// through `Waker::from(Arc<W>)`, because that conversion requires
/// `W: Send + Sync` and this payload holds an `Rc`.
struct GeneratorWaker<T> {
    /// Slot shared with the owning `GeneratorIter`; `gen_yield` stores into it.
    yielded_value: Rc<Cell<Option<T>>>,
    /// Thread that created the payload. A `Waker` is `Send + Sync`, but the
    /// `Rc`/`Cell` above are not, so they are only touched on this thread.
    owner: std::thread::ThreadId,
}

impl<T> GeneratorWaker<T> {
    /// Vtable shared by every waker built from this payload type. Waking does
    /// nothing beyond reference bookkeeping.
    const VTABLE: RawWakerVTable = RawWakerVTable::new(
        Self::vtable_clone,
        Self::vtable_wake,
        Self::vtable_wake_by_ref,
        Self::vtable_drop,
    );

    /// `clone` entry: adds a strong reference and reuses the same data pointer.
    fn vtable_clone(data: *const ()) -> RawWaker {
        // SAFETY: `data` came from `Arc::into_raw` in `new_waker`, and the waker
        // being cloned still owns a strong reference, so the allocation is live.
        unsafe { Arc::increment_strong_count(data.cast::<Self>()) };
        RawWaker::new(data, &Self::VTABLE)
    }

    /// `wake` entry: it consumes the waker, so the strong reference it owns must
    /// be released here; otherwise every `Waker::wake` would leak the payload.
    fn vtable_wake(data: *const ()) {
        Self::vtable_drop(data);
    }

    /// `wake_by_ref` entry: nothing to do.
    fn vtable_wake_by_ref(_data: *const ()) {}

    /// `drop` entry: releases one strong reference. If it was the last one and
    /// this runs off the owning thread, the payload is leaked instead of dropped,
    /// because dropping its `Rc` there could race with the owner's refcount.
    fn vtable_drop(data: *const ()) {
        // SAFETY: `data` is a strong reference produced by `Arc::into_raw`
        // (directly in `new_waker`, or duplicated by `vtable_clone`) that the
        // caller relinquishes by calling this entry.
        let payload = unsafe { Arc::from_raw(data.cast::<Self>()) };
        if let Some(inner) = Arc::into_inner(payload) {
            if inner.owner != std::thread::current().id() {
                std::mem::forget(inner);
            }
        }
    }

    /// Builds a `Waker` whose payload shares `yielded_value` and is owned by the
    /// calling thread.
    fn new_waker(yielded_value: Rc<Cell<Option<T>>>) -> Waker {
        let payload = Arc::new(Self {
            yielded_value,
            owner: std::thread::current().id(),
        });
        let raw = RawWaker::new(Arc::into_raw(payload).cast::<()>(), &Self::VTABLE);
        // SAFETY: `raw` owns the single strong reference created above, and the
        // `VTABLE` entries implement clone/wake/wake_by_ref/drop for exactly that
        // data without touching the `Rc` off its owning thread.
        unsafe { Waker::from_raw(raw) }
    }

    /// Returns the payload if `waker` was built by `new_waker` for this payload
    /// type and is being used on the thread that created it; otherwise `None`.
    fn try_cast(waker: &Waker) -> Option<&Self> {
        // NOTE(review): identity rests on vtable equality, i.e. on comparing
        // function pointers. Instantiations for different `T` whose entries
        // compile to identical code may be merged by the optimizer and then
        // compare equal; a `T`-exact check would need `T: 'static` (`TypeId`).
        if waker.vtable() != &Self::VTABLE {
            return None;
        }
        // SAFETY: the vtable matches, so `data` points at a `GeneratorWaker`
        // allocation kept alive by `waker` for the lifetime of the returned borrow.
        let payload = unsafe { &*waker.data().cast::<Self>() };
        (payload.owner == std::thread::current().id()).then_some(payload)
    }
}

/// Not used to build wakers (see the type-level docs); a consuming wake has
/// nothing to do beyond dropping the reference it receives.
impl<T> Wake for GeneratorWaker<T> {
    fn wake(self: Arc<Self>) {}
}
102
#[cfg(test)]
mod tests {
    use super::*;
    use std::marker::PhantomData;

    /// Sequence both generators below are expected to produce, in order.
    const EXPECTED: [i32; 10] = [0, 1, 2, 3, 4, -10, -9, -8, -7, -6];

    #[test]
    fn test_generator() {
        let values: Vec<i32> = async {
            for n in (0..5).chain(-10..-5) {
                gen_yield(n).await;
            }
        }
        .into_generator()
        .collect();

        assert_eq!(values, EXPECTED);
    }

    #[test]
    fn test_generator_no_send_sync() {
        // The raw-pointer marker makes `Foo` neither `Send` nor `Sync`, showing
        // that yielded values need no thread-safety bounds.
        struct Foo {
            value: i32,
            _not_send_sync: PhantomData<*const ()>,
        }

        impl Foo {
            fn new(value: i32) -> Self {
                Self {
                    value,
                    _not_send_sync: PhantomData,
                }
            }

            fn value(&self) -> i32 {
                self.value
            }
        }

        let values: Vec<i32> = async {
            for n in (0..5).chain(-10..-5) {
                gen_yield(Foo::new(n)).await;
            }
        }
        .into_generator()
        .map(|foo: Foo| foo.value())
        .collect();

        assert_eq!(values, EXPECTED);
    }
}