async_rs/implementors/
tokio.rs

1//! tokio implementation of async runtime definition traits
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit},
7    util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use cfg_if::cfg_if;
11use futures_core::Stream;
12use futures_io::{AsyncRead, AsyncWrite};
13use std::{
14    future::Future,
15    io::{self, Read, Write},
16    net::SocketAddr,
17    pin::Pin,
18    sync::Arc,
19    task::{Context, Poll},
20    time::{Duration, Instant},
21};
22use tokio::{
23    net::TcpStream,
24    runtime::{EnterGuard, Handle, Runtime as TokioRT},
25    time::Sleep,
26};
27use tokio_stream::{StreamExt, wrappers::IntervalStream};
28
29use task::TTask;
30
31/// Type alias for the tokio runtime
32pub type TokioRuntime = Runtime<Tokio>;
33
34impl TokioRuntime {
35    /// Create a new TokioRuntime and bind it to this tokio runtime.
36    pub fn tokio() -> io::Result<Self> {
37        Ok(Self::tokio_with_runtime(TokioRT::new()?))
38    }
39
40    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
41    pub fn tokio_current() -> Self {
42        Self::new(Tokio::current())
43    }
44
45    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
46    pub fn tokio_with_handle(handle: Handle) -> Self {
47        Self::new(Tokio::default().with_handle(handle))
48    }
49
50    /// Create a new TokioRuntime and bind it to this tokio runtime.
51    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
52        Self::new(Tokio::default().with_runtime(runtime))
53    }
54}
55
56/// Dummy object implementing async common interfaces on top of tokio
57#[derive(Default, Clone, Debug)]
58pub struct Tokio {
59    handle: Option<Handle>,
60    runtime: Option<Arc<TokioRT>>,
61}
62
63impl Tokio {
64    /// Bind to the tokio Runtime associated to this handle by default.
65    pub fn with_handle(mut self, handle: Handle) -> Self {
66        self.handle = Some(handle);
67        self
68    }
69
70    /// Bind to the tokio Runtime associated to this handle by default.
71    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
72        let handle = runtime.handle().clone();
73        self.runtime = Some(Arc::new(runtime));
74        self.with_handle(handle)
75    }
76
77    /// Bind to the current tokio Runtime by default.
78    pub fn current() -> Self {
79        Self::default().with_handle(Handle::current())
80    }
81
82    fn handle(&self) -> Option<Handle> {
83        self.runtime
84            .as_ref()
85            .map(|r| r.handle().clone())
86            .or_else(|| Handle::try_current().ok())
87            .or_else(|| self.handle.clone())
88    }
89
90    fn enter(&self) -> Option<EnterGuard<'_>> {
91        self.runtime
92            .as_ref()
93            .map(|r| r.handle())
94            .or(self.handle.as_ref())
95            .map(Handle::enter)
96    }
97}
98
99impl RuntimeKit for Tokio {}
100
101impl Executor for Tokio {
102    type Task<T: Send + 'static> = TTask<T>;
103
104    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
105        if let Some(runtime) = self.runtime.as_ref() {
106            runtime.block_on(f)
107        } else if let Some(handle) = self.handle() {
108            handle.block_on(f)
109        } else {
110            Handle::current().block_on(f)
111        }
112    }
113
114    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
115        &self,
116        f: F,
117    ) -> Task<Self::Task<T>> {
118        TTask(Some(if let Some(handle) = self.handle() {
119            handle.spawn(f)
120        } else {
121            tokio::task::spawn(f)
122        }))
123        .into()
124    }
125
126    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
127        &self,
128        f: F,
129    ) -> Task<Self::Task<T>> {
130        TTask(Some(if let Some(handle) = self.handle() {
131            handle.spawn_blocking(f)
132        } else {
133            tokio::task::spawn_blocking(f)
134        }))
135        .into()
136    }
137}
138
139impl Reactor for Tokio {
140    type TcpStream = Compat<TcpStream>;
141    type Sleep = Sleep;
142
143    fn register<H: Read + Write + AsSysFd + Send + 'static>(
144        &self,
145        socket: H,
146    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
147        let _enter = self.enter();
148        cfg_if! {
149            if #[cfg(unix)] {
150                Ok(unix::AsyncFdWrapper(
151                    tokio::io::unix::AsyncFd::new(socket)?,
152                ))
153            } else {
154                Err::<crate::util::DummyIO, _>(io::Error::other(
155                    "Registering FD on tokio reactor is only supported on unix",
156                ))
157            }
158        }
159    }
160
161    fn sleep(&self, dur: Duration) -> Self::Sleep {
162        let _enter = self.enter();
163        tokio::time::sleep(dur)
164    }
165
166    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
167        let _enter = self.enter();
168        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
169    }
170
171    fn tcp_connect_addr(
172        &self,
173        addr: SocketAddr,
174    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
175        let _enter = self.enter();
176        async move { Ok(TcpStream::connect(addr).await?.compat()) }
177    }
178}
179
180mod task {
181    use crate::util::TaskImpl;
182    use async_trait::async_trait;
183    use std::{
184        future::Future,
185        pin::Pin,
186        task::{Context, Poll},
187    };
188
189    /// A tokio task
190    #[derive(Debug)]
191    pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
192
193    #[async_trait]
194    impl<T: Send + 'static> TaskImpl for TTask<T> {
195        async fn cancel(&mut self) -> Option<T> {
196            let task = self.0.take()?;
197            task.abort();
198            task.await.ok()
199        }
200    }
201
202    impl<T: Send + 'static> Future for TTask<T> {
203        type Output = T;
204
205        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
206            let task = self.0.as_mut().expect("task has been canceled");
207            match Pin::new(task).poll(cx) {
208                Poll::Pending => Poll::Pending,
209                Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
210            }
211        }
212    }
213}
214
215#[cfg(unix)]
216mod unix {
217    use super::*;
218    use futures_io::{AsyncRead, AsyncWrite};
219    use std::io::{IoSlice, IoSliceMut};
220    use tokio::io::unix::AsyncFd;
221
222    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
223
224    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
225        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
226            mut self: Pin<&mut Self>,
227            cx: &mut Context<'_>,
228            f: F,
229        ) -> Option<Poll<io::Result<usize>>> {
230            Some(match self.0.poll_read_ready_mut(cx) {
231                Poll::Pending => Poll::Pending,
232                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
233                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
234                    Ok(res) => Poll::Ready(res),
235                    Err(_) => return None,
236                },
237            })
238        }
239
240        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
241            mut self: Pin<&mut Self>,
242            cx: &mut Context<'_>,
243            f: F,
244        ) -> Option<Poll<io::Result<R>>> {
245            Some(match self.0.poll_write_ready_mut(cx) {
246                Poll::Pending => Poll::Pending,
247                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
248                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
249                    Ok(res) => Poll::Ready(res),
250                    Err(_) => return None,
251                },
252            })
253        }
254    }
255
256    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
257
258    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
259        fn poll_read(
260            mut self: Pin<&mut Self>,
261            cx: &mut Context<'_>,
262            buf: &mut [u8],
263        ) -> Poll<io::Result<usize>> {
264            loop {
265                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
266                    return res;
267                }
268            }
269        }
270
271        fn poll_read_vectored(
272            mut self: Pin<&mut Self>,
273            cx: &mut Context<'_>,
274            bufs: &mut [IoSliceMut<'_>],
275        ) -> Poll<io::Result<usize>> {
276            loop {
277                if let Some(res) = self
278                    .as_mut()
279                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
280                {
281                    return res;
282                }
283            }
284        }
285    }
286
287    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
288        fn poll_write(
289            mut self: Pin<&mut Self>,
290            cx: &mut Context<'_>,
291            buf: &[u8],
292        ) -> Poll<io::Result<usize>> {
293            loop {
294                if let Some(res) = self
295                    .as_mut()
296                    .write(cx, |socket| socket.get_mut().write(buf))
297                {
298                    return res;
299                }
300            }
301        }
302
303        fn poll_write_vectored(
304            mut self: Pin<&mut Self>,
305            cx: &mut Context<'_>,
306            bufs: &[IoSlice<'_>],
307        ) -> Poll<io::Result<usize>> {
308            loop {
309                if let Some(res) = self
310                    .as_mut()
311                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
312                {
313                    return res;
314                }
315            }
316        }
317
318        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
319            loop {
320                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
321                    return res;
322                }
323            }
324        }
325
326        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
327            self.poll_flush(cx)
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// A `TokioRuntime` must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;
        let rt = TokioRuntime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}
344}