Skip to main content

async_rs/implementors/
tokio.rs

1//! tokio implementation of async runtime definition traits
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit},
7    util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use cfg_if::cfg_if;
11use futures_core::Stream;
12use futures_io::{AsyncRead, AsyncWrite};
13use std::{
14    future::Future,
15    io::{self, Read, Write},
16    net::SocketAddr,
17    pin::Pin,
18    sync::Arc,
19    task::{Context, Poll},
20    time::{Duration, Instant},
21};
22use tokio::{
23    net::TcpStream,
24    runtime::{EnterGuard, Handle, Runtime as TokioRT},
25    time::Sleep,
26};
27use tokio_stream::{StreamExt, wrappers::IntervalStream};
28
29use task::TTask;
30
/// Type alias for a [`Runtime`] driven by the [`Tokio`] implementor
pub type TokioRuntime = Runtime<Tokio>;
33
34impl TokioRuntime {
35    /// Create a new TokioRuntime and bind it to this tokio runtime.
36    pub fn tokio() -> io::Result<Self> {
37        Ok(Self::tokio_with_runtime(TokioRT::new()?))
38    }
39
40    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
41    pub fn tokio_current() -> Self {
42        Self::new(Tokio::current())
43    }
44
45    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
46    pub fn tokio_with_handle(handle: Handle) -> Self {
47        Self::new(Tokio::default().with_handle(handle))
48    }
49
50    /// Create a new TokioRuntime and bind it to this tokio runtime.
51    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
52        Self::new(Tokio::default().with_runtime(runtime))
53    }
54}
55
56/// Dummy object implementing async common interfaces on top of tokio
57#[derive(Default, Clone, Debug)]
58pub struct Tokio {
59    handle: Option<Handle>,
60    runtime: Option<Arc<TokioRT>>,
61}
62
63impl Tokio {
64    /// Bind to the tokio Runtime associated to this handle by default.
65    pub fn with_handle(mut self, handle: Handle) -> Self {
66        self.handle = Some(handle);
67        self
68    }
69
70    /// Bind to the tokio Runtime associated to this handle by default.
71    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
72        let handle = runtime.handle().clone();
73        self.runtime = Some(Arc::new(runtime));
74        self.with_handle(handle)
75    }
76
77    /// Bind to the current tokio Runtime by default.
78    pub fn current() -> Self {
79        Self::default().with_handle(Handle::current())
80    }
81
82    fn handle(&self) -> Option<Handle> {
83        self.handle.clone().or_else(|| Handle::try_current().ok())
84    }
85
86    fn enter(&self) -> Option<EnterGuard<'_>> {
87        self.runtime
88            .as_ref()
89            .map(|r| r.handle())
90            .or(self.handle.as_ref())
91            .map(Handle::enter)
92    }
93}
94
// Empty impl: no item of `RuntimeKit` needs to be provided here.
// NOTE(review): presumably the trait only bundles `Executor` + `Reactor` — confirm in `crate::traits`.
impl RuntimeKit for Tokio {}
96
97impl Executor for Tokio {
98    type Task<T: Send + 'static> = TTask<T>;
99
100    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
101        if let Some(runtime) = self.runtime.as_ref() {
102            runtime.block_on(f)
103        } else if let Some(handle) = self.handle() {
104            handle.block_on(f)
105        } else {
106            Handle::current().block_on(f)
107        }
108    }
109
110    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
111        &self,
112        f: F,
113    ) -> Task<Self::Task<T>> {
114        TTask(Some(if let Some(handle) = self.handle() {
115            handle.spawn(f)
116        } else {
117            tokio::task::spawn(f)
118        }))
119        .into()
120    }
121
122    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
123        &self,
124        f: F,
125    ) -> Task<Self::Task<T>> {
126        TTask(Some(if let Some(handle) = self.handle() {
127            handle.spawn_blocking(f)
128        } else {
129            tokio::task::spawn_blocking(f)
130        }))
131        .into()
132    }
133}
134
135impl Reactor for Tokio {
136    type TcpStream = Compat<TcpStream>;
137    type Sleep = Sleep;
138
139    fn register<H: Read + Write + AsSysFd + Send + 'static>(
140        &self,
141        socket: H,
142    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
143        let _enter = self.enter();
144        cfg_if! {
145            if #[cfg(unix)] {
146                Ok(unix::AsyncFdWrapper(
147                    tokio::io::unix::AsyncFd::new(socket)?,
148                ))
149            } else {
150                Err::<crate::util::DummyIO, _>(io::Error::other(
151                    "Registering FD on tokio reactor is only supported on unix",
152                ))
153            }
154        }
155    }
156
157    fn sleep(&self, dur: Duration) -> Self::Sleep {
158        let _enter = self.enter();
159        tokio::time::sleep(dur)
160    }
161
162    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
163        let _enter = self.enter();
164        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
165    }
166
167    fn tcp_connect_addr(
168        &self,
169        addr: SocketAddr,
170    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
171        let _enter = self.enter();
172        async move { Ok(TcpStream::connect(addr).await?.compat()) }
173    }
174}
175
176mod task {
177    use crate::util::TaskImpl;
178    use async_trait::async_trait;
179    use std::{
180        future::Future,
181        pin::Pin,
182        task::{Context, Poll},
183    };
184
185    /// A tokio task
186    #[derive(Debug)]
187    pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
188
189    #[async_trait]
190    impl<T: Send + 'static> TaskImpl for TTask<T> {
191        async fn cancel(&mut self) -> Option<T> {
192            let task = self.0.take()?;
193            task.abort();
194            task.await.ok()
195        }
196    }
197
198    impl<T: Send + 'static> Future for TTask<T> {
199        type Output = T;
200
201        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
202            let task = self.0.as_mut().expect("task has been canceled");
203            match Pin::new(task).poll(cx) {
204                Poll::Pending => Poll::Pending,
205                Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
206            }
207        }
208    }
209}
210
211#[cfg(unix)]
212mod unix {
213    use super::*;
214    use futures_io::{AsyncRead, AsyncWrite};
215    use std::io::{IoSlice, IoSliceMut};
216    use tokio::io::unix::AsyncFd;
217
218    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
219
220    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
221        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
222            mut self: Pin<&mut Self>,
223            cx: &mut Context<'_>,
224            f: F,
225        ) -> Option<Poll<io::Result<usize>>> {
226            Some(match self.0.poll_read_ready_mut(cx) {
227                Poll::Pending => Poll::Pending,
228                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
229                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
230                    Ok(res) => Poll::Ready(res),
231                    Err(_) => return None,
232                },
233            })
234        }
235
236        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
237            mut self: Pin<&mut Self>,
238            cx: &mut Context<'_>,
239            f: F,
240        ) -> Option<Poll<io::Result<R>>> {
241            Some(match self.0.poll_write_ready_mut(cx) {
242                Poll::Pending => Poll::Pending,
243                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
244                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
245                    Ok(res) => Poll::Ready(res),
246                    Err(_) => return None,
247                },
248            })
249        }
250    }
251
252    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
253
254    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
255        fn poll_read(
256            mut self: Pin<&mut Self>,
257            cx: &mut Context<'_>,
258            buf: &mut [u8],
259        ) -> Poll<io::Result<usize>> {
260            loop {
261                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
262                    return res;
263                }
264            }
265        }
266
267        fn poll_read_vectored(
268            mut self: Pin<&mut Self>,
269            cx: &mut Context<'_>,
270            bufs: &mut [IoSliceMut<'_>],
271        ) -> Poll<io::Result<usize>> {
272            loop {
273                if let Some(res) = self
274                    .as_mut()
275                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
276                {
277                    return res;
278                }
279            }
280        }
281    }
282
283    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
284        fn poll_write(
285            mut self: Pin<&mut Self>,
286            cx: &mut Context<'_>,
287            buf: &[u8],
288        ) -> Poll<io::Result<usize>> {
289            loop {
290                if let Some(res) = self
291                    .as_mut()
292                    .write(cx, |socket| socket.get_mut().write(buf))
293                {
294                    return res;
295                }
296            }
297        }
298
299        fn poll_write_vectored(
300            mut self: Pin<&mut Self>,
301            cx: &mut Context<'_>,
302            bufs: &[IoSlice<'_>],
303        ) -> Poll<io::Result<usize>> {
304            loop {
305                if let Some(res) = self
306                    .as_mut()
307                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
308                {
309                    return res;
310                }
311            }
312        }
313
314        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
315            loop {
316                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
317                    return res;
318                }
319            }
320        }
321
322        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
323            self.poll_flush(cx)
324        }
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// The tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;
        let rt = Runtime::tokio().unwrap();
        assert_clone(&rt);
        assert_send(&rt);
        assert_sync(&rt);
    }
}
340}