wait/
lib.rs

1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://flippingbinary.com/wait-rs/favicon.ico")]
4
5use core::future::Future;
6
7#[cfg(not(feature = "tokio"))]
8use core::task::{Context, Poll, Waker};
9
/// Vtable for the do-nothing waker used by the busy-wait executor that is
/// compiled when neither the `tokio` nor the `std` feature is enabled.
///
/// Every entry ignores its data pointer: cloning yields another null-data
/// raw waker that points back at this same table, while waking (by value or
/// by reference) and dropping do nothing at all.
#[cfg(all(not(feature = "tokio"), not(feature = "std")))]
static VTABLE: core::task::RawWakerVTable = {
    use core::task::{RawWaker, RawWakerVTable};

    fn clone_noop(_data: *const ()) -> RawWaker {
        RawWaker::new(core::ptr::null(), &VTABLE)
    }

    fn ignore(_data: *const ()) {}

    RawWakerVTable::new(clone_noop, ignore, ignore, ignore)
};
17
18/// The `Waitable` trait declares the `.wait()` method.
19///
20/// This trait is implemented for all types that implement the [`Future`]
21/// trait. All `async` functions return a `Future`, so this attaches the
22/// `.wait()` method to every `async` function. When called, the `.wait()`
23/// puts the thread to sleep until the `Future` is ready to return a value.
24pub trait Waitable: sealed::Sealed {
25    /// This is set to the return type of the `Future`.
26    type Output;
27
28    /// Put the thread to sleep until the `Future` is ready to return a value.
29    fn wait(self) -> Self::Output
30    where
31        Self: Sized;
32}
33
34impl<F> sealed::Sealed for F where F: Future {}
35
/// Drive `fut` to completion on the current thread, parking the thread while
/// the future is not ready.
///
/// The waker given to the future sets a per-call `notified` flag before it
/// unparks this thread, and the loop only re-polls once that flag is seen.
/// The thread's park token alone is not enough: it is shared with any other
/// code that parks this same thread while `fut` is being polled (for example
/// a nested blocking call), which can consume the token and make the wake-up
/// for this future disappear, leaving the loop parked forever. The flag also
/// filters out spurious returns from `park`, avoiding needless re-polls.
#[cfg(all(not(feature = "tokio"), feature = "std"))]
fn std_wait_block_on<F>(fut: F) -> F::Output
where
    F: Future + Sized,
{
    extern crate alloc;
    extern crate std;

    use core::sync::atomic::{AtomicBool, Ordering};
    use std::thread;

    use alloc::{boxed::Box, sync::Arc, task::Wake};

    /// State shared between this function and every clone of the waker.
    struct ThreadWaker {
        /// The thread blocked in this call.
        thread: thread::Thread,
        /// Set by the waker; cleared by the polling loop before it re-polls.
        notified: AtomicBool,
    }

    impl Wake for ThreadWaker {
        fn wake(self: Arc<Self>) {
            self.wake_by_ref();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            // Record the notification before unparking so the woken thread
            // observes it.
            self.notified.store(true, Ordering::SeqCst);
            self.thread.unpark();
        }
    }

    let state = Arc::new(ThreadWaker {
        thread: thread::current(),
        notified: AtomicBool::new(false),
    });

    let waker = Waker::from(Arc::clone(&state));
    let mut context = Context::from_waker(&waker);

    // Boxing pins the future on the heap without any `unsafe`.
    let mut future = Box::pin(fut);

    loop {
        if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
            return result;
        }
        // Sleep until *this* future's waker fires; spurious or foreign
        // unparks just loop back into `park`.
        while !state.notified.swap(false, Ordering::SeqCst) {
            thread::park();
        }
    }
}
76
77#[cfg(all(not(feature = "tokio"), not(feature = "std")))]
78fn nostd_wait_block_on<F>(mut fut: F) -> F::Output
79where
80    F: Future + Sized,
81{
82    use core::{hint::spin_loop, pin::Pin, ptr::null, task::RawWaker};
83
84    let waker = {
85        let raw_waker = RawWaker::new(null(), &VTABLE);
86        #[allow(unsafe_code)]
87        unsafe {
88            Waker::from_raw(raw_waker)
89        }
90    };
91
92    #[allow(unsafe_code)]
93    let mut future = unsafe { Pin::new_unchecked(&mut fut) };
94
95    let mut context = Context::from_waker(&waker);
96
97    loop {
98        match future.as_mut().poll(&mut context) {
99            Poll::Ready(result) => return result,
100            Poll::Pending => {
101                for _ in 0..100 {
102                    spin_loop();
103                }
104            }
105        }
106    }
107}
108
/// Drive `fut` to completion using Tokio.
///
/// - Outside any Tokio runtime, a fresh current-thread runtime with all
///   drivers enabled runs the future.
/// - Inside a multi-thread runtime, the current worker is handed over with
///   `tokio::task::block_in_place` and a fresh current-thread runtime runs
///   the future there.
///
/// Enabling all drivers on the fresh runtime lets the future use I/O and
/// timers even if the caller's runtime was built without them.
///
/// # Panics
///
/// - When called from inside a current-thread runtime, which cannot be
///   blocked without deadlocking.
/// - When the temporary Tokio runtime cannot be built.
#[cfg(feature = "tokio")]
fn tokio_wait_block_on<F>(fut: F) -> F::Output
where
    F: Future + Sized,
{
    use tokio::runtime::{Builder, Handle, RuntimeFlavor};

    // Single definition of "run `fut` on its own runtime", shared by both
    // non-panicking paths below.
    let run_on_fresh_runtime = move || {
        Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build a Tokio runtime for `.wait()`")
            .block_on(fut)
    };

    match Handle::try_current() {
        Err(_) => run_on_fresh_runtime(),
        Ok(handle) if handle.runtime_flavor() == RuntimeFlavor::CurrentThread => {
            panic!("Cannot block on a future from within a CurrentThread runtime")
        }
        Ok(_) => tokio::task::block_in_place(run_on_fresh_runtime),
    }
}
134
135impl<F> Waitable for F
136where
137    F: Future,
138{
139    type Output = F::Output;
140
141    fn wait(self) -> Self::Output
142    where
143        Self: Sized,
144    {
145        #[cfg(all(not(feature = "tokio"), feature = "std"))]
146        return std_wait_block_on(self);
147        #[cfg(all(not(feature = "tokio"), not(feature = "std")))]
148        return nostd_wait_block_on(self);
149        #[cfg(feature = "tokio")]
150        return tokio_wait_block_on(self);
151    }
152}
153
/// Private home of the `Sealed` marker trait.
///
/// This module is not exported, so code outside the crate cannot name
/// `Sealed` and therefore cannot implement any trait that requires it.
mod sealed {
    /// Marker supertrait; only this crate can implement it.
    pub trait Sealed {}
}
157
158pub mod prelude {
159    //! This is the convenience module where the magic happens.
160    //!
161    //! The alternative is to import the [`Waitable`] trait directly.
162    //!
163    //! [`Waitable`]: super::Waitable
164
165    pub use super::Waitable as _;
166}
167
#[cfg(test)]
mod tests {
    use super::prelude::*;

    use core::future::Future;
    use core::pin::Pin;
    use core::task::{Context, Poll};

    async fn add(a: usize, b: usize) -> usize {
        a + b
    }

    async fn mul(a: usize, b: usize) -> usize {
        let mut result = 0;
        for _ in 0..a {
            result = add(result, b).await;
        }
        result
    }

    /// A future that stays pending for `remaining` polls, waking itself each
    /// time, and then resolves to `value`.
    ///
    /// Immediately-ready `async fn`s never reach the `Poll::Pending` branch of
    /// an executor; this one does, on every feature configuration.
    struct SelfWaking {
        remaining: u32,
        value: usize,
    }

    impl Future for SelfWaking {
        type Output = usize;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<usize> {
            if self.remaining == 0 {
                return Poll::Ready(self.value);
            }
            self.remaining -= 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    #[test]
    fn test_single_level() {
        let result = add(2, 2).wait();
        assert_eq!(result, 4);
    }

    #[test]
    fn test_sequential_calls() {
        let result1 = add(1, 2).wait();
        let result2 = add(2, 3).wait();

        assert_eq!(result1, 3);
        assert_eq!(result2, 5);
    }

    #[test]
    fn test_nested_calls() {
        let result = mul(2, 3).wait();

        assert_eq!(result, 6);
    }

    #[test]
    fn test_pending_future_that_wakes_itself() {
        let result = SelfWaking {
            remaining: 3,
            value: 9,
        }
        .wait();

        assert_eq!(result, 9);
    }

    #[test]
    fn test_pending_future_awaited_inside_async_block() {
        let result = async {
            SelfWaking {
                remaining: 2,
                value: 5,
            }
            .await
                + add(1, 1).await
        }
        .wait();

        assert_eq!(result, 7);
    }

    #[test]
    fn test_future_woken_from_another_thread() {
        extern crate std;

        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        /// Pending until a helper thread flips `done` and fires the waker.
        struct WakeFromThread {
            started: bool,
            done: Arc<AtomicBool>,
        }

        impl Future for WakeFromThread {
            type Output = bool;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
                if self.done.load(Ordering::SeqCst) {
                    return Poll::Ready(true);
                }
                if !self.started {
                    self.started = true;
                    let done = Arc::clone(&self.done);
                    let waker = cx.waker().clone();
                    std::thread::spawn(move || {
                        done.store(true, Ordering::SeqCst);
                        waker.wake();
                    });
                }
                Poll::Pending
            }
        }

        let future = WakeFromThread {
            started: false,
            done: Arc::new(AtomicBool::new(false)),
        };

        assert!(future.wait());
    }

    // Test the tokio runtime with reqwest only if tokio feature is enabled
    #[cfg(feature = "tokio")]
    #[test]
    fn test_on_future_that_requires_tokio() {
        let response = reqwest::get("https://www.rust-lang.org").wait().unwrap();
        assert!(response.status().is_success());
    }

    // `expected` pins the panic to `.wait()`'s own guard rather than any other
    // failure inside the runtime.
    #[cfg(feature = "tokio")]
    #[test]
    #[should_panic(expected = "Cannot block on a future from within a CurrentThread runtime")]
    fn test_inside_single_thread_tokio_runtime() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();

        runtime.block_on(async {
            let _ = reqwest::get("https://www.rust-lang.org").wait();
        });
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn test_inside_multi_thread_tokio_runtime_with_no_timers_or_io() {
        let response = tokio::runtime::Builder::new_multi_thread()
            .build()
            .unwrap()
            .block_on(async { reqwest::get("https://www.rust-lang.org").wait() })
            .unwrap();

        assert!(response.status().is_success());
    }
}