1use std::{
2 cell::Cell,
3 future::poll_fn,
4 pin::Pin,
5 rc::Rc,
6 sync::Arc,
7 task::{Context, Poll, RawWaker, RawWakerVTable, Wake, Waker},
8};
9
10pub trait IntoGenerator<T> {
11 fn into_generator(self) -> GeneratorIter<T>;
12}
13
14impl<T, F: Future<Output = ()> + 'static> IntoGenerator<T> for F {
15 fn into_generator(self) -> GeneratorIter<T> {
16 GeneratorIter::new(self)
17 }
18}
19
20pub struct GeneratorIter<T> {
21 future: Pin<Box<dyn Future<Output = ()>>>,
22 yielded_value: Rc<Cell<Option<T>>>,
23}
24
25impl<T> GeneratorIter<T> {
26 pub fn new<F: Future<Output = ()> + 'static>(future: F) -> Self {
27 GeneratorIter {
28 future: Box::pin(future),
29 yielded_value: Default::default(),
30 }
31 }
32}
33
34impl<T> Iterator for GeneratorIter<T> {
35 type Item = T;
36
37 fn next(&mut self) -> Option<Self::Item> {
38 let waker = GeneratorWaker::<T>::new_waker(self.yielded_value.clone());
39 let mut context = Context::from_waker(&waker);
40 match self.future.as_mut().poll(&mut context) {
41 Poll::Ready(_) => None,
42 Poll::Pending => self.yielded_value.take(),
43 }
44 }
45}
46
47pub async fn gen_yield<T>(value: T) {
48 let mut value = Some(value);
49 poll_fn(move |cx| {
50 let waker = cx.waker();
51 if let Some(value) = value.take() {
52 if let Some(waker) = GeneratorWaker::<T>::try_cast(waker) {
53 waker.yielded_value.set(Some(value));
54 }
55 waker.wake_by_ref();
56 Poll::Pending
57 } else {
58 waker.wake_by_ref();
59 Poll::Ready(())
60 }
61 })
62 .await
63}
64
/// Waker payload that lets [`gen_yield`] reach the yield slot of the
/// [`GeneratorIter`] currently polling it.
struct GeneratorWaker<T> {
    /// The iterator's hand-off slot; `gen_yield` stores the yielded value here.
    yielded_value: Rc<Cell<Option<T>>>,
}
68
69impl<T> GeneratorWaker<T> {
70 const VTABLE: RawWakerVTable =
71 RawWakerVTable::new(Self::vtable_clone, |_| {}, |_| {}, Self::vtable_drop);
72
73 fn vtable_clone(data: *const ()) -> RawWaker {
74 let arc = unsafe { Arc::<Self>::from_raw(data as *const Self) };
75 let cloned = arc.clone();
76 std::mem::forget(arc);
77 RawWaker::new(Arc::into_raw(cloned) as *const (), &Self::VTABLE)
78 }
79
80 fn vtable_drop(data: *const ()) {
81 let _ = unsafe { Arc::from_raw(data as *const Self) };
82 }
83
84 fn new_waker(yielded_value: Rc<Cell<Option<T>>>) -> Waker {
85 let arc = Arc::new(Self { yielded_value });
86 let raw = RawWaker::new(Arc::into_raw(arc) as *const (), &Self::VTABLE);
87 unsafe { Waker::from_raw(raw) }
88 }
89
90 fn try_cast(waker: &Waker) -> Option<&Self> {
91 if waker.vtable() == &Self::VTABLE {
92 unsafe { waker.data().cast::<Self>().as_ref() }
93 } else {
94 None
95 }
96 }
97}
98
99impl<T> Wake for GeneratorWaker<T> {
100 fn wake(self: Arc<Self>) {}
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use std::marker::PhantomData;

    /// The sequence both generators below are expected to produce.
    fn expected_sequence() -> Vec<i32> {
        (0..5).chain(-10..-5).collect()
    }

    #[test]
    fn test_generator() {
        let generator = async {
            for n in (0..5).chain(-10..-5) {
                gen_yield(n).await;
            }
        };
        let provided = generator.into_generator().collect::<Vec<i32>>();

        assert_eq!(provided, expected_sequence());
    }

    #[test]
    fn test_generator_no_send_sync() {
        /// A value that is neither `Send` nor `Sync` (because of the raw
        /// pointer marker), used to check the generator requires neither.
        struct Foo {
            value: i32,
            _not_send_sync: PhantomData<*const ()>,
        }

        impl Foo {
            fn new(value: i32) -> Self {
                Self {
                    value,
                    _not_send_sync: PhantomData,
                }
            }

            fn value(&self) -> i32 {
                self.value
            }
        }

        let generator = async {
            for n in (0..5).chain(-10..-5) {
                gen_yield(Foo::new(n)).await;
            }
        };
        let provided = generator
            .into_generator()
            .map(|foo: Foo| foo.value())
            .collect::<Vec<i32>>();

        assert_eq!(provided, expected_sequence());
    }
}