Skip to main content

async_rs/implementors/
tokio.rs

1//! tokio implementation of async runtime definition traits
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit},
7    util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13    future::Future,
14    io::{self, Read, Write},
15    net::SocketAddr,
16    sync::Arc,
17    time::{Duration, Instant},
18};
19use tokio::{
20    net::TcpStream,
21    runtime::{EnterGuard, Handle, Runtime as TokioRT},
22    time::Sleep,
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25
26use task::TTask;
27
28/// Type alias for the tokio runtime
29pub type TokioRuntime = Runtime<Tokio>;
30
31impl TokioRuntime {
32    /// Create a new TokioRuntime and bind it to this tokio runtime.
33    pub fn tokio() -> io::Result<Self> {
34        Ok(Self::tokio_with_runtime(TokioRT::new()?))
35    }
36
37    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
38    pub fn tokio_current() -> Self {
39        Self::new(Tokio::current())
40    }
41
42    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
43    pub fn tokio_with_handle(handle: Handle) -> Self {
44        Self::new(Tokio::default().with_handle(handle))
45    }
46
47    /// Create a new TokioRuntime and bind it to this tokio runtime.
48    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
49        Self::new(Tokio::default().with_runtime(runtime))
50    }
51}
52
53/// Dummy object implementing async common interfaces on top of tokio
54#[derive(Default, Clone, Debug)]
55pub struct Tokio {
56    handle: Option<Handle>,
57    runtime: Option<Arc<TokioRT>>,
58}
59
60impl Tokio {
61    /// Bind to the tokio Runtime associated to this handle by default.
62    pub fn with_handle(mut self, handle: Handle) -> Self {
63        self.handle = Some(handle);
64        self
65    }
66
67    /// Bind to this tokio runtime by default.
68    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
69        let handle = runtime.handle().clone();
70        self.runtime = Some(Arc::new(runtime));
71        self.with_handle(handle)
72    }
73
74    /// Bind to the current tokio Runtime by default.
75    pub fn current() -> Self {
76        Self::default().with_handle(Handle::current())
77    }
78
79    fn handle(&self) -> Option<Handle> {
80        self.handle.clone().or_else(|| Handle::try_current().ok())
81    }
82
83    fn enter(&self) -> Option<EnterGuard<'_>> {
84        self.runtime
85            .as_ref()
86            .map(|r| r.handle())
87            .or(self.handle.as_ref())
88            .map(Handle::enter)
89    }
90}
91
92impl RuntimeKit for Tokio {}
93
94impl Executor for Tokio {
95    type Task<T: Send + 'static> = TTask<T>;
96
97    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
98        if let Some(runtime) = self.runtime.as_ref() {
99            runtime.block_on(f)
100        } else if let Some(handle) = self.handle() {
101            handle.block_on(f)
102        } else {
103            Handle::try_current()
104                .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
105                .block_on(f)
106        }
107    }
108
109    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
110        &self,
111        f: F,
112    ) -> Task<Self::Task<T>> {
113        TTask(Some(if let Some(handle) = self.handle() {
114            handle.spawn(f)
115        } else {
116            tokio::task::spawn(f)
117        }))
118        .into()
119    }
120
121    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
122        &self,
123        f: F,
124    ) -> Task<Self::Task<T>> {
125        TTask(Some(if let Some(handle) = self.handle() {
126            handle.spawn_blocking(f)
127        } else {
128            tokio::task::spawn_blocking(f)
129        }))
130        .into()
131    }
132}
133
134impl Reactor for Tokio {
135    type TcpStream = Compat<TcpStream>;
136    type Sleep = Sleep;
137
138    fn register<H: Read + Write + AsSysFd + Send + 'static>(
139        &self,
140        socket: H,
141    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
142        let _enter = self.enter();
143        #[cfg(unix)]
144        {
145            Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
146        }
147        #[cfg(not(unix))]
148        {
149            let _ = socket;
150            Err::<crate::util::DummyIO, _>(io::Error::other(
151                "Registering FD on tokio reactor is only supported on unix",
152            ))
153        }
154    }
155
156    fn sleep(&self, dur: Duration) -> Self::Sleep {
157        let _enter = self.enter();
158        tokio::time::sleep(dur)
159    }
160
161    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
162        let _enter = self.enter();
163        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
164    }
165
166    fn tcp_connect_addr(
167        &self,
168        addr: SocketAddr,
169    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
170        let _enter = self.enter();
171        async move {
172            let stream = TcpStream::connect(addr).await?;
173            stream.set_nodelay(true)?;
174            Ok(stream.compat())
175        }
176    }
177}
178
179mod task {
180    use crate::util::TaskImpl;
181    use async_trait::async_trait;
182    use std::{
183        future::Future,
184        pin::Pin,
185        task::{Context, Poll},
186    };
187
188    /// A tokio task
189    #[derive(Debug)]
190    pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
191
192    #[async_trait]
193    impl<T: Send + 'static> TaskImpl for TTask<T> {
194        async fn cancel(&mut self) -> Option<T> {
195            let task = self.0.take()?;
196            task.abort();
197            task.await.ok()
198        }
199    }
200
201    impl<T: Send + 'static> Future for TTask<T> {
202        type Output = T;
203
204        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205            match self.0.as_mut() {
206                None => Poll::Pending,
207                Some(task) => match Pin::new(task).poll(cx) {
208                    Poll::Pending => Poll::Pending,
209                    Poll::Ready(Ok(res)) => Poll::Ready(res),
210                    Poll::Ready(Err(_)) => Poll::Pending,
211                },
212            }
213        }
214    }
215}
216
217#[cfg(unix)]
218mod unix {
219    use super::*;
220    use futures_io::{AsyncRead, AsyncWrite};
221    use std::{
222        io::{IoSlice, IoSliceMut},
223        pin::Pin,
224        task::{Context, Poll},
225    };
226    use tokio::io::unix::AsyncFd;
227
228    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
229
230    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
231        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
232            mut self: Pin<&mut Self>,
233            cx: &mut Context<'_>,
234            f: F,
235        ) -> Option<Poll<io::Result<usize>>> {
236            Some(match self.0.poll_read_ready_mut(cx) {
237                Poll::Pending => Poll::Pending,
238                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
239                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
240                    Ok(res) => Poll::Ready(res),
241                    Err(_) => return None,
242                },
243            })
244        }
245
246        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
247            mut self: Pin<&mut Self>,
248            cx: &mut Context<'_>,
249            f: F,
250        ) -> Option<Poll<io::Result<R>>> {
251            Some(match self.0.poll_write_ready_mut(cx) {
252                Poll::Pending => Poll::Pending,
253                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
254                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
255                    Ok(res) => Poll::Ready(res),
256                    Err(_) => return None,
257                },
258            })
259        }
260    }
261
262    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
263
264    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
265        fn poll_read(
266            mut self: Pin<&mut Self>,
267            cx: &mut Context<'_>,
268            buf: &mut [u8],
269        ) -> Poll<io::Result<usize>> {
270            loop {
271                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
272                    return res;
273                }
274            }
275        }
276
277        fn poll_read_vectored(
278            mut self: Pin<&mut Self>,
279            cx: &mut Context<'_>,
280            bufs: &mut [IoSliceMut<'_>],
281        ) -> Poll<io::Result<usize>> {
282            loop {
283                if let Some(res) = self
284                    .as_mut()
285                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
286                {
287                    return res;
288                }
289            }
290        }
291    }
292
293    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
294        fn poll_write(
295            mut self: Pin<&mut Self>,
296            cx: &mut Context<'_>,
297            buf: &[u8],
298        ) -> Poll<io::Result<usize>> {
299            loop {
300                if let Some(res) = self
301                    .as_mut()
302                    .write(cx, |socket| socket.get_mut().write(buf))
303                {
304                    return res;
305                }
306            }
307        }
308
309        fn poll_write_vectored(
310            mut self: Pin<&mut Self>,
311            cx: &mut Context<'_>,
312            bufs: &[IoSlice<'_>],
313        ) -> Poll<io::Result<usize>> {
314            loop {
315                if let Some(res) = self
316                    .as_mut()
317                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
318                {
319                    return res;
320                }
321            }
322        }
323
324        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
325            loop {
326                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
327                    return res;
328                }
329            }
330        }
331
332        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
333            self.poll_flush(cx)
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be shareable across threads and cloneable.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = Runtime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}
350}