// async_rs/implementors/tokio.rs

//! Tokio implementation of the async runtime definition traits.

use crate::{AsyncIOHandle, Executor, IOHandle, Reactor, Runtime, RuntimeKit, Task, sys::IO};
use async_trait::async_trait;
use cfg_if::cfg_if;
use futures_core::Stream;
use std::{
    future::Future,
    io,
    net::SocketAddr,
    panic,
    pin::Pin,
    task::{Context, Poll},
    time::{Duration, Instant},
};
use tokio::{net::TcpStream, runtime::Handle};
use tokio_stream::{StreamExt, wrappers::IntervalStream};
use tokio_util::compat::TokioAsyncReadCompatExt;

19/// Type alias for the tokio runtime
20pub type TokioRuntime = Runtime<Tokio>;
21
22impl TokioRuntime {
23    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
24    pub fn tokio() -> Self {
25        Self::new(Tokio::current())
26    }
27
28    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
29    pub fn tokio_with_handle(handle: Handle) -> Self {
30        Self::new(Tokio::default().with_handle(handle))
31    }
32}
33
34/// Dummy object implementing async common interfaces on top of tokio
35#[derive(Default, Debug, Clone)]
36pub struct Tokio {
37    handle: Option<Handle>,
38}
39
40impl Tokio {
41    /// Bind to the tokio Runtime associated to this handle by default.
42    pub fn with_handle(mut self, handle: Handle) -> Self {
43        self.handle = Some(handle);
44        self
45    }
46
47    /// Bind to the current tokio Runtime by default.
48    pub fn current() -> Self {
49        Self::default().with_handle(Handle::current())
50    }
51
52    pub(crate) fn handle(&self) -> Option<Handle> {
53        Handle::try_current().ok().or_else(|| self.handle.clone())
54    }
55}
56
57struct TTask<T: Send>(Option<tokio::task::JoinHandle<T>>);
58
59impl RuntimeKit for Tokio {}
60
61impl Executor for Tokio {
62    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
63        if let Some(handle) = self.handle() {
64            handle.block_on(f)
65        } else {
66            Handle::current().block_on(f)
67        }
68    }
69
70    fn spawn<T: Send + 'static>(
71        &self,
72        f: impl Future<Output = T> + Send + 'static,
73    ) -> impl Task<T> {
74        TTask(Some(if let Some(handle) = self.handle() {
75            handle.spawn(f)
76        } else {
77            tokio::task::spawn(f)
78        }))
79    }
80
81    fn spawn_blocking<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(
82        &self,
83        f: F,
84    ) -> impl Task<T> {
85        TTask(Some(if let Some(handle) = self.handle() {
86            handle.spawn_blocking(f)
87        } else {
88            tokio::task::spawn_blocking(f)
89        }))
90    }
91}
92
93#[async_trait(?Send)]
94impl<T: Send> Task<T> for TTask<T> {
95    async fn cancel(&mut self) -> Option<T> {
96        let task = self.0.take()?;
97        task.abort();
98        task.await.ok()
99    }
100}
101
102impl<T: Send> Future for TTask<T> {
103    type Output = T;
104
105    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
106        let task = self.0.as_mut().expect("task has been canceled");
107        match Pin::new(task).poll(cx) {
108            Poll::Pending => Poll::Pending,
109            Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
110        }
111    }
112}
113
114#[async_trait]
115impl Reactor for Tokio {
116    fn register<H: IO + Send + 'static>(
117        &self,
118        socket: IOHandle<H>,
119    ) -> io::Result<impl AsyncIOHandle + Send> {
120        let _enter = self.handle().as_ref().map(|handle| handle.enter());
121        cfg_if! {
122            if #[cfg(unix)] {
123                Ok(Box::new(unix::AsyncFdWrapper(
124                    tokio::io::unix::AsyncFd::new(socket)?,
125                )))
126            } else {
127                Err::<windows::Dummy, _>(io::Error::other(
128                    "Registering FD on tokio reactor is only supported on unix",
129                ))
130            }
131        }
132    }
133
134    fn sleep(&self, dur: Duration) -> impl Future<Output = ()> {
135        tokio::time::sleep(dur)
136    }
137
138    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> {
139        let _enter = self.handle().as_ref().map(|handle| handle.enter());
140        Box::new(
141            IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std),
142        )
143    }
144
145    async fn tcp_connect(&self, addr: SocketAddr) -> io::Result<impl AsyncIOHandle + Send> {
146        // We cannot do that as EnterGuard is not Send (which makes sense)
147        // let _enter = self.handle().as_ref().map(|handle| handle.enter());
148        Ok(TcpStream::connect(addr).await?.compat())
149    }
150}
151
152#[cfg(unix)]
153mod unix {
154    use super::*;
155    use futures_io::{AsyncRead, AsyncWrite};
156    use std::io::{IoSlice, IoSliceMut, Read, Write};
157    use tokio::io::unix::AsyncFd;
158
159    pub(super) struct AsyncFdWrapper<H: IO + Send + 'static>(pub(super) AsyncFd<IOHandle<H>>);
160
161    impl<H: IO + Send + 'static> AsyncFdWrapper<H> {
162        fn read<F: FnOnce(&mut AsyncFd<IOHandle<H>>) -> io::Result<usize>>(
163            mut self: Pin<&mut Self>,
164            cx: &mut Context<'_>,
165            f: F,
166        ) -> Option<Poll<io::Result<usize>>> {
167            Some(match self.0.poll_read_ready_mut(cx) {
168                Poll::Pending => Poll::Pending,
169                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
170                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
171                    Ok(res) => Poll::Ready(res),
172                    Err(_) => return None,
173                },
174            })
175        }
176
177        fn write<R, F: FnOnce(&mut AsyncFd<IOHandle<H>>) -> io::Result<R>>(
178            mut self: Pin<&mut Self>,
179            cx: &mut Context<'_>,
180            f: F,
181        ) -> Option<Poll<io::Result<R>>> {
182            Some(match self.0.poll_write_ready_mut(cx) {
183                Poll::Pending => Poll::Pending,
184                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
185                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
186                    Ok(res) => Poll::Ready(res),
187                    Err(_) => return None,
188                },
189            })
190        }
191    }
192
193    impl<H: IO + Send + 'static> Unpin for AsyncFdWrapper<H> {}
194
195    impl<H: IO + Send + 'static> AsyncRead for AsyncFdWrapper<H> {
196        fn poll_read(
197            mut self: Pin<&mut Self>,
198            cx: &mut Context<'_>,
199            buf: &mut [u8],
200        ) -> Poll<io::Result<usize>> {
201            loop {
202                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
203                    return res;
204                }
205            }
206        }
207
208        fn poll_read_vectored(
209            mut self: Pin<&mut Self>,
210            cx: &mut Context<'_>,
211            bufs: &mut [IoSliceMut<'_>],
212        ) -> Poll<io::Result<usize>> {
213            loop {
214                if let Some(res) = self
215                    .as_mut()
216                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
217                {
218                    return res;
219                }
220            }
221        }
222    }
223
224    impl<H: IO + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
225        fn poll_write(
226            mut self: Pin<&mut Self>,
227            cx: &mut Context<'_>,
228            buf: &[u8],
229        ) -> Poll<io::Result<usize>> {
230            loop {
231                if let Some(res) = self
232                    .as_mut()
233                    .write(cx, |socket| socket.get_mut().write(buf))
234                {
235                    return res;
236                }
237            }
238        }
239
240        fn poll_write_vectored(
241            mut self: Pin<&mut Self>,
242            cx: &mut Context<'_>,
243            bufs: &[IoSlice<'_>],
244        ) -> Poll<io::Result<usize>> {
245            loop {
246                if let Some(res) = self
247                    .as_mut()
248                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
249                {
250                    return res;
251                }
252            }
253        }
254
255        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
256            loop {
257                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
258                    return res;
259                }
260            }
261        }
262
263        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
264            self.poll_flush(cx)
265        }
266    }
267}
268
// Compiled on every non-unix target: `Reactor::register`'s non-unix branch names
// `windows::Dummy`, so gating this on `cfg(windows)` alone broke other non-unix targets.
#[cfg(not(unix))]
mod windows {
    use super::*;
    use futures_io::{AsyncRead, AsyncWrite};

    /// IO handle type named by the non-unix `register` branch, which only ever returns an
    /// error.
    ///
    /// It is uninhabited, so no value can ever exist. The IO methods below are statically
    /// unreachable, instead of returning `Poll::Pending` forever without registering a waker.
    pub(super) enum Dummy {}

    impl AsyncRead for Dummy {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }
    }

    impl AsyncWrite for Dummy {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the runtime traits can be used as trait objects by boxing `Tokio` and a
    /// `TTask` behind each of them.
    #[test]
    fn dyn_compat() {
        struct TraitObjects {
            _executor: Box<dyn Executor>,
            _reactor: Box<dyn Reactor>,
            _kit: Box<dyn RuntimeKit>,
            _task: Box<dyn Task<String>>,
        }

        let objects = TraitObjects {
            _executor: Box::new(Tokio::default()),
            _reactor: Box::new(Tokio::default()),
            _kit: Box::new(Tokio::default()),
            _task: Box::new(TTask(None)),
        };
        drop(objects);
    }
}