1use async_executor::Executor;
79use async_io::{Async, Timer};
80#[allow(unused_imports)]
81use crossfire::{MAsyncRx, mpmc, null::CloseHandle};
82use futures_lite::{future::block_on, stream::StreamExt};
83use orb::AsyncRuntime;
84use orb::io::{AsyncFd, AsyncIO};
85use orb::runtime::{AsyncExec, AsyncJoiner, ThreadJoiner};
86use orb::time::{AsyncTime, TimeInterval};
87use std::cell::Cell;
88use std::fmt;
89use std::future::Future;
90use std::io;
91use std::net::{SocketAddr, TcpStream};
92use std::num::NonZero;
93use std::ops::Deref;
94use std::os::{
95 fd::{AsFd, AsRawFd},
96 unix::net::UnixStream,
97};
98use std::path::Path;
99use std::pin::Pin;
100use std::ptr;
101use std::sync::Arc;
102use std::task::*;
103use std::thread;
104use std::time::{Duration, Instant};
105
/// Marker type implementing the `orb` runtime traits (`AsyncIO`, `AsyncTime`,
/// `AsyncRuntime`) on top of smol's building blocks (`async-io`, `async-executor`).
pub struct SmolRT {}
107
/// Handle to an executor created through [`SmolRT`].
///
/// `Some` wraps a dedicated `async_executor::Executor`; `None` stands for smol's global
/// executor (only produced by `SmolRT::multi` when the `global` feature is enabled).
#[derive(Clone)]
pub struct SmolExec(Option<SmolExecInner>);
111
/// Shared state behind a non-global [`SmolExec`].
#[derive(Clone)]
struct SmolExecInner {
    /// Executor that tasks are spawned onto; `SmolRT::multi` hands clones to its worker threads.
    ex: Arc<Executor<'static>>,
    /// Held only for its lifetime. `SmolRT::multi`'s worker threads run the executor until the
    /// paired receiver's `recv()` completes, presumably once every clone of this handle is
    /// dropped — TODO confirm crossfire `CloseHandle` semantics. `None` when no worker threads
    /// were spawned (`SmolRT::current`).
    _close_h: Option<CloseHandle<mpmc::Null>>,
}
117
thread_local! {
    /// Executor targeted by the non-`global` `SmolRT::spawn` / `spawn_detach` on this thread.
    ///
    /// Null when the thread is not driving a `SmolExec`. It is set by the worker threads of
    /// `SmolRT::multi` and around `SmolExec::block_on`. It is a raw pointer because the `Arc`
    /// that owns the executor lives in those callers, not here.
    static CURRENT_EXECUTOR: Cell<*const Executor<'static>> = const { Cell::new(ptr::null()) };
}
122
123fn set_current_executor(exec: *const Executor<'static>) {
125 CURRENT_EXECUTOR.set(exec);
126}
127
128#[cfg(not(feature = "global"))]
129fn get_current_executor() -> *const Executor<'static> {
131 CURRENT_EXECUTOR.get()
132}
133
134impl fmt::Debug for SmolExec {
135 #[inline]
136 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137 if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
138 }
139}
140
141impl AsyncIO for SmolRT {
142 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
143
144 #[inline(always)]
145 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
146 let _addr = addr.clone();
147 let stream = Async::<TcpStream>::connect(_addr).await?;
148 Self::to_async_fd_rw(stream.into_inner()?)
150 }
151
152 #[inline(always)]
153 async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
154 let stream = Async::<UnixStream>::connect(addr).await?;
155 Self::to_async_fd_rw(stream.into_inner()?)
157 }
158
159 #[inline(always)]
160 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
161 fd: T,
162 ) -> io::Result<Self::AsyncFd<T>> {
163 Ok(SmolFD(Async::new(fd)?))
164 }
165
166 #[inline(always)]
167 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
168 fd: T,
169 ) -> io::Result<Self::AsyncFd<T>> {
170 Ok(SmolFD(Async::new(fd)?))
171 }
172}
173
174impl AsyncTime for SmolRT {
175 type Interval = SmolInterval;
176
177 #[inline(always)]
178 fn sleep(d: Duration) -> impl Future + Send {
179 Timer::after(d)
180 }
181
182 #[inline(always)]
183 fn interval(d: Duration) -> Self::Interval {
184 let later = std::time::Instant::now() + d;
185 SmolInterval(Timer::interval_at(later, d))
186 }
187}
188
/// Wraps a future expression depending on the `unwind` feature.
///
/// With `unwind`, the future is wrapped in `AssertUnwindSafe(..).catch_unwind()`. A panic
/// inside it then becomes an `Err(Box<dyn Any + Send>)` output instead of propagating. This
/// is why `SmolJoinHandle` has a cfg-dependent task type.
///
/// Without `unwind`, the expression is passed through unchanged.
macro_rules! unwind_wrap {
    ($f: expr) => {{
        #[cfg(feature = "unwind")]
        {
            use futures_lite::future::FutureExt;
            std::panic::AssertUnwindSafe($f).catch_unwind()
        }
        #[cfg(not(feature = "unwind"))]
        $f
    }};
}
200
/// Join handle for a task spawned through the smol runtime (`unwind` variant).
///
/// The task's output also carries the panic payload caught by `unwind_wrap!`. The task sits
/// in an `Option` so it can be moved out by `detach`. If the handle is dropped while still
/// holding it, the `Drop` impl detaches the task.
#[cfg(feature = "unwind")]
pub struct SmolJoinHandle<T>(
    Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
);
/// Join handle for a task spawned through the smol runtime.
///
/// The task sits in an `Option` so it can be moved out by `detach`. If the handle is dropped
/// while still holding it, the `Drop` impl detaches the task.
#[cfg(not(feature = "unwind"))]
pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
208
209impl<T: Send> AsyncJoiner<T> for SmolJoinHandle<T> {
210 #[inline]
211 fn is_finished(&self) -> bool {
212 self.0.as_ref().unwrap().is_finished()
213 }
214
215 #[inline(always)]
216 fn abort(self) {
217 }
219
220 #[inline(always)]
221 fn detach(mut self) {
222 self.0.take().unwrap().detach();
223 }
224
225 #[inline(always)]
226 fn abort_boxed(self: Box<Self>) {
227 }
229
230 #[inline(always)]
231 fn detach_boxed(mut self: Box<Self>) {
232 self.0.take().unwrap().detach();
233 }
234}
235
236impl<T> Future for SmolJoinHandle<T> {
237 type Output = Result<T, ()>;
238
239 #[inline]
240 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
241 let _self = unsafe { self.get_unchecked_mut() };
242 if let Some(inner) = _self.0.as_mut() {
243 if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
244 #[cfg(feature = "unwind")]
245 {
246 return Poll::Ready(r.map_err(|_e| ()));
247 }
248 #[cfg(not(feature = "unwind"))]
249 {
250 return Poll::Ready(Ok(r));
251 }
252 }
253 Poll::Pending
254 } else {
255 Poll::Ready(Err(()))
256 }
257 }
258}
259
260impl<T> Drop for SmolJoinHandle<T> {
261 fn drop(&mut self) {
262 if let Some(handle) = self.0.take() {
263 handle.detach();
264 }
265 }
266}
267
/// Join handle for a closure handed to the `blocking` crate via `blocking::unblock`.
pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
269
270impl<T> ThreadJoiner<T> for BlockingJoinHandle<T> {
271 #[inline]
272 fn is_finished(&self) -> bool {
273 self.0.is_finished()
274 }
275}
276
277impl<T> Future for BlockingJoinHandle<T> {
278 type Output = Result<T, ()>;
279
280 #[inline]
281 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
282 let _self = unsafe { self.get_unchecked_mut() };
283 if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
284 return Poll::Ready(Ok(r));
285 }
286 Poll::Pending
287 }
288}
289
290impl AsyncRuntime for SmolRT {
291 type Exec = SmolExec;
292
293 #[inline(always)]
302 fn current() -> SmolExec {
303 SmolExec(Some(SmolExecInner { ex: Arc::new(Executor::new()), _close_h: None }))
304 }
305
306 #[inline(always)]
312 fn one() -> SmolExec {
313 Self::multi(1)
314 }
315
316 #[inline(always)]
324 fn multi(mut size: usize) -> SmolExec {
325 if size == 0 {
326 size = usize::from(
327 thread::available_parallelism().unwrap_or(NonZero::new(1usize).unwrap()),
328 )
329 }
330 #[cfg(feature = "global")]
331 {
332 unsafe { std::env::set_var("SMOL_THREADS", size.to_string()) };
333 SmolExec(None)
334 }
335 #[cfg(not(feature = "global"))]
336 {
337 let (close_h, rx): (CloseHandle<mpmc::Null>, MAsyncRx<mpmc::Null>) = mpmc::new();
338 let inner = SmolExecInner { ex: Arc::new(Executor::new()), _close_h: Some(close_h) };
340 #[cfg(not(target_os = "espidf"))]
341 inner.ex.spawn(async_process::driver()).detach();
342 let ex = inner.ex.clone();
343 let ex_ptr: usize = Arc::as_ptr(&inner.ex) as usize;
345 for n in 1..=size {
346 let _ex = ex.clone();
347 let _rx = rx.clone();
348 let _ex_ptr = ex_ptr;
349 thread::Builder::new()
350 .name(format!("smol-{}", n))
351 .spawn(move || {
352 set_current_executor(_ex_ptr as *const Executor<'static>);
353 let _ = block_on(_ex.run(_rx.recv()));
354 set_current_executor(ptr::null());
355 })
356 .expect("cannot spawn executor thread");
357 }
358 SmolExec(Some(inner))
359 }
360 }
361
362 fn spawn<F, R>(f: F) -> SmolJoinHandle<R>
364 where
365 F: Future<Output = R> + Send + 'static,
366 R: Send + 'static,
367 {
368 #[cfg(feature = "global")]
369 {
370 SmolJoinHandle(Some(smol::spawn(unwind_wrap!(f))))
371 }
372 #[cfg(not(feature = "global"))]
373 {
374 let ex_ptr = get_current_executor();
375 assert!(!ex_ptr.is_null(), "spawn must be called in runtime worker context");
376 let ex = unsafe { &*ex_ptr };
377 SmolJoinHandle(Some(ex.spawn(unwind_wrap!(f))))
378 }
379 }
380
381 #[inline]
383 fn spawn_detach<F, R>(f: F)
384 where
385 F: Future<Output = R> + Send + 'static,
386 R: Send + 'static,
387 {
388 #[cfg(feature = "global")]
389 {
390 smol::spawn(f).detach()
391 }
392 #[cfg(not(feature = "global"))]
393 {
394 let ex_ptr = get_current_executor();
395 assert!(!ex_ptr.is_null(), "spawn_detach must be called in runtime worker context");
396 let ex = unsafe { &*ex_ptr };
397 ex.spawn(unwind_wrap!(f)).detach();
398 }
399 }
400
401 #[inline]
402 fn spawn_blocking<F, R>(f: F) -> BlockingJoinHandle<R>
403 where
404 F: FnOnce() -> R + Send + 'static,
405 R: Send + 'static,
406 {
407 BlockingJoinHandle(blocking::unblock(f))
408 }
409}
410
411impl AsyncExec for SmolExec {
412 type AsyncJoiner<R: Send> = SmolJoinHandle<R>;
413
414 type ThreadJoiner<R: Send> = BlockingJoinHandle<R>;
415
416 fn spawn<F, R>(&self, f: F) -> Self::AsyncJoiner<R>
418 where
419 F: Future<Output = R> + Send + 'static,
420 R: Send + 'static,
421 {
422 let handle = match &self.0 {
425 Some(inner) => inner.ex.spawn(unwind_wrap!(f)),
426 None => {
427 #[cfg(feature = "global")]
428 {
429 smol::spawn(unwind_wrap!(f))
430 }
431 #[cfg(not(feature = "global"))]
432 unreachable!();
433 }
434 };
435 SmolJoinHandle(Some(handle))
436 }
437
438 #[inline]
440 fn spawn_detach<F, R>(&self, f: F)
441 where
442 F: Future<Output = R> + Send + 'static,
443 R: Send + 'static,
444 {
445 self.spawn(unwind_wrap!(f)).detach();
446 }
447
448 #[inline]
449 fn spawn_blocking<F, R>(&self, f: F) -> Self::ThreadJoiner<R>
450 where
451 F: FnOnce() -> R + Send + 'static,
452 R: Send + 'static,
453 {
454 BlockingJoinHandle(blocking::unblock(f))
455 }
456
457 #[inline]
462 fn block_on<F, R>(&self, f: F) -> R
463 where
464 F: Future<Output = R> + Send,
465 R: 'static,
466 {
467 if let Some(inner) = &self.0 {
468 let ex_ptr: *const Executor<'static> = Arc::as_ptr(&inner.ex);
469 set_current_executor(ex_ptr);
470 let result = block_on(inner.ex.run(f));
471 set_current_executor(ptr::null());
472 result
473 } else {
474 #[cfg(feature = "global")]
475 {
476 smol::block_on(f)
477 }
478 #[cfg(not(feature = "global"))]
479 unreachable!();
480 }
481 }
482}
483
/// Periodic timer returned by `SmolRT::interval`, backed by an `async_io::Timer`.
pub struct SmolInterval(Timer);
486
487impl TimeInterval for SmolInterval {
488 #[inline]
489 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
490 let _self = self.get_mut();
491 match _self.0.poll_next(ctx) {
492 Poll::Ready(Some(i)) => Poll::Ready(i),
493 Poll::Ready(None) => unreachable!(),
494 Poll::Pending => Poll::Pending,
495 }
496 }
497}
498
/// I/O handle registered with the async-io reactor (via `Async<T>`).
///
/// Implements `orb`'s `AsyncFd` and dereferences to the wrapped `T`.
pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
501
502impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
503 #[inline(always)]
504 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
505 self.0.read_with(f).await
506 }
507
508 #[inline(always)]
509 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
510 self.0.write_with(f).await
511 }
512}
513
514impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
515 type Target = T;
516
517 #[inline(always)]
518 fn deref(&self) -> &Self::Target {
519 self.0.get_ref()
520 }
521}