1#![no_std]
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://flippingbinary.com/wait-rs/favicon.ico")]
4
5use core::future::Future;
6
7#[cfg(not(feature = "tokio"))]
8use core::task::{Context, Poll, Waker};
9
/// Vtable for the `no_std` backend's waker: every operation is a no-op.
///
/// The waker carries a null data pointer that no entry ever dereferences. `clone`
/// simply produces another null-data `RawWaker` pointing back at this same table.
#[cfg(all(not(feature = "tokio"), not(feature = "std")))]
static VTABLE: core::task::RawWakerVTable = {
    use core::task::{RawWaker, RawWakerVTable};

    /// `clone` entry: hand out another null-data waker backed by this table.
    fn clone_noop(_data: *const ()) -> RawWaker {
        RawWaker::new(core::ptr::null(), &VTABLE)
    }

    /// `wake`, `wake_by_ref` and `drop` entries: there is nothing to signal or free.
    fn ignore(_data: *const ()) {}

    RawWakerVTable::new(clone_noop, ignore, ignore, ignore)
};
17
18pub trait Waitable: sealed::Sealed {
25 type Output;
27
28 fn wait(self) -> Self::Output
30 where
31 Self: Sized;
32}
33
34impl<F> sealed::Sealed for F where F: Future {}
35
/// `std` backend for `Waitable::wait`: polls `fut` on the calling thread and parks that
/// thread whenever the future is pending.
///
/// The waker handed to the future unparks the calling thread. A wake that arrives before
/// `park` is called is not lost, because it leaves the thread's unpark token set.
#[cfg(all(not(feature = "tokio"), feature = "std"))]
fn std_wait_block_on<F>(fut: F) -> F::Output
where
    F: Future + Sized,
{
    extern crate alloc;
    extern crate std;

    use alloc::{boxed::Box, sync::Arc, task::Wake};
    use std::thread::{self, Thread};

    /// Waker payload: the thread that is blocked in `std_wait_block_on`.
    struct Unparker(Thread);

    impl Wake for Unparker {
        fn wake(self: Arc<Self>) {
            self.0.unpark();
        }
    }

    let waker: Waker = Arc::new(Unparker(thread::current())).into();
    let mut cx = Context::from_waker(&waker);
    let mut pinned = Box::pin(fut);

    loop {
        if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
            return output;
        }
        // Sleep until the waker fires. `park` may also return spuriously, which only costs
        // an extra poll.
        thread::park();
    }
}
76
77#[cfg(all(not(feature = "tokio"), not(feature = "std")))]
78fn nostd_wait_block_on<F>(mut fut: F) -> F::Output
79where
80 F: Future + Sized,
81{
82 use core::{hint::spin_loop, pin::Pin, ptr::null, task::RawWaker};
83
84 let waker = {
85 let raw_waker = RawWaker::new(null(), &VTABLE);
86 #[allow(unsafe_code)]
87 unsafe {
88 Waker::from_raw(raw_waker)
89 }
90 };
91
92 #[allow(unsafe_code)]
93 let mut future = unsafe { Pin::new_unchecked(&mut fut) };
94
95 let mut context = Context::from_waker(&waker);
96
97 loop {
98 match future.as_mut().poll(&mut context) {
99 Poll::Ready(result) => return result,
100 Poll::Pending => {
101 for _ in 0..100 {
102 spin_loop();
103 }
104 }
105 }
106 }
107}
108
/// Tokio backend for `Waitable::wait`: runs `fut` to completion on a dedicated
/// current-thread runtime with all drivers enabled.
///
/// Futures that need Tokio's I/O or timer drivers therefore work even when the caller is
/// not inside any runtime, or is inside one built without those drivers.
///
/// # Panics
///
/// - If called from inside a current-thread Tokio runtime.
/// - If the dedicated runtime cannot be built.
#[cfg(feature = "tokio")]
fn tokio_wait_block_on<F>(fut: F) -> F::Output
where
    F: Future + Sized,
{
    use tokio::runtime::{Builder, Handle, RuntimeFlavor};

    /// Builds a throwaway current-thread runtime and blocks on `fut` with it.
    fn block_on_fresh_runtime<Fut: Future>(fut: Fut) -> Fut::Output {
        Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build a Tokio current-thread runtime")
            .block_on(fut)
    }

    match Handle::try_current() {
        // Not inside a Tokio runtime: this thread can host a runtime of its own.
        Err(_) => block_on_fresh_runtime(fut),
        // Tokio's `block_in_place` is not usable on a current-thread runtime, so refuse
        // with a clear message instead.
        Ok(handle) if handle.runtime_flavor() == RuntimeFlavor::CurrentThread => {
            panic!("Cannot block on a future from within a CurrentThread runtime")
        }
        // Inside a multi-thread runtime: `block_in_place` lets this worker block without
        // stalling the runtime's other tasks, and the nested runtime can start there.
        Ok(_) => tokio::task::block_in_place(|| block_on_fresh_runtime(fut)),
    }
}
134
135impl<F> Waitable for F
136where
137 F: Future,
138{
139 type Output = F::Output;
140
141 fn wait(self) -> Self::Output
142 where
143 Self: Sized,
144 {
145 #[cfg(all(not(feature = "tokio"), feature = "std"))]
146 return std_wait_block_on(self);
147 #[cfg(all(not(feature = "tokio"), not(feature = "std")))]
148 return nostd_wait_block_on(self);
149 #[cfg(feature = "tokio")]
150 return tokio_wait_block_on(self);
151 }
152}
153
/// Private home of the supertrait that seals `Waitable`.
///
/// This module is not `pub`, so code outside the crate cannot name `Sealed`. It therefore
/// cannot implement `Waitable` for its own types.
mod sealed {
    /// Marker supertrait of `Waitable`; the crate implements it for every `Future`.
    pub trait Sealed {}
}
157
158pub mod prelude {
159 pub use super::Waitable as _;
166}
167
#[cfg(test)]
mod tests {
    use super::prelude::*;
    use core::{
        future::Future,
        pin::Pin,
        task::{Context, Poll},
    };

    async fn add(a: usize, b: usize) -> usize {
        a + b
    }

    async fn mul(a: usize, b: usize) -> usize {
        let mut result = 0;
        for _ in 0..a {
            result = add(result, b).await;
        }
        result
    }

    /// Resolves to `value` only after returning `Poll::Pending` `remaining` times. It wakes
    /// itself before each `Pending`. Plain `async fn`s like `add` complete on their first
    /// poll, so this is what exercises each backend's wake-and-re-poll path.
    struct PendingFor {
        remaining: usize,
        value: usize,
    }

    impl Future for PendingFor {
        type Output = usize;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.remaining == 0 {
                return Poll::Ready(self.value);
            }
            self.remaining -= 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    #[test]
    fn test_single_level() {
        let result = add(2, 2).wait();
        assert_eq!(result, 4);
    }

    #[test]
    fn test_sequential_calls() {
        let result1 = add(1, 2).wait();
        let result2 = add(2, 3).wait();

        assert_eq!(result1, 3);
        assert_eq!(result2, 5);
    }

    #[test]
    fn test_nested_calls() {
        let result = mul(2, 3).wait();

        assert_eq!(result, 6);
    }

    #[test]
    fn test_future_that_is_pending_before_ready() {
        let result = PendingFor {
            remaining: 3,
            value: 7,
        }
        .wait();

        assert_eq!(result, 7);
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn test_on_future_that_requires_tokio() {
        let response = reqwest::get("https://www.rust-lang.org").wait().unwrap();
        assert!(response.status().is_success());
    }

    // `expected` pins the panic to the guard in `tokio_wait_block_on`. Without it, any
    // panic (e.g. a network failure) would make this test pass.
    #[cfg(feature = "tokio")]
    #[test]
    #[should_panic(expected = "Cannot block on a future from within a CurrentThread runtime")]
    fn test_inside_single_thread_tokio_runtime() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();

        runtime.block_on(async {
            let _ = reqwest::get("https://www.rust-lang.org").wait();
        });
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn test_inside_multi_thread_tokio_runtime_with_no_timers_or_io() {
        let response = tokio::runtime::Builder::new_multi_thread()
            .build()
            .unwrap()
            .block_on(async { reqwest::get("https://www.rust-lang.org").wait() })
            .unwrap();

        assert!(response.status().is_success());
    }
}