hreq/
async_impl.rs

1//! pluggable runtimes
2
3use crate::Error;
4use crate::Stream;
5use crate::{AsyncRead, AsyncReadSeek, AsyncSeek, AsyncWrite};
6use futures_util::future::poll_fn;
7use once_cell::sync::Lazy;
8use std::future::Future;
9use std::io;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::sync::Mutex;
13use std::task::Context;
14use std::task::Poll;
15use std::time::Duration;
16
17use tokio::runtime::Runtime as TokioRuntime;
18
#[allow(clippy::needless_doctest_main)]
/// Switches between different async runtimes.
///
/// This is a global singleton: exactly one runtime choice is active for the
/// whole process, and [`AsyncRuntime::make_default`] replaces it.
///
/// hreq supports different ways of using tokio.
///
///   * `TokioSingle`. A minimal tokio current-thread runtime which executes
///     calls in one single thread. It does nothing until the current
///     thread blocks on a future using `.block()`.
///   * `TokioShared`. Picks up on a globally shared runtime by using a
///     [`Handle`]. This runtime cannot use the `.block()` extension
///     trait since that requires having a direct connection to the
///     tokio [`Runtime`]; attempting to do so panics.
///   * `TokioOwned`. Uses a preconfigured tokio [`Runtime`] that is
///     "handed over" to hreq.
///
/// If no choice has been made explicitly, the initial runtime is selected the
/// first time it is needed: `TokioShared` when a tokio context is detected on
/// the calling thread, otherwise `TokioSingle`.
///
/// [`Handle`]: https://docs.rs/tokio/latest/tokio/runtime/struct.Handle.html
/// [`Runtime`]: https://docs.rs/tokio/latest/tokio/runtime/struct.Runtime.html
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum AsyncRuntime {
    /// Use a tokio current-thread (single threaded) runtime created by hreq.
    /// This is the default when no tokio context is detected.
    TokioSingle,
    /// Pick up on a tokio shared runtime.
    ///
    /// The handle of the runtime that is current on the calling thread is
    /// captured when this variant is made the default.
    ///
    /// # Example using a shared tokio.
    ///
    /// ```no_run
    /// use hreq::AsyncRuntime;
    ///
    /// // assuming the current thread has some tokio runtime, such
    /// // as using the `#[tokio::main]` macro on `fn main() { .. }`
    ///
    /// AsyncRuntime::TokioShared.make_default();
    /// ```
    TokioShared,
    /// Use a tokio runtime owned by hreq.
    ///
    /// Ownership of the runtime is transferred to hreq when this variant is
    /// made the default.
    ///
    /// # Example using an owned tokio.
    ///
    /// ```
    /// use hreq::AsyncRuntime;
    /// // normally: use tokio::runtime::Builder;
    /// use tokio::runtime::Builder;
    ///
    /// let runtime = Builder::new_multi_thread()
    ///   .enable_io()
    ///   .enable_time()
    ///   .build()
    ///   .expect("Failed to build tokio runtime");
    ///
    /// AsyncRuntime::TokioOwned(runtime).make_default();
    /// ```
    TokioOwned(TokioRuntime),
}
76
/// Copyable tag recording which `AsyncRuntime` flavor is currently selected.
///
/// The variants mirror those of `AsyncRuntime`, but carry no data: the
/// `TokioOwned` tag does not hold the runtime itself.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[allow(unused)]
enum Inner {
    /// Single threaded runtime created by hreq.
    TokioSingle,
    /// Runtime shared with the surrounding application.
    TokioShared,
    /// Runtime handed over to hreq by the caller.
    TokioOwned,
}
84
/// Listening TCP socket used by the server feature.
///
/// Only a tokio-backed variant exists at present.
#[cfg(feature = "server")]
#[allow(dead_code)]
pub(crate) enum Listener {
    /// Listener backed by a tokio `TcpListener`.
    Tokio(tokio::net::TcpListener),
}

#[cfg(feature = "server")]
impl Listener {
    /// Waits for the next incoming connection.
    ///
    /// Returns the accepted connection, converted with
    /// `crate::tokio_conv::from_tokio`, together with the peer address.
    ///
    /// # Errors
    /// Propagates the error reported by the underlying listener's `accept`.
    pub async fn accept(&mut self) -> Result<(impl Stream, SocketAddr), Error> {
        match self {
            Listener::Tokio(listener) => {
                let (socket, peer) = listener.accept().await?;
                Ok((crate::tokio_conv::from_tokio(socket), peer))
            }
        }
    }

    /// Returns the local address this listener is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        match self {
            Listener::Tokio(listener) => listener.local_addr(),
        }
    }
}
109
110static CURRENT_RUNTIME: Lazy<Mutex<Inner>> = Lazy::new(|| {
111    let rt = if tokio::runtime::Handle::try_current().ok().is_some() {
112        trace!("Shared tokio runtime detected");
113        async_tokio::use_shared();
114        Inner::TokioShared
115    } else {
116        async_tokio::use_default();
117        Inner::TokioSingle
118    };
119
120    trace!("Default runtime: {:?}", rt);
121
122    Mutex::new(rt)
123});
124
125fn current() -> Inner {
126    *CURRENT_RUNTIME.lock().unwrap()
127}
128
129impl AsyncRuntime {
130    fn into_inner(self) -> Inner {
131        match self {
132            AsyncRuntime::TokioSingle => {
133                async_tokio::use_default();
134                Inner::TokioSingle
135            }
136            AsyncRuntime::TokioShared => {
137                async_tokio::use_shared();
138                Inner::TokioShared
139            }
140            AsyncRuntime::TokioOwned(rt) => {
141                async_tokio::use_owned(rt);
142                Inner::TokioOwned
143            }
144        }
145    }
146
147    /// Make this runtime the default.
148    pub fn make_default(self) {
149        let mut current = CURRENT_RUNTIME.lock().unwrap();
150
151        trace!("Set runtime: {:?}", self);
152
153        let inner = self.into_inner();
154        *current = inner;
155    }
156
157    pub(crate) async fn connect_tcp(addr: &str) -> Result<impl Stream, Error> {
158        use Inner::*;
159        Ok(match current() {
160            TokioSingle | TokioShared | TokioOwned => async_tokio::connect_tcp(addr).await?,
161        })
162    }
163
164    pub(crate) async fn timeout(duration: Duration) {
165        use Inner::*;
166        match current() {
167            TokioSingle | TokioShared | TokioOwned => async_tokio::timeout(duration).await,
168        }
169    }
170
171    #[doc(hidden)]
172    pub fn spawn<T: Future + Send + 'static>(task: T) {
173        use Inner::*;
174        match current() {
175            TokioSingle | TokioShared | TokioOwned => async_tokio::spawn(task),
176        }
177    }
178
179    pub(crate) fn block_on<F: Future>(task: F) -> F::Output {
180        use Inner::*;
181        match current() {
182            TokioSingle | TokioShared | TokioOwned => async_tokio::block_on(task),
183        }
184    }
185
186    #[cfg(feature = "server")]
187    pub(crate) async fn listen(addr: SocketAddr) -> Result<Listener, Error> {
188        use Inner::*;
189        match current() {
190            TokioSingle | TokioShared | TokioOwned => async_tokio::listen(addr).await,
191        }
192    }
193
194    pub(crate) fn file_to_reader(file: std::fs::File) -> impl AsyncReadSeek {
195        use Inner::*;
196        match current() {
197            TokioSingle | TokioShared | TokioOwned => async_tokio::file_to_reader(file),
198        }
199    }
200}
201
202pub(crate) mod async_tokio {
203    use super::*;
204    use crate::tokio_conv::from_tokio;
205    use std::sync::Mutex;
206    use tokio::net::TcpStream;
207    use tokio::runtime::Builder;
208    use tokio::runtime::Handle;
209
210    static RUNTIME: Lazy<Mutex<Option<TokioRuntime>>> = Lazy::new(|| Mutex::new(None));
211    static HANDLE: Lazy<Mutex<Option<Handle>>> = Lazy::new(|| Mutex::new(None));
212
213    fn set_singletons(handle: Handle, rt: Option<TokioRuntime>) {
214        let mut rt_handle = HANDLE.lock().unwrap();
215        *rt_handle = Some(handle);
216        let mut rt_singleton = RUNTIME.lock().unwrap();
217        *rt_singleton = rt;
218    }
219
220    fn unset_singletons() {
221        let unset = || {
222            let rt = RUNTIME.lock().unwrap().take();
223            {
224                let _ = HANDLE.lock().unwrap().take(); // go out of scope
225            }
226            if let Some(rt) = rt {
227                rt.shutdown_timeout(Duration::from_millis(10));
228            }
229        };
230
231        // this fails if we are currently running in a tokio context.
232        let is_in_context = Handle::try_current().is_ok();
233
234        if is_in_context {
235            std::thread::spawn(unset).join().unwrap();
236        } else {
237            unset();
238        }
239    }
240
241    pub(crate) fn use_default() {
242        unset_singletons();
243        let (handle, rt) = create_default_runtime();
244        set_singletons(handle, Some(rt));
245    }
246    pub(crate) fn use_shared() {
247        unset_singletons();
248        let handle = Handle::current();
249        set_singletons(handle, None);
250    }
251    pub(crate) fn use_owned(rt: TokioRuntime) {
252        unset_singletons();
253        let handle = rt.handle().clone();
254        set_singletons(handle, Some(rt));
255    }
256
257    fn create_default_runtime() -> (Handle, TokioRuntime) {
258        let runtime = Builder::new_current_thread()
259            .enable_io()
260            .enable_time()
261            .build()
262            .expect("Failed to build tokio runtime");
263        let handle = runtime.handle().clone();
264        (handle, runtime)
265    }
266
267    pub(crate) async fn connect_tcp(addr: &str) -> Result<impl Stream, Error> {
268        Ok(from_tokio(TcpStream::connect(addr).await?))
269    }
270    pub(crate) async fn timeout(duration: Duration) {
271        tokio::time::sleep(duration).await;
272    }
273    pub(crate) fn spawn<T>(task: T)
274    where
275        T: Future + Send + 'static,
276    {
277        let mut handle = HANDLE.lock().unwrap();
278        handle.as_mut().unwrap().spawn(async move {
279            task.await;
280        });
281    }
282    pub(crate) fn block_on<F: Future>(task: F) -> F::Output {
283        let mut rt = RUNTIME.lock().unwrap();
284        if let Some(rt) = rt.as_mut() {
285            rt.block_on(task)
286        } else {
287            panic!("Can't use .block() with a TokioShared runtime.");
288        }
289    }
290
291    #[cfg(feature = "server")]
292    pub(crate) async fn listen(addr: SocketAddr) -> Result<Listener, Error> {
293        use tokio::net::TcpListener;
294        let listener = TcpListener::bind(addr).await?;
295        Ok(Listener::Tokio(listener))
296    }
297
298    pub(crate) fn file_to_reader(file: std::fs::File) -> impl AsyncReadSeek {
299        let file = tokio::fs::File::from_std(file);
300        from_tokio(file)
301    }
302}
303
304// TODO does this cause memory leaks?
305pub async fn never() {
306    poll_fn::<(), _>(|_| Poll::Pending).await;
307    unreachable!()
308}
309
310#[allow(unused)]
311pub(crate) struct FakeListener(SocketAddr);
312
313#[allow(unused)]
314impl FakeListener {
315    async fn accept(&mut self) -> Result<(FakeStream, SocketAddr), io::Error> {
316        Ok((FakeStream, self.0))
317    }
318
319    fn local_addr(&self) -> io::Result<SocketAddr> {
320        unreachable!("local_addr() on FakeListener");
321    }
322}
323
324// filler in for "impl Stream" type
325struct FakeStream;
326
327impl AsyncRead for FakeStream {
328    fn poll_read(
329        self: Pin<&mut Self>,
330        _: &mut Context,
331        _: &mut [u8],
332    ) -> Poll<futures_io::Result<usize>> {
333        unreachable!()
334    }
335}
336impl AsyncWrite for FakeStream {
337    fn poll_write(
338        self: Pin<&mut Self>,
339        _: &mut Context,
340        _: &[u8],
341    ) -> Poll<futures_io::Result<usize>> {
342        unreachable!()
343    }
344    fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<futures_io::Result<()>> {
345        unreachable!()
346    }
347    fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<futures_io::Result<()>> {
348        unreachable!()
349    }
350}
351
352impl AsyncSeek for FakeStream {
353    fn poll_seek(self: Pin<&mut Self>, _: &mut Context, _: io::SeekFrom) -> Poll<io::Result<u64>> {
354        unreachable!()
355    }
356}