Skip to main content

orb_smol/
lib.rs

1//! # Smol Runtime adapter for Orb framework
2//!
3//! This crate provides a Smol-based implementation of the Orb async runtime traits.
4//! It allows users to leverage Smol's lightweight async runtime with the unified Orb interface.
5//!
6//! The main type provided is [`SmolRT`], which implements the core runtime functionality.
7//!
8//! See the [Orb crate](https://docs.rs/orb) for more information.
9//!
10//! ## Features
11//!
12//! - `global`: Enables the global executor, which lets you use the default global executor of
13//!   `smol` instead of providing your own executor instance. (Disabled by default; when it is
14//!   disabled, the `smol` dependency is omitted.)
15//!
16//! - `unwind`: Uses `AssertUnwindSafe` to catch a panic inside the task and returns `Err(())`
17//!   from the task join handle. (Disabled by default; when it is disabled, a panic terminates
17//!   the program.)
18//!
19//! ## Usage
20//!
21//! With a custom executor:
22//!
23//! ```rust
24//! use orb_smol::SmolRT;
25//! use std::sync::Arc;
26//! use async_executor::Executor;
27//!
28//! let executor = Arc::new(Executor::new());
29//! let rt = SmolRT::new(executor);
30//! ```
31//!
32//! With the global executor (requires the `global` feature):
33//!
34//! ```rust
35//! use orb_smol::SmolRT;
36//!
37//! #[cfg(feature = "global")]
38//! let rt = SmolRT::new_global();
39//! ```
40
41use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncHandle, ThreadHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::Path;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
61/// The SmolRT implements AsyncRuntime trait
62#[derive(Clone)]
63pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66    #[inline]
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69    }
70}
71
72impl SmolRT {
73    #[cfg(feature = "global")]
74    #[inline]
75    pub fn new_global() -> Self {
76        Self(None)
77    }
78
79    /// spawn coroutine with specified Executor
80    #[inline]
81    pub fn new(executor: Arc<Executor<'static>>) -> Self {
82        Self(Some(executor))
83    }
84}
85
86impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91    #[inline(always)]
92    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93        let _addr = addr.clone();
94        let stream = Async::<TcpStream>::connect(_addr).await?;
95        // into_inner will not change back to blocking
96        Self::to_async_fd_rw(stream.into_inner()?)
97    }
98
99    #[inline(always)]
100    async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
101        let stream = Async::<UnixStream>::connect(addr).await?;
102        // into_inner will not change back to blocking
103        Self::to_async_fd_rw(stream.into_inner()?)
104    }
105
106    #[inline(always)]
107    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
108        fd: T,
109    ) -> io::Result<Self::AsyncFd<T>> {
110        Ok(SmolFD(Async::new(fd)?))
111    }
112
113    #[inline(always)]
114    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
115        fd: T,
116    ) -> io::Result<Self::AsyncFd<T>> {
117        Ok(SmolFD(Async::new(fd)?))
118    }
119}
120
121impl AsyncTime for SmolRT {
122    type Interval = SmolInterval;
123
124    #[inline(always)]
125    fn sleep(d: Duration) -> impl Future + Send {
126        Timer::after(d)
127    }
128
129    #[inline(always)]
130    fn interval(d: Duration) -> Self::Interval {
131        let later = std::time::Instant::now() + d;
132        SmolInterval(Timer::interval_at(later, d))
133    }
134}
135
// Prepares a future expression for spawning.
//
// With the `unwind` feature the future is wrapped in `AssertUnwindSafe` and
// `catch_unwind`, so its output becomes a `Result` carrying the panic payload
// on failure. Without the feature the expression is passed through untouched.
macro_rules! unwind_wrap {
    ($fut: expr) => {{
        // Exactly one of the two cfg branches survives and becomes the value
        // of this block.
        #[cfg(feature = "unwind")]
        {
            use futures_lite::future::FutureExt;
            std::panic::AssertUnwindSafe($fut).catch_unwind()
        }
        #[cfg(not(feature = "unwind"))]
        $fut
    }};
}
147
148/// AsyncHandle implementation for smol
149#[cfg(feature = "unwind")]
150pub struct SmolJoinHandle<T>(
151    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
152);
153#[cfg(not(feature = "unwind"))]
154pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
155
156impl<T: Send> AsyncHandle<T> for SmolJoinHandle<T> {
157    #[inline(always)]
158    fn abort(self) {
159        // do nothing, the inner task will be dropped
160    }
161
162    #[inline]
163    fn detach(mut self) {
164        self.0.take().unwrap().detach();
165    }
166
167    #[inline]
168    fn is_finished(&self) -> bool {
169        self.0.as_ref().unwrap().is_finished()
170    }
171}
172
173impl<T> Future for SmolJoinHandle<T> {
174    type Output = Result<T, ()>;
175
176    #[inline]
177    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
178        let _self = unsafe { self.get_unchecked_mut() };
179        if let Some(inner) = _self.0.as_mut() {
180            if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
181                #[cfg(feature = "unwind")]
182                {
183                    return Poll::Ready(r.map_err(|_e| ()));
184                }
185                #[cfg(not(feature = "unwind"))]
186                {
187                    return Poll::Ready(Ok(r));
188                }
189            }
190            Poll::Pending
191        } else {
192            Poll::Ready(Err(()))
193        }
194    }
195}
196
197impl<T> Drop for SmolJoinHandle<T> {
198    fn drop(&mut self) {
199        if let Some(handle) = self.0.take() {
200            handle.detach();
201        }
202    }
203}
204
205pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
206
207impl<T> ThreadHandle<T> for BlockingJoinHandle<T> {
208    #[inline]
209    fn is_finished(&self) -> bool {
210        self.0.is_finished()
211    }
212}
213
214impl<T> Future for BlockingJoinHandle<T> {
215    type Output = Result<T, ()>;
216
217    #[inline]
218    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219        let _self = unsafe { self.get_unchecked_mut() };
220        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
221            return Poll::Ready(Ok(r));
222        }
223        Poll::Pending
224    }
225}
226
227impl AsyncExec for SmolRT {
228    type AsyncHandle<R: Send> = SmolJoinHandle<R>;
229
230    type ThreadHandle<R: Send> = BlockingJoinHandle<R>;
231
232    /// Spawn a task in the background
233    fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
234    where
235        F: Future<Output = R> + Send + 'static,
236        R: Send + 'static,
237    {
238        // Although SmolJoinHandle don't need Send marker, but here in the spawn()
239        // need to restrict the requirements
240        let handle = match &self.0 {
241            Some(exec) => exec.spawn(unwind_wrap!(f)),
242            None => {
243                #[cfg(feature = "global")]
244                {
245                    smol::spawn(unwind_wrap!(f))
246                }
247                #[cfg(not(feature = "global"))]
248                unreachable!();
249            }
250        };
251        SmolJoinHandle(Some(handle))
252    }
253
254    /// Depends on how you initialize SmolRT, spawn with executor or globally
255    #[inline]
256    fn spawn_detach<F, R>(&self, f: F)
257    where
258        F: Future<Output = R> + Send + 'static,
259        R: Send + 'static,
260    {
261        self.spawn(unwind_wrap!(f)).detach();
262    }
263
264    #[inline]
265    fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
266    where
267        F: FnOnce() -> R + Send + 'static,
268        R: Send + 'static,
269    {
270        BlockingJoinHandle(blocking::unblock(f))
271    }
272
273    /// Run a future to completion on the runtime
274    ///
275    /// NOTE: when initialized  with an executor,  will block current thread until the future
276    /// returns
277    #[inline]
278    fn block_on<F, R>(&self, f: F) -> R
279    where
280        F: Future<Output = R> + Send,
281        R: Send + 'static,
282    {
283        if let Some(exec) = &self.0 {
284            block_on(exec.run(f))
285        } else {
286            #[cfg(feature = "global")]
287            {
288                smol::block_on(f)
289            }
290            #[cfg(not(feature = "global"))]
291            unreachable!();
292        }
293    }
294}
295
296/// Associate type for SmolRT
297pub struct SmolInterval(Timer);
298
299impl TimeInterval for SmolInterval {
300    #[inline]
301    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
302        let _self = self.get_mut();
303        match _self.0.poll_next(ctx) {
304            Poll::Ready(Some(i)) => Poll::Ready(i),
305            Poll::Ready(None) => unreachable!(),
306            Poll::Pending => Poll::Pending,
307        }
308    }
309}
310
311/// Associate type for SmolRT
312pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
313
314impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
315    #[inline(always)]
316    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
317        self.0.read_with(f).await
318    }
319
320    #[inline(always)]
321    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
322        self.0.write_with(f).await
323    }
324}
325
326impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
327    type Target = T;
328
329    #[inline(always)]
330    fn deref(&self) -> &Self::Target {
331        self.0.get_ref()
332    }
333}