1use std::{cell::RefCell, pin::Pin, rc::Rc};
34
/// Handle to the result of a future scheduled with [`FutureRecursion::recurse`].
///
/// The scheduled future's wrapper writes its value into the shared slot when it
/// completes; awaiting this handle then yields that value.
pub struct Output<T> {
    // Result slot shared with the wrapper that fills it in on completion.
    state: Rc<RefCell<Option<T>>>,
}

impl<T> Default for Output<T> {
    /// Creates a handle whose result slot is still empty.
    fn default() -> Self {
        Self {
            state: Rc::new(RefCell::new(None)),
        }
    }
}

// No `T: Unpin` bound: `Output<T>` only stores an `Rc`, so it is `Unpin` for every
// `T`, and `poll` needs nothing more than shared access to the `RefCell`.
impl<T> Future for Output<T> {
    type Output = T;

    /// Returns `Ready` with the stored value if it is present, `Pending` otherwise.
    ///
    /// The value is moved out, so polling again afterwards returns `Pending`.
    /// The waker in `_cx` is not registered; this future relies on its driver
    /// polling it again after the producing future has completed.
    fn poll(
        self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // `Pin<&mut Self>` derefs to `&Self`, and `RefCell::take` needs only `&self`.
        match self.state.take() {
            Some(value) => std::task::Poll::Ready(value),
            None => std::task::Poll::Pending,
        }
    }
}
59
60struct FutureWrapper<F: Future> {
61 future: F,
62 state: Rc<RefCell<Option<F::Output>>>,
63}
64impl<F: Future> Future for FutureWrapper<F> {
65 type Output = ();
66 fn poll(
67 mut self: std::pin::Pin<&mut Self>,
68 cx: &mut std::task::Context<'_>,
69 ) -> std::task::Poll<Self::Output> {
70 let future = unsafe {
71 Pin::new_unchecked(
72 &mut Pin::get_unchecked_mut(self.as_mut()).future,
73 )
74 };
75 future.poll(cx).map(|out| {
76 *self.state.borrow_mut() = Some(out);
77 })
78 }
79}
80impl<F> FutureWrapper<F>
81where
82 F: Future,
83 F::Output: Unpin,
84{
85 fn new(f: F) -> (FutureWrapper<F>, Output<F::Output>) {
86 let output = Output::default();
87 (
88 FutureWrapper {
89 future: f,
90 state: output.state.clone(),
91 },
92 output,
93 )
94 }
95}
96
thread_local! {
    /// Per-thread hand-off slot between `recurse` and `start_recursion`.
    ///
    /// `recurse` parks a freshly wrapped child future here. The driver loop in
    /// `start_recursion` moves it onto its stack after the currently polled
    /// future returns `Pending`.
    static RECURSION_TEM: RefCell<Option<Pin<Box<dyn Future<Output = ()>>>>> =
        const { RefCell::new(None) };
}
100
101pub trait FutureRecursion
102where
103 Self: Future,
104{
105 fn start_recursion(self) -> Self::Output;
106 fn recurse(self) -> Output<Self::Output>;
107}
108
mod noop_waker {
    //! A `Waker` whose clone, wake and drop operations do nothing.

    use std::task::{RawWaker, RawWakerVTable, Waker};

    /// Vtable whose clone yields another no-op raw waker and whose
    /// wake / wake_by_ref / drop entries ignore their argument.
    const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, ignore, ignore, ignore);

    /// Builds a raw waker with a null data pointer bound to `VTABLE`.
    const fn raw_waker() -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }

    /// Vtable `clone` entry: the data pointer is unused, so just build a new one.
    unsafe fn clone_raw(_: *const ()) -> RawWaker {
        raw_waker()
    }

    /// Vtable `wake`, `wake_by_ref` and `drop` entry: does nothing.
    unsafe fn ignore(_: *const ()) {}

    /// Returns a `Waker` on which every operation is a no-op.
    #[inline]
    pub fn noop_waker() -> Waker {
        // SAFETY: every vtable entry ignores the (null) data pointer and performs
        // no action, so the `RawWaker`/`RawWakerVTable` contract holds trivially.
        unsafe { Waker::from_raw(raw_waker()) }
    }
}
124
125impl<F> FutureRecursion for F
126where
127 F: Future + 'static,
128 F::Output: Unpin,
129{
130 fn start_recursion(self) -> Self::Output {
131 let tem = RECURSION_TEM.replace(None);
132
133 let waker = noop_waker::noop_waker();
134 let mut context = std::task::Context::from_waker(&waker);
135 let mut stack: Vec<Pin<Box<dyn Future<Output = ()>>>> = vec![];
136
137 let (f, output) = FutureWrapper::new(self);
138 stack.push(Box::pin(f));
139 while let Some(l) = stack.last_mut() {
140 match l.as_mut().poll(&mut context) {
141 std::task::Poll::Ready(_) => {
142 stack.pop();
143 }
144 std::task::Poll::Pending => {
145 if let Some(f) = RECURSION_TEM.replace(None) {
146 stack.push(f);
147 }
148 }
149 }
150 }
151
152 RECURSION_TEM.set(tem);
153
154 output.state.take().unwrap()
155 }
156 fn recurse(self) -> Output<Self::Output> {
157 let (fw, output) = FutureWrapper::new(self);
158 if RECURSION_TEM.replace(Some(Box::pin(fw))).is_some() {
159 panic!("incorrect recursion");
160 }
161 output
162 }
163}