Skip to main content

async_rs/implementors/
tokio.rs

1//! tokio implementation of async runtime definition traits
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit},
7    util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13    future::Future,
14    io::{self, Read, Write},
15    net::SocketAddr,
16    sync::Arc,
17    time::{Duration, Instant},
18};
19use tokio::{
20    net::TcpStream,
21    runtime::{EnterGuard, Handle, Runtime as TokioRT},
22    time::Sleep,
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25
26use task::TTask;
27
28/// Type alias for the tokio runtime
29pub type TokioRuntime = Runtime<Tokio>;
30
31impl TokioRuntime {
32    /// Create a new TokioRuntime and bind it to this tokio runtime.
33    pub fn tokio() -> io::Result<Self> {
34        Ok(Self::tokio_with_runtime(TokioRT::new()?))
35    }
36
37    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
38    pub fn tokio_current() -> Self {
39        Self::new(Tokio::current())
40    }
41
42    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
43    pub fn tokio_with_handle(handle: Handle) -> Self {
44        Self::new(Tokio::default().with_handle(handle))
45    }
46
47    /// Create a new TokioRuntime and bind it to this tokio runtime.
48    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
49        Self::new(Tokio::default().with_runtime(runtime))
50    }
51}
52
53/// Dummy object implementing async common interfaces on top of tokio
54#[derive(Default, Clone, Debug)]
55pub struct Tokio {
56    handle: Option<Handle>,
57    runtime: Option<Arc<TokioRT>>,
58}
59
60impl Tokio {
61    /// Bind to the tokio Runtime associated to this handle by default.
62    pub fn with_handle(mut self, handle: Handle) -> Self {
63        self.handle = Some(handle);
64        self
65    }
66
67    /// Bind to the tokio Runtime associated to this handle by default.
68    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
69        let handle = runtime.handle().clone();
70        self.runtime = Some(Arc::new(runtime));
71        self.with_handle(handle)
72    }
73
74    /// Bind to the current tokio Runtime by default.
75    pub fn current() -> Self {
76        Self::default().with_handle(Handle::current())
77    }
78
79    fn handle(&self) -> Option<Handle> {
80        self.handle.clone().or_else(|| Handle::try_current().ok())
81    }
82
83    fn enter(&self) -> Option<EnterGuard<'_>> {
84        self.runtime
85            .as_ref()
86            .map(|r| r.handle())
87            .or(self.handle.as_ref())
88            .map(Handle::enter)
89    }
90}
91
92impl RuntimeKit for Tokio {}
93
94impl Executor for Tokio {
95    type Task<T: Send + 'static> = TTask<T>;
96
97    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
98        if let Some(runtime) = self.runtime.as_ref() {
99            runtime.block_on(f)
100        } else if let Some(handle) = self.handle() {
101            handle.block_on(f)
102        } else {
103            Handle::try_current()
104                .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
105                .block_on(f)
106        }
107    }
108
109    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
110        &self,
111        f: F,
112    ) -> Task<Self::Task<T>> {
113        TTask(Some(if let Some(handle) = self.handle() {
114            handle.spawn(f)
115        } else {
116            tokio::task::spawn(f)
117        }))
118        .into()
119    }
120
121    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
122        &self,
123        f: F,
124    ) -> Task<Self::Task<T>> {
125        TTask(Some(if let Some(handle) = self.handle() {
126            handle.spawn_blocking(f)
127        } else {
128            tokio::task::spawn_blocking(f)
129        }))
130        .into()
131    }
132}
133
134impl Reactor for Tokio {
135    type TcpStream = Compat<TcpStream>;
136    type Sleep = Sleep;
137
138    fn register<H: Read + Write + AsSysFd + Send + 'static>(
139        &self,
140        socket: H,
141    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
142        let _enter = self.enter();
143        #[cfg(unix)]
144        {
145            Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
146        }
147        #[cfg(not(unix))]
148        {
149            let _ = socket;
150            Err::<crate::util::DummyIO, _>(io::Error::other(
151                "Registering FD on tokio reactor is only supported on unix",
152            ))
153        }
154    }
155
156    fn sleep(&self, dur: Duration) -> Self::Sleep {
157        let _enter = self.enter();
158        tokio::time::sleep(dur)
159    }
160
161    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
162        let _enter = self.enter();
163        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
164    }
165
166    fn tcp_connect_addr(
167        &self,
168        addr: SocketAddr,
169    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
170        let _enter = self.enter();
171        async move { Ok(TcpStream::connect(addr).await?.compat()) }
172    }
173}
174
175mod task {
176    use crate::util::TaskImpl;
177    use async_trait::async_trait;
178    use std::{
179        future::Future,
180        pin::Pin,
181        task::{Context, Poll},
182    };
183
184    /// A tokio task
185    #[derive(Debug)]
186    pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
187
188    #[async_trait]
189    impl<T: Send + 'static> TaskImpl for TTask<T> {
190        async fn cancel(&mut self) -> Option<T> {
191            let task = self.0.take()?;
192            task.abort();
193            task.await.ok()
194        }
195    }
196
197    impl<T: Send + 'static> Future for TTask<T> {
198        type Output = T;
199
200        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
201            let task = self.0.as_mut().expect("task has been canceled");
202            match Pin::new(task).poll(cx) {
203                Poll::Pending => Poll::Pending,
204                Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
205            }
206        }
207    }
208}
209
210#[cfg(unix)]
211mod unix {
212    use super::*;
213    use futures_io::{AsyncRead, AsyncWrite};
214    use std::{
215        io::{IoSlice, IoSliceMut},
216        pin::Pin,
217        task::{Context, Poll},
218    };
219    use tokio::io::unix::AsyncFd;
220
221    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
222
223    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
224        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
225            mut self: Pin<&mut Self>,
226            cx: &mut Context<'_>,
227            f: F,
228        ) -> Option<Poll<io::Result<usize>>> {
229            Some(match self.0.poll_read_ready_mut(cx) {
230                Poll::Pending => Poll::Pending,
231                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
232                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
233                    Ok(res) => Poll::Ready(res),
234                    Err(_) => return None,
235                },
236            })
237        }
238
239        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
240            mut self: Pin<&mut Self>,
241            cx: &mut Context<'_>,
242            f: F,
243        ) -> Option<Poll<io::Result<R>>> {
244            Some(match self.0.poll_write_ready_mut(cx) {
245                Poll::Pending => Poll::Pending,
246                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
247                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
248                    Ok(res) => Poll::Ready(res),
249                    Err(_) => return None,
250                },
251            })
252        }
253    }
254
255    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
256
257    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
258        fn poll_read(
259            mut self: Pin<&mut Self>,
260            cx: &mut Context<'_>,
261            buf: &mut [u8],
262        ) -> Poll<io::Result<usize>> {
263            loop {
264                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
265                    return res;
266                }
267            }
268        }
269
270        fn poll_read_vectored(
271            mut self: Pin<&mut Self>,
272            cx: &mut Context<'_>,
273            bufs: &mut [IoSliceMut<'_>],
274        ) -> Poll<io::Result<usize>> {
275            loop {
276                if let Some(res) = self
277                    .as_mut()
278                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
279                {
280                    return res;
281                }
282            }
283        }
284    }
285
286    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
287        fn poll_write(
288            mut self: Pin<&mut Self>,
289            cx: &mut Context<'_>,
290            buf: &[u8],
291        ) -> Poll<io::Result<usize>> {
292            loop {
293                if let Some(res) = self
294                    .as_mut()
295                    .write(cx, |socket| socket.get_mut().write(buf))
296                {
297                    return res;
298                }
299            }
300        }
301
302        fn poll_write_vectored(
303            mut self: Pin<&mut Self>,
304            cx: &mut Context<'_>,
305            bufs: &[IoSlice<'_>],
306        ) -> Poll<io::Result<usize>> {
307            loop {
308                if let Some(res) = self
309                    .as_mut()
310                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
311                {
312                    return res;
313                }
314            }
315        }
316
317        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
318            loop {
319                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
320                    return res;
321                }
322            }
323        }
324
325        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
326            self.poll_flush(cx)
327        }
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = Runtime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}