1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit},
7 util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13 future::Future,
14 io::{self, Read, Write},
15 net::SocketAddr,
16 sync::Arc,
17 time::{Duration, Instant},
18};
19use tokio::{
20 net::TcpStream,
21 runtime::{EnterGuard, Handle, Runtime as TokioRT},
22 time::Sleep,
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25
26use task::TTask;
27
28pub type TokioRuntime = Runtime<Tokio>;
30
31impl TokioRuntime {
32 pub fn tokio() -> io::Result<Self> {
34 Ok(Self::tokio_with_runtime(TokioRT::new()?))
35 }
36
37 #[must_use]
39 pub fn tokio_current() -> Self {
40 Self::new(Tokio::current())
41 }
42
43 #[must_use]
45 pub fn tokio_with_handle(handle: Handle) -> Self {
46 Self::new(Tokio::default().with_handle(handle))
47 }
48
49 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
51 Self::new(Tokio::default().with_runtime(runtime))
52 }
53}
54
55#[derive(Default, Clone, Debug)]
57pub struct Tokio {
58 handle: Option<Handle>,
59 runtime: Option<Arc<TokioRT>>,
60}
61
62impl Tokio {
63 #[must_use]
65 pub fn with_handle(mut self, handle: Handle) -> Self {
66 self.handle = Some(handle);
67 self
68 }
69
70 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
72 let handle = runtime.handle().clone();
73 self.runtime = Some(Arc::new(runtime));
74 self.with_handle(handle)
75 }
76
77 #[must_use]
79 pub fn current() -> Self {
80 Self::default().with_handle(Handle::current())
81 }
82
83 fn handle(&self) -> Option<Handle> {
84 self.handle.clone().or_else(|| Handle::try_current().ok())
85 }
86
87 fn enter(&self) -> Option<EnterGuard<'_>> {
88 self.runtime
89 .as_ref()
90 .map(|r| r.handle())
91 .or(self.handle.as_ref())
92 .map(Handle::enter)
93 }
94}
95
96impl RuntimeKit for Tokio {}
97
98impl Executor for Tokio {
99 type Task<T: Send + 'static> = TTask<T>;
100
101 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
102 if let Some(runtime) = self.runtime.as_ref() {
103 runtime.block_on(f)
104 } else if let Some(handle) = self.handle() {
105 handle.block_on(f)
106 } else {
107 Handle::try_current()
108 .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
109 .block_on(f)
110 }
111 }
112
113 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
114 &self,
115 f: F,
116 ) -> Task<Self::Task<T>> {
117 TTask(Some(if let Some(handle) = self.handle() {
118 handle.spawn(f)
119 } else {
120 tokio::task::spawn(f)
121 }))
122 .into()
123 }
124
125 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
126 &self,
127 f: F,
128 ) -> Task<Self::Task<T>> {
129 TTask(Some(if let Some(handle) = self.handle() {
130 handle.spawn_blocking(f)
131 } else {
132 tokio::task::spawn_blocking(f)
133 }))
134 .into()
135 }
136}
137
138impl Reactor for Tokio {
139 type TcpStream = Compat<TcpStream>;
140 type Sleep = Sleep;
141
142 fn register<H: Read + Write + AsSysFd + Send + 'static>(
143 &self,
144 socket: H,
145 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
146 let _enter = self.enter();
147 #[cfg(unix)]
148 {
149 Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
150 }
151 #[cfg(not(unix))]
152 {
153 let _ = socket;
154 Err::<crate::util::DummyIO, _>(io::Error::other(
155 "Registering FD on tokio reactor is only supported on unix",
156 ))
157 }
158 }
159
160 fn sleep(&self, dur: Duration) -> Self::Sleep {
161 let _enter = self.enter();
162 tokio::time::sleep(dur)
163 }
164
165 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
166 let _enter = self.enter();
167 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
168 }
169
170 fn tcp_connect_addr(
171 &self,
172 addr: SocketAddr,
173 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
174 let _enter = self.enter();
175 async move {
176 let stream = TcpStream::connect(addr).await?;
177 stream.set_nodelay(true)?;
178 Ok(stream.compat())
179 }
180 }
181}
182
183mod task {
184 use crate::util::TaskImpl;
185 use async_trait::async_trait;
186 use std::{
187 future::Future,
188 pin::Pin,
189 task::{Context, Poll},
190 };
191
192 #[derive(Debug)]
194 pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
195
196 #[async_trait]
197 impl<T: Send + 'static> TaskImpl for TTask<T> {
198 async fn cancel(&mut self) -> Option<T> {
199 let task = self.0.take()?;
200 task.abort();
201 task.await.ok()
202 }
203 }
204
205 impl<T: Send + 'static> Future for TTask<T> {
206 type Output = T;
207
208 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
209 match self.0.as_mut() {
210 None => Poll::Pending,
211 Some(task) => match Pin::new(task).poll(cx) {
212 Poll::Pending => Poll::Pending,
213 Poll::Ready(Ok(res)) => Poll::Ready(res),
214 Poll::Ready(Err(_)) => Poll::Pending,
215 },
216 }
217 }
218 }
219}
220
221#[cfg(unix)]
222mod unix {
223 use super::*;
224 use futures_io::{AsyncRead, AsyncWrite};
225 use std::{
226 io::{IoSlice, IoSliceMut},
227 pin::Pin,
228 task::{Context, Poll},
229 };
230 use tokio::io::unix::AsyncFd;
231
232 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
233
234 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
235 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
236 mut self: Pin<&mut Self>,
237 cx: &mut Context<'_>,
238 f: F,
239 ) -> Option<Poll<io::Result<usize>>> {
240 Some(match self.0.poll_read_ready_mut(cx) {
241 Poll::Pending => Poll::Pending,
242 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
243 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
244 Ok(res) => Poll::Ready(res),
245 Err(_) => return None,
246 },
247 })
248 }
249
250 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
251 mut self: Pin<&mut Self>,
252 cx: &mut Context<'_>,
253 f: F,
254 ) -> Option<Poll<io::Result<R>>> {
255 Some(match self.0.poll_write_ready_mut(cx) {
256 Poll::Pending => Poll::Pending,
257 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
258 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
259 Ok(res) => Poll::Ready(res),
260 Err(_) => return None,
261 },
262 })
263 }
264 }
265
266 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
267
268 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
269 fn poll_read(
270 mut self: Pin<&mut Self>,
271 cx: &mut Context<'_>,
272 buf: &mut [u8],
273 ) -> Poll<io::Result<usize>> {
274 loop {
275 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
276 return res;
277 }
278 }
279 }
280
281 fn poll_read_vectored(
282 mut self: Pin<&mut Self>,
283 cx: &mut Context<'_>,
284 bufs: &mut [IoSliceMut<'_>],
285 ) -> Poll<io::Result<usize>> {
286 loop {
287 if let Some(res) = self
288 .as_mut()
289 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
290 {
291 return res;
292 }
293 }
294 }
295 }
296
297 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
298 fn poll_write(
299 mut self: Pin<&mut Self>,
300 cx: &mut Context<'_>,
301 buf: &[u8],
302 ) -> Poll<io::Result<usize>> {
303 loop {
304 if let Some(res) = self
305 .as_mut()
306 .write(cx, |socket| socket.get_mut().write(buf))
307 {
308 return res;
309 }
310 }
311 }
312
313 fn poll_write_vectored(
314 mut self: Pin<&mut Self>,
315 cx: &mut Context<'_>,
316 bufs: &[IoSlice<'_>],
317 ) -> Poll<io::Result<usize>> {
318 loop {
319 if let Some(res) = self
320 .as_mut()
321 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
322 {
323 return res;
324 }
325 }
326 }
327
328 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
329 loop {
330 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
331 return res;
332 }
333 }
334 }
335
336 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
337 self.poll_flush(cx)
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt: TokioRuntime = Runtime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}