1use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncHandle, ThreadHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::Path;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
61#[derive(Clone)]
63pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66 #[inline]
67 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68 if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69 }
70}
71
72impl SmolRT {
73 #[cfg(feature = "global")]
74 #[inline]
75 pub fn new_global() -> Self {
76 Self(None)
77 }
78
79 #[inline]
81 pub fn new(executor: Arc<Executor<'static>>) -> Self {
82 Self(Some(executor))
83 }
84}
85
86impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91 #[inline(always)]
92 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93 let _addr = addr.clone();
94 let stream = Async::<TcpStream>::connect(_addr).await?;
95 Self::to_async_fd_rw(stream.into_inner()?)
97 }
98
99 #[inline(always)]
100 async fn connect_unix(addr: &Path) -> io::Result<Self::AsyncFd<UnixStream>> {
101 let stream = Async::<UnixStream>::connect(addr).await?;
102 Self::to_async_fd_rw(stream.into_inner()?)
104 }
105
106 #[inline(always)]
107 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
108 fd: T,
109 ) -> io::Result<Self::AsyncFd<T>> {
110 Ok(SmolFD(Async::new(fd)?))
111 }
112
113 #[inline(always)]
114 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
115 fd: T,
116 ) -> io::Result<Self::AsyncFd<T>> {
117 Ok(SmolFD(Async::new(fd)?))
118 }
119}
120
121impl AsyncTime for SmolRT {
122 type Interval = SmolInterval;
123
124 #[inline(always)]
125 fn sleep(d: Duration) -> impl Future + Send {
126 Timer::after(d)
127 }
128
129 #[inline(always)]
130 fn interval(d: Duration) -> Self::Interval {
131 let later = std::time::Instant::now() + d;
132 SmolInterval(Timer::interval_at(later, d))
133 }
134}
135
// Wraps a future expression depending on the `unwind` feature:
// * enabled  - the future is wrapped in `AssertUnwindSafe(..).catch_unwind()`, so its
//   output becomes a `Result` whose `Err` carries the panic payload;
// * disabled - the expression is passed through untouched.
macro_rules! unwind_wrap {
    ($fut:expr) => {{
        #[cfg(feature = "unwind")]
        {
            use futures_lite::future::FutureExt;
            std::panic::AssertUnwindSafe($fut).catch_unwind()
        }
        #[cfg(not(feature = "unwind"))]
        $fut
    }};
}
147
148#[cfg(feature = "unwind")]
150pub struct SmolJoinHandle<T>(
151 Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
152);
153#[cfg(not(feature = "unwind"))]
154pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
155
156impl<T: Send> AsyncHandle<T> for SmolJoinHandle<T> {
157 #[inline(always)]
158 fn abort(self) {
159 }
161
162 #[inline]
163 fn detach(mut self) {
164 self.0.take().unwrap().detach();
165 }
166
167 #[inline]
168 fn is_finished(&self) -> bool {
169 self.0.as_ref().unwrap().is_finished()
170 }
171}
172
173impl<T> Future for SmolJoinHandle<T> {
174 type Output = Result<T, ()>;
175
176 #[inline]
177 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
178 let _self = unsafe { self.get_unchecked_mut() };
179 if let Some(inner) = _self.0.as_mut() {
180 if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
181 #[cfg(feature = "unwind")]
182 {
183 return Poll::Ready(r.map_err(|_e| ()));
184 }
185 #[cfg(not(feature = "unwind"))]
186 {
187 return Poll::Ready(Ok(r));
188 }
189 }
190 Poll::Pending
191 } else {
192 Poll::Ready(Err(()))
193 }
194 }
195}
196
197impl<T> Drop for SmolJoinHandle<T> {
198 fn drop(&mut self) {
199 if let Some(handle) = self.0.take() {
200 handle.detach();
201 }
202 }
203}
204
205pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
206
207impl<T> ThreadHandle<T> for BlockingJoinHandle<T> {
208 #[inline]
209 fn is_finished(&self) -> bool {
210 self.0.is_finished()
211 }
212}
213
214impl<T> Future for BlockingJoinHandle<T> {
215 type Output = Result<T, ()>;
216
217 #[inline]
218 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219 let _self = unsafe { self.get_unchecked_mut() };
220 if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
221 return Poll::Ready(Ok(r));
222 }
223 Poll::Pending
224 }
225}
226
227impl AsyncExec for SmolRT {
228 type AsyncHandle<R: Send> = SmolJoinHandle<R>;
229
230 type ThreadHandle<R: Send> = BlockingJoinHandle<R>;
231
232 fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
234 where
235 F: Future<Output = R> + Send + 'static,
236 R: Send + 'static,
237 {
238 let handle = match &self.0 {
241 Some(exec) => exec.spawn(unwind_wrap!(f)),
242 None => {
243 #[cfg(feature = "global")]
244 {
245 smol::spawn(unwind_wrap!(f))
246 }
247 #[cfg(not(feature = "global"))]
248 unreachable!();
249 }
250 };
251 SmolJoinHandle(Some(handle))
252 }
253
254 #[inline]
256 fn spawn_detach<F, R>(&self, f: F)
257 where
258 F: Future<Output = R> + Send + 'static,
259 R: Send + 'static,
260 {
261 self.spawn(unwind_wrap!(f)).detach();
262 }
263
264 #[inline]
265 fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
266 where
267 F: FnOnce() -> R + Send + 'static,
268 R: Send + 'static,
269 {
270 BlockingJoinHandle(blocking::unblock(f))
271 }
272
273 #[inline]
278 fn block_on<F, R>(&self, f: F) -> R
279 where
280 F: Future<Output = R> + Send,
281 R: Send + 'static,
282 {
283 if let Some(exec) = &self.0 {
284 block_on(exec.run(f))
285 } else {
286 #[cfg(feature = "global")]
287 {
288 smol::block_on(f)
289 }
290 #[cfg(not(feature = "global"))]
291 unreachable!();
292 }
293 }
294}
295
296pub struct SmolInterval(Timer);
298
299impl TimeInterval for SmolInterval {
300 #[inline]
301 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
302 let _self = self.get_mut();
303 match _self.0.poll_next(ctx) {
304 Poll::Ready(Some(i)) => Poll::Ready(i),
305 Poll::Ready(None) => unreachable!(),
306 Poll::Pending => Poll::Pending,
307 }
308 }
309}
310
311pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
313
314impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
315 #[inline(always)]
316 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
317 self.0.read_with(f).await
318 }
319
320 #[inline(always)]
321 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
322 self.0.write_with(f).await
323 }
324}
325
326impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
327 type Target = T;
328
329 #[inline(always)]
330 fn deref(&self) -> &Self::Target {
331 self.0.get_ref()
332 }
333}