Skip to main content

async_rs/implementors/
tokio.rs

//! Tokio implementation of the async runtime definition traits.
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit},
7    util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13    future::Future,
14    io::{self, Read, Write},
15    net::SocketAddr,
16    sync::Arc,
17    time::{Duration, Instant},
18};
19use tokio::{
20    net::TcpStream,
21    runtime::{EnterGuard, Handle, Runtime as TokioRT},
22    time::Sleep,
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25
26use task::TTask;
27
28/// Type alias for the tokio runtime
29pub type TokioRuntime = Runtime<Tokio>;
30
31impl TokioRuntime {
32    /// Create a new TokioRuntime backed by a freshly created tokio multi-threaded runtime.
33    pub fn tokio() -> io::Result<Self> {
34        Ok(Self::tokio_with_runtime(TokioRT::new()?))
35    }
36
37    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
38    #[must_use]
39    pub fn tokio_current() -> Self {
40        Self::new(Tokio::current())
41    }
42
43    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
44    #[must_use]
45    pub fn tokio_with_handle(handle: Handle) -> Self {
46        Self::new(Tokio::default().with_handle(handle))
47    }
48
49    /// Create a new TokioRuntime and bind it to this tokio runtime.
50    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
51        Self::new(Tokio::default().with_runtime(runtime))
52    }
53}
54
55/// The [`RuntimeKit`] implementation backed by the tokio async runtime
56#[derive(Default, Clone, Debug)]
57pub struct Tokio {
58    handle: Option<Handle>,
59    runtime: Option<Arc<TokioRT>>,
60}
61
62impl Tokio {
63    /// Bind to the tokio Runtime associated to this handle by default.
64    #[must_use]
65    pub fn with_handle(mut self, handle: Handle) -> Self {
66        self.handle = Some(handle);
67        self
68    }
69
70    /// Bind to this tokio runtime by default.
71    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
72        let handle = runtime.handle().clone();
73        self.runtime = Some(Arc::new(runtime));
74        self.with_handle(handle)
75    }
76
77    /// Bind to the current tokio Runtime by default.
78    #[must_use]
79    pub fn current() -> Self {
80        Self::default().with_handle(Handle::current())
81    }
82
83    fn handle(&self) -> Option<Handle> {
84        self.handle.clone().or_else(|| Handle::try_current().ok())
85    }
86
87    fn enter(&self) -> Option<EnterGuard<'_>> {
88        self.runtime
89            .as_ref()
90            .map(|r| r.handle())
91            .or(self.handle.as_ref())
92            .map(Handle::enter)
93    }
94}
95
96impl RuntimeKit for Tokio {}
97
98impl Executor for Tokio {
99    type Task<T: Send + 'static> = TTask<T>;
100
101    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
102        if let Some(runtime) = self.runtime.as_ref() {
103            runtime.block_on(f)
104        } else if let Some(handle) = self.handle() {
105            handle.block_on(f)
106        } else {
107            Handle::try_current()
108                .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
109                .block_on(f)
110        }
111    }
112
113    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
114        &self,
115        f: F,
116    ) -> Task<Self::Task<T>> {
117        TTask(Some(if let Some(handle) = self.handle() {
118            handle.spawn(f)
119        } else {
120            tokio::task::spawn(f)
121        }))
122        .into()
123    }
124
125    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
126        &self,
127        f: F,
128    ) -> Task<Self::Task<T>> {
129        TTask(Some(if let Some(handle) = self.handle() {
130            handle.spawn_blocking(f)
131        } else {
132            tokio::task::spawn_blocking(f)
133        }))
134        .into()
135    }
136}
137
138impl Reactor for Tokio {
139    type TcpStream = Compat<TcpStream>;
140    type Sleep = Sleep;
141
142    fn register<H: Read + Write + AsSysFd + Send + 'static>(
143        &self,
144        socket: H,
145    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
146        let _enter = self.enter();
147        #[cfg(unix)]
148        {
149            Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
150        }
151        #[cfg(not(unix))]
152        {
153            let _ = socket;
154            Err::<crate::util::DummyIO, _>(io::Error::other(
155                "Registering FD on tokio reactor is only supported on unix",
156            ))
157        }
158    }
159
160    fn sleep(&self, dur: Duration) -> Self::Sleep {
161        let _enter = self.enter();
162        tokio::time::sleep(dur)
163    }
164
165    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
166        let _enter = self.enter();
167        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
168    }
169
170    fn tcp_connect_addr(
171        &self,
172        addr: SocketAddr,
173    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
174        let _enter = self.enter();
175        async move {
176            let stream = TcpStream::connect(addr).await?;
177            stream.set_nodelay(true)?;
178            Ok(stream.compat())
179        }
180    }
181}
182
183mod task {
184    use crate::util::TaskImpl;
185    use async_trait::async_trait;
186    use std::{
187        future::Future,
188        pin::Pin,
189        task::{Context, Poll},
190    };
191
192    /// A tokio task
193    #[derive(Debug)]
194    pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
195
196    #[async_trait]
197    impl<T: Send + 'static> TaskImpl for TTask<T> {
198        async fn cancel(&mut self) -> Option<T> {
199            let task = self.0.take()?;
200            task.abort();
201            task.await.ok()
202        }
203    }
204
205    impl<T: Send + 'static> Future for TTask<T> {
206        type Output = T;
207
208        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
209            match self.0.as_mut() {
210                None => Poll::Pending,
211                Some(task) => match Pin::new(task).poll(cx) {
212                    Poll::Pending => Poll::Pending,
213                    Poll::Ready(Ok(res)) => Poll::Ready(res),
214                    Poll::Ready(Err(_)) => Poll::Pending,
215                },
216            }
217        }
218    }
219}
220
221#[cfg(unix)]
222mod unix {
223    use super::*;
224    use futures_io::{AsyncRead, AsyncWrite};
225    use std::{
226        io::{IoSlice, IoSliceMut},
227        pin::Pin,
228        task::{Context, Poll},
229    };
230    use tokio::io::unix::AsyncFd;
231
232    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
233
234    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
235        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
236            mut self: Pin<&mut Self>,
237            cx: &mut Context<'_>,
238            f: F,
239        ) -> Option<Poll<io::Result<usize>>> {
240            Some(match self.0.poll_read_ready_mut(cx) {
241                Poll::Pending => Poll::Pending,
242                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
243                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
244                    Ok(res) => Poll::Ready(res),
245                    Err(_) => return None,
246                },
247            })
248        }
249
250        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
251            mut self: Pin<&mut Self>,
252            cx: &mut Context<'_>,
253            f: F,
254        ) -> Option<Poll<io::Result<R>>> {
255            Some(match self.0.poll_write_ready_mut(cx) {
256                Poll::Pending => Poll::Pending,
257                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
258                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
259                    Ok(res) => Poll::Ready(res),
260                    Err(_) => return None,
261                },
262            })
263        }
264    }
265
266    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
267
268    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
269        fn poll_read(
270            mut self: Pin<&mut Self>,
271            cx: &mut Context<'_>,
272            buf: &mut [u8],
273        ) -> Poll<io::Result<usize>> {
274            loop {
275                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
276                    return res;
277                }
278            }
279        }
280
281        fn poll_read_vectored(
282            mut self: Pin<&mut Self>,
283            cx: &mut Context<'_>,
284            bufs: &mut [IoSliceMut<'_>],
285        ) -> Poll<io::Result<usize>> {
286            loop {
287                if let Some(res) = self
288                    .as_mut()
289                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
290                {
291                    return res;
292                }
293            }
294        }
295    }
296
297    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
298        fn poll_write(
299            mut self: Pin<&mut Self>,
300            cx: &mut Context<'_>,
301            buf: &[u8],
302        ) -> Poll<io::Result<usize>> {
303            loop {
304                if let Some(res) = self
305                    .as_mut()
306                    .write(cx, |socket| socket.get_mut().write(buf))
307                {
308                    return res;
309                }
310            }
311        }
312
313        fn poll_write_vectored(
314            mut self: Pin<&mut Self>,
315            cx: &mut Context<'_>,
316            bufs: &[IoSlice<'_>],
317        ) -> Poll<io::Result<usize>> {
318            loop {
319                if let Some(res) = self
320                    .as_mut()
321                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
322                {
323                    return res;
324                }
325            }
326        }
327
328        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
329            loop {
330                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
331                    return res;
332                }
333            }
334        }
335
336        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
337            self.poll_flush(cx)
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;
        let rt = TokioRuntime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}