Skip to main content

async_rs/implementors/
tokio.rs

1//! tokio implementation of async runtime definition traits
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit},
7    util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13    future::Future,
14    io::{self, Read, Write},
15    net::SocketAddr,
16    sync::Arc,
17    time::{Duration, Instant},
18};
19use tokio::{
20    net::TcpStream,
21    runtime::{EnterGuard, Handle, Runtime as TokioRT},
22    time::Sleep,
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25
26use task::TTask;
27
28/// Type alias for the tokio runtime
29pub type TokioRuntime = Runtime<Tokio>;
30
31impl TokioRuntime {
32    /// Create a new TokioRuntime and bind it to this tokio runtime.
33    pub fn tokio() -> io::Result<Self> {
34        Ok(Self::tokio_with_runtime(TokioRT::new()?))
35    }
36
37    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
38    pub fn tokio_current() -> Self {
39        Self::new(Tokio::current())
40    }
41
42    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
43    pub fn tokio_with_handle(handle: Handle) -> Self {
44        Self::new(Tokio::default().with_handle(handle))
45    }
46
47    /// Create a new TokioRuntime and bind it to this tokio runtime.
48    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
49        Self::new(Tokio::default().with_runtime(runtime))
50    }
51}
52
53/// Dummy object implementing async common interfaces on top of tokio
54#[derive(Default, Clone, Debug)]
55pub struct Tokio {
56    handle: Option<Handle>,
57    runtime: Option<Arc<TokioRT>>,
58}
59
60impl Tokio {
61    /// Bind to the tokio Runtime associated to this handle by default.
62    pub fn with_handle(mut self, handle: Handle) -> Self {
63        self.handle = Some(handle);
64        self
65    }
66
67    /// Bind to the tokio Runtime associated to this handle by default.
68    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
69        let handle = runtime.handle().clone();
70        self.runtime = Some(Arc::new(runtime));
71        self.with_handle(handle)
72    }
73
74    /// Bind to the current tokio Runtime by default.
75    pub fn current() -> Self {
76        Self::default().with_handle(Handle::current())
77    }
78
79    fn handle(&self) -> Option<Handle> {
80        self.handle.clone().or_else(|| Handle::try_current().ok())
81    }
82
83    fn enter(&self) -> Option<EnterGuard<'_>> {
84        self.runtime
85            .as_ref()
86            .map(|r| r.handle())
87            .or(self.handle.as_ref())
88            .map(Handle::enter)
89    }
90}
91
92impl RuntimeKit for Tokio {}
93
94impl Executor for Tokio {
95    type Task<T: Send + 'static> = TTask<T>;
96
97    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
98        if let Some(runtime) = self.runtime.as_ref() {
99            runtime.block_on(f)
100        } else if let Some(handle) = self.handle() {
101            handle.block_on(f)
102        } else {
103            Handle::try_current()
104                .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
105                .block_on(f)
106        }
107    }
108
109    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
110        &self,
111        f: F,
112    ) -> Task<Self::Task<T>> {
113        TTask(Some(if let Some(handle) = self.handle() {
114            handle.spawn(f)
115        } else {
116            tokio::task::spawn(f)
117        }))
118        .into()
119    }
120
121    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
122        &self,
123        f: F,
124    ) -> Task<Self::Task<T>> {
125        TTask(Some(if let Some(handle) = self.handle() {
126            handle.spawn_blocking(f)
127        } else {
128            tokio::task::spawn_blocking(f)
129        }))
130        .into()
131    }
132}
133
134impl Reactor for Tokio {
135    type TcpStream = Compat<TcpStream>;
136    type Sleep = Sleep;
137
138    fn register<H: Read + Write + AsSysFd + Send + 'static>(
139        &self,
140        socket: H,
141    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
142        let _enter = self.enter();
143        #[cfg(unix)]
144        {
145            Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
146        }
147        #[cfg(not(unix))]
148        {
149            let _ = socket;
150            Err::<crate::util::DummyIO, _>(io::Error::other(
151                "Registering FD on tokio reactor is only supported on unix",
152            ))
153        }
154    }
155
156    fn sleep(&self, dur: Duration) -> Self::Sleep {
157        let _enter = self.enter();
158        tokio::time::sleep(dur)
159    }
160
161    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
162        let _enter = self.enter();
163        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
164    }
165
166    fn tcp_connect_addr(
167        &self,
168        addr: SocketAddr,
169    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
170        let _enter = self.enter();
171        async move {
172            let stream = TcpStream::connect(addr).await?;
173            stream.set_nodelay(true)?;
174            Ok(stream.compat())
175        }
176    }
177}
178
179mod task {
180    use crate::util::TaskImpl;
181    use async_trait::async_trait;
182    use std::{
183        future::Future,
184        pin::Pin,
185        task::{Context, Poll},
186    };
187
188    /// A tokio task
189    #[derive(Debug)]
190    pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
191
192    #[async_trait]
193    impl<T: Send + 'static> TaskImpl for TTask<T> {
194        async fn cancel(&mut self) -> Option<T> {
195            let task = self.0.take()?;
196            task.abort();
197            task.await.ok()
198        }
199    }
200
201    impl<T: Send + 'static> Future for TTask<T> {
202        type Output = T;
203
204        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205            let task = self.0.as_mut().expect("task has been canceled");
206            match Pin::new(task).poll(cx) {
207                Poll::Pending => Poll::Pending,
208                Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
209            }
210        }
211    }
212}
213
214#[cfg(unix)]
215mod unix {
216    use super::*;
217    use futures_io::{AsyncRead, AsyncWrite};
218    use std::{
219        io::{IoSlice, IoSliceMut},
220        pin::Pin,
221        task::{Context, Poll},
222    };
223    use tokio::io::unix::AsyncFd;
224
225    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
226
227    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
228        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
229            mut self: Pin<&mut Self>,
230            cx: &mut Context<'_>,
231            f: F,
232        ) -> Option<Poll<io::Result<usize>>> {
233            Some(match self.0.poll_read_ready_mut(cx) {
234                Poll::Pending => Poll::Pending,
235                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
236                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
237                    Ok(res) => Poll::Ready(res),
238                    Err(_) => return None,
239                },
240            })
241        }
242
243        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
244            mut self: Pin<&mut Self>,
245            cx: &mut Context<'_>,
246            f: F,
247        ) -> Option<Poll<io::Result<R>>> {
248            Some(match self.0.poll_write_ready_mut(cx) {
249                Poll::Pending => Poll::Pending,
250                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
251                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
252                    Ok(res) => Poll::Ready(res),
253                    Err(_) => return None,
254                },
255            })
256        }
257    }
258
259    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
260
261    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
262        fn poll_read(
263            mut self: Pin<&mut Self>,
264            cx: &mut Context<'_>,
265            buf: &mut [u8],
266        ) -> Poll<io::Result<usize>> {
267            loop {
268                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
269                    return res;
270                }
271            }
272        }
273
274        fn poll_read_vectored(
275            mut self: Pin<&mut Self>,
276            cx: &mut Context<'_>,
277            bufs: &mut [IoSliceMut<'_>],
278        ) -> Poll<io::Result<usize>> {
279            loop {
280                if let Some(res) = self
281                    .as_mut()
282                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
283                {
284                    return res;
285                }
286            }
287        }
288    }
289
290    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
291        fn poll_write(
292            mut self: Pin<&mut Self>,
293            cx: &mut Context<'_>,
294            buf: &[u8],
295        ) -> Poll<io::Result<usize>> {
296            loop {
297                if let Some(res) = self
298                    .as_mut()
299                    .write(cx, |socket| socket.get_mut().write(buf))
300                {
301                    return res;
302                }
303            }
304        }
305
306        fn poll_write_vectored(
307            mut self: Pin<&mut Self>,
308            cx: &mut Context<'_>,
309            bufs: &[IoSlice<'_>],
310        ) -> Poll<io::Result<usize>> {
311            loop {
312                if let Some(res) = self
313                    .as_mut()
314                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
315                {
316                    return res;
317                }
318            }
319        }
320
321        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
322            loop {
323                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
324                    return res;
325                }
326            }
327        }
328
329        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
330            self.poll_flush(cx)
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be shareable across threads and cloneable.
    #[test]
    fn auto_traits() {
        use crate::util::test::{assert_clone, assert_send, assert_sync};

        let rt = Runtime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}