Skip to main content

orb_smol/
lib.rs

1//! # Smol Runtime adapter for Orb framework
2//!
3//! This crate provides a Smol-based implementation of the Orb async runtime traits.
4//! It allows users to leverage Smol's lightweight async runtime with the unified Orb interface.
5//!
6//! The main type provided is [`SmolRT`], which implements the core runtime functionality.
7//!
8//! See the [Orb crate](https://docs.rs/orb) for more information.
9//!
10//! ## Features
11//!
//! - `global`: Enables `SmolRT::new_global()`, which spawns tasks on smol's default global
//!   executor instead of an executor instance you provide. Disabled by default, in which case
//!   the `smol` dependency is omitted.
15//!
//! - `unwind`: Wraps each spawned task in `AssertUnwindSafe(..).catch_unwind()`, so a panic
//!   inside the task is caught and the task's join handle resolves to `Err(())`. Disabled by
//!   default, in which case a panic is not caught and terminates the program.
18//!
19//! ## Usage
20//!
21//! With a custom executor:
22//!
23//! ```rust
24//! use orb_smol::SmolRT;
25//! use std::sync::Arc;
26//! use async_executor::Executor;
27//!
28//! let executor = Arc::new(Executor::new());
29//! let rt = SmolRT::new(executor);
30//! ```
31//!
32//! With the global executor (requires the `global` feature):
33//!
34//! ```rust
35//! use orb_smol::SmolRT;
36//!
37//! #[cfg(feature = "global")]
38//! let rt = SmolRT::new_global();
39//! ```
40
41use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncExecDyn, AsyncHandle, ThreadHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::Path;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
61/// The SmolRT implements AsyncRuntime trait
62#[derive(Clone)]
63pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66    #[inline]
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69    }
70}
71
72impl SmolRT {
73    #[cfg(feature = "global")]
74    #[inline]
75    pub fn new_global() -> Self {
76        Self(None)
77    }
78
79    /// spawn coroutine with specified Executor
80    #[inline]
81    pub fn new(executor: Arc<Executor<'static>>) -> Self {
82        Self(Some(executor))
83    }
84}
85
86impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91    #[inline(always)]
92    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93        let _addr = addr.clone();
94        let stream = Async::<TcpStream>::connect(_addr).await?;
95        // into_inner will not change back to blocking
96        Self::to_async_fd_rw(stream.into_inner()?)
97    }
98
99    #[inline(always)]
100    async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
101        let stream = Async::<UnixStream>::connect(addr).await?;
102        // into_inner will not change back to blocking
103        Self::to_async_fd_rw(stream.into_inner()?)
104    }
105
106    #[inline(always)]
107    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
108        fd: T,
109    ) -> io::Result<Self::AsyncFd<T>> {
110        Ok(SmolFD(Async::new(fd)?))
111    }
112
113    #[inline(always)]
114    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
115        fd: T,
116    ) -> io::Result<Self::AsyncFd<T>> {
117        Ok(SmolFD(Async::new(fd)?))
118    }
119}
120
121impl AsyncTime for SmolRT {
122    type Interval = SmolInterval;
123
124    #[inline(always)]
125    fn sleep(d: Duration) -> impl Future + Send {
126        Timer::after(d)
127    }
128
129    #[inline(always)]
130    fn interval(d: Duration) -> Self::Interval {
131        let later = std::time::Instant::now() + d;
132        SmolInterval(Timer::interval_at(later, d))
133    }
134}
135
/// Wraps a future according to the `unwind` feature.
///
/// With `unwind`, the future is wrapped as `AssertUnwindSafe(fut).catch_unwind()`, so a panic
/// while polling it becomes an `Err(payload)` output. Without `unwind`, the future is returned
/// unchanged.
macro_rules! unwind_wrap {
    ($fut: expr) => {{
        #[cfg(not(feature = "unwind"))]
        let wrapped = $fut;
        #[cfg(feature = "unwind")]
        let wrapped = futures_lite::future::FutureExt::catch_unwind(
            std::panic::AssertUnwindSafe($fut),
        );
        wrapped
    }};
}

148/// AsyncHandle implementation for smol
149#[cfg(feature = "unwind")]
150pub struct SmolJoinHandle<T>(
151    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
152);
153#[cfg(not(feature = "unwind"))]
154pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
155
156impl<T: Send> AsyncHandle<T> for SmolJoinHandle<T> {
157    #[inline]
158    fn is_finished(&self) -> bool {
159        self.0.as_ref().unwrap().is_finished()
160    }
161
162    #[inline(always)]
163    fn abort(self) {
164        // do nothing, the inner task will be dropped
165    }
166
167    #[inline(always)]
168    fn detach(mut self) {
169        self.0.take().unwrap().detach();
170    }
171
172    #[inline(always)]
173    fn abort_boxed(self: Box<Self>) {
174        // do nothing, the inner task will be dropped
175    }
176
177    #[inline(always)]
178    fn detach_boxed(mut self: Box<Self>) {
179        self.0.take().unwrap().detach();
180    }
181}
182
183impl<T> Future for SmolJoinHandle<T> {
184    type Output = Result<T, ()>;
185
186    #[inline]
187    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188        let _self = unsafe { self.get_unchecked_mut() };
189        if let Some(inner) = _self.0.as_mut() {
190            if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
191                #[cfg(feature = "unwind")]
192                {
193                    return Poll::Ready(r.map_err(|_e| ()));
194                }
195                #[cfg(not(feature = "unwind"))]
196                {
197                    return Poll::Ready(Ok(r));
198                }
199            }
200            Poll::Pending
201        } else {
202            Poll::Ready(Err(()))
203        }
204    }
205}
206
207impl<T> Drop for SmolJoinHandle<T> {
208    fn drop(&mut self) {
209        if let Some(handle) = self.0.take() {
210            handle.detach();
211        }
212    }
213}
214
215pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
216
217impl<T> ThreadHandle<T> for BlockingJoinHandle<T> {
218    #[inline]
219    fn is_finished(&self) -> bool {
220        self.0.is_finished()
221    }
222}
223
224impl<T> Future for BlockingJoinHandle<T> {
225    type Output = Result<T, ()>;
226
227    #[inline]
228    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
229        let _self = unsafe { self.get_unchecked_mut() };
230        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
231            return Poll::Ready(Ok(r));
232        }
233        Poll::Pending
234    }
235}
236
237impl AsyncExecDyn for SmolRT {
238    #[inline(always)]
239    fn spawn_detach_dyn(&self, f: Box<dyn Future<Output = ()> + Send + Unpin>) {
240        self.spawn(unwind_wrap!(f)).detach();
241    }
242}
243
244impl AsyncExec for SmolRT {
245    type AsyncHandle<R: Send> = SmolJoinHandle<R>;
246
247    type ThreadHandle<R: Send> = BlockingJoinHandle<R>;
248
249    /// Spawn a task in the background
250    fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
251    where
252        F: Future<Output = R> + Send + 'static,
253        R: Send + 'static,
254    {
255        // Although SmolJoinHandle don't need Send marker, but here in the spawn()
256        // need to restrict the requirements
257        let handle = match &self.0 {
258            Some(exec) => exec.spawn(unwind_wrap!(f)),
259            None => {
260                #[cfg(feature = "global")]
261                {
262                    smol::spawn(unwind_wrap!(f))
263                }
264                #[cfg(not(feature = "global"))]
265                unreachable!();
266            }
267        };
268        SmolJoinHandle(Some(handle))
269    }
270
271    /// Depends on how you initialize SmolRT, spawn with executor or globally
272    #[inline]
273    fn spawn_detach<F, R>(&self, f: F)
274    where
275        F: Future<Output = R> + Send + 'static,
276        R: Send + 'static,
277    {
278        self.spawn(unwind_wrap!(f)).detach();
279    }
280
281    #[inline]
282    fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
283    where
284        F: FnOnce() -> R + Send + 'static,
285        R: Send + 'static,
286    {
287        BlockingJoinHandle(blocking::unblock(f))
288    }
289
290    /// Run a future to completion on the runtime
291    ///
292    /// NOTE: when initialized  with an executor,  will block current thread until the future
293    /// returns
294    #[inline]
295    fn block_on<F, R>(&self, f: F) -> R
296    where
297        F: Future<Output = R> + Send,
298        R: Send + 'static,
299    {
300        if let Some(exec) = &self.0 {
301            block_on(exec.run(f))
302        } else {
303            #[cfg(feature = "global")]
304            {
305                smol::block_on(f)
306            }
307            #[cfg(not(feature = "global"))]
308            unreachable!();
309        }
310    }
311}
312
313/// Associate type for SmolRT
314pub struct SmolInterval(Timer);
315
316impl TimeInterval for SmolInterval {
317    #[inline]
318    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
319        let _self = self.get_mut();
320        match _self.0.poll_next(ctx) {
321            Poll::Ready(Some(i)) => Poll::Ready(i),
322            Poll::Ready(None) => unreachable!(),
323            Poll::Pending => Poll::Pending,
324        }
325    }
326}
327
328/// Associate type for SmolRT
329pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
330
331impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
332    #[inline(always)]
333    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
334        self.0.read_with(f).await
335    }
336
337    #[inline(always)]
338    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
339        self.0.write_with(f).await
340    }
341}
342
343impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
344    type Target = T;
345
346    #[inline(always)]
347    fn deref(&self) -> &Self::Target {
348        self.0.get_ref()
349    }
350}