Skip to main content

orb_smol/
lib.rs

//! # Smol runtime adapter for the Orb framework
//!
//! This crate provides a Smol-based implementation of the Orb async runtime traits.
//! It lets users run on Smol's lightweight async runtime through the unified Orb interface.
//!
//! The main type provided is [`SmolRT`], which implements the core runtime functionality.
//!
//! See the [Orb crate](https://docs.rs/orb) for more information.
//!
//! ## Features
//!
//! - `global`: Use smol's default global executor instead of providing your own executor
//!   instance. (Not enabled by default; when disabled, the `smol` dependency is omitted.)
//!
//! - `unwind`: Wrap each task in `AssertUnwindSafe` to catch a panic inside the task and return
//!   `Err(())` through the task's join handle. (Not enabled by default; without it, a panic
//!   inside a task is not caught.)
18//!
19//! ## Usage
20//!
//! With a multi-threaded runtime:
//!
//! ```rust
//! use orb_smol::SmolRT;
//! use orb::prelude::*;
//! use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
//! use std::time::Duration;
//! let rt = SmolRT::multi(0); // 0 = spawn one background thread per available CPU
//! let counter = Arc::new(AtomicUsize::new(0));
//! let _counter = counter.clone();
//! rt.spawn(async move {
//!     loop {
//!         SmolRT::sleep(Duration::from_secs(1)).await;
//!         _counter.fetch_add(1, Ordering::SeqCst);
//!     }
//! });
//! // the background task keeps running until `rt` is dropped
//! std::thread::sleep(Duration::from_secs(3));
//! drop(rt);
//! let count = counter.load(Ordering::SeqCst);
//! assert!(count >= 2 && count <= 4, "{count}");
//! ```
43//!
44//! ## Static Spawn
45//!
46//! This runtime supports static spawn methods through the [`AsyncRuntime`] trait
47//! that use the current runtime context:
48//!
49//! ```rust
50//! use orb::AsyncRuntime;
51//! use orb::runtime::AsyncExec;
52//!
53//! fn example<RT: AsyncRuntime>() {
54//!     let rt = RT::multi(2);
55//!     rt.block_on(async {
56//!         // Spawn a task using the static method - uses current runtime context
57//!         let handle = RT::spawn(async {
58//!             42
59//!         });
60//!         let result = handle.await.unwrap();
61//!         assert_eq!(result, 42);
62//!
63//!         // Spawn and detach a task
64//!         RT::spawn_detach(async {
65//!             println!("detached task running");
66//!         });
67//!     });
68//! }
69//! ```
70//!
71//! The static spawn methods ([`AsyncRuntime::spawn`], [`AsyncRuntime::spawn_detach`]) automatically use
72//! the runtime context of the current thread. This is implemented using thread-local
73//! storage that is registered when entering `block_on` or when worker threads are spawned.
74//!
75//! This feature provides a unified interface across different runtime implementations
76//! (smol, tokio, etc.) and fills the gap in `async-executor`'s native functionality.
77//!
//! With the smol global executor (requires the `global` feature):
//!
//! ```rust
//! use orb_smol::SmolExec;
//!
//! #[cfg(feature = "global")]
//! let rt = SmolExec::new_global();
//! ```
86
87use async_executor::Executor;
88use async_io::{Async, Timer};
89use crossfire::{MAsyncRx, mpmc, null::CloseHandle};
90use futures_lite::{future::block_on, stream::StreamExt};
91use orb::AsyncRuntime;
92use orb::io::{AsyncFd, AsyncIO};
93use orb::runtime::{AsyncExec, AsyncJoiner, ThreadJoiner};
94use orb::time::{AsyncTime, TimeInterval};
95use std::cell::Cell;
96use std::fmt;
97use std::future::Future;
98use std::io;
99use std::net::{SocketAddr, TcpStream};
100use std::num::NonZero;
101use std::ops::Deref;
102use std::os::{
103    fd::{AsFd, AsRawFd},
104    unix::net::UnixStream,
105};
106use std::path::Path;
107use std::pin::Pin;
108use std::ptr;
109use std::sync::Arc;
110use std::task::*;
111use std::thread;
112use std::time::{Duration, Instant};
113
114pub struct SmolRT {}
115
116/// The SmolRT implements AsyncRuntime trait
117#[derive(Clone)]
118pub struct SmolExec(Option<SmolExecInner>);
119
120#[derive(Clone)]
121struct SmolExecInner {
122    ex: Arc<Executor<'static>>,
123    _close_h: Option<CloseHandle<mpmc::Null>>,
124}
125
126// Thread local storage for the current executor context (pointer to Arc<Executor>)
127thread_local! {
128    static CURRENT_EXECUTOR: Cell<*const Executor<'static>> = const { Cell::new(ptr::null()) };
129}
130
131/// Set the current executor context for this thread
132fn set_current_executor(exec: *const Executor<'static>) {
133    CURRENT_EXECUTOR.set(exec);
134}
135
136/// Get the current executor context for this thread
137fn get_current_executor() -> *const Executor<'static> {
138    CURRENT_EXECUTOR.get()
139}
140
141impl fmt::Debug for SmolExec {
142    #[inline]
143    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
144        if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
145    }
146}
147
148impl SmolExec {
149    #[cfg(feature = "global")]
150    #[inline]
151    pub fn new_global() -> Self {
152        Self(None)
153    }
154}
155
156impl AsyncIO for SmolRT {
157    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
158
159    #[inline(always)]
160    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
161        let _addr = addr.clone();
162        let stream = Async::<TcpStream>::connect(_addr).await?;
163        // into_inner will not change back to blocking
164        Self::to_async_fd_rw(stream.into_inner()?)
165    }
166
167    #[inline(always)]
168    async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
169        let stream = Async::<UnixStream>::connect(addr).await?;
170        // into_inner will not change back to blocking
171        Self::to_async_fd_rw(stream.into_inner()?)
172    }
173
174    #[inline(always)]
175    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
176        fd: T,
177    ) -> io::Result<Self::AsyncFd<T>> {
178        Ok(SmolFD(Async::new(fd)?))
179    }
180
181    #[inline(always)]
182    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
183        fd: T,
184    ) -> io::Result<Self::AsyncFd<T>> {
185        Ok(SmolFD(Async::new(fd)?))
186    }
187}
188
189impl AsyncTime for SmolRT {
190    type Interval = SmolInterval;
191
192    #[inline(always)]
193    fn sleep(d: Duration) -> impl Future + Send {
194        Timer::after(d)
195    }
196
197    #[inline(always)]
198    fn interval(d: Duration) -> Self::Interval {
199        let later = std::time::Instant::now() + d;
200        SmolInterval(Timer::interval_at(later, d))
201    }
202}
203
/// Wrap a future so that, with the `unwind` feature, a panic while polling it is caught.
///
/// - With `unwind`: expands to `AssertUnwindSafe($f).catch_unwind()`. The resulting future
///   yields `Result<Output, Box<dyn Any + Send>>`.
/// - Without `unwind`: expands to `$f` unchanged.
///
/// The two cfg-gated items rely on statement order. Whichever survives cfg-stripping becomes
/// the tail expression of the outer block, so they must not be reordered.
macro_rules! unwind_wrap {
    ($f: expr) => {{
        #[cfg(feature = "unwind")]
        {
            use futures_lite::future::FutureExt;
            std::panic::AssertUnwindSafe($f).catch_unwind()
        }
        #[cfg(not(feature = "unwind"))]
        $f
    }};
}
215
216/// AsyncJoiner implementation for smol
217#[cfg(feature = "unwind")]
218pub struct SmolJoinHandle<T>(
219    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
220);
221#[cfg(not(feature = "unwind"))]
222pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
223
224impl<T: Send> AsyncJoiner<T> for SmolJoinHandle<T> {
225    #[inline]
226    fn is_finished(&self) -> bool {
227        self.0.as_ref().unwrap().is_finished()
228    }
229
230    #[inline(always)]
231    fn abort(self) {
232        // do nothing, the inner task will be dropped
233    }
234
235    #[inline(always)]
236    fn detach(mut self) {
237        self.0.take().unwrap().detach();
238    }
239
240    #[inline(always)]
241    fn abort_boxed(self: Box<Self>) {
242        // do nothing, the inner task will be dropped
243    }
244
245    #[inline(always)]
246    fn detach_boxed(mut self: Box<Self>) {
247        self.0.take().unwrap().detach();
248    }
249}
250
251impl<T> Future for SmolJoinHandle<T> {
252    type Output = Result<T, ()>;
253
254    #[inline]
255    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
256        let _self = unsafe { self.get_unchecked_mut() };
257        if let Some(inner) = _self.0.as_mut() {
258            if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
259                #[cfg(feature = "unwind")]
260                {
261                    return Poll::Ready(r.map_err(|_e| ()));
262                }
263                #[cfg(not(feature = "unwind"))]
264                {
265                    return Poll::Ready(Ok(r));
266                }
267            }
268            Poll::Pending
269        } else {
270            Poll::Ready(Err(()))
271        }
272    }
273}
274
275impl<T> Drop for SmolJoinHandle<T> {
276    fn drop(&mut self) {
277        if let Some(handle) = self.0.take() {
278            handle.detach();
279        }
280    }
281}
282
283pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
284
285impl<T> ThreadJoiner<T> for BlockingJoinHandle<T> {
286    #[inline]
287    fn is_finished(&self) -> bool {
288        self.0.is_finished()
289    }
290}
291
292impl<T> Future for BlockingJoinHandle<T> {
293    type Output = Result<T, ()>;
294
295    #[inline]
296    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
297        let _self = unsafe { self.get_unchecked_mut() };
298        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
299            return Poll::Ready(Ok(r));
300        }
301        Poll::Pending
302    }
303}
304
305impl AsyncRuntime for SmolRT {
306    type Exec = SmolExec;
307
308    /// Initiate executor using current thread.
309    ///
310    /// # Safety
311    ///
312    /// You should run [Self::block_on()] with this executor.
313    ///
314    /// If spawn without a `block_on()` running, it's possible
315    /// the runtime just init future without scheduling.
316    #[inline(always)]
317    fn current() -> SmolExec {
318        SmolExec(Some(SmolExecInner { ex: Arc::new(Executor::new()), _close_h: None }))
319    }
320
321    /// Initiate executor with one background thread.
322    ///
323    /// # NOTE
324    ///
325    /// [Self::block_on()] is optional.
326    #[inline(always)]
327    fn one() -> SmolExec {
328        Self::multi(1)
329    }
330
331    /// Initiate executor with multiple background threads.
332    ///
333    /// # NOTE
334    ///
335    /// When `num` == 0, start threads that match cpu number
336    /// [Self::block_on()] is optional.
337    #[inline(always)]
338    fn multi(mut size: usize) -> SmolExec {
339        if size == 0 {
340            size = usize::from(
341                thread::available_parallelism().unwrap_or(NonZero::new(1usize).unwrap()),
342            )
343        }
344        #[cfg(feature = "global")]
345        {
346            unsafe { std::env::set_var("SMOL_THREADS", size.to_string()) };
347            SmolExec::new_global()
348        }
349        #[cfg(not(feature = "global"))]
350        {
351            let (close_h, rx): (CloseHandle<mpmc::Null>, MAsyncRx<mpmc::Null>) = mpmc::new();
352            // Prevent spawning another thread by running the process driver on this thread.
353            let inner = SmolExecInner { ex: Arc::new(Executor::new()), _close_h: Some(close_h) };
354            #[cfg(not(target_os = "espidf"))]
355            inner.ex.spawn(async_process::driver()).detach();
356            let ex = inner.ex.clone();
357            // Get pointer to the Executor in Arc - this pointer is stable as long as Arc is alive
358            let ex_ptr: usize = Arc::as_ptr(&inner.ex) as usize;
359            for n in 1..=size {
360                let _ex = ex.clone();
361                let _rx = rx.clone();
362                let _ex_ptr = ex_ptr;
363                thread::Builder::new()
364                    .name(format!("smol-{}", n))
365                    .spawn(move || {
366                        set_current_executor(_ex_ptr as *const Executor<'static>);
367                        let _ = block_on(_ex.run(_rx.recv()));
368                        set_current_executor(ptr::null());
369                    })
370                    .expect("cannot spawn executor thread");
371            }
372            SmolExec(Some(inner))
373        }
374    }
375
376    /// Spawn a task in the background
377    fn spawn<F, R>(f: F) -> SmolJoinHandle<R>
378    where
379        F: Future<Output = R> + Send + 'static,
380        R: Send + 'static,
381    {
382        #[cfg(feature = "global")]
383        {
384            SmolJoinHandle(smol::spawn(f))
385        }
386        #[cfg(not(feature = "global"))]
387        {
388            let ex_ptr = get_current_executor();
389            assert!(!ex_ptr.is_null(), "spawn must be called in runtime worker context");
390            let ex = unsafe { &*ex_ptr };
391            SmolJoinHandle(Some(ex.spawn(unwind_wrap!(f))))
392        }
393    }
394
395    /// Depends on how you initialize SmolRT, spawn with executor or globally
396    #[inline]
397    fn spawn_detach<F, R>(f: F)
398    where
399        F: Future<Output = R> + Send + 'static,
400        R: Send + 'static,
401    {
402        #[cfg(feature = "global")]
403        {
404            smol::spawn(f).detach()
405        }
406        #[cfg(not(feature = "global"))]
407        {
408            let ex_ptr = get_current_executor();
409            assert!(!ex_ptr.is_null(), "spawn_detach must be called in runtime worker context");
410            let ex = unsafe { &*ex_ptr };
411            ex.spawn(unwind_wrap!(f)).detach();
412        }
413    }
414
415    #[inline]
416    fn spawn_blocking<F, R>(f: F) -> BlockingJoinHandle<R>
417    where
418        F: FnOnce() -> R + Send + 'static,
419        R: Send + 'static,
420    {
421        BlockingJoinHandle(blocking::unblock(f))
422    }
423}
424
425impl AsyncExec for SmolExec {
426    type AsyncJoiner<R: Send> = SmolJoinHandle<R>;
427
428    type ThreadJoiner<R: Send> = BlockingJoinHandle<R>;
429
430    /// Spawn a task in the background
431    fn spawn<F, R>(&self, f: F) -> Self::AsyncJoiner<R>
432    where
433        F: Future<Output = R> + Send + 'static,
434        R: Send + 'static,
435    {
436        // Although SmolJoinHandle don't need Send marker, but here in the spawn()
437        // need to restrict the requirements
438        let handle = match &self.0 {
439            Some(inner) => inner.ex.spawn(unwind_wrap!(f)),
440            None => {
441                #[cfg(feature = "global")]
442                {
443                    smol::spawn(unwind_wrap!(f))
444                }
445                #[cfg(not(feature = "global"))]
446                unreachable!();
447            }
448        };
449        SmolJoinHandle(Some(handle))
450    }
451
452    /// Depends on how you initialize SmolRT, spawn with executor or globally
453    #[inline]
454    fn spawn_detach<F, R>(&self, f: F)
455    where
456        F: Future<Output = R> + Send + 'static,
457        R: Send + 'static,
458    {
459        self.spawn(unwind_wrap!(f)).detach();
460    }
461
462    #[inline]
463    fn spawn_blocking<F, R>(&self, f: F) -> Self::ThreadJoiner<R>
464    where
465        F: FnOnce() -> R + Send + 'static,
466        R: Send + 'static,
467    {
468        BlockingJoinHandle(blocking::unblock(f))
469    }
470
471    /// Run a future to completion on the runtime
472    ///
473    /// NOTE: when initialized  with an executor,  will block current thread until the future
474    /// returns
475    #[inline]
476    fn block_on<F, R>(&self, f: F) -> R
477    where
478        F: Future<Output = R> + Send,
479        R: 'static,
480    {
481        if let Some(inner) = &self.0 {
482            let ex_ptr: *const Executor<'static> = Arc::as_ptr(&inner.ex);
483            set_current_executor(ex_ptr);
484            let result = block_on(inner.ex.run(f));
485            set_current_executor(ptr::null());
486            result
487        } else {
488            #[cfg(feature = "global")]
489            {
490                smol::block_on(f)
491            }
492            #[cfg(not(feature = "global"))]
493            unreachable!();
494        }
495    }
496}
497
498/// Associate type for SmolRT
499pub struct SmolInterval(Timer);
500
501impl TimeInterval for SmolInterval {
502    #[inline]
503    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
504        let _self = self.get_mut();
505        match _self.0.poll_next(ctx) {
506            Poll::Ready(Some(i)) => Poll::Ready(i),
507            Poll::Ready(None) => unreachable!(),
508            Poll::Pending => Poll::Pending,
509        }
510    }
511}
512
513/// Associate type for SmolRT
514pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
515
516impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
517    #[inline(always)]
518    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
519        self.0.read_with(f).await
520    }
521
522    #[inline(always)]
523    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
524        self.0.write_with(f).await
525    }
526}
527
528impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
529    type Target = T;
530
531    #[inline(always)]
532    fn deref(&self) -> &Self::Target {
533        self.0.get_ref()
534    }
535}