call_recursion/
lib.rs

1//! Do recursion on the heap
2//! ===
3//!
4//! This crate provides a method to avoid stack overflows
5//! by converting async functions into state machines and
6//! doing recursion on the heap.
7//!
8//! # Usage
9//!
10//! ``` rust
11//! // Import trait
12//! use call_recursion::FutureRecursion;
13//!
//! // Write deeply recursive functions as `async fn`s
15//! async fn pow_mod(base: usize, n: usize, r#mod: usize) -> usize {
16//!     if n == 0 {
17//!         1
18//!     }
19//!     else {
//!         // Call `recurse` on the recursive call so that it runs on the heap.
//!         // `recurse` returns a future, so `.await` it.
22//!         (base * pow_mod(base, n - 1, r#mod).recurse().await) % r#mod
23//!     }
24//! }
25//!
26//! fn main() {
//!     // Call `start_recursion` on the outermost call to drive the recursion.
//!     // It returns the output of the future unchanged.
29//!     println!("{}", pow_mod(2, 10_000_000, 1_000_000).start_recursion());
30//! }
31//! ```
32
33use std::{cell::RefCell, pin::Pin, rc::Rc};
34
/// Handle to the result of a future scheduled with [`FutureRecursion::recurse`].
///
/// The handle shares a slot with the producer of the value. Polling it yields
/// the value once it has been stored in the slot; until then every poll
/// returns `Pending`.
pub struct Output<T> {
    /// Slot shared with the producer; `None` until a value has been stored
    /// (and again after the value has been handed out).
    state: Rc<RefCell<Option<T>>>,
}

impl<T> Default for Output<T> {
    /// Creates a handle whose slot is still empty.
    fn default() -> Self {
        Self {
            state: Rc::new(RefCell::new(None)),
        }
    }
}

// No `T: Unpin` bound is needed: the handle only stores an `Rc`, and the slot
// is emptied through a shared reference, so nothing is ever moved out of
// pinned memory.
impl<T> Future for Output<T> {
    type Output = T;

    /// Takes the value out of the shared slot if one is present.
    ///
    /// The context is ignored and no waker is registered, so whoever drives
    /// this future has to poll it again after the slot has been filled.
    fn poll(
        self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        match self.state.take() {
            Some(value) => std::task::Poll::Ready(value),
            None => std::task::Poll::Pending,
        }
    }
}
59
60struct FutureWrapper<F: Future> {
61    future: F,
62    state: Rc<RefCell<Option<F::Output>>>,
63}
64impl<F: Future> Future for FutureWrapper<F> {
65    type Output = ();
66    fn poll(
67        mut self: std::pin::Pin<&mut Self>,
68        cx: &mut std::task::Context<'_>,
69    ) -> std::task::Poll<Self::Output> {
70        let future = unsafe {
71            Pin::new_unchecked(
72                &mut Pin::get_unchecked_mut(self.as_mut()).future,
73            )
74        };
75        future.poll(cx).map(|out| {
76            *self.state.borrow_mut() = Some(out);
77        })
78    }
79}
80impl<F> FutureWrapper<F>
81where
82    F: Future,
83    F::Output: Unpin,
84{
85    fn new(f: F) -> (FutureWrapper<F>, Output<F::Output>) {
86        let output = Output::default();
87        (
88            FutureWrapper {
89                future: f,
90                state: output.state.clone(),
91            },
92            output,
93        )
94    }
95}
96
// Per-thread hand-off slot between `recurse` and `start_recursion`:
// `recurse` stores the boxed, wrapped child future here, and the driver loop
// in `start_recursion` takes it out after a poll returned `Pending` and pushes
// it onto its heap-allocated stack. The slot holds at most one future.
thread_local! {
    static RECURSION_TEM: RefCell<Option<Pin<Box<dyn Future<Output = ()>>>>> = const { RefCell::new(None) };
}
100
101pub trait FutureRecursion
102where
103    Self: Future,
104{
105    fn start_recursion(self) -> Self::Output;
106    fn recurse(self) -> Output<Self::Output>;
107}
108
/// Minimal waker whose every operation does nothing.
mod noop_waker {
    use std::task::{RawWaker, RawWakerVTable, Waker};

    /// Vtable shared by every no-op waker: cloning yields another no-op raw
    /// waker; wake, wake-by-ref and drop are all empty.
    const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, ignore, ignore, ignore);

    /// Clone hook: the data pointer carries no state, so a fresh raw waker
    /// is returned.
    unsafe fn clone(_: *const ()) -> RawWaker {
        raw()
    }

    /// Shared hook for wake, wake-by-ref and drop: intentionally empty.
    unsafe fn ignore(_: *const ()) {}

    /// Builds the raw waker: a null data pointer plus the no-op vtable.
    const fn raw() -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }

    /// Returns a waker that does nothing when woken.
    #[inline]
    pub fn noop_waker() -> Waker {
        // SAFETY: none of the vtable functions dereference the (null) data
        // pointer or own any resource, so every call on the waker is valid.
        unsafe { Waker::from_raw(raw()) }
    }
}
124
125impl<F> FutureRecursion for F
126where
127    F: Future + 'static,
128    F::Output: Unpin,
129{
130    fn start_recursion(self) -> Self::Output {
131        let tem = RECURSION_TEM.replace(None);
132
133        let waker = noop_waker::noop_waker();
134        let mut context = std::task::Context::from_waker(&waker);
135        let mut stack: Vec<Pin<Box<dyn Future<Output = ()>>>> = vec![];
136
137        let (f, output) = FutureWrapper::new(self);
138        stack.push(Box::pin(f));
139        while let Some(l) = stack.last_mut() {
140            match l.as_mut().poll(&mut context) {
141                std::task::Poll::Ready(_) => {
142                    stack.pop();
143                }
144                std::task::Poll::Pending => {
145                    if let Some(f) = RECURSION_TEM.replace(None) {
146                        stack.push(f);
147                    }
148                }
149            }
150        }
151
152        RECURSION_TEM.set(tem);
153
154        output.state.take().unwrap()
155    }
156    fn recurse(self) -> Output<Self::Output> {
157        let (fw, output) = FutureWrapper::new(self);
158        if RECURSION_TEM.replace(Some(Box::pin(fw))).is_some() {
159            panic!("incorrect recursion");
160        }
161        output
162    }
163}