Skip to main content

async_rs/implementors/
tokio.rs

1//! tokio implementation of async runtime definition traits
2
3use crate::{
4    Runtime,
5    sys::AsSysFd,
6    traits::{Executor, Reactor, RuntimeKit},
7    util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13    future::Future,
14    io::{self, Read, Write},
15    net::SocketAddr,
16    pin::Pin,
17    sync::Arc,
18    task::{Context, Poll},
19    time::{Duration, Instant},
20};
21use tokio::{
22    net::TcpStream,
23    runtime::{EnterGuard, Handle, Runtime as TokioRT},
24    time::Sleep,
25};
26use tokio_stream::{StreamExt, wrappers::IntervalStream};
27
28use task::TTask;
29
30/// Type alias for the tokio runtime
31pub type TokioRuntime = Runtime<Tokio>;
32
33impl TokioRuntime {
34    /// Create a new TokioRuntime and bind it to this tokio runtime.
35    pub fn tokio() -> io::Result<Self> {
36        Ok(Self::tokio_with_runtime(TokioRT::new()?))
37    }
38
39    /// Create a new TokioRuntime and bind it to the current tokio runtime by default.
40    pub fn tokio_current() -> Self {
41        Self::new(Tokio::current())
42    }
43
44    /// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
45    pub fn tokio_with_handle(handle: Handle) -> Self {
46        Self::new(Tokio::default().with_handle(handle))
47    }
48
49    /// Create a new TokioRuntime and bind it to this tokio runtime.
50    pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
51        Self::new(Tokio::default().with_runtime(runtime))
52    }
53}
54
55/// Dummy object implementing async common interfaces on top of tokio
56#[derive(Default, Clone, Debug)]
57pub struct Tokio {
58    handle: Option<Handle>,
59    runtime: Option<Arc<TokioRT>>,
60}
61
62impl Tokio {
63    /// Bind to the tokio Runtime associated to this handle by default.
64    pub fn with_handle(mut self, handle: Handle) -> Self {
65        self.handle = Some(handle);
66        self
67    }
68
69    /// Bind to the tokio Runtime associated to this handle by default.
70    pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
71        let handle = runtime.handle().clone();
72        self.runtime = Some(Arc::new(runtime));
73        self.with_handle(handle)
74    }
75
76    /// Bind to the current tokio Runtime by default.
77    pub fn current() -> Self {
78        Self::default().with_handle(Handle::current())
79    }
80
81    fn handle(&self) -> Option<Handle> {
82        self.handle.clone().or_else(|| Handle::try_current().ok())
83    }
84
85    fn enter(&self) -> Option<EnterGuard<'_>> {
86        self.runtime
87            .as_ref()
88            .map(|r| r.handle())
89            .or(self.handle.as_ref())
90            .map(Handle::enter)
91    }
92}
93
94impl RuntimeKit for Tokio {}
95
96impl Executor for Tokio {
97    type Task<T: Send + 'static> = TTask<T>;
98
99    fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
100        if let Some(runtime) = self.runtime.as_ref() {
101            runtime.block_on(f)
102        } else if let Some(handle) = self.handle() {
103            handle.block_on(f)
104        } else {
105            Handle::try_current()
106                .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
107                .block_on(f)
108        }
109    }
110
111    fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
112        &self,
113        f: F,
114    ) -> Task<Self::Task<T>> {
115        TTask(Some(if let Some(handle) = self.handle() {
116            handle.spawn(f)
117        } else {
118            tokio::task::spawn(f)
119        }))
120        .into()
121    }
122
123    fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
124        &self,
125        f: F,
126    ) -> Task<Self::Task<T>> {
127        TTask(Some(if let Some(handle) = self.handle() {
128            handle.spawn_blocking(f)
129        } else {
130            tokio::task::spawn_blocking(f)
131        }))
132        .into()
133    }
134}
135
136impl Reactor for Tokio {
137    type TcpStream = Compat<TcpStream>;
138    type Sleep = Sleep;
139
140    fn register<H: Read + Write + AsSysFd + Send + 'static>(
141        &self,
142        socket: H,
143    ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
144        let _enter = self.enter();
145        #[cfg(unix)]
146        {
147            Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
148        }
149        #[cfg(not(unix))]
150        {
151            let _ = socket;
152            Err::<crate::util::DummyIO, _>(io::Error::other(
153                "Registering FD on tokio reactor is only supported on unix",
154            ))
155        }
156    }
157
158    fn sleep(&self, dur: Duration) -> Self::Sleep {
159        let _enter = self.enter();
160        tokio::time::sleep(dur)
161    }
162
163    fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
164        let _enter = self.enter();
165        IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
166    }
167
168    fn tcp_connect_addr(
169        &self,
170        addr: SocketAddr,
171    ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
172        let _enter = self.enter();
173        async move { Ok(TcpStream::connect(addr).await?.compat()) }
174    }
175}
176
177mod task {
178    use crate::util::TaskImpl;
179    use async_trait::async_trait;
180    use std::{
181        future::Future,
182        pin::Pin,
183        task::{Context, Poll},
184    };
185
186    /// A tokio task
187    #[derive(Debug)]
188    pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
189
190    #[async_trait]
191    impl<T: Send + 'static> TaskImpl for TTask<T> {
192        async fn cancel(&mut self) -> Option<T> {
193            let task = self.0.take()?;
194            task.abort();
195            task.await.ok()
196        }
197    }
198
199    impl<T: Send + 'static> Future for TTask<T> {
200        type Output = T;
201
202        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
203            let task = self.0.as_mut().expect("task has been canceled");
204            match Pin::new(task).poll(cx) {
205                Poll::Pending => Poll::Pending,
206                Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
207            }
208        }
209    }
210}
211
212#[cfg(unix)]
213mod unix {
214    use super::*;
215    use futures_io::{AsyncRead, AsyncWrite};
216    use std::io::{IoSlice, IoSliceMut};
217    use tokio::io::unix::AsyncFd;
218
219    pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
220
221    impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
222        fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
223            mut self: Pin<&mut Self>,
224            cx: &mut Context<'_>,
225            f: F,
226        ) -> Option<Poll<io::Result<usize>>> {
227            Some(match self.0.poll_read_ready_mut(cx) {
228                Poll::Pending => Poll::Pending,
229                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
230                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
231                    Ok(res) => Poll::Ready(res),
232                    Err(_) => return None,
233                },
234            })
235        }
236
237        fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
238            mut self: Pin<&mut Self>,
239            cx: &mut Context<'_>,
240            f: F,
241        ) -> Option<Poll<io::Result<R>>> {
242            Some(match self.0.poll_write_ready_mut(cx) {
243                Poll::Pending => Poll::Pending,
244                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
245                Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
246                    Ok(res) => Poll::Ready(res),
247                    Err(_) => return None,
248                },
249            })
250        }
251    }
252
253    impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
254
255    impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
256        fn poll_read(
257            mut self: Pin<&mut Self>,
258            cx: &mut Context<'_>,
259            buf: &mut [u8],
260        ) -> Poll<io::Result<usize>> {
261            loop {
262                if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
263                    return res;
264                }
265            }
266        }
267
268        fn poll_read_vectored(
269            mut self: Pin<&mut Self>,
270            cx: &mut Context<'_>,
271            bufs: &mut [IoSliceMut<'_>],
272        ) -> Poll<io::Result<usize>> {
273            loop {
274                if let Some(res) = self
275                    .as_mut()
276                    .read(cx, |socket| socket.get_mut().read_vectored(bufs))
277                {
278                    return res;
279                }
280            }
281        }
282    }
283
284    impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
285        fn poll_write(
286            mut self: Pin<&mut Self>,
287            cx: &mut Context<'_>,
288            buf: &[u8],
289        ) -> Poll<io::Result<usize>> {
290            loop {
291                if let Some(res) = self
292                    .as_mut()
293                    .write(cx, |socket| socket.get_mut().write(buf))
294                {
295                    return res;
296                }
297            }
298        }
299
300        fn poll_write_vectored(
301            mut self: Pin<&mut Self>,
302            cx: &mut Context<'_>,
303            bufs: &[IoSlice<'_>],
304        ) -> Poll<io::Result<usize>> {
305            loop {
306                if let Some(res) = self
307                    .as_mut()
308                    .write(cx, |socket| socket.get_mut().write_vectored(bufs))
309                {
310                    return res;
311                }
312            }
313        }
314
315        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
316            loop {
317                if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
318                    return res;
319                }
320            }
321        }
322
323        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
324            self.poll_flush(cx)
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the crate's `assert_send` / `assert_sync` / `assert_clone` helpers
    /// against a runtime backed by a freshly built tokio runtime.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;
        let rt = TokioRuntime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}
341}