1use crate::{
4 Runtime,
5 sys::AsSysFd,
6 traits::{Executor, Reactor, RuntimeKit},
7 util::Task,
8};
9use async_compat::{Compat, CompatExt};
10use futures_core::Stream;
11use futures_io::{AsyncRead, AsyncWrite};
12use std::{
13 future::Future,
14 io::{self, Read, Write},
15 net::SocketAddr,
16 pin::Pin,
17 sync::Arc,
18 task::{Context, Poll},
19 time::{Duration, Instant},
20};
21use tokio::{
22 net::TcpStream,
23 runtime::{EnterGuard, Handle, Runtime as TokioRT},
24 time::Sleep,
25};
26use tokio_stream::{StreamExt, wrappers::IntervalStream};
27
28use task::TTask;
29
/// A [`Runtime`] whose executor and reactor are provided by [`Tokio`].
pub type TokioRuntime = Runtime<Tokio>;
32
33impl TokioRuntime {
34 pub fn tokio() -> io::Result<Self> {
36 Ok(Self::tokio_with_runtime(TokioRT::new()?))
37 }
38
39 pub fn tokio_current() -> Self {
41 Self::new(Tokio::current())
42 }
43
44 pub fn tokio_with_handle(handle: Handle) -> Self {
46 Self::new(Tokio::default().with_handle(handle))
47 }
48
49 pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
51 Self::new(Tokio::default().with_runtime(runtime))
52 }
53}
54
/// Tokio-backed implementation of the crate's `Executor` and `Reactor` traits.
///
/// Both fields are optional. When no handle is configured, spawning falls back to the tokio
/// runtime context of the calling thread (see `Tokio::handle`).
#[derive(Default, Clone, Debug)]
pub struct Tokio {
    /// Handle used for spawning and for entering the runtime; set by `with_handle`/`with_runtime`.
    handle: Option<Handle>,
    /// Owned runtime, shared between clones through the `Arc`; set only by `with_runtime`.
    runtime: Option<Arc<TokioRT>>,
}
61
62impl Tokio {
63 pub fn with_handle(mut self, handle: Handle) -> Self {
65 self.handle = Some(handle);
66 self
67 }
68
69 pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
71 let handle = runtime.handle().clone();
72 self.runtime = Some(Arc::new(runtime));
73 self.with_handle(handle)
74 }
75
76 pub fn current() -> Self {
78 Self::default().with_handle(Handle::current())
79 }
80
81 fn handle(&self) -> Option<Handle> {
82 self.handle.clone().or_else(|| Handle::try_current().ok())
83 }
84
85 fn enter(&self) -> Option<EnterGuard<'_>> {
86 self.runtime
87 .as_ref()
88 .map(|r| r.handle())
89 .or(self.handle.as_ref())
90 .map(Handle::enter)
91 }
92}
93
// `Tokio` implements `Executor` and `Reactor` below; the empty body takes every `RuntimeKit`
// item from the trait's defaults.
impl RuntimeKit for Tokio {}
95
96impl Executor for Tokio {
97 type Task<T: Send + 'static> = TTask<T>;
98
99 fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
100 if let Some(runtime) = self.runtime.as_ref() {
101 runtime.block_on(f)
102 } else if let Some(handle) = self.handle() {
103 handle.block_on(f)
104 } else {
105 Handle::try_current()
106 .expect("no tokio runtime: use Runtime::tokio() or Runtime::tokio_with_handle()")
107 .block_on(f)
108 }
109 }
110
111 fn spawn<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
112 &self,
113 f: F,
114 ) -> Task<Self::Task<T>> {
115 TTask(Some(if let Some(handle) = self.handle() {
116 handle.spawn(f)
117 } else {
118 tokio::task::spawn(f)
119 }))
120 .into()
121 }
122
123 fn spawn_blocking<T: Send + 'static, F: FnOnce() -> T + Send + 'static>(
124 &self,
125 f: F,
126 ) -> Task<Self::Task<T>> {
127 TTask(Some(if let Some(handle) = self.handle() {
128 handle.spawn_blocking(f)
129 } else {
130 tokio::task::spawn_blocking(f)
131 }))
132 .into()
133 }
134}
135
136impl Reactor for Tokio {
137 type TcpStream = Compat<TcpStream>;
138 type Sleep = Sleep;
139
140 fn register<H: Read + Write + AsSysFd + Send + 'static>(
141 &self,
142 socket: H,
143 ) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
144 let _enter = self.enter();
145 #[cfg(unix)]
146 {
147 Ok(unix::AsyncFdWrapper(tokio::io::unix::AsyncFd::new(socket)?))
148 }
149 #[cfg(not(unix))]
150 {
151 let _ = socket;
152 Err::<crate::util::DummyIO, _>(io::Error::other(
153 "Registering FD on tokio reactor is only supported on unix",
154 ))
155 }
156 }
157
158 fn sleep(&self, dur: Duration) -> Self::Sleep {
159 let _enter = self.enter();
160 tokio::time::sleep(dur)
161 }
162
163 fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
164 let _enter = self.enter();
165 IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
166 }
167
168 fn tcp_connect_addr(
169 &self,
170 addr: SocketAddr,
171 ) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
172 let _enter = self.enter();
173 async move { Ok(TcpStream::connect(addr).await?.compat()) }
174 }
175}
176
177mod task {
178 use crate::util::TaskImpl;
179 use async_trait::async_trait;
180 use std::{
181 future::Future,
182 pin::Pin,
183 task::{Context, Poll},
184 };
185
186 #[derive(Debug)]
188 pub struct TTask<T: Send + 'static>(pub(super) Option<tokio::task::JoinHandle<T>>);
189
190 #[async_trait]
191 impl<T: Send + 'static> TaskImpl for TTask<T> {
192 async fn cancel(&mut self) -> Option<T> {
193 let task = self.0.take()?;
194 task.abort();
195 task.await.ok()
196 }
197 }
198
199 impl<T: Send + 'static> Future for TTask<T> {
200 type Output = T;
201
202 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
203 let task = self.0.as_mut().expect("task has been canceled");
204 match Pin::new(task).poll(cx) {
205 Poll::Pending => Poll::Pending,
206 Poll::Ready(res) => Poll::Ready(res.expect("task has been canceled")),
207 }
208 }
209 }
210}
211
212#[cfg(unix)]
213mod unix {
214 use super::*;
215 use futures_io::{AsyncRead, AsyncWrite};
216 use std::io::{IoSlice, IoSliceMut};
217 use tokio::io::unix::AsyncFd;
218
219 pub(super) struct AsyncFdWrapper<H: Read + Write + AsSysFd>(pub(super) AsyncFd<H>);
220
221 impl<H: Read + Write + AsSysFd> AsyncFdWrapper<H> {
222 fn read<F: FnOnce(&mut AsyncFd<H>) -> io::Result<usize>>(
223 mut self: Pin<&mut Self>,
224 cx: &mut Context<'_>,
225 f: F,
226 ) -> Option<Poll<io::Result<usize>>> {
227 Some(match self.0.poll_read_ready_mut(cx) {
228 Poll::Pending => Poll::Pending,
229 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
230 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
231 Ok(res) => Poll::Ready(res),
232 Err(_) => return None,
233 },
234 })
235 }
236
237 fn write<R, F: FnOnce(&mut AsyncFd<H>) -> io::Result<R>>(
238 mut self: Pin<&mut Self>,
239 cx: &mut Context<'_>,
240 f: F,
241 ) -> Option<Poll<io::Result<R>>> {
242 Some(match self.0.poll_write_ready_mut(cx) {
243 Poll::Pending => Poll::Pending,
244 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
245 Poll::Ready(Ok(mut guard)) => match guard.try_io(f) {
246 Ok(res) => Poll::Ready(res),
247 Err(_) => return None,
248 },
249 })
250 }
251 }
252
253 impl<H: Read + Write + AsSysFd> Unpin for AsyncFdWrapper<H> {}
254
255 impl<H: Read + Write + AsSysFd> AsyncRead for AsyncFdWrapper<H> {
256 fn poll_read(
257 mut self: Pin<&mut Self>,
258 cx: &mut Context<'_>,
259 buf: &mut [u8],
260 ) -> Poll<io::Result<usize>> {
261 loop {
262 if let Some(res) = self.as_mut().read(cx, |socket| socket.get_mut().read(buf)) {
263 return res;
264 }
265 }
266 }
267
268 fn poll_read_vectored(
269 mut self: Pin<&mut Self>,
270 cx: &mut Context<'_>,
271 bufs: &mut [IoSliceMut<'_>],
272 ) -> Poll<io::Result<usize>> {
273 loop {
274 if let Some(res) = self
275 .as_mut()
276 .read(cx, |socket| socket.get_mut().read_vectored(bufs))
277 {
278 return res;
279 }
280 }
281 }
282 }
283
284 impl<H: Read + Write + AsSysFd + Send + 'static> AsyncWrite for AsyncFdWrapper<H> {
285 fn poll_write(
286 mut self: Pin<&mut Self>,
287 cx: &mut Context<'_>,
288 buf: &[u8],
289 ) -> Poll<io::Result<usize>> {
290 loop {
291 if let Some(res) = self
292 .as_mut()
293 .write(cx, |socket| socket.get_mut().write(buf))
294 {
295 return res;
296 }
297 }
298 }
299
300 fn poll_write_vectored(
301 mut self: Pin<&mut Self>,
302 cx: &mut Context<'_>,
303 bufs: &[IoSlice<'_>],
304 ) -> Poll<io::Result<usize>> {
305 loop {
306 if let Some(res) = self
307 .as_mut()
308 .write(cx, |socket| socket.get_mut().write_vectored(bufs))
309 {
310 return res;
311 }
312 }
313 }
314
315 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
316 loop {
317 if let Some(res) = self.as_mut().write(cx, |socket| socket.get_mut().flush()) {
318 return res;
319 }
320 }
321 }
322
323 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
324 self.poll_flush(cx)
325 }
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// A tokio-backed runtime must be `Send`, `Sync` and `Clone`.
    #[test]
    fn auto_traits() {
        use crate::util::test::*;

        let rt = Runtime::tokio().unwrap();
        assert_send(&rt);
        assert_sync(&rt);
        assert_clone(&rt);
    }
}