1use async_executor::Executor;
88use async_io::{Async, Timer};
89use crossfire::{MAsyncRx, mpmc, null::CloseHandle};
90use futures_lite::{future::block_on, stream::StreamExt};
91use orb::AsyncRuntime;
92use orb::io::{AsyncFd, AsyncIO};
93use orb::runtime::{AsyncExec, AsyncJoiner, ThreadJoiner};
94use orb::time::{AsyncTime, TimeInterval};
95use std::cell::Cell;
96use std::fmt;
97use std::future::Future;
98use std::io;
99use std::net::{SocketAddr, TcpStream};
100use std::num::NonZero;
101use std::ops::Deref;
102use std::os::{
103 fd::{AsFd, AsRawFd},
104 unix::net::UnixStream,
105};
106use std::path::Path;
107use std::pin::Pin;
108use std::ptr;
109use std::sync::Arc;
110use std::task::*;
111use std::thread;
112use std::time::{Duration, Instant};
113
/// Runtime marker type: implements orb's `AsyncIO`, `AsyncTime` and `AsyncRuntime`
/// traits on top of smol's component crates (async-io, async-executor, blocking).
pub struct SmolRT {}

/// Executor handle produced by `SmolRT::current` / `one` / `multi`.
///
/// `Some` wraps a private `async_executor::Executor`; `None` stands for smol's global
/// executor and is only produced by `SmolExec::new_global` (`global` feature).
#[derive(Clone)]
pub struct SmolExec(Option<SmolExecInner>);

/// Shared state behind a non-global `SmolExec`.
#[derive(Clone)]
struct SmolExecInner {
    // The executor that tasks spawned through this handle are queued on.
    ex: Arc<Executor<'static>>,
    // `Some` only for executors built by `multi` without the `global` feature, whose
    // worker threads wait on the matching receiver. NOTE(review): presumably dropping the
    // last close handle ends those waits so the workers exit — confirm against crossfire.
    _close_h: Option<CloseHandle<mpmc::Null>>,
}
125
126thread_local! {
128 static CURRENT_EXECUTOR: Cell<*const Executor<'static>> = const { Cell::new(ptr::null()) };
129}
130
131fn set_current_executor(exec: *const Executor<'static>) {
133 CURRENT_EXECUTOR.set(exec);
134}
135
136fn get_current_executor() -> *const Executor<'static> {
138 CURRENT_EXECUTOR.get()
139}
140
141impl fmt::Debug for SmolExec {
142 #[inline]
143 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
144 if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
145 }
146}
147
148impl SmolExec {
149 #[cfg(feature = "global")]
150 #[inline]
151 pub fn new_global() -> Self {
152 Self(None)
153 }
154}
155
156impl AsyncIO for SmolRT {
157 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
158
159 #[inline(always)]
160 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
161 let _addr = addr.clone();
162 let stream = Async::<TcpStream>::connect(_addr).await?;
163 Self::to_async_fd_rw(stream.into_inner()?)
165 }
166
167 #[inline(always)]
168 async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
169 let stream = Async::<UnixStream>::connect(addr).await?;
170 Self::to_async_fd_rw(stream.into_inner()?)
172 }
173
174 #[inline(always)]
175 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
176 fd: T,
177 ) -> io::Result<Self::AsyncFd<T>> {
178 Ok(SmolFD(Async::new(fd)?))
179 }
180
181 #[inline(always)]
182 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
183 fd: T,
184 ) -> io::Result<Self::AsyncFd<T>> {
185 Ok(SmolFD(Async::new(fd)?))
186 }
187}
188
189impl AsyncTime for SmolRT {
190 type Interval = SmolInterval;
191
192 #[inline(always)]
193 fn sleep(d: Duration) -> impl Future + Send {
194 Timer::after(d)
195 }
196
197 #[inline(always)]
198 fn interval(d: Duration) -> Self::Interval {
199 let later = std::time::Instant::now() + d;
200 SmolInterval(Timer::interval_at(later, d))
201 }
202}
203
/// Wrap a future so that a panic inside it is caught and surfaced as the `Err` side of
/// the output (`catch_unwind`). This is the `unwind` feature build.
#[cfg(feature = "unwind")]
macro_rules! unwind_wrap {
    ($f: expr) => {{
        use futures_lite::future::FutureExt;
        std::panic::AssertUnwindSafe($f).catch_unwind()
    }};
}

/// Without the `unwind` feature the future is passed through untouched.
#[cfg(not(feature = "unwind"))]
macro_rules! unwind_wrap {
    ($f: expr) => {{ $f }};
}
215
216#[cfg(feature = "unwind")]
218pub struct SmolJoinHandle<T>(
219 Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
220);
221#[cfg(not(feature = "unwind"))]
222pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
223
224impl<T: Send> AsyncJoiner<T> for SmolJoinHandle<T> {
225 #[inline]
226 fn is_finished(&self) -> bool {
227 self.0.as_ref().unwrap().is_finished()
228 }
229
230 #[inline(always)]
231 fn abort(self) {
232 }
234
235 #[inline(always)]
236 fn detach(mut self) {
237 self.0.take().unwrap().detach();
238 }
239
240 #[inline(always)]
241 fn abort_boxed(self: Box<Self>) {
242 }
244
245 #[inline(always)]
246 fn detach_boxed(mut self: Box<Self>) {
247 self.0.take().unwrap().detach();
248 }
249}
250
251impl<T> Future for SmolJoinHandle<T> {
252 type Output = Result<T, ()>;
253
254 #[inline]
255 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
256 let _self = unsafe { self.get_unchecked_mut() };
257 if let Some(inner) = _self.0.as_mut() {
258 if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
259 #[cfg(feature = "unwind")]
260 {
261 return Poll::Ready(r.map_err(|_e| ()));
262 }
263 #[cfg(not(feature = "unwind"))]
264 {
265 return Poll::Ready(Ok(r));
266 }
267 }
268 Poll::Pending
269 } else {
270 Poll::Ready(Err(()))
271 }
272 }
273}
274
275impl<T> Drop for SmolJoinHandle<T> {
276 fn drop(&mut self) {
277 if let Some(handle) = self.0.take() {
278 handle.detach();
279 }
280 }
281}
282
283pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
284
285impl<T> ThreadJoiner<T> for BlockingJoinHandle<T> {
286 #[inline]
287 fn is_finished(&self) -> bool {
288 self.0.is_finished()
289 }
290}
291
292impl<T> Future for BlockingJoinHandle<T> {
293 type Output = Result<T, ()>;
294
295 #[inline]
296 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
297 let _self = unsafe { self.get_unchecked_mut() };
298 if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
299 return Poll::Ready(Ok(r));
300 }
301 Poll::Pending
302 }
303}
304
305impl AsyncRuntime for SmolRT {
306 type Exec = SmolExec;
307
308 #[inline(always)]
317 fn current() -> SmolExec {
318 SmolExec(Some(SmolExecInner { ex: Arc::new(Executor::new()), _close_h: None }))
319 }
320
321 #[inline(always)]
327 fn one() -> SmolExec {
328 Self::multi(1)
329 }
330
331 #[inline(always)]
338 fn multi(mut size: usize) -> SmolExec {
339 if size == 0 {
340 size = usize::from(
341 thread::available_parallelism().unwrap_or(NonZero::new(1usize).unwrap()),
342 )
343 }
344 #[cfg(feature = "global")]
345 {
346 unsafe { std::env::set_var("SMOL_THREADS", size.to_string()) };
347 SmolExec::new_global()
348 }
349 #[cfg(not(feature = "global"))]
350 {
351 let (close_h, rx): (CloseHandle<mpmc::Null>, MAsyncRx<mpmc::Null>) = mpmc::new();
352 let inner = SmolExecInner { ex: Arc::new(Executor::new()), _close_h: Some(close_h) };
354 #[cfg(not(target_os = "espidf"))]
355 inner.ex.spawn(async_process::driver()).detach();
356 let ex = inner.ex.clone();
357 let ex_ptr: usize = Arc::as_ptr(&inner.ex) as usize;
359 for n in 1..=size {
360 let _ex = ex.clone();
361 let _rx = rx.clone();
362 let _ex_ptr = ex_ptr;
363 thread::Builder::new()
364 .name(format!("smol-{}", n))
365 .spawn(move || {
366 set_current_executor(_ex_ptr as *const Executor<'static>);
367 let _ = block_on(_ex.run(_rx.recv()));
368 set_current_executor(ptr::null());
369 })
370 .expect("cannot spawn executor thread");
371 }
372 SmolExec(Some(inner))
373 }
374 }
375
376 fn spawn<F, R>(f: F) -> SmolJoinHandle<R>
378 where
379 F: Future<Output = R> + Send + 'static,
380 R: Send + 'static,
381 {
382 #[cfg(feature = "global")]
383 {
384 SmolJoinHandle(smol::spawn(f))
385 }
386 #[cfg(not(feature = "global"))]
387 {
388 let ex_ptr = get_current_executor();
389 assert!(!ex_ptr.is_null(), "spawn must be called in runtime worker context");
390 let ex = unsafe { &*ex_ptr };
391 SmolJoinHandle(Some(ex.spawn(unwind_wrap!(f))))
392 }
393 }
394
395 #[inline]
397 fn spawn_detach<F, R>(f: F)
398 where
399 F: Future<Output = R> + Send + 'static,
400 R: Send + 'static,
401 {
402 #[cfg(feature = "global")]
403 {
404 smol::spawn(f).detach()
405 }
406 #[cfg(not(feature = "global"))]
407 {
408 let ex_ptr = get_current_executor();
409 assert!(!ex_ptr.is_null(), "spawn_detach must be called in runtime worker context");
410 let ex = unsafe { &*ex_ptr };
411 ex.spawn(unwind_wrap!(f)).detach();
412 }
413 }
414
415 #[inline]
416 fn spawn_blocking<F, R>(f: F) -> BlockingJoinHandle<R>
417 where
418 F: FnOnce() -> R + Send + 'static,
419 R: Send + 'static,
420 {
421 BlockingJoinHandle(blocking::unblock(f))
422 }
423}
424
425impl AsyncExec for SmolExec {
426 type AsyncJoiner<R: Send> = SmolJoinHandle<R>;
427
428 type ThreadJoiner<R: Send> = BlockingJoinHandle<R>;
429
430 fn spawn<F, R>(&self, f: F) -> Self::AsyncJoiner<R>
432 where
433 F: Future<Output = R> + Send + 'static,
434 R: Send + 'static,
435 {
436 let handle = match &self.0 {
439 Some(inner) => inner.ex.spawn(unwind_wrap!(f)),
440 None => {
441 #[cfg(feature = "global")]
442 {
443 smol::spawn(unwind_wrap!(f))
444 }
445 #[cfg(not(feature = "global"))]
446 unreachable!();
447 }
448 };
449 SmolJoinHandle(Some(handle))
450 }
451
452 #[inline]
454 fn spawn_detach<F, R>(&self, f: F)
455 where
456 F: Future<Output = R> + Send + 'static,
457 R: Send + 'static,
458 {
459 self.spawn(unwind_wrap!(f)).detach();
460 }
461
462 #[inline]
463 fn spawn_blocking<F, R>(&self, f: F) -> Self::ThreadJoiner<R>
464 where
465 F: FnOnce() -> R + Send + 'static,
466 R: Send + 'static,
467 {
468 BlockingJoinHandle(blocking::unblock(f))
469 }
470
471 #[inline]
476 fn block_on<F, R>(&self, f: F) -> R
477 where
478 F: Future<Output = R> + Send,
479 R: 'static,
480 {
481 if let Some(inner) = &self.0 {
482 let ex_ptr: *const Executor<'static> = Arc::as_ptr(&inner.ex);
483 set_current_executor(ex_ptr);
484 let result = block_on(inner.ex.run(f));
485 set_current_executor(ptr::null());
486 result
487 } else {
488 #[cfg(feature = "global")]
489 {
490 smol::block_on(f)
491 }
492 #[cfg(not(feature = "global"))]
493 unreachable!();
494 }
495 }
496}
497
498pub struct SmolInterval(Timer);
500
501impl TimeInterval for SmolInterval {
502 #[inline]
503 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
504 let _self = self.get_mut();
505 match _self.0.poll_next(ctx) {
506 Poll::Ready(Some(i)) => Poll::Ready(i),
507 Poll::Ready(None) => unreachable!(),
508 Poll::Pending => Poll::Pending,
509 }
510 }
511}
512
513pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
515
516impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
517 #[inline(always)]
518 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
519 self.0.read_with(f).await
520 }
521
522 #[inline(always)]
523 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
524 self.0.write_with(f).await
525 }
526}
527
528impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
529 type Target = T;
530
531 #[inline(always)]
532 fn deref(&self) -> &Self::Target {
533 self.0.get_ref()
534 }
535}