1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit},
7 util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use cfg_if::cfg_if;
11use futures_core::Stream;
12use futures_io::{AsyncRead, AsyncWrite};
13use std::{
14 future::Future,
15 io::{self, Read, Write},
16 net::SocketAddr,
17 pin::Pin,
18 sync::Arc,
19 task::{Context, Poll},
20 time::{Duration, Instant},
21};
22use tokio::{
23 net::TcpStream,
24 runtime::{EnterGuard, Handle, Runtime as TokioRT},
25 time::Sleep,
26};
27use tokio_stream::{StreamExt, wrappers::IntervalStream};
28
29use task::TTask;
30
31pub type TokioRuntime = Runtime<Tokio>;
33
34impl TokioRuntime {
35 pub fn tokio() -> io::Result<Self> {
37 Ok(Self::tokio_with_runtime(TokioRT::new()?))
38 }
39
40 pub fn tokio_current() -> Self {
42 Self::new(Tokio::current())
43 }
44
45 pub fn tokio_with_handle(handle: Handle) -> Self {
47 Self::new(Tokio::default().with_handle(handle))
48 }
49
50 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
52 Self::new(Tokio::default().with_runtime(runtime))
53 }
54}
55
56#[derive(Default, Clone, Debug)]
58pub struct Tokio {
59 handle: Option<Handle>,
60 runtime: Option<Arc<TokioRT>>,
61}
62
63impl Tokio {
64 pub fn with_handle(mut self, handle: Handle) -> Self {
66 self.handle = Some(handle);
67 self
68 }
69
70 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
72 let handle = runtime.handle().clone();
73 self.runtime = Some(Arc::new(runtime));
74 self.with_handle(handle)
75 }
76
77 pub fn current() -> Self {
79 Self::default().with_handle(Handle::current())
80 }
81
82 fn handle(&self) -> Option<Handle> {
83 self.handle.clone().or_else(|| Handle::try_current().ok())
84 }
85
86 fn enter(&self) -> Option<EnterGuard<'_>> {
87 self.runtime
88 .as_ref()
89 .map(|r| r.handle())
90 .or(self.handle.as_ref())
91 .map(Handle::enter)
92 }
93}
94
// Marks `Tokio` as a complete runtime kit. `Executor` and `Reactor` are both
// implemented for `Tokio` in this file; presumably those are what `RuntimeKit`
// requires (confirm against the trait definition). No extra items are needed here.
impl RuntimeKit for Tokio {}
96
97impl Executor for Tokio {
98 type Task<T: Send + 'static> = TTask<T>;
99
100 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
101 if let Some(runtime) = self.runtime.as_ref() {
102 runtime.block_on(f)
103 } else if let Some(handle) = self.handle() {
104 handle.block_on(f)
105 } else {
106 Handle::current().block_on(f)
107 }
108 }
109
110 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
111 &self,
112 f: F,
113 ) -> Task<Self::Task<T>> {
114 TTask(Some(if let Some(handle) = self.handle() {
115 handle.spawn(f)
116 } else {
117 tokio::task::spawn(f)
118 }))
119 .into()
120 }
121
122 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
123 &self,
124 f: F,
125 ) -> Task<Self::Task<T>> {
126 TTask(Some(if let Some(handle) = self.handle() {
127 handle.spawn_blocking(f)
128 } else {
129 tokio::task::spawn_blocking(f)
130 }))
131 .into()
132 }
133}
134
135impl Reactor for Tokio {
136 type TcpStream = Compat<TcpStream>;
137 type Sleep = Sleep;
138
139 fn register<H: Read + Write + AsSysFd + Send + 'static>(
140 &self,
141 socket: H,
142 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
143 let _enter = self.enter();
144 cfg_if! {
145 if #[cfg(unix)] {
146 Ok(unix::AsyncFdWrapper(
147 tokio::io::unix::AsyncFd::new(socket)?,
148 ))
149 } else {
150 Err::<crate::util::DummyIO, _>(io::Error::other(
151 "Registering FD on tokio reactor is only supported on unix",
152 ))
153 }
154 }
155 }
156
157 fn sleep(&self, dur: Duration) -> Self::Sleep {
158 let _enter = self.enter();
159 tokio::time::sleep(dur)
160 }
161
162 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
163 let _enter = self.enter();
164 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
165 }
166
167 fn tcp_connect_addr(
168 &self,
169 addr: SocketAddr,
170 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
171 let _enter = self.enter();
172 async move { Ok(TcpStream::connect(addr).await?.compat()) }
173 }
174}
175
176mod task {
177 use crate::util::TaskImpl;
178 use async_trait::async_trait;
179 use std::{
180 future::Future,
181 pin::Pin,
182 task::{Context, Poll},
183 };
184
185 #[derive(Debug)]
187 pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
188
189 #[async_trait]
190 impl<T: Send + 'static> TaskImpl for TTask<T> {
191 async fn cancel(&mut self) -> Option<T> {
192 let task = self.0.take()?;
193 task.abort();
194 task.await.ok()
195 }
196 }
197
198 impl<T: Send + 'static> Future for TTask<T> {
199 type Output = T;
200
201 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
202 let task = self.0.as_mut().expect("task has been canceled");
203 match Pin::new(task).poll(cx) {
204 Poll::Pending => Poll::Pending,
205 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
206 }
207 }
208 }
209}
210
211#[cfg(unix)]
212mod unix {
213 use super::*;
214 use futures_io::{AsyncRead, AsyncWrite};
215 use std::io::{IoSlice, IoSliceMut};
216 use tokio::io::unix::AsyncFd;
217
218 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
219
220 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
221 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
222 mut self: Pin<&mut Self>,
223 cx: &mut Context<'_>,
224 f: F,
225 ) -> Option<Poll<io::Result<usize>>> {
226 Some(match self.0.poll_read_ready_mut(cx) {
227 Poll::Pending => Poll::Pending,
228 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
229 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
230 Ok(res) => Poll::Ready(res),
231 Err(_) => return None,
232 },
233 })
234 }
235
236 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
237 mut self: Pin<&mut Self>,
238 cx: &mut Context<'_>,
239 f: F,
240 ) -> Option<Poll<io::Result<R>>> {
241 Some(match self.0.poll_write_ready_mut(cx) {
242 Poll::Pending => Poll::Pending,
243 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
244 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
245 Ok(res) => Poll::Ready(res),
246 Err(_) => return None,
247 },
248 })
249 }
250 }
251
252 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
253
254 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
255 fn poll_read(
256 mut self: Pin<&mut Self>,
257 cx: &mut Context<'_>,
258 buf: &mut [u8],
259 ) -> Poll<io::Result<usize>> {
260 loop {
261 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
262 return res;
263 }
264 }
265 }
266
267 fn poll_read_vectored(
268 mut self: Pin<&mut Self>,
269 cx: &mut Context<'_>,
270 bufs: &mut [IoSliceMut<'_>],
271 ) -> Poll<io::Result<usize>> {
272 loop {
273 if let Some(res) = self
274 .as_mut()
275 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
276 {
277 return res;
278 }
279 }
280 }
281 }
282
283 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
284 fn poll_write(
285 mut self: Pin<&mut Self>,
286 cx: &mut Context<'_>,
287 buf: &[u8],
288 ) -> Poll<io::Result<usize>> {
289 loop {
290 if let Some(res) = self
291 .as_mut()
292 .write(cx, |socket| socket.get_mut().write(buf))
293 {
294 return res;
295 }
296 }
297 }
298
299 fn poll_write_vectored(
300 mut self: Pin<&mut Self>,
301 cx: &mut Context<'_>,
302 bufs: &[IoSlice<'_>],
303 ) -> Poll<io::Result<usize>> {
304 loop {
305 if let Some(res) = self
306 .as_mut()
307 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
308 {
309 return res;
310 }
311 }
312 }
313
314 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
315 loop {
316 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
317 return res;
318 }
319 }
320 }
321
322 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
323 self.poll_flush(cx)
324 }
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// The tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = TokioRuntime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}
340}