1use crate::{AsyncIOHandle, Executor, IOHandle, Reactor, Runtime, RuntimeKit, Task, sys::IO};
4use async_trait::async_trait;
5use cfg_if::cfg_if;
6use futures_core::Stream;
7use std::{
8 future::Future,
9 io,
10 net::SocketAddr,
11 pin::Pin,
12 task::{Context, Poll},
13 time::{Duration, Instant},
14};
15use tokio::{net::TcpStream, runtime::Handle};
16use tokio_stream::{StreamExt, wrappers::IntervalStream};
17use tokio_util::compat::TokioAsyncReadCompatExt;
18
19pub type TokioRuntime = Runtime<Tokio>;
21
22impl TokioRuntime {
23 pub fn tokio() -> Self {
25 Self::new(Tokio::current())
26 }
27
28 pub fn tokio_with_handle(handle: Handle) -> Self {
30 Self::new(Tokio::default().with_handle(handle))
31 }
32}
33
34#[derive(Default, Debug, Clone)]
36pub struct Tokio {
37 handle: Option<Handle>,
38}
39
40impl Tokio {
41 pub fn with_handle(mut self, handle: Handle) -> Self {
43 self.handle = Some(handle);
44 self
45 }
46
47 pub fn current() -> Self {
49 Self::default().with_handle(Handle::current())
50 }
51
52 pub(crate) fn handle(&self) -> Option<Handle> {
53 Handle::try_current().ok().or_else(|| self.handle.clone())
54 }
55}
56
57struct TTask<T: Send>(Option<tokio::task::JoinHandle<T>>);
58
59impl RuntimeKit for Tokio {}
60
61impl Executor for Tokio {
62 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
63 if let Some(handle) = self.handle() {
64 handle.block_on(f)
65 } else {
66 Handle::current().block_on(f)
67 }
68 }
69
70 fn spawn<T: Send + 'static>(
71 &self,
72 f: impl Future<Output = T> + Send + 'static,
73 ) -> impl Task<T> {
74 TTask(Some(if let Some(handle) = self.handle() {
75 handle.spawn(f)
76 } else {
77 tokio::task::spawn(f)
78 }))
79 }
80
81 fn spawn_blocking<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(
82 &self,
83 f: F,
84 ) -> impl Task<T> {
85 TTask(Some(if let Some(handle) = self.handle() {
86 handle.spawn_blocking(f)
87 } else {
88 tokio::task::spawn_blocking(f)
89 }))
90 }
91}
92
93#[async_trait(?Send)]
94impl<T: Send> Task<T> for TTask<T> {
95 async fn cancel(&mut self) -> Option<T> {
96 let task = self.0.take()?;
97 task.abort();
98 task.await.ok()
99 }
100}
101
102impl<T: Send> Future for TTask<T> {
103 type Output = T;
104
105 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
106 let task = self.0.as_mut().expect("task has been canceled");
107 match Pin::new(task).poll(cx) {
108 Poll::Pending => Poll::Pending,
109 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
110 }
111 }
112}
113
114#[async_trait]
115impl Reactor for Tokio {
116 fn register<H: IO + Send + 'static>(
117 &self,
118 socket: IOHandle<H>,
119 ) -> io::Result<impl AsyncIOHandle + Send> {
120 let _enter = self.handle().as_ref().map(|handle| handle.enter());
121 cfg_if! {
122 if #[cfg(unix)] {
123 Ok(Box::new(unix::AsyncFdWrapper(
124 tokio::io::unix::AsyncFd::new(socket)?,
125 )))
126 } else {
127 Err::<windows::Dummy, _>(io::Error::other(
128 "Registering FD on tokio reactor is only supported on unix",
129 ))
130 }
131 }
132 }
133
134 fn sleep(&self, dur: Duration) -> impl Future<Output = ()> {
135 tokio::time::sleep(dur)
136 }
137
138 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> {
139 let _enter = self.handle().as_ref().map(|handle| handle.enter());
140 Box::new(
141 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std),
142 )
143 }
144
145 async fn tcp_connect(&self, addr: SocketAddr) -> io::Result<impl AsyncIOHandle + Send> {
146 Ok(TcpStream::connect(addr).await?.compat())
149 }
150}
151
152#[cfg(unix)]
153mod unix {
154 use super::*;
155 use futures_io::{AsyncRead, AsyncWrite};
156 use std::io::{IoSlice, IoSliceMut, Read, Write};
157 use tokio::io::unix::AsyncFd;
158
159 pub(super) struct AsyncFdWrapper<H: IO + Send + 'static>(pub(super) AsyncFd<IOHandle<H>>);
160
161 impl<H: IO + Send + 'static> AsyncFdWrapper<H> {
162 fn read<F: FnOnce(&mut AsyncFd<IOHandle<H>>) -> io::Result<usize>>(
163 mut self: Pin<&mut Self>,
164 cx: &mut Context<'_>,
165 f: F,
166 ) -> Option<Poll<io::Result<usize>>> {
167 Some(match self.0.poll_read_ready_mut(cx) {
168 Poll::Pending => Poll::Pending,
169 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
170 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
171 Ok(res) => Poll::Ready(res),
172 Err(_) => return None,
173 },
174 })
175 }
176
177 fn write<R, F: FnOnce(&mut AsyncFd<IOHandle<H>>) -> io::Result<R>>(
178 mut self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 f: F,
181 ) -> Option<Poll<io::Result<R>>> {
182 Some(match self.0.poll_write_ready_mut(cx) {
183 Poll::Pending => Poll::Pending,
184 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
185 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
186 Ok(res) => Poll::Ready(res),
187 Err(_) => return None,
188 },
189 })
190 }
191 }
192
193 impl<H: IO + Send + 'static> Unpin for AsyncFdWrapper<H> {}
194
195 impl<H: IO + Send + 'static> AsyncRead for AsyncFdWrapper<H> {
196 fn poll_read(
197 mut self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 buf: &mut [u8],
200 ) -> Poll<io::Result<usize>> {
201 loop {
202 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
203 return res;
204 }
205 }
206 }
207
208 fn poll_read_vectored(
209 mut self: Pin<&mut Self>,
210 cx: &mut Context<'_>,
211 bufs: &mut [IoSliceMut<'_>],
212 ) -> Poll<io::Result<usize>> {
213 loop {
214 if let Some(res) = self
215 .as_mut()
216 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
217 {
218 return res;
219 }
220 }
221 }
222 }
223
224 impl<H: IO + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
225 fn poll_write(
226 mut self: Pin<&mut Self>,
227 cx: &mut Context<'_>,
228 buf: &[u8],
229 ) -> Poll<io::Result<usize>> {
230 loop {
231 if let Some(res) = self
232 .as_mut()
233 .write(cx, |socket| socket.get_mut().write(buf))
234 {
235 return res;
236 }
237 }
238 }
239
240 fn poll_write_vectored(
241 mut self: Pin<&mut Self>,
242 cx: &mut Context<'_>,
243 bufs: &[IoSlice<'_>],
244 ) -> Poll<io::Result<usize>> {
245 loop {
246 if let Some(res) = self
247 .as_mut()
248 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
249 {
250 return res;
251 }
252 }
253 }
254
255 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
256 loop {
257 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
258 return res;
259 }
260 }
261 }
262
263 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
264 self.poll_flush(cx)
265 }
266 }
267}
268
/// Placeholder I/O handle type for platforms where `Reactor::register` is unsupported.
///
/// Compiled on every non-unix target, because `register` names `windows::Dummy` in its
/// non-unix branch. The module name is kept for that reference.
#[cfg(not(unix))]
mod windows {
    use super::*;
    use futures_io::{AsyncRead, AsyncWrite};

    /// Uninhabited: `register` only ever mentions this type inside an `Err`. A value can
    /// never exist, so every method below is statically unreachable.
    pub(super) enum Dummy {}

    impl AsyncRead for Dummy {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }
    }

    impl AsyncWrite for Dummy {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }
    }
}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts each runtime trait into a trait object.
    ///
    /// The test succeeds as soon as it compiles, which proves the traits stay dyn-compatible.
    #[test]
    fn dyn_compat() {
        fn accept_trait_objects(
            _executor: Box<dyn Executor>,
            _reactor: Box<dyn Reactor>,
            _kit: Box<dyn RuntimeKit>,
            _task: Box<dyn Task<String>>,
        ) {
        }

        accept_trait_objects(
            Box::new(Tokio::default()),
            Box::new(Tokio::default()),
            Box::new(Tokio::default()),
            Box::new(TTask(None)),
        );
    }
}