orb_smol/
lib.rs

1//! # Smol Runtime adapter for Orb framework
2//!
3//! This crate provides a Smol-based implementation of the Orb async runtime traits.
4//! It allows users to leverage Smol's lightweight async runtime with the unified Orb interface.
5//!
6//! The main type provided is [`SmolRT`], which implements the core runtime functionality.
7//!
8//! See the [Orb crate](https://docs.rs/orb) for more information.
9//!
10//! ## Features
11//!
//! - `global`: Enables `SmolRT::new_global`, which spawns onto `smol`'s default global executor
//!   instead of an executor instance you provide. Disabled by default, in which case the `smol`
//!   dependency is omitted.
15//!
//! - `unwind`: Wraps each spawned task in `AssertUnwindSafe` and catches any panic raised inside
//!   it, so awaiting the task's join handle yields `Err(())`. Disabled by default, in which case a
//!   panic inside a task is not caught and terminates the program.
18//!
19//! ## Usage
20//!
21//! With a custom executor:
22//!
23//! ```rust
24//! use orb_smol::SmolRT;
25//! use std::sync::Arc;
26//! use async_executor::Executor;
27//!
28//! let executor = Arc::new(Executor::new());
29//! let rt = SmolRT::new(executor);
30//! ```
31//!
32//! With the global executor (requires the `global` feature):
33//!
34//! ```rust
35//! use orb_smol::SmolRT;
36//!
37//! #[cfg(feature = "global")]
38//! let rt = SmolRT::new_global();
39//! ```
40
41use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncHandle, ThreadHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::PathBuf;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
61/// The SmolRT implements AsyncRuntime trait
62#[derive(Clone)]
63pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66    #[inline]
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69    }
70}
71
72impl SmolRT {
73    #[cfg(feature = "global")]
74    #[inline]
75    pub fn new_global() -> Self {
76        Self(None)
77    }
78
79    /// spawn coroutine with specified Executor
80    #[inline]
81    pub fn new(executor: Arc<Executor<'static>>) -> Self {
82        Self(Some(executor))
83    }
84}
85
86impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91    #[inline(always)]
92    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93        let _addr = addr.clone();
94        let stream = Async::<TcpStream>::connect(_addr).await?;
95        // into_inner will not change back to blocking
96        Self::to_async_fd_rw(stream.into_inner()?)
97    }
98
99    #[inline(always)]
100    async fn connect_unix(addr: &PathBuf) -> io::Result<Self::AsyncFd<UnixStream>> {
101        let _addr = addr.clone();
102        let stream = Async::<UnixStream>::connect(_addr).await?;
103        // into_inner will not change back to blocking
104        Self::to_async_fd_rw(stream.into_inner()?)
105    }
106
107    #[inline(always)]
108    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
109        fd: T,
110    ) -> io::Result<Self::AsyncFd<T>> {
111        Ok(SmolFD(Async::new(fd)?))
112    }
113
114    #[inline(always)]
115    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
116        fd: T,
117    ) -> io::Result<Self::AsyncFd<T>> {
118        Ok(SmolFD(Async::new(fd)?))
119    }
120}
121
122impl AsyncTime for SmolRT {
123    type Interval = SmolInterval;
124
125    #[inline(always)]
126    fn sleep(d: Duration) -> impl Future + Send {
127        Timer::after(d)
128    }
129
130    #[inline(always)]
131    fn tick(d: Duration) -> Self::Interval {
132        let later = std::time::Instant::now() + d;
133        SmolInterval(Timer::interval_at(later, d))
134    }
135}
136
/// Prepares a future for spawning.
///
/// With the `unwind` feature the future is wrapped in `AssertUnwindSafe(..).catch_unwind()`, so
/// a panic inside it becomes `Err(payload)` instead of unwinding further.
#[cfg(feature = "unwind")]
macro_rules! unwind_wrap {
    ($f: expr) => {{
        use futures_lite::future::FutureExt;
        std::panic::AssertUnwindSafe($f).catch_unwind()
    }};
}

/// Prepares a future for spawning.
///
/// Without the `unwind` feature the future is passed through untouched.
#[cfg(not(feature = "unwind"))]
macro_rules! unwind_wrap {
    ($f: expr) => {
        $f
    };
}
148
149/// AsyncHandle implementation for smol
150#[cfg(feature = "unwind")]
151pub struct SmolJoinHandle<T>(
152    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
153);
154#[cfg(not(feature = "unwind"))]
155pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
156
157impl<T: Send> AsyncHandle<T> for SmolJoinHandle<T> {
158    #[inline(always)]
159    fn abort(self) {
160        // do nothing, the inner task will be dropped
161    }
162
163    #[inline]
164    fn detach(mut self) {
165        self.0.take().unwrap().detach();
166    }
167
168    #[inline]
169    fn is_finished(&self) -> bool {
170        self.0.as_ref().unwrap().is_finished()
171    }
172}
173
174impl<T> Future for SmolJoinHandle<T> {
175    type Output = Result<T, ()>;
176
177    #[inline]
178    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        let _self = unsafe { self.get_unchecked_mut() };
180        if let Some(inner) = _self.0.as_mut() {
181            if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
182                #[cfg(feature = "unwind")]
183                {
184                    return Poll::Ready(r.map_err(|_e| ()));
185                }
186                #[cfg(not(feature = "unwind"))]
187                {
188                    return Poll::Ready(Ok(r));
189                }
190            }
191            Poll::Pending
192        } else {
193            Poll::Ready(Err(()))
194        }
195    }
196}
197
198impl<T> Drop for SmolJoinHandle<T> {
199    fn drop(&mut self) {
200        if let Some(handle) = self.0.take() {
201            handle.detach();
202        }
203    }
204}
205
206pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
207
208impl<T> ThreadHandle<T> for BlockingJoinHandle<T> {
209    #[inline]
210    fn is_finished(&self) -> bool {
211        self.0.is_finished()
212    }
213}
214
215impl<T> Future for BlockingJoinHandle<T> {
216    type Output = Result<T, ()>;
217
218    #[inline]
219    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220        let _self = unsafe { self.get_unchecked_mut() };
221        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
222            return Poll::Ready(Ok(r));
223        }
224        Poll::Pending
225    }
226}
227
228impl AsyncExec for SmolRT {
229    type AsyncHandle<R: Send> = SmolJoinHandle<R>;
230
231    type ThreadHandle<R: Send> = BlockingJoinHandle<R>;
232
233    /// Spawn a task in the background
234    fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
235    where
236        F: Future<Output = R> + Send + 'static,
237        R: Send + 'static,
238    {
239        // Although SmolJoinHandle don't need Send marker, but here in the spawn()
240        // need to restrict the requirements
241        let handle = match &self.0 {
242            Some(exec) => exec.spawn(unwind_wrap!(f)),
243            None => {
244                #[cfg(feature = "global")]
245                {
246                    smol::spawn(unwind_wrap!(f))
247                }
248                #[cfg(not(feature = "global"))]
249                unreachable!();
250            }
251        };
252        SmolJoinHandle(Some(handle))
253    }
254
255    /// Depends on how you initialize SmolRT, spawn with executor or globally
256    #[inline]
257    fn spawn_detach<F, R>(&self, f: F)
258    where
259        F: Future<Output = R> + Send + 'static,
260        R: Send + 'static,
261    {
262        self.spawn(unwind_wrap!(f)).detach();
263    }
264
265    #[inline]
266    fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
267    where
268        F: FnOnce() -> R + Send + 'static,
269        R: Send + 'static,
270    {
271        BlockingJoinHandle(blocking::unblock(f))
272    }
273
274    /// Run a future to completion on the runtime
275    ///
276    /// NOTE: when initialized  with an executor,  will block current thread until the future
277    /// returns
278    #[inline]
279    fn block_on<F, R>(&self, f: F) -> R
280    where
281        F: Future<Output = R> + Send,
282        R: Send + 'static,
283    {
284        if let Some(exec) = &self.0 {
285            block_on(exec.run(f))
286        } else {
287            #[cfg(feature = "global")]
288            {
289                smol::block_on(f)
290            }
291            #[cfg(not(feature = "global"))]
292            unreachable!();
293        }
294    }
295}
296
297/// Associate type for SmolRT
298pub struct SmolInterval(Timer);
299
300impl TimeInterval for SmolInterval {
301    #[inline]
302    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
303        let _self = self.get_mut();
304        match _self.0.poll_next(ctx) {
305            Poll::Ready(Some(i)) => Poll::Ready(i),
306            Poll::Ready(None) => unreachable!(),
307            Poll::Pending => Poll::Pending,
308        }
309    }
310}
311
312/// Associate type for SmolRT
313pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
314
315impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
316    #[inline(always)]
317    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
318        self.0.read_with(f).await
319    }
320
321    #[inline(always)]
322    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
323        self.0.write_with(f).await
324    }
325}
326
327impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
328    type Target = T;
329
330    #[inline(always)]
331    fn deref(&self) -> &Self::Target {
332        self.0.get_ref()
333    }
334}