async_rs/implementors/
tokio.rs

1//! tokio implementation of async runtime definition traits
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit},
7    util::Task,
8};
9use cfg_if::cfg_if;
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13    future::Future,
14    io::{self, Read, Write},
15    net::SocketAddr,
16    pin::Pin,
17    sync::Arc,
18    task::{Context, Poll},
19    time::{Duration, Instant},
20};
21use tokio::{
22    net::TcpStream,
23    runtime::{EnterGuard, Handle, Runtime as TokioRT},
24};
25use tokio_stream::{StreamExt, wrappers::IntervalStream};
26use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
27
28use task::TTask;
29
30/// Type alias for the tokio runtime
31pub type TokioRuntime = Runtime<Tokio>;
32
33impl TokioRuntime {
34    /// Create a new TokioRuntime and bind it to this tokio runtime.
35    pub fn tokio() -> io::Result<Self> {
36        Ok(Self::tokio_with_runtime(TokioRT::new()?))
37    }
38
39    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
40    pub fn tokio_current() -> Self {
41        Self::new(Tokio::current())
42    }
43
44    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
45    pub fn tokio_with_handle(handle: Handle) -> Self {
46        Self::new(Tokio::default().with_handle(handle))
47    }
48
49    /// Create a new TokioRuntime and bind it to this tokio runtime.
50    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
51        Self::new(Tokio::default().with_runtime(runtime))
52    }
53}
54
55/// Dummy object implementing async common interfaces on top of tokio
56#[derive(Default, Clone, Debug)]
57pub struct Tokio {
58    handle: Option<Handle>,
59    runtime: Option<Arc<TokioRT>>,
60}
61
62impl Tokio {
63    /// Bind to the tokio Runtime associated to this handle by default.
64    pub fn with_handle(mut self, handle: Handle) -> Self {
65        self.handle = Some(handle);
66        self
67    }
68
69    /// Bind to the tokio Runtime associated to this handle by default.
70    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
71        let handle = runtime.handle().clone();
72        self.runtime = Some(Arc::new(runtime));
73        self.with_handle(handle)
74    }
75
76    /// Bind to the current tokio Runtime by default.
77    pub fn current() -> Self {
78        Self::default().with_handle(Handle::current())
79    }
80
81    fn handle(&self) -> Option<Handle> {
82        self.runtime
83            .as_ref()
84            .map(|r| r.handle().clone())
85            .or_else(|| Handle::try_current().ok())
86            .or_else(|| self.handle.clone())
87    }
88
89    fn enter(&self) -> Option<EnterGuard<'_>> {
90        self.runtime
91            .as_ref()
92            .map(|r| r.handle())
93            .or(self.handle.as_ref())
94            .map(Handle::enter)
95    }
96}
97
98impl RuntimeKit for Tokio {}
99
100impl Executor for Tokio {
101    type Task<T: Send + 'static> = TTask<T>;
102
103    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
104        if let Some(runtime) = self.runtime.as_ref() {
105            runtime.block_on(f)
106        } else if let Some(handle) = self.handle() {
107            handle.block_on(f)
108        } else {
109            Handle::current().block_on(f)
110        }
111    }
112
113    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
114        &self,
115        f: F,
116    ) -> Task<Self::Task<T>> {
117        TTask(Some(if let Some(handle) = self.handle() {
118            handle.spawn(f)
119        } else {
120            tokio::task::spawn(f)
121        }))
122        .into()
123    }
124
125    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
126        &self,
127        f: F,
128    ) -> Task<Self::Task<T>> {
129        TTask(Some(if let Some(handle) = self.handle() {
130            handle.spawn_blocking(f)
131        } else {
132            tokio::task::spawn_blocking(f)
133        }))
134        .into()
135    }
136}
137
138impl Reactor for Tokio {
139    type TcpStream = Compat<TcpStream>;
140
141    fn register<H: Read + Write + AsSysFd + Send + 'static>(
142        &self,
143        socket: H,
144    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
145        let _enter = self.enter();
146        cfg_if! {
147            if #[cfg(unix)] {
148                Ok(unix::AsyncFdWrapper(
149                    tokio::io::unix::AsyncFd::new(socket)?,
150                ))
151            } else {
152                Err::<crate::util::DummyIO, _>(io::Error::other(
153                    "Registering FD on tokio reactor is only supported on unix",
154                ))
155            }
156        }
157    }
158
159    fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send + 'static {
160        let _enter = self.enter();
161        tokio::time::sleep(dur)
162    }
163
164    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
165        let _enter = self.enter();
166        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
167    }
168
169    fn tcp_connect_addr(
170        &self,
171        addr: SocketAddr,
172    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
173        let _enter = self.enter();
174        async move { Ok(TcpStream::connect(addr).await?.compat()) }
175    }
176}
177
178mod task {
179    use crate::util::TaskImpl;
180    use async_trait::async_trait;
181    use std::{
182        future::Future,
183        pin::Pin,
184        task::{Context, Poll},
185    };
186
187    /// A tokio task
188    #[derive(Debug)]
189    pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
190
191    #[async_trait]
192    impl<T: Send + 'static> TaskImpl for TTask<T> {
193        async fn cancel(&mut self) -> Option<T> {
194            let task = self.0.take()?;
195            task.abort();
196            task.await.ok()
197        }
198    }
199
200    impl<T: Send + 'static> Future for TTask<T> {
201        type Output = T;
202
203        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
204            let task = self.0.as_mut().expect("task has been canceled");
205            match Pin::new(task).poll(cx) {
206                Poll::Pending => Poll::Pending,
207                Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
208            }
209        }
210    }
211}
212
213#[cfg(unix)]
214mod unix {
215    use super::*;
216    use futures_io::{AsyncRead, AsyncWrite};
217    use std::io::{IoSlice, IoSliceMut};
218    use tokio::io::unix::AsyncFd;
219
220    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
221
222    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
223        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
224            mut self: Pin<&mut Self>,
225            cx: &mut Context<'_>,
226            f: F,
227        ) -> Option<Poll<io::Result<usize>>> {
228            Some(match self.0.poll_read_ready_mut(cx) {
229                Poll::Pending => Poll::Pending,
230                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
231                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
232                    Ok(res) => Poll::Ready(res),
233                    Err(_) => return None,
234                },
235            })
236        }
237
238        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
239            mut self: Pin<&mut Self>,
240            cx: &mut Context<'_>,
241            f: F,
242        ) -> Option<Poll<io::Result<R>>> {
243            Some(match self.0.poll_write_ready_mut(cx) {
244                Poll::Pending => Poll::Pending,
245                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
246                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
247                    Ok(res) => Poll::Ready(res),
248                    Err(_) => return None,
249                },
250            })
251        }
252    }
253
254    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
255
256    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
257        fn poll_read(
258            mut self: Pin<&mut Self>,
259            cx: &mut Context<'_>,
260            buf: &mut [u8],
261        ) -> Poll<io::Result<usize>> {
262            loop {
263                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
264                    return res;
265                }
266            }
267        }
268
269        fn poll_read_vectored(
270            mut self: Pin<&mut Self>,
271            cx: &mut Context<'_>,
272            bufs: &mut [IoSliceMut<'_>],
273        ) -> Poll<io::Result<usize>> {
274            loop {
275                if let Some(res) = self
276                    .as_mut()
277                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
278                {
279                    return res;
280                }
281            }
282        }
283    }
284
285    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
286        fn poll_write(
287            mut self: Pin<&mut Self>,
288            cx: &mut Context<'_>,
289            buf: &[u8],
290        ) -> Poll<io::Result<usize>> {
291            loop {
292                if let Some(res) = self
293                    .as_mut()
294                    .write(cx, |socket| socket.get_mut().write(buf))
295                {
296                    return res;
297                }
298            }
299        }
300
301        fn poll_write_vectored(
302            mut self: Pin<&mut Self>,
303            cx: &mut Context<'_>,
304            bufs: &[IoSlice<'_>],
305        ) -> Poll<io::Result<usize>> {
306            loop {
307                if let Some(res) = self
308                    .as_mut()
309                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
310                {
311                    return res;
312                }
313            }
314        }
315
316        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
317            loop {
318                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
319                    return res;
320                }
321            }
322        }
323
324        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
325            self.poll_flush(cx)
326        }
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// A runtime owning its tokio runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;
        let rt = TokioRuntime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}
342}