1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit, Task},
7};
8use async_trait::async_trait;
9use cfg_if::cfg_if;
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13 future::Future,
14 io::{self, Read, Write},
15 net::SocketAddr,
16 pin::Pin,
17 task::{Context, Poll},
18 time::{Duration, Instant},
19};
20use tokio::{net::TcpStream, runtime::Handle};
21use tokio_stream::{StreamExt, wrappers::IntervalStream};
22use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
23
24pub type TokioRuntime = Runtime<Tokio>;
26
27impl TokioRuntime {
28 pub fn tokio() -> Self {
30 Self::new(Tokio::current())
31 }
32
33 pub fn tokio_with_handle(handle: Handle) -> Self {
35 Self::new(Tokio::default().with_handle(handle))
36 }
37}
38
39#[derive(Default, Debug, Clone)]
41pub struct Tokio {
42 handle: Option<Handle>,
43}
44
45impl Tokio {
46 pub fn with_handle(mut self, handle: Handle) -> Self {
48 self.handle = Some(handle);
49 self
50 }
51
52 pub fn current() -> Self {
54 Self::default().with_handle(Handle::current())
55 }
56
57 pub(crate) fn handle(&self) -> Option<Handle> {
58 Handle::try_current().ok().or_else(|| self.handle.clone())
59 }
60}
61
62struct TTask<T: Send + 'static>(Option<tokio::task::JoinHandle<T>>);
63
64impl RuntimeKit for Tokio {}
65
66impl Executor for Tokio {
67 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
68 if let Some(handle) = self.handle() {
69 handle.block_on(f)
70 } else {
71 Handle::current().block_on(f)
72 }
73 }
74
75 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
76 &self,
77 f: F,
78 ) -> impl Task<T> + 'static {
79 TTask(Some(if let Some(handle) = self.handle() {
80 handle.spawn(f)
81 } else {
82 tokio::task::spawn(f)
83 }))
84 }
85
86 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
87 &self,
88 f: F,
89 ) -> impl Task<T> + 'static {
90 TTask(Some(if let Some(handle) = self.handle() {
91 handle.spawn_blocking(f)
92 } else {
93 tokio::task::spawn_blocking(f)
94 }))
95 }
96}
97
98#[async_trait]
99impl<T: Send + 'static> Task<T> for TTask<T> {
100 async fn cancel(&mut self) -> Option<T> {
101 let task = self.0.take()?;
102 task.abort();
103 task.await.ok()
104 }
105}
106
107impl<T: Send + 'static> Future for TTask<T> {
108 type Output = T;
109
110 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
111 let task = self.0.as_mut().expect("task has been canceled");
112 match Pin::new(task).poll(cx) {
113 Poll::Pending => Poll::Pending,
114 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
115 }
116 }
117}
118
119impl Reactor for Tokio {
120 type TcpStream = Compat<TcpStream>;
121
122 fn register<H: Read + Write + AsSysFd + Send + 'static>(
123 &self,
124 socket: H,
125 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
126 let _enter = self.handle().as_ref().map(|handle| handle.enter());
127 cfg_if! {
128 if #[cfg(unix)] {
129 Ok(unix::AsyncFdWrapper(
130 tokio::io::unix::AsyncFd::new(socket)?,
131 ))
132 } else {
133 Err::<windows::Dummy, _>(io::Error::other(
134 "Registering FD on tokio reactor is only supported on unix",
135 ))
136 }
137 }
138 }
139
140 fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send + 'static {
141 tokio::time::sleep(dur)
142 }
143
144 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
145 let _enter = self.handle().as_ref().map(|handle| handle.enter());
146 Box::new(
147 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std),
148 )
149 }
150
151 fn tcp_connect(
152 &self,
153 addr: SocketAddr,
154 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
155 let _enter = self.handle().as_ref().map(|handle| handle.enter());
156 async move { Ok(TcpStream::connect(addr).await?.compat()) }
157 }
158}
159
160#[cfg(unix)]
161mod unix {
162 use super::*;
163 use futures_io::{AsyncRead, AsyncWrite};
164 use std::io::{IoSlice, IoSliceMut};
165 use tokio::io::unix::AsyncFd;
166
167 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
168
169 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
170 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
171 mut self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 f: F,
174 ) -> Option<Poll<io::Result<usize>>> {
175 Some(match self.0.poll_read_ready_mut(cx) {
176 Poll::Pending => Poll::Pending,
177 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
178 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
179 Ok(res) => Poll::Ready(res),
180 Err(_) => return None,
181 },
182 })
183 }
184
185 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
186 mut self: Pin<&mut Self>,
187 cx: &mut Context<'_>,
188 f: F,
189 ) -> Option<Poll<io::Result<R>>> {
190 Some(match self.0.poll_write_ready_mut(cx) {
191 Poll::Pending => Poll::Pending,
192 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
193 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
194 Ok(res) => Poll::Ready(res),
195 Err(_) => return None,
196 },
197 })
198 }
199 }
200
201 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
202
203 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
204 fn poll_read(
205 mut self: Pin<&mut Self>,
206 cx: &mut Context<'_>,
207 buf: &mut [u8],
208 ) -> Poll<io::Result<usize>> {
209 loop {
210 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
211 return res;
212 }
213 }
214 }
215
216 fn poll_read_vectored(
217 mut self: Pin<&mut Self>,
218 cx: &mut Context<'_>,
219 bufs: &mut [IoSliceMut<'_>],
220 ) -> Poll<io::Result<usize>> {
221 loop {
222 if let Some(res) = self
223 .as_mut()
224 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
225 {
226 return res;
227 }
228 }
229 }
230 }
231
232 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
233 fn poll_write(
234 mut self: Pin<&mut Self>,
235 cx: &mut Context<'_>,
236 buf: &[u8],
237 ) -> Poll<io::Result<usize>> {
238 loop {
239 if let Some(res) = self
240 .as_mut()
241 .write(cx, |socket| socket.get_mut().write(buf))
242 {
243 return res;
244 }
245 }
246 }
247
248 fn poll_write_vectored(
249 mut self: Pin<&mut Self>,
250 cx: &mut Context<'_>,
251 bufs: &[IoSlice<'_>],
252 ) -> Poll<io::Result<usize>> {
253 loop {
254 if let Some(res) = self
255 .as_mut()
256 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
257 {
258 return res;
259 }
260 }
261 }
262
263 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
264 loop {
265 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
266 return res;
267 }
268 }
269 }
270
271 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
272 self.poll_flush(cx)
273 }
274 }
275}
276
#[cfg(windows)]
mod windows {
    use super::*;
    use futures_io::{AsyncRead, AsyncWrite};

    /// Placeholder I/O type that only names the `Ok` type of the always-failing
    /// `register` result on Windows.
    ///
    /// It has no variants, so no value can ever exist. The trait methods below are
    /// statically unreachable, which is stated with empty `match`es instead of
    /// returning `Pending` without registering a waker.
    pub(super) enum Dummy {}

    impl AsyncRead for Dummy {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }
    }

    impl AsyncWrite for Dummy {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            match *self {}
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            match *self {}
        }
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that every runtime trait can be used as a trait object
    /// backed by the tokio implementations.
    #[test]
    fn dyn_compat() {
        type TokioTcp = Compat<TcpStream>;

        struct TraitObjects {
            _executor: Box<dyn Executor>,
            _reactor: Box<dyn Reactor<TcpStream = TokioTcp>>,
            _kit: Box<dyn RuntimeKit<TcpStream = TokioTcp>>,
            _task: Box<dyn Task<String>>,
        }

        let _ = TraitObjects {
            _executor: Box::new(Tokio::default()),
            _reactor: Box::new(Tokio::default()),
            _kit: Box::new(Tokio::default()),
            _task: Box::new(TTask(None)),
        };
    }
}