1use orb::io::{AsyncFd, AsyncIO};
19pub use orb::runtime::{AsyncExec, AsyncHandle, ThreadHandle};
20use orb::time::{AsyncTime, TimeInterval};
21use std::fmt;
22use std::future::Future;
23use std::io;
24use std::net::SocketAddr;
25use std::net::TcpStream;
26use std::ops::Deref;
27use std::os::fd::{AsFd, AsRawFd};
28use std::os::unix::net::UnixStream;
29use std::path::Path;
30use std::pin::Pin;
31use std::task::*;
32use std::time::{Duration, Instant};
33use tokio::runtime::{Builder, Handle, Runtime};
34
35pub enum TokioRT {
37 Runtime(Runtime),
38 Handle(Handle),
39}
40
41impl fmt::Debug for TokioRT {
42 #[inline]
43 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
44 match self {
45 Self::Runtime(_) => write!(f, "tokio(rt)"),
46 Self::Handle(_) => write!(f, "tokio(handle)"),
47 }
48 }
49}
50
51impl TokioRT {
52 #[inline]
54 pub fn new_with_runtime(rt: Runtime) -> Self {
55 Self::Runtime(rt)
56 }
57
58 #[inline]
59 pub fn new_multi_thread(workers: usize) -> Self {
60 let mut builder = Builder::new_multi_thread();
61 if workers > 0 {
62 builder.worker_threads(workers);
63 }
64 Self::Runtime(builder.enable_all().build().unwrap())
65 }
66
67 #[inline]
68 pub fn new_current_thread() -> Self {
69 let mut builder = Builder::new_current_thread();
70 Self::Runtime(builder.enable_all().build().unwrap())
71 }
72
73 #[inline]
76 pub fn new_with_handle(handle: Handle) -> Self {
77 Self::Handle(handle)
78 }
79}
80
81impl Clone for TokioRT {
82 fn clone(&self) -> Self {
84 match self {
85 Self::Handle(h) => {
86 return Self::Handle(h.clone());
87 }
88 Self::Runtime(r) => {
89 let handle = {
90 let _guard = r.enter();
91 Handle::current()
92 };
93 Self::Handle(handle)
94 }
95 }
96 }
97}
98
// Marker impl: provides no items of its own. The runtime capabilities come from the
// `AsyncIO`, `AsyncTime` and `AsyncExec` impls for `TokioRT`.
impl orb::AsyncRuntime for TokioRT {}
100
101impl AsyncIO for TokioRT {
102 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = TokioFD<T>;
103
104 #[inline(always)]
105 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
106 let stream = tokio::net::TcpStream::connect(addr).await?;
107 Self::to_async_fd_rw(stream.into_std()?)
109 }
110
111 #[inline(always)]
112 async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
113 let stream = tokio::net::UnixStream::connect(addr).await?;
114 Self::to_async_fd_rw(stream.into_std()?)
116 }
117
118 #[inline(always)]
119 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
120 fd: T,
121 ) -> io::Result<Self::AsyncFd<T>> {
122 use tokio::io;
123 Ok(TokioFD(io::unix::AsyncFd::with_interest(fd, io::Interest::READABLE)?))
124 }
125
126 #[inline(always)]
127 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
128 fd: T,
129 ) -> io::Result<Self::AsyncFd<T>> {
130 use tokio::io;
131 use tokio::io::Interest;
132 Ok(TokioFD(io::unix::AsyncFd::with_interest(fd, Interest::READABLE | Interest::WRITABLE)?))
133 }
134}
135
136impl AsyncTime for TokioRT {
137 type Interval = TokioInterval;
138
139 #[inline(always)]
140 fn sleep(d: Duration) -> impl Future + Send {
141 tokio::time::sleep(d)
142 }
143
144 #[inline(always)]
145 fn interval(d: Duration) -> Self::Interval {
146 let later = tokio::time::Instant::now() + d;
147 TokioInterval(tokio::time::interval_at(later, d))
148 }
149}
150
151impl AsyncExec for TokioRT {
152 type AsyncHandle<R: Send> = TokioJoinHandle<R>;
153
154 type ThreadHandle<R: Send> = TokioThreadHandle<R>;
155
156 #[inline]
158 fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
159 where
160 F: Future<Output = R> + Send + 'static,
161 R: Send + 'static,
162 {
163 match self {
166 Self::Runtime(s) => {
167 return TokioJoinHandle(s.spawn(f));
168 }
169 Self::Handle(s) => {
170 return TokioJoinHandle(s.spawn(f));
171 }
172 }
173 }
174
175 #[inline]
177 fn spawn_detach<F, R>(&self, f: F)
178 where
179 F: Future<Output = R> + Send + 'static,
180 R: Send + 'static,
181 {
182 match self {
183 Self::Runtime(s) => {
184 s.spawn(f);
185 }
186 Self::Handle(s) => {
187 s.spawn(f);
188 }
189 }
190 }
191
192 #[inline(always)]
193 fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
194 where
195 F: FnOnce() -> R + Send + 'static,
196 R: Send + 'static,
197 {
198 TokioThreadHandle(tokio::task::spawn_blocking(f))
199 }
200
201 #[inline]
203 fn block_on<F, R>(&self, f: F) -> R
204 where
205 F: Future<Output = R>,
206 R: 'static,
207 {
208 match self {
209 Self::Runtime(s) => {
210 return s.block_on(f);
211 }
212 Self::Handle(_s) => {
213 panic!("handle is not allowed to block_on");
216 }
217 }
218 }
219}
220
221pub struct TokioInterval(tokio::time::Interval);
223
224impl TimeInterval for TokioInterval {
225 #[inline]
226 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
227 let _self = self.get_mut();
228 if let Poll::Ready(i) = _self.0.poll_tick(ctx) {
229 Poll::Ready(i.into_std())
230 } else {
231 Poll::Pending
232 }
233 }
234}
235
236pub struct TokioFD<T: AsRawFd + AsFd + Send + Sync + 'static>(tokio::io::unix::AsyncFd<T>);
238
239impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for TokioFD<T> {
240 #[inline(always)]
241 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
242 self.0.async_io(tokio::io::Interest::READABLE, f).await
243 }
244
245 #[inline(always)]
246 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
247 self.0.async_io(tokio::io::Interest::WRITABLE, f).await
248 }
249}
250
251impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for TokioFD<T> {
252 type Target = T;
253
254 #[inline(always)]
255 fn deref(&self) -> &Self::Target {
256 self.0.get_ref()
257 }
258}
259
260pub struct TokioJoinHandle<T>(tokio::task::JoinHandle<T>);
262
263impl<T: Send> AsyncHandle<T> for TokioJoinHandle<T> {
264 #[inline]
265 fn is_finished(&self) -> bool {
266 self.0.is_finished()
267 }
268
269 #[inline]
270 fn detach(self) {
271 }
274
275 #[inline]
276 fn abort(self) {
277 self.0.abort();
278 }
279}
280
281impl<T> Future for TokioJoinHandle<T> {
282 type Output = Result<T, ()>;
283
284 #[inline]
285 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
286 let _self = unsafe { self.get_unchecked_mut() };
287 if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
288 return Poll::Ready(r.map_err(|_e| ()));
289 }
290 Poll::Pending
291 }
292}
293
294pub struct TokioThreadHandle<T>(tokio::task::JoinHandle<T>);
296
297impl<T> ThreadHandle<T> for TokioThreadHandle<T> {
298 #[inline]
299 fn is_finished(&self) -> bool {
300 self.0.is_finished()
301 }
302}
303
304impl<T> Future for TokioThreadHandle<T> {
305 type Output = Result<T, ()>;
306
307 #[inline]
308 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309 let _self = unsafe { self.get_unchecked_mut() };
310 if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
311 return Poll::Ready(r.map_err(|_e| ()));
312 }
313 Poll::Pending
314 }
315}