1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit},
7 util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13 future::Future,
14 io::{self, Read, Write},
15 net::SocketAddr,
16 sync::Arc,
17 time::{Duration, Instant},
18};
19use tokio::{
20 net::TcpStream,
21 runtime::{EnterGuard, Handle, Runtime as TokioRT},
22 time::Sleep,
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25
26use task::TTask;
27
28pub type TokioRuntime = Runtime<Tokio>;
30
31impl TokioRuntime {
32 pub fn tokio() -> io::Result<Self> {
34 Ok(Self::tokio_with_runtime(TokioRT::new()?))
35 }
36
37 pub fn tokio_current() -> Self {
39 Self::new(Tokio::current())
40 }
41
42 pub fn tokio_with_handle(handle: Handle) -> Self {
44 Self::new(Tokio::default().with_handle(handle))
45 }
46
47 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
49 Self::new(Tokio::default().with_runtime(runtime))
50 }
51}
52
53#[derive(Default, Clone, Debug)]
55pub struct Tokio {
56 handle: Option<Handle>,
57 runtime: Option<Arc<TokioRT>>,
58}
59
60impl Tokio {
61 pub fn with_handle(mut self, handle: Handle) -> Self {
63 self.handle = Some(handle);
64 self
65 }
66
67 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
69 let handle = runtime.handle().clone();
70 self.runtime = Some(Arc::new(runtime));
71 self.with_handle(handle)
72 }
73
74 pub fn current() -> Self {
76 Self::default().with_handle(Handle::current())
77 }
78
79 fn handle(&self) -> Option<Handle> {
80 self.handle.clone().or_else(|| Handle::try_current().ok())
81 }
82
83 fn enter(&self) -> Option<EnterGuard<'_>> {
84 self.runtime
85 .as_ref()
86 .map(|r| r.handle())
87 .or(self.handle.as_ref())
88 .map(Handle::enter)
89 }
90}
91
92impl RuntimeKit for Tokio {}
93
94impl Executor for Tokio {
95 type Task<T: Send + 'static> = TTask<T>;
96
97 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
98 if let Some(runtime) = self.runtime.as_ref() {
99 runtime.block_on(f)
100 } else if let Some(handle) = self.handle() {
101 handle.block_on(f)
102 } else {
103 Handle::try_current()
104 .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
105 .block_on(f)
106 }
107 }
108
109 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
110 &self,
111 f: F,
112 ) -> Task<Self::Task<T>> {
113 TTask(Some(if let Some(handle) = self.handle() {
114 handle.spawn(f)
115 } else {
116 tokio::task::spawn(f)
117 }))
118 .into()
119 }
120
121 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
122 &self,
123 f: F,
124 ) -> Task<Self::Task<T>> {
125 TTask(Some(if let Some(handle) = self.handle() {
126 handle.spawn_blocking(f)
127 } else {
128 tokio::task::spawn_blocking(f)
129 }))
130 .into()
131 }
132}
133
134impl Reactor for Tokio {
135 type TcpStream = Compat<TcpStream>;
136 type Sleep = Sleep;
137
138 fn register<H: Read + Write + AsSysFd + Send + 'static>(
139 &self,
140 socket: H,
141 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
142 let _enter = self.enter();
143 #[cfg(unix)]
144 {
145 Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
146 }
147 #[cfg(not(unix))]
148 {
149 let _ = socket;
150 Err::<crate::util::DummyIO, _>(io::Error::other(
151 "Registering FD on tokio reactor is only supported on unix",
152 ))
153 }
154 }
155
156 fn sleep(&self, dur: Duration) -> Self::Sleep {
157 let _enter = self.enter();
158 tokio::time::sleep(dur)
159 }
160
161 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
162 let _enter = self.enter();
163 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
164 }
165
166 fn tcp_connect_addr(
167 &self,
168 addr: SocketAddr,
169 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
170 let _enter = self.enter();
171 async move {
172 let stream = TcpStream::connect(addr).await?;
173 stream.set_nodelay(true)?;
174 Ok(stream.compat())
175 }
176 }
177}
178
179mod task {
180 use crate::util::TaskImpl;
181 use async_trait::async_trait;
182 use std::{
183 future::Future,
184 pin::Pin,
185 task::{Context, Poll},
186 };
187
188 #[derive(Debug)]
190 pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
191
192 #[async_trait]
193 impl<T: Send + 'static> TaskImpl for TTask<T> {
194 async fn cancel(&mut self) -> Option<T> {
195 let task = self.0.take()?;
196 task.abort();
197 task.await.ok()
198 }
199 }
200
201 impl<T: Send + 'static> Future for TTask<T> {
202 type Output = T;
203
204 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205 match self.0.as_mut() {
206 None => Poll::Pending,
207 Some(task) => match Pin::new(task).poll(cx) {
208 Poll::Pending => Poll::Pending,
209 Poll::Ready(Ok(res)) => Poll::Ready(res),
210 Poll::Ready(Err(_)) => Poll::Pending,
211 },
212 }
213 }
214 }
215}
216
217#[cfg(unix)]
218mod unix {
219 use super::*;
220 use futures_io::{AsyncRead, AsyncWrite};
221 use std::{
222 io::{IoSlice, IoSliceMut},
223 pin::Pin,
224 task::{Context, Poll},
225 };
226 use tokio::io::unix::AsyncFd;
227
228 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
229
230 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
231 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
232 mut self: Pin<&mut Self>,
233 cx: &mut Context<'_>,
234 f: F,
235 ) -> Option<Poll<io::Result<usize>>> {
236 Some(match self.0.poll_read_ready_mut(cx) {
237 Poll::Pending => Poll::Pending,
238 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
239 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
240 Ok(res) => Poll::Ready(res),
241 Err(_) => return None,
242 },
243 })
244 }
245
246 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
247 mut self: Pin<&mut Self>,
248 cx: &mut Context<'_>,
249 f: F,
250 ) -> Option<Poll<io::Result<R>>> {
251 Some(match self.0.poll_write_ready_mut(cx) {
252 Poll::Pending => Poll::Pending,
253 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
254 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
255 Ok(res) => Poll::Ready(res),
256 Err(_) => return None,
257 },
258 })
259 }
260 }
261
262 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
263
264 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
265 fn poll_read(
266 mut self: Pin<&mut Self>,
267 cx: &mut Context<'_>,
268 buf: &mut [u8],
269 ) -> Poll<io::Result<usize>> {
270 loop {
271 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
272 return res;
273 }
274 }
275 }
276
277 fn poll_read_vectored(
278 mut self: Pin<&mut Self>,
279 cx: &mut Context<'_>,
280 bufs: &mut [IoSliceMut<'_>],
281 ) -> Poll<io::Result<usize>> {
282 loop {
283 if let Some(res) = self
284 .as_mut()
285 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
286 {
287 return res;
288 }
289 }
290 }
291 }
292
293 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
294 fn poll_write(
295 mut self: Pin<&mut Self>,
296 cx: &mut Context<'_>,
297 buf: &[u8],
298 ) -> Poll<io::Result<usize>> {
299 loop {
300 if let Some(res) = self
301 .as_mut()
302 .write(cx, |socket| socket.get_mut().write(buf))
303 {
304 return res;
305 }
306 }
307 }
308
309 fn poll_write_vectored(
310 mut self: Pin<&mut Self>,
311 cx: &mut Context<'_>,
312 bufs: &[IoSlice<'_>],
313 ) -> Poll<io::Result<usize>> {
314 loop {
315 if let Some(res) = self
316 .as_mut()
317 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
318 {
319 return res;
320 }
321 }
322 }
323
324 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
325 loop {
326 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
327 return res;
328 }
329 }
330 }
331
332 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
333 self.poll_flush(cx)
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be shareable across threads and cheaply clonable.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = Runtime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}