1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit},
7 util::Task,
8};
9use cfg_if::cfg_if;
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13 future::Future,
14 io::{self, Read, Write},
15 net::SocketAddr,
16 pin::Pin,
17 sync::Arc,
18 task::{Context, Poll},
19 time::{Duration, Instant},
20};
21use tokio::{
22 net::TcpStream,
23 runtime::{EnterGuard, Handle, Runtime as TokioRT},
24};
25use tokio_stream::{StreamExt, wrappers::IntervalStream};
26use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
27
28use task::TTask;
29
30pub type TokioRuntime = Runtime<Tokio>;
32
33impl TokioRuntime {
34 pub fn tokio() -> io::Result<Self> {
36 Ok(Self::tokio_with_runtime(TokioRT::new()?))
37 }
38
39 pub fn tokio_current() -> Self {
41 Self::new(Tokio::current())
42 }
43
44 pub fn tokio_with_handle(handle: Handle) -> Self {
46 Self::new(Tokio::default().with_handle(handle))
47 }
48
49 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
51 Self::new(Tokio::default().with_runtime(runtime))
52 }
53}
54
55#[derive(Default, Clone, Debug)]
57pub struct Tokio {
58 handle: Option<Handle>,
59 runtime: Option<Arc<TokioRT>>,
60}
61
62impl Tokio {
63 pub fn with_handle(mut self, handle: Handle) -> Self {
65 self.handle = Some(handle);
66 self
67 }
68
69 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
71 let handle = runtime.handle().clone();
72 self.runtime = Some(Arc::new(runtime));
73 self.with_handle(handle)
74 }
75
76 pub fn current() -> Self {
78 Self::default().with_handle(Handle::current())
79 }
80
81 fn handle(&self) -> Option<Handle> {
82 self.runtime
83 .as_ref()
84 .map(|r| r.handle().clone())
85 .or_else(|| Handle::try_current().ok())
86 .or_else(|| self.handle.clone())
87 }
88
89 fn enter(&self) -> Option<EnterGuard<'_>> {
90 self.runtime
91 .as_ref()
92 .map(|r| r.handle())
93 .or(self.handle.as_ref())
94 .map(Handle::enter)
95 }
96}
97
98impl RuntimeKit for Tokio {}
99
100impl Executor for Tokio {
101 type Task<T: Send + 'static> = TTask<T>;
102
103 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
104 if let Some(runtime) = self.runtime.as_ref() {
105 runtime.block_on(f)
106 } else if let Some(handle) = self.handle() {
107 handle.block_on(f)
108 } else {
109 Handle::current().block_on(f)
110 }
111 }
112
113 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
114 &self,
115 f: F,
116 ) -> Task<Self::Task<T>> {
117 TTask(Some(if let Some(handle) = self.handle() {
118 handle.spawn(f)
119 } else {
120 tokio::task::spawn(f)
121 }))
122 .into()
123 }
124
125 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
126 &self,
127 f: F,
128 ) -> Task<Self::Task<T>> {
129 TTask(Some(if let Some(handle) = self.handle() {
130 handle.spawn_blocking(f)
131 } else {
132 tokio::task::spawn_blocking(f)
133 }))
134 .into()
135 }
136}
137
138impl Reactor for Tokio {
139 type TcpStream = Compat<TcpStream>;
140
141 fn register<H: Read + Write + AsSysFd + Send + 'static>(
142 &self,
143 socket: H,
144 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
145 let _enter = self.enter();
146 cfg_if! {
147 if #[cfg(unix)] {
148 Ok(unix::AsyncFdWrapper(
149 tokio::io::unix::AsyncFd::new(socket)?,
150 ))
151 } else {
152 Err::<crate::util::DummyIO, _>(io::Error::other(
153 "Registering FD on tokio reactor is only supported on unix",
154 ))
155 }
156 }
157 }
158
159 fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send + 'static {
160 let _enter = self.enter();
161 tokio::time::sleep(dur)
162 }
163
164 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
165 let _enter = self.enter();
166 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
167 }
168
169 fn tcp_connect_addr(
170 &self,
171 addr: SocketAddr,
172 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
173 let _enter = self.enter();
174 async move { Ok(TcpStream::connect(addr).await?.compat()) }
175 }
176}
177
178mod task {
179 use crate::util::TaskImpl;
180 use async_trait::async_trait;
181 use std::{
182 future::Future,
183 pin::Pin,
184 task::{Context, Poll},
185 };
186
187 #[derive(Debug)]
189 pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
190
191 #[async_trait]
192 impl<T: Send + 'static> TaskImpl for TTask<T> {
193 async fn cancel(&mut self) -> Option<T> {
194 let task = self.0.take()?;
195 task.abort();
196 task.await.ok()
197 }
198 }
199
200 impl<T: Send + 'static> Future for TTask<T> {
201 type Output = T;
202
203 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
204 let task = self.0.as_mut().expect("task has been canceled");
205 match Pin::new(task).poll(cx) {
206 Poll::Pending => Poll::Pending,
207 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
208 }
209 }
210 }
211}
212
213#[cfg(unix)]
214mod unix {
215 use super::*;
216 use futures_io::{AsyncRead, AsyncWrite};
217 use std::io::{IoSlice, IoSliceMut};
218 use tokio::io::unix::AsyncFd;
219
220 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
221
222 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
223 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
224 mut self: Pin<&mut Self>,
225 cx: &mut Context<'_>,
226 f: F,
227 ) -> Option<Poll<io::Result<usize>>> {
228 Some(match self.0.poll_read_ready_mut(cx) {
229 Poll::Pending => Poll::Pending,
230 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
231 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
232 Ok(res) => Poll::Ready(res),
233 Err(_) => return None,
234 },
235 })
236 }
237
238 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
239 mut self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 f: F,
242 ) -> Option<Poll<io::Result<R>>> {
243 Some(match self.0.poll_write_ready_mut(cx) {
244 Poll::Pending => Poll::Pending,
245 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
246 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
247 Ok(res) => Poll::Ready(res),
248 Err(_) => return None,
249 },
250 })
251 }
252 }
253
254 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
255
256 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
257 fn poll_read(
258 mut self: Pin<&mut Self>,
259 cx: &mut Context<'_>,
260 buf: &mut [u8],
261 ) -> Poll<io::Result<usize>> {
262 loop {
263 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
264 return res;
265 }
266 }
267 }
268
269 fn poll_read_vectored(
270 mut self: Pin<&mut Self>,
271 cx: &mut Context<'_>,
272 bufs: &mut [IoSliceMut<'_>],
273 ) -> Poll<io::Result<usize>> {
274 loop {
275 if let Some(res) = self
276 .as_mut()
277 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
278 {
279 return res;
280 }
281 }
282 }
283 }
284
285 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
286 fn poll_write(
287 mut self: Pin<&mut Self>,
288 cx: &mut Context<'_>,
289 buf: &[u8],
290 ) -> Poll<io::Result<usize>> {
291 loop {
292 if let Some(res) = self
293 .as_mut()
294 .write(cx, |socket| socket.get_mut().write(buf))
295 {
296 return res;
297 }
298 }
299 }
300
301 fn poll_write_vectored(
302 mut self: Pin<&mut Self>,
303 cx: &mut Context<'_>,
304 bufs: &[IoSlice<'_>],
305 ) -> Poll<io::Result<usize>> {
306 loop {
307 if let Some(res) = self
308 .as_mut()
309 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
310 {
311 return res;
312 }
313 }
314 }
315
316 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
317 loop {
318 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
319 return res;
320 }
321 }
322 }
323
324 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
325 self.poll_flush(cx)
326 }
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// The tokio-backed runtime is expected to be `Send`, `Sync` and `Clone`.
    // NOTE(review): the `assert_*` helpers come from `crate::util::test`;
    // presumably they are compile-time trait-bound checks — confirm there.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = Runtime::tokio().unwrap();
        assert_clone(&rt);
        assert_sync(&rt);
        assert_send(&rt);
    }
}