Skip to main content

orb_tokio/
lib.rs

1//! # Orb Tokio Runtime
2//!
3//! This crate provides a Tokio-based implementation of the Orb async runtime traits.
4//! It allows users to leverage Tokio's powerful async runtime with the unified Orb interface.
5//!
6//! The main type provided is [`TokioRT`], which implements the core runtime functionality.
7//!
8//! See the [main Orb documentation](https://github.com/NaturalIO/orb) for more information.
9//!
10//! ## Usage
11//!
12//! ```rust
13//! use orb_tokio::TokioRT;
14//!
15//! let rt = TokioRT::new_multi_thread(4);
16//! ```
17
18use orb::io::{AsyncFd, AsyncIO};
19pub use orb::runtime::{AsyncExec, AsyncHandle, ThreadHandle};
20use orb::time::{AsyncTime, TimeInterval};
21use std::fmt;
22use std::future::Future;
23use std::io;
24use std::net::SocketAddr;
25use std::net::TcpStream;
26use std::ops::Deref;
27use std::os::fd::{AsFd, AsRawFd};
28use std::os::unix::net::UnixStream;
29use std::path::Path;
30use std::pin::Pin;
31use std::task::*;
32use std::time::{Duration, Instant};
33use tokio::runtime::{Builder, Handle, Runtime};
34
35/// The main struct for tokio runtime IO, assign this type to AsyncIO trait when used.
36pub enum TokioRT {
37    Runtime(Runtime),
38    Handle(Handle),
39}
40
41impl fmt::Debug for TokioRT {
42    #[inline]
43    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
44        match self {
45            Self::Runtime(_) => write!(f, "tokio(rt)"),
46            Self::Handle(_) => write!(f, "tokio(handle)"),
47        }
48    }
49}
50
51impl TokioRT {
52    /// Capture a runtime
53    #[inline]
54    pub fn new_with_runtime(rt: Runtime) -> Self {
55        Self::Runtime(rt)
56    }
57
58    #[inline]
59    pub fn new_multi_thread(workers: usize) -> Self {
60        let mut builder = Builder::new_multi_thread();
61        if workers > 0 {
62            builder.worker_threads(workers);
63        }
64        Self::Runtime(builder.enable_all().build().unwrap())
65    }
66
67    #[inline]
68    pub fn new_current_thread() -> Self {
69        let mut builder = Builder::new_current_thread();
70        Self::Runtime(builder.enable_all().build().unwrap())
71    }
72
73    /// Only capture a runtime handle. Should acquire with
74    /// `async { Handle::current() }`
75    #[inline]
76    pub fn new_with_handle(handle: Handle) -> Self {
77        Self::Handle(handle)
78    }
79}
80
81impl Clone for TokioRT {
82    /// Clone a TokioRT::Handle out of runtime, for spawn
83    fn clone(&self) -> Self {
84        match self {
85            Self::Handle(h) => {
86                return Self::Handle(h.clone());
87            }
88            Self::Runtime(r) => {
89                let handle = {
90                    let _guard = r.enter();
91                    Handle::current()
92                };
93                Self::Handle(handle)
94            }
95        }
96    }
97}
98
99impl orb::AsyncRuntime for TokioRT {}
100
101impl AsyncIO for TokioRT {
102    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = TokioFD<T>;
103
104    #[inline(always)]
105    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
106        let stream = tokio::net::TcpStream::connect(addr).await?;
107        // into_std will not change back to blocking
108        Self::to_async_fd_rw(stream.into_std()?)
109    }
110
111    #[inline(always)]
112    async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
113        let stream = tokio::net::UnixStream::connect(addr).await?;
114        // into_std will not change back to blocking
115        Self::to_async_fd_rw(stream.into_std()?)
116    }
117
118    #[inline(always)]
119    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
120        fd: T,
121    ) -> io::Result<Self::AsyncFd<T>> {
122        use tokio::io;
123        Ok(TokioFD(io::unix::AsyncFd::with_interest(fd, io::Interest::READABLE)?))
124    }
125
126    #[inline(always)]
127    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
128        fd: T,
129    ) -> io::Result<Self::AsyncFd<T>> {
130        use tokio::io;
131        use tokio::io::Interest;
132        Ok(TokioFD(io::unix::AsyncFd::with_interest(fd, Interest::READABLE | Interest::WRITABLE)?))
133    }
134}
135
136impl AsyncTime for TokioRT {
137    type Interval = TokioInterval;
138
139    #[inline(always)]
140    fn sleep(d: Duration) -> impl Future + Send {
141        tokio::time::sleep(d)
142    }
143
144    #[inline(always)]
145    fn interval(d: Duration) -> Self::Interval {
146        let later = tokio::time::Instant::now() + d;
147        TokioInterval(tokio::time::interval_at(later, d))
148    }
149}
150
151impl AsyncExec for TokioRT {
152    type AsyncHandle<R: Send> = TokioJoinHandle<R>;
153
154    type ThreadHandle<R: Send> = TokioThreadHandle<R>;
155
156    /// Spawn a task in the background, returning a handle to await its result
157    #[inline]
158    fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
159    where
160        F: Future<Output = R> + Send + 'static,
161        R: Send + 'static,
162    {
163        // Although AsyncHandle don't need Send marker, but here in the spawn()
164        // need to restrict the requirements
165        match self {
166            Self::Runtime(s) => {
167                return TokioJoinHandle(s.spawn(f));
168            }
169            Self::Handle(s) => {
170                return TokioJoinHandle(s.spawn(f));
171            }
172        }
173    }
174
175    /// Spawn a task and detach it (no handle returned)
176    #[inline]
177    fn spawn_detach<F, R>(&self, f: F)
178    where
179        F: Future<Output = R> + Send + 'static,
180        R: Send + 'static,
181    {
182        match self {
183            Self::Runtime(s) => {
184                s.spawn(f);
185            }
186            Self::Handle(s) => {
187                s.spawn(f);
188            }
189        }
190    }
191
192    #[inline(always)]
193    fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
194    where
195        F: FnOnce() -> R + Send + 'static,
196        R: Send + 'static,
197    {
198        TokioThreadHandle(tokio::task::spawn_blocking(f))
199    }
200
201    /// Run a future to completion on the runtime
202    #[inline]
203    fn block_on<F, R>(&self, f: F) -> R
204    where
205        F: Future<Output = R>,
206        R: 'static,
207    {
208        match self {
209            Self::Runtime(s) => {
210                return s.block_on(f);
211            }
212            Self::Handle(_s) => {
213                // panic in order to prevent misbehaved code.
214                // refer to https://docs.rs/tokio/latest/tokio/runtime/struct.Handle.html#method.block_on
215                panic!("handle is not allowed to block_on");
216            }
217        }
218    }
219}
220
221/// Associate type for TokioRT
222pub struct TokioInterval(tokio::time::Interval);
223
224impl TimeInterval for TokioInterval {
225    #[inline]
226    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
227        let _self = self.get_mut();
228        if let Poll::Ready(i) = _self.0.poll_tick(ctx) {
229            Poll::Ready(i.into_std())
230        } else {
231            Poll::Pending
232        }
233    }
234}
235
236/// Associate type for TokioRT
237pub struct TokioFD<T: AsRawFd + AsFd + Send + Sync + 'static>(tokio::io::unix::AsyncFd<T>);
238
239impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for TokioFD<T> {
240    #[inline(always)]
241    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
242        self.0.async_io(tokio::io::Interest::READABLE, f).await
243    }
244
245    #[inline(always)]
246    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
247        self.0.async_io(tokio::io::Interest::WRITABLE, f).await
248    }
249}
250
251impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for TokioFD<T> {
252    type Target = T;
253
254    #[inline(always)]
255    fn deref(&self) -> &Self::Target {
256        self.0.get_ref()
257    }
258}
259
260/// A wrapper around tokio's JoinHandle that implements AsyncHandle
261pub struct TokioJoinHandle<T>(tokio::task::JoinHandle<T>);
262
263impl<T: Send> AsyncHandle<T> for TokioJoinHandle<T> {
264    #[inline]
265    fn is_finished(&self) -> bool {
266        self.0.is_finished()
267    }
268
269    #[inline]
270    fn detach(self) {
271        // Tokio's JoinHandle doesn't need explicit detach, it will run in background
272        // when the handle is dropped
273    }
274
275    #[inline]
276    fn abort(self) {
277        self.0.abort();
278    }
279}
280
281impl<T> Future for TokioJoinHandle<T> {
282    type Output = Result<T, ()>;
283
284    #[inline]
285    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
286        let _self = unsafe { self.get_unchecked_mut() };
287        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
288            return Poll::Ready(r.map_err(|_e| ()));
289        }
290        Poll::Pending
291    }
292}
293
294/// A wrapper around tokio's JoinHandle that implements ThreadHandle
295pub struct TokioThreadHandle<T>(tokio::task::JoinHandle<T>);
296
297impl<T> ThreadHandle<T> for TokioThreadHandle<T> {
298    #[inline]
299    fn is_finished(&self) -> bool {
300        self.0.is_finished()
301    }
302}
303
304impl<T> Future for TokioThreadHandle<T> {
305    type Output = Result<T, ()>;
306
307    #[inline]
308    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309        let _self = unsafe { self.get_unchecked_mut() };
310        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
311            return Poll::Ready(r.map_err(|_e| ()));
312        }
313        Poll::Pending
314    }
315}