1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit},
7 util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13 future::Future,
14 io::{self, Read, Write},
15 net::SocketAddr,
16 sync::Arc,
17 time::{Duration, Instant},
18};
19use tokio::{
20 net::TcpStream,
21 runtime::{EnterGuard, Handle, Runtime as TokioRT},
22 time::Sleep,
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25
26use task::TTask;
27
28pub type TokioRuntime = Runtime<Tokio>;
30
31impl TokioRuntime {
32 pub fn tokio() -> io::Result<Self> {
34 Ok(Self::tokio_with_runtime(TokioRT::new()?))
35 }
36
37 pub fn tokio_current() -> Self {
39 Self::new(Tokio::current())
40 }
41
42 pub fn tokio_with_handle(handle: Handle) -> Self {
44 Self::new(Tokio::default().with_handle(handle))
45 }
46
47 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
49 Self::new(Tokio::default().with_runtime(runtime))
50 }
51}
52
53#[derive(Default, Clone, Debug)]
55pub struct Tokio {
56 handle: Option<Handle>,
57 runtime: Option<Arc<TokioRT>>,
58}
59
60impl Tokio {
61 pub fn with_handle(mut self, handle: Handle) -> Self {
63 self.handle = Some(handle);
64 self
65 }
66
67 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
69 let handle = runtime.handle().clone();
70 self.runtime = Some(Arc::new(runtime));
71 self.with_handle(handle)
72 }
73
74 pub fn current() -> Self {
76 Self::default().with_handle(Handle::current())
77 }
78
79 fn handle(&self) -> Option<Handle> {
80 self.handle.clone().or_else(|| Handle::try_current().ok())
81 }
82
83 fn enter(&self) -> Option<EnterGuard<'_>> {
84 self.runtime
85 .as_ref()
86 .map(|r| r.handle())
87 .or(self.handle.as_ref())
88 .map(Handle::enter)
89 }
90}
91
92impl RuntimeKit for Tokio {}
93
94impl Executor for Tokio {
95 type Task<T: Send + 'static> = TTask<T>;
96
97 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
98 if let Some(runtime) = self.runtime.as_ref() {
99 runtime.block_on(f)
100 } else if let Some(handle) = self.handle() {
101 handle.block_on(f)
102 } else {
103 Handle::try_current()
104 .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
105 .block_on(f)
106 }
107 }
108
109 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
110 &self,
111 f: F,
112 ) -> Task<Self::Task<T>> {
113 TTask(Some(if let Some(handle) = self.handle() {
114 handle.spawn(f)
115 } else {
116 tokio::task::spawn(f)
117 }))
118 .into()
119 }
120
121 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
122 &self,
123 f: F,
124 ) -> Task<Self::Task<T>> {
125 TTask(Some(if let Some(handle) = self.handle() {
126 handle.spawn_blocking(f)
127 } else {
128 tokio::task::spawn_blocking(f)
129 }))
130 .into()
131 }
132}
133
134impl Reactor for Tokio {
135 type TcpStream = Compat<TcpStream>;
136 type Sleep = Sleep;
137
138 fn register<H: Read + Write + AsSysFd + Send + 'static>(
139 &self,
140 socket: H,
141 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
142 let _enter = self.enter();
143 #[cfg(unix)]
144 {
145 Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
146 }
147 #[cfg(not(unix))]
148 {
149 let _ = socket;
150 Err::<crate::util::DummyIO, _>(io::Error::other(
151 "Registering FD on tokio reactor is only supported on unix",
152 ))
153 }
154 }
155
156 fn sleep(&self, dur: Duration) -> Self::Sleep {
157 let _enter = self.enter();
158 tokio::time::sleep(dur)
159 }
160
161 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
162 let _enter = self.enter();
163 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
164 }
165
166 fn tcp_connect_addr(
167 &self,
168 addr: SocketAddr,
169 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
170 let _enter = self.enter();
171 async move { Ok(TcpStream::connect(addr).await?.compat()) }
172 }
173}
174
175mod task {
176 use crate::util::TaskImpl;
177 use async_trait::async_trait;
178 use std::{
179 future::Future,
180 pin::Pin,
181 task::{Context, Poll},
182 };
183
184 #[derive(Debug)]
186 pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
187
188 #[async_trait]
189 impl<T: Send + 'static> TaskImpl for TTask<T> {
190 async fn cancel(&mut self) -> Option<T> {
191 let task = self.0.take()?;
192 task.abort();
193 task.await.ok()
194 }
195 }
196
197 impl<T: Send + 'static> Future for TTask<T> {
198 type Output = T;
199
200 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
201 let task = self.0.as_mut().expect("task has been canceled");
202 match Pin::new(task).poll(cx) {
203 Poll::Pending => Poll::Pending,
204 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
205 }
206 }
207 }
208}
209
210#[cfg(unix)]
211mod unix {
212 use super::*;
213 use futures_io::{AsyncRead, AsyncWrite};
214 use std::{
215 io::{IoSlice, IoSliceMut},
216 pin::Pin,
217 task::{Context, Poll},
218 };
219 use tokio::io::unix::AsyncFd;
220
221 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
222
223 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
224 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
225 mut self: Pin<&mut Self>,
226 cx: &mut Context<'_>,
227 f: F,
228 ) -> Option<Poll<io::Result<usize>>> {
229 Some(match self.0.poll_read_ready_mut(cx) {
230 Poll::Pending => Poll::Pending,
231 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
232 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
233 Ok(res) => Poll::Ready(res),
234 Err(_) => return None,
235 },
236 })
237 }
238
239 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
240 mut self: Pin<&mut Self>,
241 cx: &mut Context<'_>,
242 f: F,
243 ) -> Option<Poll<io::Result<R>>> {
244 Some(match self.0.poll_write_ready_mut(cx) {
245 Poll::Pending => Poll::Pending,
246 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
247 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
248 Ok(res) => Poll::Ready(res),
249 Err(_) => return None,
250 },
251 })
252 }
253 }
254
255 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
256
257 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
258 fn poll_read(
259 mut self: Pin<&mut Self>,
260 cx: &mut Context<'_>,
261 buf: &mut [u8],
262 ) -> Poll<io::Result<usize>> {
263 loop {
264 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
265 return res;
266 }
267 }
268 }
269
270 fn poll_read_vectored(
271 mut self: Pin<&mut Self>,
272 cx: &mut Context<'_>,
273 bufs: &mut [IoSliceMut<'_>],
274 ) -> Poll<io::Result<usize>> {
275 loop {
276 if let Some(res) = self
277 .as_mut()
278 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
279 {
280 return res;
281 }
282 }
283 }
284 }
285
286 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
287 fn poll_write(
288 mut self: Pin<&mut Self>,
289 cx: &mut Context<'_>,
290 buf: &[u8],
291 ) -> Poll<io::Result<usize>> {
292 loop {
293 if let Some(res) = self
294 .as_mut()
295 .write(cx, |socket| socket.get_mut().write(buf))
296 {
297 return res;
298 }
299 }
300 }
301
302 fn poll_write_vectored(
303 mut self: Pin<&mut Self>,
304 cx: &mut Context<'_>,
305 bufs: &[IoSlice<'_>],
306 ) -> Poll<io::Result<usize>> {
307 loop {
308 if let Some(res) = self
309 .as_mut()
310 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
311 {
312 return res;
313 }
314 }
315 }
316
317 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
318 loop {
319 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
320 return res;
321 }
322 }
323 }
324
325 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
326 self.poll_flush(cx)
327 }
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = Runtime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}