1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit},
7 util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use cfg_if::cfg_if;
11use futures_core::Stream;
12use futures_io::{AsyncRead, AsyncWrite};
13use std::{
14 future::Future,
15 io::{self, Read, Write},
16 net::SocketAddr,
17 pin::Pin,
18 sync::Arc,
19 task::{Context, Poll},
20 time::{Duration, Instant},
21};
22use tokio::{
23 net::TcpStream,
24 runtime::{EnterGuard, Handle, Runtime as TokioRT},
25 time::Sleep,
26};
27use tokio_stream::{StreamExt, wrappers::IntervalStream};
28
29use task::TTask;
30
31pub type TokioRuntime = Runtime<Tokio>;
33
34impl TokioRuntime {
35 pub fn tokio() -> io::Result<Self> {
37 Ok(Self::tokio_with_runtime(TokioRT::new()?))
38 }
39
40 pub fn tokio_current() -> Self {
42 Self::new(Tokio::current())
43 }
44
45 pub fn tokio_with_handle(handle: Handle) -> Self {
47 Self::new(Tokio::default().with_handle(handle))
48 }
49
50 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
52 Self::new(Tokio::default().with_runtime(runtime))
53 }
54}
55
56#[derive(Default, Clone, Debug)]
58pub struct Tokio {
59 handle: Option<Handle>,
60 runtime: Option<Arc<TokioRT>>,
61}
62
63impl Tokio {
64 pub fn with_handle(mut self, handle: Handle) -> Self {
66 self.handle = Some(handle);
67 self
68 }
69
70 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
72 let handle = runtime.handle().clone();
73 self.runtime = Some(Arc::new(runtime));
74 self.with_handle(handle)
75 }
76
77 pub fn current() -> Self {
79 Self::default().with_handle(Handle::current())
80 }
81
82 fn handle(&self) -> Option<Handle> {
83 self.runtime
84 .as_ref()
85 .map(|r| r.handle().clone())
86 .or_else(|| Handle::try_current().ok())
87 .or_else(|| self.handle.clone())
88 }
89
90 fn enter(&self) -> Option<EnterGuard<'_>> {
91 self.runtime
92 .as_ref()
93 .map(|r| r.handle())
94 .or(self.handle.as_ref())
95 .map(Handle::enter)
96 }
97}
98
// Empty impl: nothing is overridden here, so `Tokio` only opts into the trait.
// NOTE(review): presumably `RuntimeKit` just bundles the `Executor` and
// `Reactor` traits implemented below; confirm against the trait definition.
impl RuntimeKit for Tokio {}
100
101impl Executor for Tokio {
102 type Task<T: Send + 'static> = TTask<T>;
103
104 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
105 if let Some(runtime) = self.runtime.as_ref() {
106 runtime.block_on(f)
107 } else if let Some(handle) = self.handle() {
108 handle.block_on(f)
109 } else {
110 Handle::current().block_on(f)
111 }
112 }
113
114 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
115 &self,
116 f: F,
117 ) -> Task<Self::Task<T>> {
118 TTask(Some(if let Some(handle) = self.handle() {
119 handle.spawn(f)
120 } else {
121 tokio::task::spawn(f)
122 }))
123 .into()
124 }
125
126 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
127 &self,
128 f: F,
129 ) -> Task<Self::Task<T>> {
130 TTask(Some(if let Some(handle) = self.handle() {
131 handle.spawn_blocking(f)
132 } else {
133 tokio::task::spawn_blocking(f)
134 }))
135 .into()
136 }
137}
138
139impl Reactor for Tokio {
140 type TcpStream = Compat<TcpStream>;
141 type Sleep = Sleep;
142
143 fn register<H: Read + Write + AsSysFd + Send + 'static>(
144 &self,
145 socket: H,
146 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
147 let _enter = self.enter();
148 cfg_if! {
149 if #[cfg(unix)] {
150 Ok(unix::AsyncFdWrapper(
151 tokio::io::unix::AsyncFd::new(socket)?,
152 ))
153 } else {
154 Err::<crate::util::DummyIO, _>(io::Error::other(
155 "Registering FD on tokio reactor is only supported on unix",
156 ))
157 }
158 }
159 }
160
161 fn sleep(&self, dur: Duration) -> Self::Sleep {
162 let _enter = self.enter();
163 tokio::time::sleep(dur)
164 }
165
166 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
167 let _enter = self.enter();
168 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
169 }
170
171 fn tcp_connect_addr(
172 &self,
173 addr: SocketAddr,
174 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
175 let _enter = self.enter();
176 async move { Ok(TcpStream::connect(addr).await?.compat()) }
177 }
178}
179
180mod task {
181 use crate::util::TaskImpl;
182 use async_trait::async_trait;
183 use std::{
184 future::Future,
185 pin::Pin,
186 task::{Context, Poll},
187 };
188
189 #[derive(Debug)]
191 pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
192
193 #[async_trait]
194 impl<T: Send + 'static> TaskImpl for TTask<T> {
195 async fn cancel(&mut self) -> Option<T> {
196 let task = self.0.take()?;
197 task.abort();
198 task.await.ok()
199 }
200 }
201
202 impl<T: Send + 'static> Future for TTask<T> {
203 type Output = T;
204
205 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
206 let task = self.0.as_mut().expect("task has been canceled");
207 match Pin::new(task).poll(cx) {
208 Poll::Pending => Poll::Pending,
209 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
210 }
211 }
212 }
213}
214
215#[cfg(unix)]
216mod unix {
217 use super::*;
218 use futures_io::{AsyncRead, AsyncWrite};
219 use std::io::{IoSlice, IoSliceMut};
220 use tokio::io::unix::AsyncFd;
221
222 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
223
224 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
225 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
226 mut self: Pin<&mut Self>,
227 cx: &mut Context<'_>,
228 f: F,
229 ) -> Option<Poll<io::Result<usize>>> {
230 Some(match self.0.poll_read_ready_mut(cx) {
231 Poll::Pending => Poll::Pending,
232 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
233 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
234 Ok(res) => Poll::Ready(res),
235 Err(_) => return None,
236 },
237 })
238 }
239
240 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
241 mut self: Pin<&mut Self>,
242 cx: &mut Context<'_>,
243 f: F,
244 ) -> Option<Poll<io::Result<R>>> {
245 Some(match self.0.poll_write_ready_mut(cx) {
246 Poll::Pending => Poll::Pending,
247 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
248 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
249 Ok(res) => Poll::Ready(res),
250 Err(_) => return None,
251 },
252 })
253 }
254 }
255
256 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
257
258 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
259 fn poll_read(
260 mut self: Pin<&mut Self>,
261 cx: &mut Context<'_>,
262 buf: &mut [u8],
263 ) -> Poll<io::Result<usize>> {
264 loop {
265 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
266 return res;
267 }
268 }
269 }
270
271 fn poll_read_vectored(
272 mut self: Pin<&mut Self>,
273 cx: &mut Context<'_>,
274 bufs: &mut [IoSliceMut<'_>],
275 ) -> Poll<io::Result<usize>> {
276 loop {
277 if let Some(res) = self
278 .as_mut()
279 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
280 {
281 return res;
282 }
283 }
284 }
285 }
286
287 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
288 fn poll_write(
289 mut self: Pin<&mut Self>,
290 cx: &mut Context<'_>,
291 buf: &[u8],
292 ) -> Poll<io::Result<usize>> {
293 loop {
294 if let Some(res) = self
295 .as_mut()
296 .write(cx, |socket| socket.get_mut().write(buf))
297 {
298 return res;
299 }
300 }
301 }
302
303 fn poll_write_vectored(
304 mut self: Pin<&mut Self>,
305 cx: &mut Context<'_>,
306 bufs: &[IoSlice<'_>],
307 ) -> Poll<io::Result<usize>> {
308 loop {
309 if let Some(res) = self
310 .as_mut()
311 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
312 {
313 return res;
314 }
315 }
316 }
317
318 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
319 loop {
320 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
321 return res;
322 }
323 }
324 }
325
326 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
327 self.poll_flush(cx)
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// The tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = TokioRuntime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}
344}