1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit, Task},
7};
8use async_trait::async_trait;
9use cfg_if::cfg_if;
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13 future::Future,
14 io::{self, Read, Write},
15 net::SocketAddr,
16 pin::Pin,
17 task::{Context, Poll},
18 time::{Duration, Instant},
19};
20use tokio::{
21 net::TcpStream,
22 runtime::{EnterGuard, Handle, Runtime as TokioRT},
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
26
/// The crate's [`Runtime`] specialised to the tokio-backed [`Tokio`] kit.
pub type TokioRuntime = Runtime<Tokio>;
29
30impl TokioRuntime {
31 pub fn tokio() -> io::Result<Self> {
33 Ok(Self::tokio_with_runtime(TokioRT::new()?))
34 }
35
36 pub fn tokio_current() -> Self {
38 Self::new(Tokio::current())
39 }
40
41 pub fn tokio_with_handle(handle: Handle) -> Self {
43 Self::new(Tokio::default().with_handle(handle))
44 }
45
46 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
48 Self::new(Tokio::default().with_runtime(runtime))
49 }
50}
51
/// Tokio-backed runtime kit: implements the crate's `Executor` and `Reactor` traits.
///
/// Both fields start out empty (`Default`); they are filled by `with_handle`,
/// `with_runtime` or `current`.
#[derive(Default, Debug)]
pub struct Tokio {
    // Handle to the tokio runtime to use, when one was configured.
    handle: Option<Handle>,
    // Tokio runtime owned by this kit, when it was built with `with_runtime`.
    runtime: Option<TokioRT>,
}
58
59impl Tokio {
60 pub fn with_handle(mut self, handle: Handle) -> Self {
62 self.handle = Some(handle);
63 self
64 }
65
66 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
68 let handle = runtime.handle().clone();
69 self.runtime = Some(runtime);
70 self.with_handle(handle)
71 }
72
73 pub fn current() -> Self {
75 Self::default().with_handle(Handle::current())
76 }
77
78 fn handle(&self) -> Option<Handle> {
79 self.runtime
80 .as_ref()
81 .map(|r| r.handle().clone())
82 .or_else(|| Handle::try_current().ok())
83 .or_else(|| self.handle.clone())
84 }
85
86 fn enter(&self) -> Option<EnterGuard<'_>> {
87 self.runtime
88 .as_ref()
89 .map(TokioRT::handle)
90 .or(self.handle.as_ref())
91 .map(Handle::enter)
92 }
93}
94
95struct TTask<T: Send + 'static>(Option<tokio::task::JoinHandle<T>>);
96
// Empty impl: `Tokio` contributes no `RuntimeKit`-specific items; its `Executor` and
// `Reactor` implementations are provided separately in this file.
impl RuntimeKit for Tokio {}
98
99impl Executor for Tokio {
100 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
101 if let Some(runtime) = self.runtime.as_ref() {
102 runtime.block_on(f)
103 } else if let Some(handle) = self.handle() {
104 handle.block_on(f)
105 } else {
106 Handle::current().block_on(f)
107 }
108 }
109
110 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
111 &self,
112 f: F,
113 ) -> impl Task<T> + 'static {
114 TTask(Some(if let Some(handle) = self.handle() {
115 handle.spawn(f)
116 } else {
117 tokio::task::spawn(f)
118 }))
119 }
120
121 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
122 &self,
123 f: F,
124 ) -> impl Task<T> + 'static {
125 TTask(Some(if let Some(handle) = self.handle() {
126 handle.spawn_blocking(f)
127 } else {
128 tokio::task::spawn_blocking(f)
129 }))
130 }
131}
132
133#[async_trait]
134impl<T: Send + 'static> Task<T> for TTask<T> {
135 async fn cancel(&mut self) -> Option<T> {
136 let task = self.0.take()?;
137 task.abort();
138 task.await.ok()
139 }
140}
141
142impl<T: Send + 'static> Future for TTask<T> {
143 type Output = T;
144
145 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
146 let task = self.0.as_mut().expect("task has been canceled");
147 match Pin::new(task).poll(cx) {
148 Poll::Pending => Poll::Pending,
149 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
150 }
151 }
152}
153
154impl Reactor for Tokio {
155 type TcpStream = Compat<TcpStream>;
156
157 fn register<H: Read + Write + AsSysFd + Send + 'static>(
158 &self,
159 socket: H,
160 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
161 let _enter = self.enter();
162 cfg_if! {
163 if #[cfg(unix)] {
164 Ok(unix::AsyncFdWrapper(
165 tokio::io::unix::AsyncFd::new(socket)?,
166 ))
167 } else {
168 Err::<windows::Dummy, _>(io::Error::other(
169 "Registering FD on tokio reactor is only supported on unix",
170 ))
171 }
172 }
173 }
174
175 fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send + 'static {
176 let _enter = self.enter();
177 tokio::time::sleep(dur)
178 }
179
180 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
181 let _enter = self.enter();
182 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
183 }
184
185 fn tcp_connect_addr(
186 &self,
187 addr: SocketAddr,
188 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
189 let _enter = self.enter();
190 async move { Ok(TcpStream::connect(addr).await?.compat()) }
191 }
192}
193
194#[cfg(unix)]
195mod unix {
196 use super::*;
197 use futures_io::{AsyncRead, AsyncWrite};
198 use std::io::{IoSlice, IoSliceMut};
199 use tokio::io::unix::AsyncFd;
200
201 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
202
203 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
204 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
205 mut self: Pin<&mut Self>,
206 cx: &mut Context<'_>,
207 f: F,
208 ) -> Option<Poll<io::Result<usize>>> {
209 Some(match self.0.poll_read_ready_mut(cx) {
210 Poll::Pending => Poll::Pending,
211 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
212 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
213 Ok(res) => Poll::Ready(res),
214 Err(_) => return None,
215 },
216 })
217 }
218
219 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
220 mut self: Pin<&mut Self>,
221 cx: &mut Context<'_>,
222 f: F,
223 ) -> Option<Poll<io::Result<R>>> {
224 Some(match self.0.poll_write_ready_mut(cx) {
225 Poll::Pending => Poll::Pending,
226 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
227 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
228 Ok(res) => Poll::Ready(res),
229 Err(_) => return None,
230 },
231 })
232 }
233 }
234
235 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
236
237 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
238 fn poll_read(
239 mut self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 buf: &mut [u8],
242 ) -> Poll<io::Result<usize>> {
243 loop {
244 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
245 return res;
246 }
247 }
248 }
249
250 fn poll_read_vectored(
251 mut self: Pin<&mut Self>,
252 cx: &mut Context<'_>,
253 bufs: &mut [IoSliceMut<'_>],
254 ) -> Poll<io::Result<usize>> {
255 loop {
256 if let Some(res) = self
257 .as_mut()
258 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
259 {
260 return res;
261 }
262 }
263 }
264 }
265
266 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
267 fn poll_write(
268 mut self: Pin<&mut Self>,
269 cx: &mut Context<'_>,
270 buf: &[u8],
271 ) -> Poll<io::Result<usize>> {
272 loop {
273 if let Some(res) = self
274 .as_mut()
275 .write(cx, |socket| socket.get_mut().write(buf))
276 {
277 return res;
278 }
279 }
280 }
281
282 fn poll_write_vectored(
283 mut self: Pin<&mut Self>,
284 cx: &mut Context<'_>,
285 bufs: &[IoSlice<'_>],
286 ) -> Poll<io::Result<usize>> {
287 loop {
288 if let Some(res) = self
289 .as_mut()
290 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
291 {
292 return res;
293 }
294 }
295 }
296
297 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
298 loop {
299 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
300 return res;
301 }
302 }
303 }
304
305 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
306 self.poll_flush(cx)
307 }
308 }
309}
310
#[cfg(windows)]
mod windows {
    use super::*;
    use futures_io::{AsyncRead, AsyncWrite};

    /// Stand-in for the I/O type returned by `Reactor::register` on platforms without
    /// fd registration support.
    ///
    /// `register` only ever returns an error there, so no value of this type is ever
    /// created. Making it an empty enum encodes that in the type system: every method
    /// below is provably unreachable, and the "never constructed" lint a unit struct
    /// triggers goes away.
    pub(super) enum Dummy {}

    impl AsyncRead for Dummy {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }
    }

    impl AsyncWrite for Dummy {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }
    }
}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `Tokio` and `TTask` can be used behind the crate's
    /// trait objects.
    #[test]
    fn dyn_compat() {
        let _executor: Box<dyn Executor> = Box::new(Tokio::default());
        let _reactor: Box<dyn Reactor<TcpStream = Compat<TcpStream>>> =
            Box::new(Tokio::default());
        let _kit: Box<dyn RuntimeKit<TcpStream = Compat<TcpStream>>> =
            Box::new(Tokio::default());
        let _task: Box<dyn Task<String>> = Box::new(TTask(None));
    }
}
367}