// decurse/for_macro_only/sound.rs

1pub use super::pend_once::PendOnce;
2pub use decurse_macro::decurse_sound;
3use pfn::PFnOnce;
4use pinned_vec::PinnedVec;
5use scoped_tls::scoped_thread_local;
6use std::{any::Any, cell::RefCell, future::Future, task::Poll};
7
/// State shared between [`execute`] and the futures it drives.
///
/// Before yielding, a recursing future stores the future for its recursive call in
/// `next`, and `execute` moves it onto its heap stack. When that child completes,
/// `execute` stores the child's output in `result`. The parent takes it from there
/// once it is resumed.
pub struct Context<F: Future> {
	/// Future for a pending recursive call, filled by [`Context::set_next`].
	next: RefCell<Option<F>>,
	/// Output of the most recently completed child, taken by [`Context::get_result`].
	result: RefCell<Option<F::Output>>,
}

impl<F: Future + 'static> Context<F> {
	/// Creates a context with no queued future and no stored result.
	pub fn new() -> Self {
		Self {
			next: RefCell::new(None),
			result: RefCell::new(None),
		}
	}

	/// Stores `fut` as the next future for the executor to drive.
	///
	/// `self_ptr` is deliberately `&Box<dyn Any>` and not `&dyn Any`. If a
	/// `&Box<dyn Any>` were passed where `&dyn Any` is expected, it could be
	/// unsize-coerced into a `dyn Any` for the `Box` itself, and the downcast
	/// below would no longer find the `Context<F>` inside.
	///
	/// # Panics
	///
	/// Panics if `self_ptr` does not hold a `Context<F>`, i.e. the running executor
	/// drives a different future type than `F`.
	#[allow(clippy::borrowed_box)] // see the doc comment: `&dyn Any` would change behavior
	pub fn set_next(self_ptr: &Box<dyn Any>, fut: F) {
		let this: &Self = self_ptr
			.downcast_ref()
			.expect("decurse context does not hold a Context for this future type");
		*this.next.borrow_mut() = Some(fut);
	}

	/// Takes the output of the most recently completed child future, leaving the
	/// slot empty.
	///
	/// # Panics
	///
	/// Panics if `self_ptr` does not hold a `Context<F>`, or if no result is stored.
	#[allow(clippy::borrowed_box)] // same reason as `set_next`
	pub fn get_result(self_ptr: &Box<dyn Any>) -> F::Output {
		let this: &Self = self_ptr
			.downcast_ref()
			.expect("decurse context does not hold a Context for this future type");
		this.result
			.borrow_mut()
			.take()
			.expect("no result stored: the recursive call has not completed")
	}
}

impl<F: Future + 'static> Default for Context<F> {
	/// Same as [`Context::new`].
	fn default() -> Self {
		Self::new()
	}
}
29
30scoped_thread_local! (static CONTEXT: Box<dyn Any>);
31
32pub fn set_next<F: Future + 'static>(fut: F) {
33	CONTEXT.with(|c| Context::set_next(c, fut))
34}
35
36pub fn get_result<A, R, F>(_phantom: R) -> F::Output
37where
38	R: PFnOnce<A, PFnOutput = F>,
39	F: Future + 'static,
40{
41	CONTEXT.with(|c| Context::<F>::get_result(c))
42}
43
44pub fn execute<F>(fut: F) -> F::Output
45where
46	F: Future + 'static,
47{
48	let dummy_waker = waker_fn::waker_fn(|| {});
49	let mut dummy_async_cx: std::task::Context = std::task::Context::from_waker(&dummy_waker);
50	let ctx: Context<F> = Context::new();
51	let any_ctx: Box<dyn Any> = Box::new(ctx);
52	let ctx: &Context<F> = any_ctx.downcast_ref().unwrap();
53
54	let output = CONTEXT.set(&any_ctx, || {
55		let mut heap_stack: PinnedVec<F> = PinnedVec::new();
56		heap_stack.push(fut);
57		loop {
58			let len = heap_stack.len();
59			// UNWRAP Safety: The only way len could go down is through the pop in the Poll::Ready case,
60			// in which we return if len is 1. So len never gets to 0.
61			let fut = heap_stack.get_mut(len - 1).unwrap();
62			let polled = fut.poll(&mut dummy_async_cx);
63			match polled {
64				Poll::Ready(r) => {
65					if len == 1 {
66						break r;
67					} else {
68						let mut bm = ctx.result.borrow_mut();
69						*bm = Some(r);
70						heap_stack.pop();
71					}
72				}
73				Poll::Pending => {
74					// UNWRAP Safety: The decurse macro only yields when recursing,
75					// in which case `next` would be filled before Pending is returned (see ctx.set_next).
76					heap_stack.push(ctx.next.borrow_mut().take().unwrap());
77				}
78			}
79		}
80	});
81	output
82}
83
/// Makes a recursive call to `$func` from inside a future driven by `execute`,
/// without growing the native call stack.
///
/// The expansion does three things:
///
/// 1. Queues the callee's future with `set_next`.
/// 2. Awaits `PendOnce`, so the executor regains control and runs the queued future.
/// 3. Takes the callee's output with `get_result`. `$func` is passed there only so
///    the callee's future type can be inferred.
///
/// It must therefore appear inside an `async` body. This macro lives in
/// `for_macro_only`, next to the re-exported `decurse_sound` macro, and is not
/// meant to be written by hand. The argument list accepts an optional trailing comma.
#[macro_export]
macro_rules! for_macro_only_recurse_sound {
	($func:path, ($($args:expr),* $(,)?)) => {
		({
			$crate::for_macro_only::sound::set_next($func($($args),*));
			$crate::for_macro_only::sound::PendOnce::new().await;
			$crate::for_macro_only::sound::get_result($func)
		})
	};
}
94
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn stack_factorial() {
		fn factorial(x: u32) -> u32 {
			if x == 0 {
				1
			} else {
				x * factorial(x - 1)
			}
		}
		assert_eq!(factorial(6), 720);
	}

	#[test]
	fn stack_fibonacci() {
		fn fibonacci(x: u32) -> u32 {
			if x == 0 || x == 1 {
				1
			} else {
				fibonacci(x - 1) + fibonacci(x - 2)
			}
		}
		assert_eq!(fibonacci(10), 89);
	}

	#[test]
	fn factorial() {
		async fn factorial(x: u32) -> u32 {
			if x == 0 {
				1
			} else {
				for_macro_only_recurse_sound!(factorial, (x - 1)) * x
			}
		}
		assert_eq!(execute(factorial(6)), 720);
	}

	#[test]
	fn fibonacci() {
		async fn fibonacci(x: u32) -> u32 {
			if x == 0 || x == 1 {
				1
			} else {
				for_macro_only_recurse_sound!(fibonacci, (x - 1))
					+ for_macro_only_recurse_sound!(fibonacci, (x - 2))
			}
		}
		assert_eq!(execute(fibonacci(10)), 89);
	}

	#[test]
	fn base_case_completes_on_first_poll() {
		async fn factorial(x: u32) -> u32 {
			if x == 0 {
				1
			} else {
				for_macro_only_recurse_sound!(factorial, (x - 1)) * x
			}
		}
		// `factorial(0)` is ready on its first poll, so `execute` returns without
		// ever pushing a recursive call onto its heap stack.
		assert_eq!(execute(factorial(0)), 1);
	}

	// Plain recursion 200_000 levels deep causes a stack overflow, so this test
	// stays disabled. `triangular` below computes the same value through `execute`,
	// which keeps the recursion on a heap-allocated stack instead.
	// #[test]
	// fn stack_triangular() {
	//     fn stack_triangular(x: u64) -> u64 {
	//         if x == 0 {
	//             0
	//         } else {
	//             stack_triangular(x - 1) + x
	//         }
	//     }
	//     assert_eq!(20000100000, stack_triangular(200000));
	// }

	#[test]
	fn triangular() {
		fn triangular(x: u64) -> u64 {
			async fn decurse_triangular(x: u64) -> u64 {
				if x == 0 {
					0
				} else {
					for_macro_only_recurse_sound!(decurse_triangular, (x - 1)) + x
				}
			}
			execute(decurse_triangular(x))
		}
		assert_eq!(20000100000, triangular(200000));
	}
}