Skip to main content

orb_smol/
lib.rs

1//! # Smol Runtime adapter for Orb framework
2//!
3//! This crate provides a Smol-based implementation of the Orb async runtime traits.
4//! It allows users to leverage Smol's lightweight async runtime with the unified Orb interface.
5//!
6//! The main type provided is [`SmolRT`], which implements the core runtime functionality.
7//!
8//! See the [Orb crate](https://docs.rs/orb) for more information.
9//!
10//! ## Features
11//!
//! - `global`: Enables the global executor feature, which allows using `smol`'s default global
//!   executor instead of providing your own executor instance. (Disabled by default, which
//!   omits the `smol` dependency.)
15//!
//! - `unwind`: Wraps each spawned task in `AssertUnwindSafe` + `catch_unwind`, so a panic inside
//!   the task is captured and the task's join handle resolves to `Err(())`. (Disabled by
//!   default; a panic inside a task is then not caught.)
18//!
19//! ## Usage
20//!
//! With a multi-threaded runtime:
22//!
23//! ```rust
24//! use orb_smol::SmolRT;
25//! use orb::prelude::*;
26//! use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
27//! use std::time::Duration;
//! let rt = SmolRT::multi(0); // spawn one background thread per available CPU
29//! let counter = Arc::new(AtomicUsize::new(0));
30//! let _counter = counter.clone();
31//! rt.spawn(async move {
32//!     loop {
33//!         SmolRT::sleep(Duration::from_secs(1)).await;
34//!         _counter.fetch_add(1, Ordering::SeqCst);
35//!     }
36//! });
//! // the background task keeps running until `rt` is dropped
38//! std::thread::sleep(Duration::from_secs(3));
39//! drop(rt);
40//! let count = counter.load(Ordering::SeqCst);
41//! assert!(count >= 2 && count <= 4, "{count}");
42//! ```
43//!
44//!
45//! With a custom executor:
46//!
47//! ```rust
48//! use orb_smol::SmolRT;
49//! use orb::prelude::*;
50//! use std::sync::Arc;
51//! use async_executor::Executor;
52//!
53//! let executor = Arc::new(Executor::new());
54//! let rt = SmolRT::new_with_executor(executor);
55//! rt.block_on(async move {
56//!     for _ in 0..3 {
57//!         SmolRT::sleep(std::time::Duration::from_secs(1)).await;
58//!     }
59//!     println!("background task will stop once the block_on is finish");
60//! });
61//! ```
62//!
//! With the smol global executor (requires the `global` feature):
64//!
65//! ```rust
66//! use orb_smol::SmolRT;
67//!
68//! #[cfg(feature = "global")]
69//! let rt = SmolRT::new_global();
70//! ```
71
72use async_executor::Executor;
73use async_io::{Async, Timer};
74use crossfire::{MAsyncRx, mpmc, null::CloseHandle};
75use futures_lite::{future::block_on, stream::StreamExt};
76use orb::io::{AsyncFd, AsyncIO};
77use orb::runtime::{AsyncExec, AsyncExecDyn, AsyncHandle, ThreadHandle};
78use orb::time::{AsyncTime, TimeInterval};
79use std::fmt;
80use std::future::Future;
81use std::io;
82use std::net::{SocketAddr, TcpStream};
83use std::num::NonZero;
84use std::ops::Deref;
85use std::os::{
86    fd::{AsFd, AsRawFd},
87    unix::net::UnixStream,
88};
89use std::path::Path;
90use std::pin::Pin;
91use std::sync::Arc;
92use std::task::*;
93use std::thread;
94use std::time::{Duration, Instant};
95
96/// The SmolRT implements AsyncRuntime trait
97#[derive(Clone)]
98pub struct SmolRT(Option<SmolRTInner>);
99
100#[derive(Clone)]
101struct SmolRTInner {
102    ex: Arc<Executor<'static>>,
103    _close_h: Option<CloseHandle<mpmc::Null>>,
104}
105
106impl fmt::Debug for SmolRT {
107    #[inline]
108    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
109        if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
110    }
111}
112
113impl SmolRT {
114    #[cfg(feature = "global")]
115    #[inline]
116    pub fn new_global() -> Self {
117        Self(None)
118    }
119
120    /// spawn coroutine with specified Executor.
121    ///
122    /// # Safety
123    ///
124    /// You should run block_on on this executor somewhere (self.block_on also counts),
125    /// otherwise the future spawn into this executor will not run.
126    #[inline]
127    pub fn new_with_executor(executor: Arc<Executor<'static>>) -> Self {
128        Self(Some(SmolRTInner { ex: executor, _close_h: None }))
129    }
130}
131
132impl orb::AsyncRuntime for SmolRT {}
133
134impl AsyncIO for SmolRT {
135    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
136
137    #[inline(always)]
138    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
139        let _addr = addr.clone();
140        let stream = Async::<TcpStream>::connect(_addr).await?;
141        // into_inner will not change back to blocking
142        Self::to_async_fd_rw(stream.into_inner()?)
143    }
144
145    #[inline(always)]
146    async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
147        let stream = Async::<UnixStream>::connect(addr).await?;
148        // into_inner will not change back to blocking
149        Self::to_async_fd_rw(stream.into_inner()?)
150    }
151
152    #[inline(always)]
153    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
154        fd: T,
155    ) -> io::Result<Self::AsyncFd<T>> {
156        Ok(SmolFD(Async::new(fd)?))
157    }
158
159    #[inline(always)]
160    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
161        fd: T,
162    ) -> io::Result<Self::AsyncFd<T>> {
163        Ok(SmolFD(Async::new(fd)?))
164    }
165}
166
167impl AsyncTime for SmolRT {
168    type Interval = SmolInterval;
169
170    #[inline(always)]
171    fn sleep(d: Duration) -> impl Future + Send {
172        Timer::after(d)
173    }
174
175    #[inline(always)]
176    fn interval(d: Duration) -> Self::Interval {
177        let later = std::time::Instant::now() + d;
178        SmolInterval(Timer::interval_at(later, d))
179    }
180}
181
/// Wraps a future for spawning.
///
/// With the `unwind` feature, the future is wrapped in `AssertUnwindSafe` + `catch_unwind`, so
/// its output becomes `Result<T, Box<dyn Any + Send>>`. Without the feature it expands to the
/// future unchanged.
macro_rules! unwind_wrap {
    ($fut: expr) => {{
        #[cfg(feature = "unwind")]
        {
            futures_lite::future::FutureExt::catch_unwind(std::panic::AssertUnwindSafe($fut))
        }
        #[cfg(not(feature = "unwind"))]
        $fut
    }};
}
193
194/// AsyncHandle implementation for smol
195#[cfg(feature = "unwind")]
196pub struct SmolJoinHandle<T>(
197    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
198);
199#[cfg(not(feature = "unwind"))]
200pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
201
202impl<T: Send> AsyncHandle<T> for SmolJoinHandle<T> {
203    #[inline]
204    fn is_finished(&self) -> bool {
205        self.0.as_ref().unwrap().is_finished()
206    }
207
208    #[inline(always)]
209    fn abort(self) {
210        // do nothing, the inner task will be dropped
211    }
212
213    #[inline(always)]
214    fn detach(mut self) {
215        self.0.take().unwrap().detach();
216    }
217
218    #[inline(always)]
219    fn abort_boxed(self: Box<Self>) {
220        // do nothing, the inner task will be dropped
221    }
222
223    #[inline(always)]
224    fn detach_boxed(mut self: Box<Self>) {
225        self.0.take().unwrap().detach();
226    }
227}
228
229impl<T> Future for SmolJoinHandle<T> {
230    type Output = Result<T, ()>;
231
232    #[inline]
233    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
234        let _self = unsafe { self.get_unchecked_mut() };
235        if let Some(inner) = _self.0.as_mut() {
236            if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
237                #[cfg(feature = "unwind")]
238                {
239                    return Poll::Ready(r.map_err(|_e| ()));
240                }
241                #[cfg(not(feature = "unwind"))]
242                {
243                    return Poll::Ready(Ok(r));
244                }
245            }
246            Poll::Pending
247        } else {
248            Poll::Ready(Err(()))
249        }
250    }
251}
252
253impl<T> Drop for SmolJoinHandle<T> {
254    fn drop(&mut self) {
255        if let Some(handle) = self.0.take() {
256            handle.detach();
257        }
258    }
259}
260
261pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
262
263impl<T> ThreadHandle<T> for BlockingJoinHandle<T> {
264    #[inline]
265    fn is_finished(&self) -> bool {
266        self.0.is_finished()
267    }
268}
269
270impl<T> Future for BlockingJoinHandle<T> {
271    type Output = Result<T, ()>;
272
273    #[inline]
274    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
275        let _self = unsafe { self.get_unchecked_mut() };
276        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
277            return Poll::Ready(Ok(r));
278        }
279        Poll::Pending
280    }
281}
282
283impl AsyncExecDyn for SmolRT {
284    #[inline(always)]
285    fn spawn_detach_dyn(&self, f: Box<dyn Future<Output = ()> + Send + Unpin>) {
286        self.spawn(unwind_wrap!(f)).detach();
287    }
288}
289
290impl AsyncExec for SmolRT {
291    type AsyncHandle<R: Send> = SmolJoinHandle<R>;
292
293    type ThreadHandle<R: Send> = BlockingJoinHandle<R>;
294
295    /// Initiate executor using current thread.
296    ///
297    /// # Safety
298    ///
299    /// You should run [Self::block_on()] with this executor.
300    ///
301    /// If spawn without a `block_on()` running, it's possible
302    /// the runtime just init future without scheduling.
303    #[inline(always)]
304    fn current() -> Self {
305        Self::new_with_executor(Arc::new(Executor::new()))
306    }
307
308    /// Initiate executor with one background thread.
309    ///
310    /// # NOTE
311    ///
312    /// [Self::block_on()] is optional.
313    #[inline(always)]
314    fn one() -> Self {
315        Self::multi(1)
316    }
317
318    /// Initiate executor with multiple background threads.
319    ///
320    /// # NOTE
321    ///
322    /// When `num` == 0, start threads that match cpu number
323    /// [Self::block_on()] is optional.
324    #[inline(always)]
325    fn multi(mut size: usize) -> Self {
326        if size == 0 {
327            size = usize::from(
328                thread::available_parallelism().unwrap_or(NonZero::new(1usize).unwrap()),
329            )
330        }
331        #[cfg(feature = "global")]
332        {
333            unsafe { std::env::set_var("SMOL_THREADS", size.to_string()) };
334            Self(None)
335        }
336        #[cfg(not(feature = "global"))]
337        {
338            let (close_h, rx): (CloseHandle<mpmc::Null>, MAsyncRx<mpmc::Null>) = mpmc::new();
339            // Prevent spawning another thread by running the process driver on this thread.
340            let inner = SmolRTInner { ex: Arc::new(Executor::new()), _close_h: Some(close_h) };
341            #[cfg(not(target_os = "espidf"))]
342            inner.ex.spawn(async_process::driver()).detach();
343            let ex = inner.ex.clone();
344            for n in 1..=size {
345                let _ex = ex.clone();
346                let _rx = rx.clone();
347                thread::Builder::new()
348                    .name(format!("smol-{}", n))
349                    .spawn(move || block_on(_ex.run(_rx.recv())))
350                    .expect("cannot spawn executor thread");
351            }
352            Self(Some(inner))
353        }
354    }
355
356    /// Spawn a task in the background
357    fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
358    where
359        F: Future<Output = R> + Send + 'static,
360        R: Send + 'static,
361    {
362        // Although SmolJoinHandle don't need Send marker, but here in the spawn()
363        // need to restrict the requirements
364        let handle = match &self.0 {
365            Some(inner) => inner.ex.spawn(unwind_wrap!(f)),
366            None => {
367                #[cfg(feature = "global")]
368                {
369                    smol::spawn(unwind_wrap!(f))
370                }
371                #[cfg(not(feature = "global"))]
372                unreachable!();
373            }
374        };
375        SmolJoinHandle(Some(handle))
376    }
377
378    /// Depends on how you initialize SmolRT, spawn with executor or globally
379    #[inline]
380    fn spawn_detach<F, R>(&self, f: F)
381    where
382        F: Future<Output = R> + Send + 'static,
383        R: Send + 'static,
384    {
385        self.spawn(unwind_wrap!(f)).detach();
386    }
387
388    #[inline]
389    fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
390    where
391        F: FnOnce() -> R + Send + 'static,
392        R: Send + 'static,
393    {
394        BlockingJoinHandle(blocking::unblock(f))
395    }
396
397    /// Run a future to completion on the runtime
398    ///
399    /// NOTE: when initialized  with an executor,  will block current thread until the future
400    /// returns
401    #[inline]
402    fn block_on<F, R>(&self, f: F) -> R
403    where
404        F: Future<Output = R> + Send,
405        R: 'static,
406    {
407        if let Some(inner) = &self.0 {
408            block_on(inner.ex.run(f))
409        } else {
410            #[cfg(feature = "global")]
411            {
412                smol::block_on(f)
413            }
414            #[cfg(not(feature = "global"))]
415            unreachable!();
416        }
417    }
418}
419
420/// Associate type for SmolRT
421pub struct SmolInterval(Timer);
422
423impl TimeInterval for SmolInterval {
424    #[inline]
425    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
426        let _self = self.get_mut();
427        match _self.0.poll_next(ctx) {
428            Poll::Ready(Some(i)) => Poll::Ready(i),
429            Poll::Ready(None) => unreachable!(),
430            Poll::Pending => Poll::Pending,
431        }
432    }
433}
434
435/// Associate type for SmolRT
436pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
437
438impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
439    #[inline(always)]
440    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
441        self.0.read_with(f).await
442    }
443
444    #[inline(always)]
445    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
446        self.0.write_with(f).await
447    }
448}
449
450impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
451    type Target = T;
452
453    #[inline(always)]
454    fn deref(&self) -> &Self::Target {
455        self.0.get_ref()
456    }
457}