async_rs/implementors/
tokio.rs

1//! tokio implementation of async runtime definition traits
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit, Task},
7};
8use async_trait::async_trait;
9use cfg_if::cfg_if;
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13    future::Future,
14    io::{self, Read, Write},
15    net::SocketAddr,
16    pin::Pin,
17    task::{Context, Poll},
18    time::{Duration, Instant},
19};
20use tokio::{net::TcpStream, runtime::Handle};
21use tokio_stream::{StreamExt, wrappers::IntervalStream};
22use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
23
24/// Type alias for the tokio runtime
25pub type TokioRuntime = Runtime<Tokio>;
26
27impl TokioRuntime {
28    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
29    pub fn tokio() -> Self {
30        Self::new(Tokio::current())
31    }
32
33    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
34    pub fn tokio_with_handle(handle: Handle) -> Self {
35        Self::new(Tokio::default().with_handle(handle))
36    }
37}
38
39/// Dummy object implementing async common interfaces on top of tokio
40#[derive(Default, Debug, Clone)]
41pub struct Tokio {
42    handle: Option<Handle>,
43}
44
45impl Tokio {
46    /// Bind to the tokio Runtime associated to this handle by default.
47    pub fn with_handle(mut self, handle: Handle) -> Self {
48        self.handle = Some(handle);
49        self
50    }
51
52    /// Bind to the current tokio Runtime by default.
53    pub fn current() -> Self {
54        Self::default().with_handle(Handle::current())
55    }
56
57    pub(crate) fn handle(&self) -> Option<Handle> {
58        Handle::try_current().ok().or_else(|| self.handle.clone())
59    }
60}
61
62struct TTask<T: Send + 'static>(Option<tokio::task::JoinHandle<T>>);
63
64impl RuntimeKit for Tokio {}
65
66impl Executor for Tokio {
67    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
68        if let Some(handle) = self.handle() {
69            handle.block_on(f)
70        } else {
71            Handle::current().block_on(f)
72        }
73    }
74
75    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
76        &self,
77        f: F,
78    ) -> impl Task<T> + 'static {
79        TTask(Some(if let Some(handle) = self.handle() {
80            handle.spawn(f)
81        } else {
82            tokio::task::spawn(f)
83        }))
84    }
85
86    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
87        &self,
88        f: F,
89    ) -> impl Task<T> + 'static {
90        TTask(Some(if let Some(handle) = self.handle() {
91            handle.spawn_blocking(f)
92        } else {
93            tokio::task::spawn_blocking(f)
94        }))
95    }
96}
97
98#[async_trait]
99impl<T: Send + 'static> Task<T> for TTask<T> {
100    async fn cancel(&mut self) -> Option<T> {
101        let task = self.0.take()?;
102        task.abort();
103        task.await.ok()
104    }
105}
106
107impl<T: Send + 'static> Future for TTask<T> {
108    type Output = T;
109
110    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
111        let task = self.0.as_mut().expect("task has been canceled");
112        match Pin::new(task).poll(cx) {
113            Poll::Pending => Poll::Pending,
114            Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
115        }
116    }
117}
118
119impl Reactor for Tokio {
120    type TcpStream = Compat<TcpStream>;
121
122    fn register<H: Read + Write + AsSysFd + Send + 'static>(
123        &self,
124        socket: H,
125    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
126        let _enter = self.handle().as_ref().map(|handle| handle.enter());
127        cfg_if! {
128            if #[cfg(unix)] {
129                Ok(unix::AsyncFdWrapper(
130                    tokio::io::unix::AsyncFd::new(socket)?,
131                ))
132            } else {
133                Err::<windows::Dummy, _>(io::Error::other(
134                    "Registering FD on tokio reactor is only supported on unix",
135                ))
136            }
137        }
138    }
139
140    fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send + 'static {
141        tokio::time::sleep(dur)
142    }
143
144    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
145        let _enter = self.handle().as_ref().map(|handle| handle.enter());
146        Box::new(
147            IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std),
148        )
149    }
150
151    fn tcp_connect(
152        &self,
153        addr: SocketAddr,
154    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
155        let _enter = self.handle().as_ref().map(|handle| handle.enter());
156        async move { Ok(TcpStream::connect(addr).await?.compat()) }
157    }
158}
159
160#[cfg(unix)]
161mod unix {
162    use super::*;
163    use futures_io::{AsyncRead, AsyncWrite};
164    use std::io::{IoSlice, IoSliceMut};
165    use tokio::io::unix::AsyncFd;
166
167    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
168
169    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
170        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
171            mut self: Pin<&mut Self>,
172            cx: &mut Context<'_>,
173            f: F,
174        ) -> Option<Poll<io::Result<usize>>> {
175            Some(match self.0.poll_read_ready_mut(cx) {
176                Poll::Pending => Poll::Pending,
177                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
178                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
179                    Ok(res) => Poll::Ready(res),
180                    Err(_) => return None,
181                },
182            })
183        }
184
185        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
186            mut self: Pin<&mut Self>,
187            cx: &mut Context<'_>,
188            f: F,
189        ) -> Option<Poll<io::Result<R>>> {
190            Some(match self.0.poll_write_ready_mut(cx) {
191                Poll::Pending => Poll::Pending,
192                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
193                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
194                    Ok(res) => Poll::Ready(res),
195                    Err(_) => return None,
196                },
197            })
198        }
199    }
200
201    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
202
203    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
204        fn poll_read(
205            mut self: Pin<&mut Self>,
206            cx: &mut Context<'_>,
207            buf: &mut [u8],
208        ) -> Poll<io::Result<usize>> {
209            loop {
210                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
211                    return res;
212                }
213            }
214        }
215
216        fn poll_read_vectored(
217            mut self: Pin<&mut Self>,
218            cx: &mut Context<'_>,
219            bufs: &mut [IoSliceMut<'_>],
220        ) -> Poll<io::Result<usize>> {
221            loop {
222                if let Some(res) = self
223                    .as_mut()
224                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
225                {
226                    return res;
227                }
228            }
229        }
230    }
231
232    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
233        fn poll_write(
234            mut self: Pin<&mut Self>,
235            cx: &mut Context<'_>,
236            buf: &[u8],
237        ) -> Poll<io::Result<usize>> {
238            loop {
239                if let Some(res) = self
240                    .as_mut()
241                    .write(cx, |socket| socket.get_mut().write(buf))
242                {
243                    return res;
244                }
245            }
246        }
247
248        fn poll_write_vectored(
249            mut self: Pin<&mut Self>,
250            cx: &mut Context<'_>,
251            bufs: &[IoSlice<'_>],
252        ) -> Poll<io::Result<usize>> {
253            loop {
254                if let Some(res) = self
255                    .as_mut()
256                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
257                {
258                    return res;
259                }
260            }
261        }
262
263        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
264            loop {
265                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
266                    return res;
267                }
268            }
269        }
270
271        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
272            self.poll_flush(cx)
273        }
274    }
275}
276
// Must exist on every target where the `else` branch of the `cfg_if!` in `Reactor::register`
// is compiled (all non-unix targets), since that branch names `windows::Dummy`.
#[cfg(not(unix))]
mod windows {
    use super::*;
    use futures_io::{AsyncRead, AsyncWrite};

    /// Placeholder I/O type for platforms where `Reactor::register` is unsupported.
    ///
    /// `register` only ever returns `Err` on those platforms, so this type just fills in the
    /// `Ok` side of the result. It is an empty enum so that no value can exist, which makes the
    /// trait methods below statically unreachable.
    pub(super) enum Dummy {}

    impl AsyncRead for Dummy {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }
    }

    impl AsyncWrite for Dummy {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time guard: the traits implemented for `Tokio` and `TTask` must stay usable as
    /// trait objects.
    #[test]
    fn dyn_compat() {
        let executor: Box<dyn Executor> = Box::new(Tokio::default());
        let reactor: Box<dyn Reactor<TcpStream = Compat<TcpStream>>> =
            Box::new(Tokio::default());
        let kit: Box<dyn RuntimeKit<TcpStream = Compat<TcpStream>>> = Box::new(Tokio::default());
        let task: Box<dyn Task<String>> = Box::new(TTask(None));
        let _ = (executor, reactor, kit, task);
    }
}