Skip to main content

orb_smol/
lib.rs

1//! # Smol Runtime adapter for Orb framework
2//!
3//! This crate provides a Smol-based implementation of the Orb async runtime traits.
4//! It allows users to leverage Smol's lightweight async runtime with the unified Orb interface.
5//!
6//! The main type provided is [`SmolRT`], which implements the core runtime functionality.
7//!
8//! See the [Orb crate](https://docs.rs/orb) for more information.
9//!
10//! ## Features
11//!
//! - `global`: Enables the global executor feature, which allows using the default global `smol`
//!   executor instead of providing your own executor instance. (Not enabled by default; the
//!   `smol` dependency is omitted unless this feature is on.)
15//!
//! - `unwind`: Wraps each spawned task in `AssertUnwindSafe` to catch panics inside the task; the
//!   task's join handle then returns `Err(())`. (Not enabled by default; a panic terminates the
//!   program.)
18//!
19//! ## Usage
20//!
21//! With a custom executor:
22//!
23//! ```rust
24//! use orb_smol::SmolRT;
25//! use std::sync::Arc;
26//! use async_executor::Executor;
27//!
28//! let executor = Arc::new(Executor::new());
29//! let rt = SmolRT::new(executor);
30//! ```
31//!
32//! With the global executor (requires the `global` feature):
33//!
34//! ```rust
35//! use orb_smol::SmolRT;
36//!
37//! #[cfg(feature = "global")]
38//! let rt = SmolRT::new_global();
39//! ```
40
41use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncHandle, ThreadHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::Path;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
/// Smol-backed runtime handle that implements the Orb `AsyncRuntime` trait.
///
/// The wrapped `Option` selects where tasks run: `Some` holds the shared [`Executor`]
/// given to `SmolRT::new`, while `None` (only produced by `SmolRT::new_global`, behind the
/// `global` feature) routes spawning and blocking through the `smol` crate.
#[derive(Clone)]
pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66    #[inline]
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69    }
70}
71
72impl SmolRT {
73    #[cfg(feature = "global")]
74    #[inline]
75    pub fn new_global() -> Self {
76        Self(None)
77    }
78
79    /// spawn coroutine with specified Executor
80    #[inline]
81    pub fn new(executor: Arc<Executor<'static>>) -> Self {
82        Self(Some(executor))
83    }
84}
85
// Empty impl: no items of `orb::AsyncRuntime` are defined or overridden here.
// NOTE(review): presumably an umbrella trait over the `AsyncIO`, `AsyncTime` and
// `AsyncExec` impls in this file — confirm against the `orb` crate.
impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89    type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91    #[inline(always)]
92    async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93        let _addr = addr.clone();
94        let stream = Async::<TcpStream>::connect(_addr).await?;
95        // into_inner will not change back to blocking
96        Self::to_async_fd_rw(stream.into_inner()?)
97    }
98
99    #[inline(always)]
100    async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
101        let stream = Async::<UnixStream>::connect(addr).await?;
102        // into_inner will not change back to blocking
103        Self::to_async_fd_rw(stream.into_inner()?)
104    }
105
106    #[inline(always)]
107    fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
108        fd: T,
109    ) -> io::Result<Self::AsyncFd<T>> {
110        Ok(SmolFD(Async::new(fd)?))
111    }
112
113    #[inline(always)]
114    fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
115        fd: T,
116    ) -> io::Result<Self::AsyncFd<T>> {
117        Ok(SmolFD(Async::new(fd)?))
118    }
119}
120
121impl AsyncTime for SmolRT {
122    type Interval = SmolInterval;
123
124    #[inline(always)]
125    fn sleep(d: Duration) -> impl Future + Send {
126        Timer::after(d)
127    }
128
129    #[inline(always)]
130    fn interval(d: Duration) -> Self::Interval {
131        let later = std::time::Instant::now() + d;
132        SmolInterval(Timer::interval_at(later, d))
133    }
134}
135
// Wraps a future expression before it is spawned.
//
// With the `unwind` feature the future is placed in `AssertUnwindSafe` and converted with
// `catch_unwind`, so its output becomes a `Result` carrying the panic payload on failure.
#[cfg(feature = "unwind")]
macro_rules! unwind_wrap {
    ($f: expr) => {
        futures_lite::future::FutureExt::catch_unwind(std::panic::AssertUnwindSafe($f))
    };
}

// Without the `unwind` feature the expression is passed through untouched.
#[cfg(not(feature = "unwind"))]
macro_rules! unwind_wrap {
    ($f: expr) => {
        $f
    };
}
147
/// AsyncHandle implementation for smol
///
/// This is the `unwind` variant: the wrapped task yields
/// `Result<T, Box<dyn Any + Send>>`, matching the output of a future passed through
/// `unwind_wrap!`. The task sits in an `Option` so it can be moved out of the handle.
#[cfg(feature = "unwind")]
pub struct SmolJoinHandle<T>(
    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
);
/// AsyncHandle implementation for smol
///
/// This is the variant without the `unwind` feature: the wrapped task yields `T` directly.
/// The task sits in an `Option` so it can be moved out of the handle.
#[cfg(not(feature = "unwind"))]
pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
155
156impl<T: Send> AsyncHandle<T> for SmolJoinHandle<T> {
157    #[inline]
158    fn is_finished(&self) -> bool {
159        self.0.as_ref().unwrap().is_finished()
160    }
161
162    #[inline(always)]
163    fn abort(self) {
164        // do nothing, the inner task will be dropped
165    }
166
167    #[inline(always)]
168    fn detach(mut self) {
169        self.0.take().unwrap().detach();
170    }
171
172    #[inline(always)]
173    fn abort_boxed(self: Box<Self>) {
174        // do nothing, the inner task will be dropped
175    }
176
177    #[inline(always)]
178    fn detach_boxed(mut self: Box<Self>) {
179        self.0.take().unwrap().detach();
180    }
181}
182
183impl<T> Future for SmolJoinHandle<T> {
184    type Output = Result<T, ()>;
185
186    #[inline]
187    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188        let _self = unsafe { self.get_unchecked_mut() };
189        if let Some(inner) = _self.0.as_mut() {
190            if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
191                #[cfg(feature = "unwind")]
192                {
193                    return Poll::Ready(r.map_err(|_e| ()));
194                }
195                #[cfg(not(feature = "unwind"))]
196                {
197                    return Poll::Ready(Ok(r));
198                }
199            }
200            Poll::Pending
201        } else {
202            Poll::Ready(Err(()))
203        }
204    }
205}
206
207impl<T> Drop for SmolJoinHandle<T> {
208    fn drop(&mut self) {
209        if let Some(handle) = self.0.take() {
210            handle.detach();
211        }
212    }
213}
214
/// Join handle returned by `SmolRT::spawn_blocking`; wraps the task produced by
/// `blocking::unblock`.
pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
216
217impl<T> ThreadHandle<T> for BlockingJoinHandle<T> {
218    #[inline]
219    fn is_finished(&self) -> bool {
220        self.0.is_finished()
221    }
222}
223
224impl<T> Future for BlockingJoinHandle<T> {
225    type Output = Result<T, ()>;
226
227    #[inline]
228    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
229        let _self = unsafe { self.get_unchecked_mut() };
230        if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
231            return Poll::Ready(Ok(r));
232        }
233        Poll::Pending
234    }
235}
236
237impl AsyncExec for SmolRT {
238    type AsyncHandle<R: Send> = SmolJoinHandle<R>;
239
240    type ThreadHandle<R: Send> = BlockingJoinHandle<R>;
241
242    /// Spawn a task in the background
243    fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
244    where
245        F: Future<Output = R> + Send + 'static,
246        R: Send + 'static,
247    {
248        // Although SmolJoinHandle don't need Send marker, but here in the spawn()
249        // need to restrict the requirements
250        let handle = match &self.0 {
251            Some(exec) => exec.spawn(unwind_wrap!(f)),
252            None => {
253                #[cfg(feature = "global")]
254                {
255                    smol::spawn(unwind_wrap!(f))
256                }
257                #[cfg(not(feature = "global"))]
258                unreachable!();
259            }
260        };
261        SmolJoinHandle(Some(handle))
262    }
263
264    /// Depends on how you initialize SmolRT, spawn with executor or globally
265    #[inline]
266    fn spawn_detach<F, R>(&self, f: F)
267    where
268        F: Future<Output = R> + Send + 'static,
269        R: Send + 'static,
270    {
271        self.spawn(unwind_wrap!(f)).detach();
272    }
273
274    #[inline]
275    fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
276    where
277        F: FnOnce() -> R + Send + 'static,
278        R: Send + 'static,
279    {
280        BlockingJoinHandle(blocking::unblock(f))
281    }
282
283    /// Run a future to completion on the runtime
284    ///
285    /// NOTE: when initialized  with an executor,  will block current thread until the future
286    /// returns
287    #[inline]
288    fn block_on<F, R>(&self, f: F) -> R
289    where
290        F: Future<Output = R> + Send,
291        R: Send + 'static,
292    {
293        if let Some(exec) = &self.0 {
294            block_on(exec.run(f))
295        } else {
296            #[cfg(feature = "global")]
297            {
298                smol::block_on(f)
299            }
300            #[cfg(not(feature = "global"))]
301            unreachable!();
302        }
303    }
304}
305
/// Associated `Interval` type for `SmolRT`; a newtype around an `async_io` [`Timer`].
pub struct SmolInterval(Timer);
308
309impl TimeInterval for SmolInterval {
310    #[inline]
311    fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
312        let _self = self.get_mut();
313        match _self.0.poll_next(ctx) {
314            Poll::Ready(Some(i)) => Poll::Ready(i),
315            Poll::Ready(None) => unreachable!(),
316            Poll::Pending => Poll::Pending,
317        }
318    }
319}
320
/// Associated `AsyncFd` type for `SmolRT`; a newtype around `async_io`'s [`Async<T>`]
/// that dereferences to the wrapped `T`.
pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
323
324impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
325    #[inline(always)]
326    async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
327        self.0.read_with(f).await
328    }
329
330    #[inline(always)]
331    async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
332        self.0.write_with(f).await
333    }
334}
335
336impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
337    type Target = T;
338
339    #[inline(always)]
340    fn deref(&self) -> &Self::Target {
341        self.0.get_ref()
342    }
343}