async_rs/implementors/
tokio.rs

//! Tokio implementation of the async runtime definition traits.
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit, Task},
7};
8use async_trait::async_trait;
9use cfg_if::cfg_if;
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13    future::Future,
14    io::{self, Read, Write},
15    net::SocketAddr,
16    pin::Pin,
17    task::{Context, Poll},
18    time::{Duration, Instant},
19};
20use tokio::{
21    net::TcpStream,
22    runtime::{EnterGuard, Handle, Runtime as TokioRT},
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
26
27/// Type alias for the tokio runtime
28pub type TokioRuntime = Runtime<Tokio>;
29
30impl TokioRuntime {
31    /// Create a new TokioRuntime and bind it to this tokio runtime.
32    pub fn tokio() -> io::Result<Self> {
33        Ok(Self::tokio_with_runtime(TokioRT::new()?))
34    }
35
36    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
37    pub fn tokio_current() -> Self {
38        Self::new(Tokio::current())
39    }
40
41    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
42    pub fn tokio_with_handle(handle: Handle) -> Self {
43        Self::new(Tokio::default().with_handle(handle))
44    }
45
46    /// Create a new TokioRuntime and bind it to this tokio runtime.
47    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
48        Self::new(Tokio::default().with_runtime(runtime))
49    }
50}
51
52/// Dummy object implementing async common interfaces on top of tokio
53#[derive(Default, Debug)]
54pub struct Tokio {
55    handle: Option<Handle>,
56    runtime: Option<TokioRT>,
57}
58
59impl Tokio {
60    /// Bind to the tokio Runtime associated to this handle by default.
61    pub fn with_handle(mut self, handle: Handle) -> Self {
62        self.handle = Some(handle);
63        self
64    }
65
66    /// Bind to the tokio Runtime associated to this handle by default.
67    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
68        let handle = runtime.handle().clone();
69        self.runtime = Some(runtime);
70        self.with_handle(handle)
71    }
72
73    /// Bind to the current tokio Runtime by default.
74    pub fn current() -> Self {
75        Self::default().with_handle(Handle::current())
76    }
77
78    fn handle(&self) -> Option<Handle> {
79        self.runtime
80            .as_ref()
81            .map(|r| r.handle().clone())
82            .or_else(|| Handle::try_current().ok())
83            .or_else(|| self.handle.clone())
84    }
85
86    fn enter(&self) -> Option<EnterGuard<'_>> {
87        self.runtime
88            .as_ref()
89            .map(TokioRT::handle)
90            .or(self.handle.as_ref())
91            .map(Handle::enter)
92    }
93}
94
95struct TTask<T: Send + 'static>(Option<tokio::task::JoinHandle<T>>);
96
97impl RuntimeKit for Tokio {}
98
99impl Executor for Tokio {
100    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
101        if let Some(runtime) = self.runtime.as_ref() {
102            runtime.block_on(f)
103        } else if let Some(handle) = self.handle() {
104            handle.block_on(f)
105        } else {
106            Handle::current().block_on(f)
107        }
108    }
109
110    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
111        &self,
112        f: F,
113    ) -> impl Task<T> + 'static {
114        TTask(Some(if let Some(handle) = self.handle() {
115            handle.spawn(f)
116        } else {
117            tokio::task::spawn(f)
118        }))
119    }
120
121    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
122        &self,
123        f: F,
124    ) -> impl Task<T> + 'static {
125        TTask(Some(if let Some(handle) = self.handle() {
126            handle.spawn_blocking(f)
127        } else {
128            tokio::task::spawn_blocking(f)
129        }))
130    }
131}
132
133#[async_trait]
134impl<T: Send + 'static> Task<T> for TTask<T> {
135    async fn cancel(&mut self) -> Option<T> {
136        let task = self.0.take()?;
137        task.abort();
138        task.await.ok()
139    }
140}
141
142impl<T: Send + 'static> Future for TTask<T> {
143    type Output = T;
144
145    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
146        let task = self.0.as_mut().expect("task has been canceled");
147        match Pin::new(task).poll(cx) {
148            Poll::Pending => Poll::Pending,
149            Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
150        }
151    }
152}
153
154impl Reactor for Tokio {
155    type TcpStream = Compat<TcpStream>;
156
157    fn register<H: Read + Write + AsSysFd + Send + 'static>(
158        &self,
159        socket: H,
160    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
161        let _enter = self.enter();
162        cfg_if! {
163            if #[cfg(unix)] {
164                Ok(unix::AsyncFdWrapper(
165                    tokio::io::unix::AsyncFd::new(socket)?,
166                ))
167            } else {
168                Err::<windows::Dummy, _>(io::Error::other(
169                    "Registering FD on tokio reactor is only supported on unix",
170                ))
171            }
172        }
173    }
174
175    fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send + 'static {
176        let _enter = self.enter();
177        tokio::time::sleep(dur)
178    }
179
180    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
181        let _enter = self.enter();
182        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
183    }
184
185    fn tcp_connect_addr(
186        &self,
187        addr: SocketAddr,
188    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
189        let _enter = self.enter();
190        async move { Ok(TcpStream::connect(addr).await?.compat()) }
191    }
192}
193
194#[cfg(unix)]
195mod unix {
196    use super::*;
197    use futures_io::{AsyncRead, AsyncWrite};
198    use std::io::{IoSlice, IoSliceMut};
199    use tokio::io::unix::AsyncFd;
200
201    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
202
203    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
204        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
205            mut self: Pin<&mut Self>,
206            cx: &mut Context<'_>,
207            f: F,
208        ) -> Option<Poll<io::Result<usize>>> {
209            Some(match self.0.poll_read_ready_mut(cx) {
210                Poll::Pending => Poll::Pending,
211                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
212                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
213                    Ok(res) => Poll::Ready(res),
214                    Err(_) => return None,
215                },
216            })
217        }
218
219        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
220            mut self: Pin<&mut Self>,
221            cx: &mut Context<'_>,
222            f: F,
223        ) -> Option<Poll<io::Result<R>>> {
224            Some(match self.0.poll_write_ready_mut(cx) {
225                Poll::Pending => Poll::Pending,
226                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
227                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
228                    Ok(res) => Poll::Ready(res),
229                    Err(_) => return None,
230                },
231            })
232        }
233    }
234
235    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
236
237    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
238        fn poll_read(
239            mut self: Pin<&mut Self>,
240            cx: &mut Context<'_>,
241            buf: &mut [u8],
242        ) -> Poll<io::Result<usize>> {
243            loop {
244                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
245                    return res;
246                }
247            }
248        }
249
250        fn poll_read_vectored(
251            mut self: Pin<&mut Self>,
252            cx: &mut Context<'_>,
253            bufs: &mut [IoSliceMut<'_>],
254        ) -> Poll<io::Result<usize>> {
255            loop {
256                if let Some(res) = self
257                    .as_mut()
258                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
259                {
260                    return res;
261                }
262            }
263        }
264    }
265
266    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
267        fn poll_write(
268            mut self: Pin<&mut Self>,
269            cx: &mut Context<'_>,
270            buf: &[u8],
271        ) -> Poll<io::Result<usize>> {
272            loop {
273                if let Some(res) = self
274                    .as_mut()
275                    .write(cx, |socket| socket.get_mut().write(buf))
276                {
277                    return res;
278                }
279            }
280        }
281
282        fn poll_write_vectored(
283            mut self: Pin<&mut Self>,
284            cx: &mut Context<'_>,
285            bufs: &[IoSlice<'_>],
286        ) -> Poll<io::Result<usize>> {
287            loop {
288                if let Some(res) = self
289                    .as_mut()
290                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
291                {
292                    return res;
293                }
294            }
295        }
296
297        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
298            loop {
299                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
300                    return res;
301                }
302            }
303        }
304
305        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
306            self.poll_flush(cx)
307        }
308    }
309}
310
#[cfg(windows)]
mod windows {
    use super::*;
    use futures_io::{AsyncRead, AsyncWrite};

    /// Placeholder I/O type naming the success type of `Reactor::register` on
    /// non-unix targets, where registration always returns an error.
    ///
    /// It has no variants, so no value of it can exist and the trait methods below
    /// are statically unreachable. Unlike a never-constructed unit struct, it triggers
    /// no dead-code warning.
    pub(super) enum Dummy {}

    impl AsyncRead for Dummy {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }
    }

    impl AsyncWrite for Dummy {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }
    }
}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the tokio implementors can be used as trait objects.
    #[test]
    fn dyn_compat() {
        let executor: Box<dyn Executor> = Box::new(Tokio::default());
        let reactor: Box<dyn Reactor<TcpStream = Compat<TcpStream>>> =
            Box::new(Tokio::default());
        let kit: Box<dyn RuntimeKit<TcpStream = Compat<TcpStream>>> =
            Box::new(Tokio::default());
        let task: Box<dyn Task<String>> = Box::new(TTask(None));
        drop((executor, reactor, kit, task));
    }
}
367}