orb_tokio/
lib.rs

//! # Orb Tokio Runtime
//!
//! This crate provides a Tokio-based implementation of the Orb async runtime traits.
//! It lets users run on Tokio's async runtime through the unified Orb interface.
//!
//! The main type provided is [`TokioRT`], which implements the core runtime functionality.
//!
//! See the [main Orb documentation](https://github.com/NaturalIO/orb) for more information.
//!
//! ## Usage
//!
//! ```rust
//! use orb_tokio::TokioRT;
//!
//! let rt = TokioRT::new_multi_thread(4);
//! ```
17
18use orb::io::{AsyncFd, AsyncIO};
19pub use orb::runtime::{AsyncExec, AsyncJoinHandle};
20use orb::time::{AsyncTime, TimeInterval};
21use std::fmt;
22use std::future::Future;
23use std::io;
24use std::net::SocketAddr;
25use std::net::TcpStream;
26use std::ops::Deref;
27use std::os::fd::{AsFd, AsRawFd};
28use std::os::unix::net::UnixStream;
29use std::path::PathBuf;
30use std::pin::Pin;
31use std::task::*;
32use std::time::{Duration, Instant};
33use tokio::runtime::{Builder, Handle, Runtime};
34
35/// The main struct for tokio runtime IO, assign this type to AsyncIO trait when used.
36pub enum TokioRT {
37    Runtime(Runtime),
38    Handle(Handle),
39}
40
41impl fmt::Debug for TokioRT {
42    #[inline]
43    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
44        match self {
45            Self::Runtime(_) => write!(f, "tokio(rt)"),
46            Self::Handle(_) => write!(f, "tokio(handle)"),
47        }
48    }
49}
50
51impl TokioRT {
52    /// Capture a runtime
53    #[inline]
54    pub fn new_with_runtime(rt: Runtime) -> Self {
55        Self::Runtime(rt)
56    }
57
58    #[inline]
59    pub fn new_multi_thread(workers: usize) -> Self {
60        let mut builder = Builder::new_multi_thread();
61        if workers > 0 {
62            builder.worker_threads(workers);
63        }
64        Self::Runtime(builder.enable_all().build().unwrap())
65    }
66
67    #[inline]
68    pub fn new_current_thread() -> Self {
69        let mut builder = Builder::new_current_thread();
70        Self::Runtime(builder.enable_all().build().unwrap())
71    }
72
73    /// Only capture a runtime handle. Should acquire with
74    /// `async { Handle::current() }`
75    #[inline]
76    pub fn new_with_handle(handle: Handle) -> Self {
77        Self::Handle(handle)
78    }
79}
80
81impl orb::AsyncRuntime for TokioRT {}
82
83impl AsyncIO for TokioRT {
84    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = TokioFD<T>;
85
86    #[inline(always)]
87    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
88        let stream = tokio::net::TcpStream::connect(addr).await?;
89        // into_std will not change back to blocking
90        Self::to_async_fd_rw(stream.into_std()?)
91    }
92
93    #[inline(always)]
94    async fn connect_unix(addr: &PathBuf) -> io::Result<Self::AsyncFd<UnixStream>> {
95        let stream = tokio::net::UnixStream::connect(addr).await?;
96        // into_std will not change back to blocking
97        Self::to_async_fd_rw(stream.into_std()?)
98    }
99
100    #[inline(always)]
101    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
102        fd: T,
103    ) -> io::Result<Self::AsyncFd<T>> {
104        use tokio::io;
105        Ok(TokioFD(io::unix::AsyncFd::with_interest(fd, io::Interest::READABLE)?))
106    }
107
108    #[inline(always)]
109    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
110        fd: T,
111    ) -> io::Result<Self::AsyncFd<T>> {
112        use tokio::io;
113        use tokio::io::Interest;
114        Ok(TokioFD(io::unix::AsyncFd::with_interest(fd, Interest::READABLE | Interest::WRITABLE)?))
115    }
116}
117
118impl AsyncTime for TokioRT {
119    type Interval = TokioInterval;
120
121    #[inline(always)]
122    fn sleep(d: Duration) -> impl Future + Send {
123        tokio::time::sleep(d)
124    }
125
126    #[inline(always)]
127    fn tick(d: Duration) -> Self::Interval {
128        let later = tokio::time::Instant::now() + d;
129        TokioInterval(tokio::time::interval_at(later, d))
130    }
131}
132
133impl AsyncExec for TokioRT {
134    /// Spawn a task in the background, returning a handle to await its result
135    #[inline]
136    fn spawn<F, R>(&self, f: F) -> impl AsyncJoinHandle<R>
137    where
138        F: Future<Output = R> + Send + 'static,
139        R: Send + 'static,
140    {
141        match self {
142            Self::Runtime(s) => {
143                return TokioJoinHandle(s.spawn(f));
144            }
145            Self::Handle(s) => {
146                return TokioJoinHandle(s.spawn(f));
147            }
148        }
149    }
150
151    /// Spawn a task and detach it (no handle returned)
152    #[inline]
153    fn spawn_detach<F, R>(&self, f: F)
154    where
155        F: Future<Output = R> + Send + 'static,
156        R: Send + 'static,
157    {
158        match self {
159            Self::Runtime(s) => {
160                s.spawn(f);
161            }
162            Self::Handle(s) => {
163                s.spawn(f);
164            }
165        }
166    }
167
168    /// Run a future to completion on the runtime
169    #[inline]
170    fn block_on<F, R>(&self, f: F) -> R
171    where
172        F: Future<Output = R> + Send,
173        R: Send + 'static,
174    {
175        match self {
176            Self::Runtime(s) => {
177                return s.block_on(f);
178            }
179            Self::Handle(s) => {
180                return s.block_on(f);
181            }
182        }
183    }
184}
185
186/// Associate type for TokioRT
187pub struct TokioInterval(tokio::time::Interval);
188
189impl TimeInterval for TokioInterval {
190    #[inline]
191    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
192        let _self = self.get_mut();
193        if let Poll::Ready(i) = _self.0.poll_tick(ctx) {
194            Poll::Ready(i.into_std())
195        } else {
196            Poll::Pending
197        }
198    }
199}
200
201/// Associate type for TokioRT
202pub struct TokioFD<T: AsRawFd + AsFd + Send + Sync + 'static>(tokio::io::unix::AsyncFd<T>);
203
204impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for TokioFD<T> {
205    #[inline(always)]
206    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
207        self.0.async_io(tokio::io::Interest::READABLE, f).await
208    }
209
210    #[inline(always)]
211    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
212        self.0.async_io(tokio::io::Interest::WRITABLE, f).await
213    }
214}
215
216impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for TokioFD<T> {
217    type Target = T;
218
219    #[inline(always)]
220    fn deref(&self) -> &Self::Target {
221        self.0.get_ref()
222    }
223}
224
225/// A wrapper around tokio's JoinHandle that implements AsyncJoinHandle
226pub struct TokioJoinHandle<T>(tokio::task::JoinHandle<T>);
227
228impl<T: Send + 'static> AsyncJoinHandle<T> for TokioJoinHandle<T> {
229    #[inline]
230    async fn join(self) -> Result<T, ()> {
231        match self.0.await {
232            Ok(r) => Ok(r),
233            Err(_) => Err(()),
234        }
235    }
236
237    #[inline]
238    fn detach(self) {
239        // Tokio's JoinHandle doesn't need explicit detach, it will run in background
240        // when the handle is dropped
241    }
242}