1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit},
7 util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13 future::Future,
14 io::{self, Read, Write},
15 net::SocketAddr,
16 sync::Arc,
17 time::{Duration, Instant},
18};
19use tokio::{
20 net::TcpStream,
21 runtime::{EnterGuard, Handle, Runtime as TokioRT},
22 time::Sleep,
23};
24use tokio_stream::{StreamExt, wrappers::IntervalStream};
25
26use task::TTask;
27
28pub type TokioRuntime = Runtime<Tokio>;
30
31impl TokioRuntime {
32 pub fn tokio() -> io::Result<Self> {
34 Ok(Self::tokio_with_runtime(TokioRT::new()?))
35 }
36
37 pub fn tokio_current() -> Self {
39 Self::new(Tokio::current())
40 }
41
42 pub fn tokio_with_handle(handle: Handle) -> Self {
44 Self::new(Tokio::default().with_handle(handle))
45 }
46
47 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
49 Self::new(Tokio::default().with_runtime(runtime))
50 }
51}
52
53#[derive(Default, Clone, Debug)]
55pub struct Tokio {
56 handle: Option<Handle>,
57 runtime: Option<Arc<TokioRT>>,
58}
59
60impl Tokio {
61 pub fn with_handle(mut self, handle: Handle) -> Self {
63 self.handle = Some(handle);
64 self
65 }
66
67 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
69 let handle = runtime.handle().clone();
70 self.runtime = Some(Arc::new(runtime));
71 self.with_handle(handle)
72 }
73
74 pub fn current() -> Self {
76 Self::default().with_handle(Handle::current())
77 }
78
79 fn handle(&self) -> Option<Handle> {
80 self.handle.clone().or_else(|| Handle::try_current().ok())
81 }
82
83 fn enter(&self) -> Option<EnterGuard<'_>> {
84 self.runtime
85 .as_ref()
86 .map(|r| r.handle())
87 .or(self.handle.as_ref())
88 .map(Handle::enter)
89 }
90}
91
92impl RuntimeKit for Tokio {}
93
94impl Executor for Tokio {
95 type Task<T: Send + 'static> = TTask<T>;
96
97 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
98 if let Some(runtime) = self.runtime.as_ref() {
99 runtime.block_on(f)
100 } else if let Some(handle) = self.handle() {
101 handle.block_on(f)
102 } else {
103 Handle::try_current()
104 .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
105 .block_on(f)
106 }
107 }
108
109 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
110 &self,
111 f: F,
112 ) -> Task<Self::Task<T>> {
113 TTask(Some(if let Some(handle) = self.handle() {
114 handle.spawn(f)
115 } else {
116 tokio::task::spawn(f)
117 }))
118 .into()
119 }
120
121 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
122 &self,
123 f: F,
124 ) -> Task<Self::Task<T>> {
125 TTask(Some(if let Some(handle) = self.handle() {
126 handle.spawn_blocking(f)
127 } else {
128 tokio::task::spawn_blocking(f)
129 }))
130 .into()
131 }
132}
133
134impl Reactor for Tokio {
135 type TcpStream = Compat<TcpStream>;
136 type Sleep = Sleep;
137
138 fn register<H: Read + Write + AsSysFd + Send + 'static>(
139 &self,
140 socket: H,
141 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
142 let _enter = self.enter();
143 #[cfg(unix)]
144 {
145 Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
146 }
147 #[cfg(not(unix))]
148 {
149 let _ = socket;
150 Err::<crate::util::DummyIO, _>(io::Error::other(
151 "Registering FD on tokio reactor is only supported on unix",
152 ))
153 }
154 }
155
156 fn sleep(&self, dur: Duration) -> Self::Sleep {
157 let _enter = self.enter();
158 tokio::time::sleep(dur)
159 }
160
161 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
162 let _enter = self.enter();
163 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
164 }
165
166 fn tcp_connect_addr(
167 &self,
168 addr: SocketAddr,
169 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
170 let _enter = self.enter();
171 async move {
172 let stream = TcpStream::connect(addr).await?;
173 stream.set_nodelay(true)?;
174 Ok(stream.compat())
175 }
176 }
177}
178
179mod task {
180 use crate::util::TaskImpl;
181 use async_trait::async_trait;
182 use std::{
183 future::Future,
184 pin::Pin,
185 task::{Context, Poll},
186 };
187
188 #[derive(Debug)]
190 pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
191
192 #[async_trait]
193 impl<T: Send + 'static> TaskImpl for TTask<T> {
194 async fn cancel(&mut self) -> Option<T> {
195 let task = self.0.take()?;
196 task.abort();
197 task.await.ok()
198 }
199 }
200
201 impl<T: Send + 'static> Future for TTask<T> {
202 type Output = T;
203
204 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205 let task = self.0.as_mut().expect("task has been canceled");
206 match Pin::new(task).poll(cx) {
207 Poll::Pending => Poll::Pending,
208 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
209 }
210 }
211 }
212}
213
214#[cfg(unix)]
215mod unix {
216 use super::*;
217 use futures_io::{AsyncRead, AsyncWrite};
218 use std::{
219 io::{IoSlice, IoSliceMut},
220 pin::Pin,
221 task::{Context, Poll},
222 };
223 use tokio::io::unix::AsyncFd;
224
225 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
226
227 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
228 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
229 mut self: Pin<&mut Self>,
230 cx: &mut Context<'_>,
231 f: F,
232 ) -> Option<Poll<io::Result<usize>>> {
233 Some(match self.0.poll_read_ready_mut(cx) {
234 Poll::Pending => Poll::Pending,
235 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
236 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
237 Ok(res) => Poll::Ready(res),
238 Err(_) => return None,
239 },
240 })
241 }
242
243 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
244 mut self: Pin<&mut Self>,
245 cx: &mut Context<'_>,
246 f: F,
247 ) -> Option<Poll<io::Result<R>>> {
248 Some(match self.0.poll_write_ready_mut(cx) {
249 Poll::Pending => Poll::Pending,
250 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
251 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
252 Ok(res) => Poll::Ready(res),
253 Err(_) => return None,
254 },
255 })
256 }
257 }
258
259 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
260
261 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
262 fn poll_read(
263 mut self: Pin<&mut Self>,
264 cx: &mut Context<'_>,
265 buf: &mut [u8],
266 ) -> Poll<io::Result<usize>> {
267 loop {
268 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
269 return res;
270 }
271 }
272 }
273
274 fn poll_read_vectored(
275 mut self: Pin<&mut Self>,
276 cx: &mut Context<'_>,
277 bufs: &mut [IoSliceMut<'_>],
278 ) -> Poll<io::Result<usize>> {
279 loop {
280 if let Some(res) = self
281 .as_mut()
282 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
283 {
284 return res;
285 }
286 }
287 }
288 }
289
290 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
291 fn poll_write(
292 mut self: Pin<&mut Self>,
293 cx: &mut Context<'_>,
294 buf: &[u8],
295 ) -> Poll<io::Result<usize>> {
296 loop {
297 if let Some(res) = self
298 .as_mut()
299 .write(cx, |socket| socket.get_mut().write(buf))
300 {
301 return res;
302 }
303 }
304 }
305
306 fn poll_write_vectored(
307 mut self: Pin<&mut Self>,
308 cx: &mut Context<'_>,
309 bufs: &[IoSlice<'_>],
310 ) -> Poll<io::Result<usize>> {
311 loop {
312 if let Some(res) = self
313 .as_mut()
314 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
315 {
316 return res;
317 }
318 }
319 }
320
321 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
322 loop {
323 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
324 return res;
325 }
326 }
327 }
328
329 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
330 self.poll_flush(cx)
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = TokioRuntime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}
347}