1use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncExecDyn, AsyncHandle, ThreadHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::Path;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
61#[derive(Clone)]
63pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66 #[inline]
67 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68 if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69 }
70}
71
72impl SmolRT {
73 #[cfg(feature = "global")]
74 #[inline]
75 pub fn new_global() -> Self {
76 Self(None)
77 }
78
79 #[inline]
81 pub fn new(executor: Arc<Executor<'static>>) -> Self {
82 Self(Some(executor))
83 }
84}
85
86impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91 #[inline(always)]
92 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93 let _addr = addr.clone();
94 let stream = Async::<TcpStream>::connect(_addr).await?;
95 Self::to_async_fd_rw(stream.into_inner()?)
97 }
98
99 #[inline(always)]
100 async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
101 let stream = Async::<UnixStream>::connect(addr).await?;
102 Self::to_async_fd_rw(stream.into_inner()?)
104 }
105
106 #[inline(always)]
107 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
108 fd: T,
109 ) -> io::Result<Self::AsyncFd<T>> {
110 Ok(SmolFD(Async::new(fd)?))
111 }
112
113 #[inline(always)]
114 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
115 fd: T,
116 ) -> io::Result<Self::AsyncFd<T>> {
117 Ok(SmolFD(Async::new(fd)?))
118 }
119}
120
121impl AsyncTime for SmolRT {
122 type Interval = SmolInterval;
123
124 #[inline(always)]
125 fn sleep(d: Duration) -> impl Future + Send {
126 Timer::after(d)
127 }
128
129 #[inline(always)]
130 fn interval(d: Duration) -> Self::Interval {
131 let later = std::time::Instant::now() + d;
132 SmolInterval(Timer::interval_at(later, d))
133 }
134}
135
/// Wraps the future expression `$f` in `catch_unwind` when the `unwind`
/// feature is enabled; otherwise expands to `$f` unchanged.
///
/// With `unwind`, the wrapped future's output becomes
/// `Result<Output, Box<dyn Any + Send>>`, which is the task output type stored
/// by the `unwind` variant of `SmolJoinHandle`.
macro_rules! unwind_wrap {
    ($f: expr) => {{
        // Exactly one of the two cfg-gated items below survives expansion; it
        // becomes the value of the outer block.
        #[cfg(feature = "unwind")]
        {
            use futures_lite::future::FutureExt;
            // `catch_unwind` requires an `UnwindSafe` future; `AssertUnwindSafe`
            // asserts that for an arbitrary `$f`.
            std::panic::AssertUnwindSafe($f).catch_unwind()
        }
        #[cfg(not(feature = "unwind"))]
        $f
    }};
}
147
148#[cfg(feature = "unwind")]
150pub struct SmolJoinHandle<T>(
151 Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
152);
153#[cfg(not(feature = "unwind"))]
154pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
155
156impl<T: Send> AsyncHandle<T> for SmolJoinHandle<T> {
157 #[inline]
158 fn is_finished(&self) -> bool {
159 self.0.as_ref().unwrap().is_finished()
160 }
161
162 #[inline(always)]
163 fn abort(self) {
164 }
166
167 #[inline(always)]
168 fn detach(mut self) {
169 self.0.take().unwrap().detach();
170 }
171
172 #[inline(always)]
173 fn abort_boxed(self: Box<Self>) {
174 }
176
177 #[inline(always)]
178 fn detach_boxed(mut self: Box<Self>) {
179 self.0.take().unwrap().detach();
180 }
181}
182
183impl<T> Future for SmolJoinHandle<T> {
184 type Output = Result<T, ()>;
185
186 #[inline]
187 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188 let _self = unsafe { self.get_unchecked_mut() };
189 if let Some(inner) = _self.0.as_mut() {
190 if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
191 #[cfg(feature = "unwind")]
192 {
193 return Poll::Ready(r.map_err(|_e| ()));
194 }
195 #[cfg(not(feature = "unwind"))]
196 {
197 return Poll::Ready(Ok(r));
198 }
199 }
200 Poll::Pending
201 } else {
202 Poll::Ready(Err(()))
203 }
204 }
205}
206
207impl<T> Drop for SmolJoinHandle<T> {
208 fn drop(&mut self) {
209 if let Some(handle) = self.0.take() {
210 handle.detach();
211 }
212 }
213}
214
215pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
216
217impl<T> ThreadHandle<T> for BlockingJoinHandle<T> {
218 #[inline]
219 fn is_finished(&self) -> bool {
220 self.0.is_finished()
221 }
222}
223
224impl<T> Future for BlockingJoinHandle<T> {
225 type Output = Result<T, ()>;
226
227 #[inline]
228 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
229 let _self = unsafe { self.get_unchecked_mut() };
230 if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
231 return Poll::Ready(Ok(r));
232 }
233 Poll::Pending
234 }
235}
236
237impl AsyncExecDyn for SmolRT {
238 #[inline(always)]
239 fn spawn_detach_dyn(&self, f: Box<dyn Future<Output = ()> + Send + Unpin>) {
240 self.spawn(unwind_wrap!(f)).detach();
241 }
242}
243
244impl AsyncExec for SmolRT {
245 type AsyncHandle<R: Send> = SmolJoinHandle<R>;
246
247 type ThreadHandle<R: Send> = BlockingJoinHandle<R>;
248
249 fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
251 where
252 F: Future<Output = R> + Send + 'static,
253 R: Send + 'static,
254 {
255 let handle = match &self.0 {
258 Some(exec) => exec.spawn(unwind_wrap!(f)),
259 None => {
260 #[cfg(feature = "global")]
261 {
262 smol::spawn(unwind_wrap!(f))
263 }
264 #[cfg(not(feature = "global"))]
265 unreachable!();
266 }
267 };
268 SmolJoinHandle(Some(handle))
269 }
270
271 #[inline]
273 fn spawn_detach<F, R>(&self, f: F)
274 where
275 F: Future<Output = R> + Send + 'static,
276 R: Send + 'static,
277 {
278 self.spawn(unwind_wrap!(f)).detach();
279 }
280
281 #[inline]
282 fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
283 where
284 F: FnOnce() -> R + Send + 'static,
285 R: Send + 'static,
286 {
287 BlockingJoinHandle(blocking::unblock(f))
288 }
289
290 #[inline]
295 fn block_on<F, R>(&self, f: F) -> R
296 where
297 F: Future<Output = R> + Send,
298 R: Send + 'static,
299 {
300 if let Some(exec) = &self.0 {
301 block_on(exec.run(f))
302 } else {
303 #[cfg(feature = "global")]
304 {
305 smol::block_on(f)
306 }
307 #[cfg(not(feature = "global"))]
308 unreachable!();
309 }
310 }
311}
312
313pub struct SmolInterval(Timer);
315
316impl TimeInterval for SmolInterval {
317 #[inline]
318 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
319 let _self = self.get_mut();
320 match _self.0.poll_next(ctx) {
321 Poll::Ready(Some(i)) => Poll::Ready(i),
322 Poll::Ready(None) => unreachable!(),
323 Poll::Pending => Poll::Pending,
324 }
325 }
326}
327
328pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
330
331impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
332 #[inline(always)]
333 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
334 self.0.read_with(f).await
335 }
336
337 #[inline(always)]
338 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
339 self.0.write_with(f).await
340 }
341}
342
343impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
344 type Target = T;
345
346 #[inline(always)]
347 fn deref(&self) -> &Self::Target {
348 self.0.get_ref()
349 }
350}