Skip to main content

orb_smol/
lib.rs

1//! # Smol Runtime adapter for Orb framework
2//!
3//! This crate provides a Smol-based implementation of the Orb async runtime traits.
4//! It allows users to leverage Smol's lightweight async runtime with the unified Orb interface.
5//!
6//! The main type provided is [`SmolRT`], which implements the core runtime functionality.
7//!
8//! See the [Orb crate](https://docs.rs/orb) for more information.
9//!
10//! ## Features
11//!
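//! - `global`: Use `smol`'s default global executor instead of a per-runtime executor.
//!   This requires pulling in the `smol` dependency, and the number of worker threads is
//!   configured through the `SMOL_THREADS` environment variable (which `SmolRT::multi()` sets).
//!   (Not enabled by default; the per-runtime executor created by `SmolRT::multi()` is more
//!   convenient.)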
15//!
//! - `unwind`: Wrap every task in `AssertUnwindSafe(..).catch_unwind()`, so a panic inside the
//!   task is caught and reported as `Err(())` through the task's join handle.
//!   (Not enabled by default; without it, a panic inside a task is not caught.)
18//!
19//! ## Usage
20//!
21//! With multi thread runtime
22//!
23//! ```rust
24//! use orb_smol::SmolRT;
25//! use orb::prelude::*;
26//! use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
27//! use std::time::Duration;
//! let rt = SmolRT::multi(0); // spawn one background thread per CPU
29//! let counter = Arc::new(AtomicUsize::new(0));
30//! let _counter = counter.clone();
31//! rt.spawn(async move {
32//!     loop {
33//!         SmolRT::sleep(Duration::from_secs(1)).await;
34//!         _counter.fetch_add(1, Ordering::SeqCst);
35//!     }
36//! });
//! // the background task keeps running until rt is dropped
38//! std::thread::sleep(Duration::from_secs(3));
39//! drop(rt);
40//! let count = counter.load(Ordering::SeqCst);
41//! assert!(count >= 2 && count <= 4, "{count}");
42//! ```
43//!
44//! ## Static Spawn
45//!
46//! This runtime supports static spawn methods through the [`AsyncRuntime`] trait
47//! that use the current runtime context:
48//!
49//! ```rust
50//! use orb::AsyncRuntime;
51//! use orb::runtime::AsyncExec;
52//!
53//! fn example<RT: AsyncRuntime>() {
54//!     let rt = RT::multi(2);
55//!     rt.block_on(async {
56//!         // Spawn a task using the static method - uses current runtime context
57//!         let handle = RT::spawn(async {
58//!             42
59//!         });
60//!         let result = handle.await.unwrap();
61//!         assert_eq!(result, 42);
62//!
63//!         // Spawn and detach a task
64//!         RT::spawn_detach(async {
65//!             println!("detached task running");
66//!         });
67//!     });
68//! }
69//! ```
70//!
71//! The static spawn methods ([`AsyncRuntime::spawn`], [`AsyncRuntime::spawn_detach`]) automatically use
72//! the runtime context of the current thread. This is implemented using thread-local
73//! storage that is registered when entering `block_on` or when worker threads are spawned.
74//!
75//! This feature provides a unified interface across different runtime implementations
76//! (smol, tokio, etc.) and fills the gap in `async-executor`'s native functionality.
77
78use async_executor::Executor;
79use async_io::{Async, Timer};
80#[allow(unused_imports)]
81use crossfire::{MAsyncRx, mpmc, null::CloseHandle};
82use futures_lite::{future::block_on, stream::StreamExt};
83use orb::AsyncRuntime;
84use orb::io::{AsyncFd, AsyncIO};
85use orb::runtime::{AsyncExec, AsyncJoiner, ThreadJoiner};
86use orb::time::{AsyncTime, TimeInterval};
87use std::cell::Cell;
88use std::fmt;
89use std::future::Future;
90use std::io;
91use std::net::{SocketAddr, TcpStream};
92use std::num::NonZero;
93use std::ops::Deref;
94use std::os::{
95    fd::{AsFd, AsRawFd},
96    unix::net::UnixStream,
97};
98use std::path::Path;
99use std::pin::Pin;
100use std::ptr;
101use std::sync::Arc;
102use std::task::*;
103use std::thread;
104use std::time::{Duration, Instant};
105
/// Zero-sized type that selects the smol backend for the Orb runtime traits.
///
/// `SmolRT` implements [`AsyncRuntime`], [`AsyncIO`] and [`AsyncTime`]; the runtime
/// instances it creates (e.g. via `SmolRT::multi(n)`) are of type [`SmolExec`].
pub struct SmolRT {}
107
108/// The SmolRT implements AsyncRuntime trait
109#[derive(Clone)]
110pub struct SmolExec(Option<SmolExecInner>);
111
112#[derive(Clone)]
113struct SmolExecInner {
114    ex: Arc<Executor<'static>>,
115    _close_h: Option<CloseHandle<mpmc::Null>>,
116}
117
118// Thread local storage for the current executor context (pointer to Arc<Executor>)
119thread_local! {
120    static CURRENT_EXECUTOR: Cell<*const Executor<'static>> = const { Cell::new(ptr::null()) };
121}
122
123/// Set the current executor context for this thread
124fn set_current_executor(exec: *const Executor<'static>) {
125    CURRENT_EXECUTOR.set(exec);
126}
127
128#[cfg(not(feature = "global"))]
129/// Get the current executor context for this thread
130fn get_current_executor() -> *const Executor<'static> {
131    CURRENT_EXECUTOR.get()
132}
133
134impl fmt::Debug for SmolExec {
135    #[inline]
136    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137        if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
138    }
139}
140
141impl AsyncIO for SmolRT {
142    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
143
144    #[inline(always)]
145    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
146        let _addr = addr.clone();
147        let stream = Async::<TcpStream>::connect(_addr).await?;
148        // into_inner will not change back to blocking
149        Self::to_async_fd_rw(stream.into_inner()?)
150    }
151
152    #[inline(always)]
153    async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
154        let stream = Async::<UnixStream>::connect(addr).await?;
155        // into_inner will not change back to blocking
156        Self::to_async_fd_rw(stream.into_inner()?)
157    }
158
159    #[inline(always)]
160    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
161        fd: T,
162    ) -> io::Result<Self::AsyncFd<T>> {
163        Ok(SmolFD(Async::new(fd)?))
164    }
165
166    #[inline(always)]
167    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
168        fd: T,
169    ) -> io::Result<Self::AsyncFd<T>> {
170        Ok(SmolFD(Async::new(fd)?))
171    }
172}
173
174impl AsyncTime for SmolRT {
175    type Interval = SmolInterval;
176
177    #[inline(always)]
178    fn sleep(d: Duration) -> impl Future + Send {
179        Timer::after(d)
180    }
181
182    #[inline(always)]
183    fn interval(d: Duration) -> Self::Interval {
184        let later = std::time::Instant::now() + d;
185        SmolInterval(Timer::interval_at(later, d))
186    }
187}
188
/// Optionally wraps a future so that a panic while polling it is caught.
///
/// With the `unwind` feature this expands to `AssertUnwindSafe($fut).catch_unwind()`, whose
/// output is `Result<T, Box<dyn Any + Send>>`; without it, the future is returned unchanged.
#[cfg(feature = "unwind")]
macro_rules! unwind_wrap {
    ($fut: expr) => {{
        use futures_lite::future::FutureExt;
        std::panic::AssertUnwindSafe($fut).catch_unwind()
    }};
}

#[cfg(not(feature = "unwind"))]
macro_rules! unwind_wrap {
    ($fut: expr) => {
        $fut
    };
}
200
201/// AsyncJoiner implementation for smol
202#[cfg(feature = "unwind")]
203pub struct SmolJoinHandle<T>(
204    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
205);
206#[cfg(not(feature = "unwind"))]
207pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
208
209impl<T: Send> AsyncJoiner<T> for SmolJoinHandle<T> {
210    #[inline]
211    fn is_finished(&self) -> bool {
212        self.0.as_ref().unwrap().is_finished()
213    }
214
215    #[inline(always)]
216    fn abort(self) {
217        // do nothing, the inner task will be dropped
218    }
219
220    #[inline(always)]
221    fn detach(mut self) {
222        self.0.take().unwrap().detach();
223    }
224
225    #[inline(always)]
226    fn abort_boxed(self: Box<Self>) {
227        // do nothing, the inner task will be dropped
228    }
229
230    #[inline(always)]
231    fn detach_boxed(mut self: Box<Self>) {
232        self.0.take().unwrap().detach();
233    }
234}
235
236impl<T> Future for SmolJoinHandle<T> {
237    type Output = Result<T, ()>;
238
239    #[inline]
240    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
241        let _self = unsafe { self.get_unchecked_mut() };
242        if let Some(inner) = _self.0.as_mut() {
243            if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
244                #[cfg(feature = "unwind")]
245                {
246                    return Poll::Ready(r.map_err(|_e| ()));
247                }
248                #[cfg(not(feature = "unwind"))]
249                {
250                    return Poll::Ready(Ok(r));
251                }
252            }
253            Poll::Pending
254        } else {
255            Poll::Ready(Err(()))
256        }
257    }
258}
259
260impl<T> Drop for SmolJoinHandle<T> {
261    fn drop(&mut self) {
262        if let Some(handle) = self.0.take() {
263            handle.detach();
264        }
265    }
266}
267
268pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
269
270impl<T> ThreadJoiner<T> for BlockingJoinHandle<T> {
271    #[inline]
272    fn is_finished(&self) -> bool {
273        self.0.is_finished()
274    }
275}
276
277impl<T> Future for BlockingJoinHandle<T> {
278    type Output = Result<T, ()>;
279
280    #[inline]
281    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
282        let _self = unsafe { self.get_unchecked_mut() };
283        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
284            return Poll::Ready(Ok(r));
285        }
286        Poll::Pending
287    }
288}
289
290impl AsyncRuntime for SmolRT {
291    type Exec = SmolExec;
292
293    /// Initiate executor using current thread.
294    ///
295    /// # Safety
296    ///
297    /// You should run [AsyncExec::block_on()] with this executor.
298    ///
299    /// If spawn without a `block_on()` running, it's possible
300    /// the runtime just init future without scheduling.
301    #[inline(always)]
302    fn current() -> SmolExec {
303        SmolExec(Some(SmolExecInner { ex: Arc::new(Executor::new()), _close_h: None }))
304    }
305
306    /// Initiate executor with one background thread.
307    ///
308    /// # NOTE
309    ///
310    /// [AsyncExec::block_on()] is optional, you can directly call [AsyncExec::spawn] with it.
311    #[inline(always)]
312    fn one() -> SmolExec {
313        Self::multi(1)
314    }
315
316    /// Initiate executor with multiple background threads.
317    ///
318    /// # NOTE
319    ///
320    /// When `num` == 0, start threads that match cpu number
321    ///
322    /// [AsyncExec::block_on()] is optional, you can directly call [AsyncExec::spawn] with it.
323    #[inline(always)]
324    fn multi(mut size: usize) -> SmolExec {
325        if size == 0 {
326            size = usize::from(
327                thread::available_parallelism().unwrap_or(NonZero::new(1usize).unwrap()),
328            )
329        }
330        #[cfg(feature = "global")]
331        {
332            unsafe { std::env::set_var("SMOL_THREADS", size.to_string()) };
333            SmolExec(None)
334        }
335        #[cfg(not(feature = "global"))]
336        {
337            let (close_h, rx): (CloseHandle<mpmc::Null>, MAsyncRx<mpmc::Null>) = mpmc::new();
338            // Prevent spawning another thread by running the process driver on this thread.
339            let inner = SmolExecInner { ex: Arc::new(Executor::new()), _close_h: Some(close_h) };
340            #[cfg(not(target_os = "espidf"))]
341            inner.ex.spawn(async_process::driver()).detach();
342            let ex = inner.ex.clone();
343            // Get pointer to the Executor in Arc - this pointer is stable as long as Arc is alive
344            let ex_ptr: usize = Arc::as_ptr(&inner.ex) as usize;
345            for n in 1..=size {
346                let _ex = ex.clone();
347                let _rx = rx.clone();
348                let _ex_ptr = ex_ptr;
349                thread::Builder::new()
350                    .name(format!("smol-{}", n))
351                    .spawn(move || {
352                        set_current_executor(_ex_ptr as *const Executor<'static>);
353                        let _ = block_on(_ex.run(_rx.recv()));
354                        set_current_executor(ptr::null());
355                    })
356                    .expect("cannot spawn executor thread");
357            }
358            SmolExec(Some(inner))
359        }
360    }
361
362    /// Spawn a task in the background
363    fn spawn<F, R>(f: F) -> SmolJoinHandle<R>
364    where
365        F: Future<Output = R> + Send + 'static,
366        R: Send + 'static,
367    {
368        #[cfg(feature = "global")]
369        {
370            SmolJoinHandle(Some(smol::spawn(unwind_wrap!(f))))
371        }
372        #[cfg(not(feature = "global"))]
373        {
374            let ex_ptr = get_current_executor();
375            assert!(!ex_ptr.is_null(), "spawn must be called in runtime worker context");
376            let ex = unsafe { &*ex_ptr };
377            SmolJoinHandle(Some(ex.spawn(unwind_wrap!(f))))
378        }
379    }
380
381    /// Depends on how you initialize SmolRT, spawn with executor or globally
382    #[inline]
383    fn spawn_detach<F, R>(f: F)
384    where
385        F: Future<Output = R> + Send + 'static,
386        R: Send + 'static,
387    {
388        #[cfg(feature = "global")]
389        {
390            smol::spawn(f).detach()
391        }
392        #[cfg(not(feature = "global"))]
393        {
394            let ex_ptr = get_current_executor();
395            assert!(!ex_ptr.is_null(), "spawn_detach must be called in runtime worker context");
396            let ex = unsafe { &*ex_ptr };
397            ex.spawn(unwind_wrap!(f)).detach();
398        }
399    }
400
401    #[inline]
402    fn spawn_blocking<F, R>(f: F) -> BlockingJoinHandle<R>
403    where
404        F: FnOnce() -> R + Send + 'static,
405        R: Send + 'static,
406    {
407        BlockingJoinHandle(blocking::unblock(f))
408    }
409}
410
411impl AsyncExec for SmolExec {
412    type AsyncJoiner<R: Send> = SmolJoinHandle<R>;
413
414    type ThreadJoiner<R: Send> = BlockingJoinHandle<R>;
415
416    /// Spawn a task in the background
417    fn spawn<F, R>(&self, f: F) -> Self::AsyncJoiner<R>
418    where
419        F: Future<Output = R> + Send + 'static,
420        R: Send + 'static,
421    {
422        // Although SmolJoinHandle don't need Send marker, but here in the spawn()
423        // need to restrict the requirements
424        let handle = match &self.0 {
425            Some(inner) => inner.ex.spawn(unwind_wrap!(f)),
426            None => {
427                #[cfg(feature = "global")]
428                {
429                    smol::spawn(unwind_wrap!(f))
430                }
431                #[cfg(not(feature = "global"))]
432                unreachable!();
433            }
434        };
435        SmolJoinHandle(Some(handle))
436    }
437
438    /// Depends on how you initialize SmolRT, spawn with executor or globally
439    #[inline]
440    fn spawn_detach<F, R>(&self, f: F)
441    where
442        F: Future<Output = R> + Send + 'static,
443        R: Send + 'static,
444    {
445        self.spawn(unwind_wrap!(f)).detach();
446    }
447
448    #[inline]
449    fn spawn_blocking<F, R>(&self, f: F) -> Self::ThreadJoiner<R>
450    where
451        F: FnOnce() -> R + Send + 'static,
452        R: Send + 'static,
453    {
454        BlockingJoinHandle(blocking::unblock(f))
455    }
456
457    /// Run a future to completion on the runtime
458    ///
459    /// NOTE: when initialized  with an executor,  will block current thread until the future
460    /// returns
461    #[inline]
462    fn block_on<F, R>(&self, f: F) -> R
463    where
464        F: Future<Output = R> + Send,
465        R: 'static,
466    {
467        if let Some(inner) = &self.0 {
468            let ex_ptr: *const Executor<'static> = Arc::as_ptr(&inner.ex);
469            set_current_executor(ex_ptr);
470            let result = block_on(inner.ex.run(f));
471            set_current_executor(ptr::null());
472            result
473        } else {
474            #[cfg(feature = "global")]
475            {
476                smol::block_on(f)
477            }
478            #[cfg(not(feature = "global"))]
479            unreachable!();
480        }
481    }
482}
483
484/// Associate type for SmolRT
485pub struct SmolInterval(Timer);
486
487impl TimeInterval for SmolInterval {
488    #[inline]
489    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
490        let _self = self.get_mut();
491        match _self.0.poll_next(ctx) {
492            Poll::Ready(Some(i)) => Poll::Ready(i),
493            Poll::Ready(None) => unreachable!(),
494            Poll::Pending => Poll::Pending,
495        }
496    }
497}
498
499/// Associate type for SmolRT
500pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
501
502impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
503    #[inline(always)]
504    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
505        self.0.read_with(f).await
506    }
507
508    #[inline(always)]
509    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
510        self.0.write_with(f).await
511    }
512}
513
514impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
515    type Target = T;
516
517    #[inline(always)]
518    fn deref(&self) -> &Self::Target {
519        self.0.get_ref()
520    }
521}