1use crate::Error;
4use crate::Stream;
5use crate::{AsyncRead, AsyncReadSeek, AsyncSeek, AsyncWrite};
6use futures_util::future::poll_fn;
7use once_cell::sync::Lazy;
8use std::future::Future;
9use std::io;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::sync::Mutex;
13use std::task::Context;
14use std::task::Poll;
15use std::time::Duration;
16
17use tokio::runtime::Runtime as TokioRuntime;
18
19#[allow(clippy::needless_doctest_main)]
20#[derive(Debug)]
39#[allow(clippy::large_enum_variant)]
40pub enum AsyncRuntime {
41 TokioSingle,
43 TokioShared,
57 TokioOwned(TokioRuntime),
75}
76
/// Lightweight, copyable tag recording which `AsyncRuntime` flavour is installed.
///
/// Unlike `AsyncRuntime` it carries no runtime value, so it can live behind the
/// `CURRENT_RUNTIME` mutex and be copied out cheaply.
#[allow(unused)]
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Inner {
    /// Crate-owned current-thread runtime (installed by `async_tokio::use_default`).
    TokioSingle,
    /// Caller's ambient runtime, referenced only through a handle.
    TokioShared,
    /// User-supplied runtime owned by this crate.
    TokioOwned,
}
84
/// A bound TCP listener as returned by `AsyncRuntime::listen`.
///
/// Only a tokio-backed variant exists at present.
// NOTE(review): `dead_code` is presumably allowed for builds where the listener is never
// constructed; confirm before removing.
#[allow(dead_code)]
#[cfg(feature = "server")]
pub(crate) enum Listener {
    /// Listener backed by a tokio `TcpListener`.
    Tokio(tokio::net::TcpListener),
}
90
#[cfg(feature = "server")]
impl Listener {
    /// Waits for the next inbound TCP connection.
    ///
    /// Returns the accepted connection wrapped as a crate `Stream`, together with the
    /// peer address. I/O failures from the underlying accept are converted into the
    /// crate `Error` through `?`.
    pub async fn accept(&mut self) -> Result<(impl Stream, SocketAddr), Error> {
        match self {
            Listener::Tokio(inner) => {
                let (tcp, peer) = inner.accept().await?;
                Ok((crate::tokio_conv::from_tokio(tcp), peer))
            }
        }
    }

    /// Returns the local address this listener is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let Listener::Tokio(inner) = self;
        inner.local_addr()
    }
}
109
110static CURRENT_RUNTIME: Lazy<Mutex<Inner>> = Lazy::new(|| {
111 let rt = if tokio::runtime::Handle::try_current().ok().is_some() {
112 trace!("Shared tokio runtime detected");
113 async_tokio::use_shared();
114 Inner::TokioShared
115 } else {
116 async_tokio::use_default();
117 Inner::TokioSingle
118 };
119
120 trace!("Default runtime: {:?}", rt);
121
122 Mutex::new(rt)
123});
124
125fn current() -> Inner {
126 *CURRENT_RUNTIME.lock().unwrap()
127}
128
129impl AsyncRuntime {
130 fn into_inner(self) -> Inner {
131 match self {
132 AsyncRuntime::TokioSingle => {
133 async_tokio::use_default();
134 Inner::TokioSingle
135 }
136 AsyncRuntime::TokioShared => {
137 async_tokio::use_shared();
138 Inner::TokioShared
139 }
140 AsyncRuntime::TokioOwned(rt) => {
141 async_tokio::use_owned(rt);
142 Inner::TokioOwned
143 }
144 }
145 }
146
147 pub fn make_default(self) {
149 let mut current = CURRENT_RUNTIME.lock().unwrap();
150
151 trace!("Set runtime: {:?}", self);
152
153 let inner = self.into_inner();
154 *current = inner;
155 }
156
157 pub(crate) async fn connect_tcp(addr: &str) -> Result<impl Stream, Error> {
158 use Inner::*;
159 Ok(match current() {
160 TokioSingle | TokioShared | TokioOwned => async_tokio::connect_tcp(addr).await?,
161 })
162 }
163
164 pub(crate) async fn timeout(duration: Duration) {
165 use Inner::*;
166 match current() {
167 TokioSingle | TokioShared | TokioOwned => async_tokio::timeout(duration).await,
168 }
169 }
170
171 #[doc(hidden)]
172 pub fn spawn<T: Future + Send + 'static>(task: T) {
173 use Inner::*;
174 match current() {
175 TokioSingle | TokioShared | TokioOwned => async_tokio::spawn(task),
176 }
177 }
178
179 pub(crate) fn block_on<F: Future>(task: F) -> F::Output {
180 use Inner::*;
181 match current() {
182 TokioSingle | TokioShared | TokioOwned => async_tokio::block_on(task),
183 }
184 }
185
186 #[cfg(feature = "server")]
187 pub(crate) async fn listen(addr: SocketAddr) -> Result<Listener, Error> {
188 use Inner::*;
189 match current() {
190 TokioSingle | TokioShared | TokioOwned => async_tokio::listen(addr).await,
191 }
192 }
193
194 pub(crate) fn file_to_reader(file: std::fs::File) -> impl AsyncReadSeek {
195 use Inner::*;
196 match current() {
197 TokioSingle | TokioShared | TokioOwned => async_tokio::file_to_reader(file),
198 }
199 }
200}
201
202pub(crate) mod async_tokio {
203 use super::*;
204 use crate::tokio_conv::from_tokio;
205 use std::sync::Mutex;
206 use tokio::net::TcpStream;
207 use tokio::runtime::Builder;
208 use tokio::runtime::Handle;
209
210 static RUNTIME: Lazy<Mutex<Option<TokioRuntime>>> = Lazy::new(|| Mutex::new(None));
211 static HANDLE: Lazy<Mutex<Option<Handle>>> = Lazy::new(|| Mutex::new(None));
212
213 fn set_singletons(handle: Handle, rt: Option<TokioRuntime>) {
214 let mut rt_handle = HANDLE.lock().unwrap();
215 *rt_handle = Some(handle);
216 let mut rt_singleton = RUNTIME.lock().unwrap();
217 *rt_singleton = rt;
218 }
219
220 fn unset_singletons() {
221 let unset = || {
222 let rt = RUNTIME.lock().unwrap().take();
223 {
224 let _ = HANDLE.lock().unwrap().take(); }
226 if let Some(rt) = rt {
227 rt.shutdown_timeout(Duration::from_millis(10));
228 }
229 };
230
231 let is_in_context = Handle::try_current().is_ok();
233
234 if is_in_context {
235 std::thread::spawn(unset).join().unwrap();
236 } else {
237 unset();
238 }
239 }
240
241 pub(crate) fn use_default() {
242 unset_singletons();
243 let (handle, rt) = create_default_runtime();
244 set_singletons(handle, Some(rt));
245 }
246 pub(crate) fn use_shared() {
247 unset_singletons();
248 let handle = Handle::current();
249 set_singletons(handle, None);
250 }
251 pub(crate) fn use_owned(rt: TokioRuntime) {
252 unset_singletons();
253 let handle = rt.handle().clone();
254 set_singletons(handle, Some(rt));
255 }
256
257 fn create_default_runtime() -> (Handle, TokioRuntime) {
258 let runtime = Builder::new_current_thread()
259 .enable_io()
260 .enable_time()
261 .build()
262 .expect("Failed to build tokio runtime");
263 let handle = runtime.handle().clone();
264 (handle, runtime)
265 }
266
267 pub(crate) async fn connect_tcp(addr: &str) -> Result<impl Stream, Error> {
268 Ok(from_tokio(TcpStream::connect(addr).await?))
269 }
270 pub(crate) async fn timeout(duration: Duration) {
271 tokio::time::sleep(duration).await;
272 }
273 pub(crate) fn spawn<T>(task: T)
274 where
275 T: Future + Send + 'static,
276 {
277 let mut handle = HANDLE.lock().unwrap();
278 handle.as_mut().unwrap().spawn(async move {
279 task.await;
280 });
281 }
282 pub(crate) fn block_on<F: Future>(task: F) -> F::Output {
283 let mut rt = RUNTIME.lock().unwrap();
284 if let Some(rt) = rt.as_mut() {
285 rt.block_on(task)
286 } else {
287 panic!("Can't use .block() with a TokioShared runtime.");
288 }
289 }
290
291 #[cfg(feature = "server")]
292 pub(crate) async fn listen(addr: SocketAddr) -> Result<Listener, Error> {
293 use tokio::net::TcpListener;
294 let listener = TcpListener::bind(addr).await?;
295 Ok(Listener::Tokio(listener))
296 }
297
298 pub(crate) fn file_to_reader(file: std::fs::File) -> impl AsyncReadSeek {
299 let file = tokio::fs::File::from_std(file);
300 from_tokio(file)
301 }
302}
303
304pub async fn never() {
306 poll_fn::<(), _>(|_| Poll::Pending).await;
307 unreachable!()
308}
309
310#[allow(unused)]
311pub(crate) struct FakeListener(SocketAddr);
312
313#[allow(unused)]
314impl FakeListener {
315 async fn accept(&mut self) -> Result<(FakeStream, SocketAddr), io::Error> {
316 Ok((FakeStream, self.0))
317 }
318
319 fn local_addr(&self) -> io::Result<SocketAddr> {
320 unreachable!("local_addr() on FakeListener");
321 }
322}
323
/// Zero-sized dummy stream; every I/O trait method implemented for it panics.
struct FakeStream;
326
327impl AsyncRead for FakeStream {
328 fn poll_read(
329 self: Pin<&mut Self>,
330 _: &mut Context,
331 _: &mut [u8],
332 ) -> Poll<futures_io::Result<usize>> {
333 unreachable!()
334 }
335}
336impl AsyncWrite for FakeStream {
337 fn poll_write(
338 self: Pin<&mut Self>,
339 _: &mut Context,
340 _: &[u8],
341 ) -> Poll<futures_io::Result<usize>> {
342 unreachable!()
343 }
344 fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<futures_io::Result<()>> {
345 unreachable!()
346 }
347 fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<futures_io::Result<()>> {
348 unreachable!()
349 }
350}
351
352impl AsyncSeek for FakeStream {
353 fn poll_seek(self: Pin<&mut Self>, _: &mut Context, _: io::SeekFrom) -> Poll<io::Result<u64>> {
354 unreachable!()
355 }
356}