orb_smol/
lib.rs

1//! # Smol Runtime adapter for Orb framework
2//!
3//! This crate provides a Smol-based implementation of the Orb async runtime traits.
4//! It allows users to leverage Smol's lightweight async runtime with the unified Orb interface.
5//!
6//! The main type provided is [`SmolRT`], which implements the core runtime functionality.
7//!
8//! See the [Orb crate](https://docs.rs/orb) for more information.
9//!
10//! ## Features
11//!
//! - `global`: enables `SmolRT::new_global`, which uses `smol`'s default global executor
//!   instead of an executor instance that you provide. Disabled by default, in which case the
//!   `smol` dependency is not pulled in.
15//!
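//! - `unwind`: wraps every spawned task in `AssertUnwindSafe(..).catch_unwind()`, so a panic
//!   inside the task is caught and the task's join handle resolves to `Err(())`. Disabled by
//!   default; without it, panics are not caught by this crate and are left to the executor's
//!   own panic handling (which may terminate the program).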
18//!
19//! ## Usage
20//!
21//! With a custom executor:
22//!
23//! ```rust
24//! use orb_smol::SmolRT;
25//! use std::sync::Arc;
26//! use async_executor::Executor;
27//!
28//! let executor = Arc::new(Executor::new());
29//! let rt = SmolRT::new(executor);
30//! ```
31//!
32//! With the global executor (requires the `global` feature):
33//!
34//! ```rust
35//! use orb_smol::SmolRT;
36//!
37//! #[cfg(feature = "global")]
38//! let rt = SmolRT::new_global();
39//! ```
40
41use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncJoinHandle, ThreadJoinHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::PathBuf;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
61/// The SmolRT implements AsyncRuntime trait
62#[derive(Clone)]
63pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66    #[inline]
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69    }
70}
71
72impl SmolRT {
73    #[cfg(feature = "global")]
74    #[inline]
75    pub fn new_global() -> Self {
76        Self(None)
77    }
78
79    /// spawn coroutine with specified Executor
80    #[inline]
81    pub fn new(executor: Arc<Executor<'static>>) -> Self {
82        Self(Some(executor))
83    }
84}
85
86impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91    #[inline(always)]
92    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93        let _addr = addr.clone();
94        let stream = Async::<TcpStream>::connect(_addr).await?;
95        // into_inner will not change back to blocking
96        Self::to_async_fd_rw(stream.into_inner()?)
97    }
98
99    #[inline(always)]
100    async fn connect_unix(addr: &PathBuf) -> io::Result<Self::AsyncFd<UnixStream>> {
101        let _addr = addr.clone();
102        let stream = Async::<UnixStream>::connect(_addr).await?;
103        // into_inner will not change back to blocking
104        Self::to_async_fd_rw(stream.into_inner()?)
105    }
106
107    #[inline(always)]
108    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
109        fd: T,
110    ) -> io::Result<Self::AsyncFd<T>> {
111        Ok(SmolFD(Async::new(fd)?))
112    }
113
114    #[inline(always)]
115    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
116        fd: T,
117    ) -> io::Result<Self::AsyncFd<T>> {
118        Ok(SmolFD(Async::new(fd)?))
119    }
120}
121
122impl AsyncTime for SmolRT {
123    type Interval = SmolInterval;
124
125    #[inline(always)]
126    fn sleep(d: Duration) -> impl Future + Send {
127        Timer::after(d)
128    }
129
130    #[inline(always)]
131    fn tick(d: Duration) -> Self::Interval {
132        let later = std::time::Instant::now() + d;
133        SmolInterval(Timer::interval_at(later, d))
134    }
135}
136
// `unwind_wrap!(fut)`: with the `unwind` feature, wraps `fut` so that a panic while polling it
// is caught (`AssertUnwindSafe(fut).catch_unwind()`); otherwise expands to `fut` unchanged.
#[cfg(feature = "unwind")]
macro_rules! unwind_wrap {
    ($fut:expr) => {{
        use futures_lite::future::FutureExt;
        std::panic::AssertUnwindSafe($fut).catch_unwind()
    }};
}

// Without `unwind` the future is passed through untouched.
#[cfg(not(feature = "unwind"))]
macro_rules! unwind_wrap {
    ($fut:expr) => {
        $fut
    };
}
148
149/// AsyncJoinHandle implementation for smol
150#[cfg(feature = "unwind")]
151pub struct SmolJoinHandle<T>(
152    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
153);
154#[cfg(not(feature = "unwind"))]
155pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
156
157impl<T: Send + 'static> AsyncJoinHandle<T> for SmolJoinHandle<T> {
158    #[inline(always)]
159    fn abort(self) {
160        // do nothing, the inner task will be dropped
161    }
162
163    #[inline]
164    fn detach(mut self) {
165        self.0.take().unwrap().detach();
166    }
167
168    #[inline]
169    fn is_finished(&self) -> bool {
170        self.0.as_ref().unwrap().is_finished()
171    }
172}
173
174impl<T: Send + 'static> Future for SmolJoinHandle<T> {
175    type Output = Result<T, ()>;
176
177    #[inline]
178    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        let _self = unsafe { self.get_unchecked_mut() };
180        if let Some(inner) = _self.0.as_mut() {
181            if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
182                #[cfg(feature = "unwind")]
183                {
184                    return Poll::Ready(r.map_err(|_e| ()));
185                }
186                #[cfg(not(feature = "unwind"))]
187                {
188                    return Poll::Ready(Ok(r));
189                }
190            }
191            Poll::Pending
192        } else {
193            Poll::Ready(Err(()))
194        }
195    }
196}
197
198impl<T> Drop for SmolJoinHandle<T> {
199    fn drop(&mut self) {
200        if let Some(handle) = self.0.take() {
201            handle.detach();
202        }
203    }
204}
205
206pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
207
208impl<T: Send + 'static> ThreadJoinHandle<T> for BlockingJoinHandle<T> {
209    #[inline]
210    fn is_finished(&self) -> bool {
211        self.0.is_finished()
212    }
213}
214
215impl<T: Send + 'static> Future for BlockingJoinHandle<T> {
216    type Output = Result<T, ()>;
217
218    #[inline]
219    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220        let _self = unsafe { self.get_unchecked_mut() };
221        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
222            return Poll::Ready(Ok(r));
223        }
224        Poll::Pending
225    }
226}
227
228impl AsyncExec for SmolRT {
229    /// Spawn a task in the background
230    fn spawn<F, R>(&self, f: F) -> impl AsyncJoinHandle<R>
231    where
232        F: Future<Output = R> + Send + 'static,
233        R: Send + 'static,
234    {
235        let handle = match &self.0 {
236            Some(exec) => exec.spawn(unwind_wrap!(f)),
237            None => {
238                #[cfg(feature = "global")]
239                {
240                    smol::spawn(unwind_wrap!(f))
241                }
242                #[cfg(not(feature = "global"))]
243                unreachable!();
244            }
245        };
246        SmolJoinHandle(Some(handle))
247    }
248
249    /// Depends on how you initialize SmolRT, spawn with executor or globally
250    #[inline]
251    fn spawn_detach<F, R>(&self, f: F)
252    where
253        F: Future<Output = R> + Send + 'static,
254        R: Send + 'static,
255    {
256        self.spawn(unwind_wrap!(f)).detach();
257    }
258
259    #[inline]
260    fn spawn_blocking<F, R>(f: F) -> impl ThreadJoinHandle<R>
261    where
262        F: FnOnce() -> R + Send + 'static,
263        R: Send + 'static,
264    {
265        BlockingJoinHandle(blocking::unblock(f))
266    }
267
268    /// Run a future to completion on the runtime
269    ///
270    /// NOTE: when initialized  with an executor,  will block current thread until the future
271    /// returns
272    #[inline]
273    fn block_on<F, R>(&self, f: F) -> R
274    where
275        F: Future<Output = R> + Send,
276        R: Send + 'static,
277    {
278        if let Some(exec) = &self.0 {
279            block_on(exec.run(f))
280        } else {
281            #[cfg(feature = "global")]
282            {
283                smol::block_on(f)
284            }
285            #[cfg(not(feature = "global"))]
286            unreachable!();
287        }
288    }
289}
290
291/// Associate type for SmolRT
292pub struct SmolInterval(Timer);
293
294impl TimeInterval for SmolInterval {
295    #[inline]
296    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
297        let _self = self.get_mut();
298        match _self.0.poll_next(ctx) {
299            Poll::Ready(Some(i)) => Poll::Ready(i),
300            Poll::Ready(None) => unreachable!(),
301            Poll::Pending => Poll::Pending,
302        }
303    }
304}
305
306/// Associate type for SmolRT
307pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
308
309impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
310    #[inline(always)]
311    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
312        self.0.read_with(f).await
313    }
314
315    #[inline(always)]
316    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
317        self.0.write_with(f).await
318    }
319}
320
321impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
322    type Target = T;
323
324    #[inline(always)]
325    fn deref(&self) -> &Self::Target {
326        self.0.get_ref()
327    }
328}