1use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncHandle, ThreadHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::PathBuf;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
61#[derive(Clone)]
63pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66 #[inline]
67 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68 if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69 }
70}
71
72impl SmolRT {
73 #[cfg(feature = "global")]
74 #[inline]
75 pub fn new_global() -> Self {
76 Self(None)
77 }
78
79 #[inline]
81 pub fn new(executor: Arc<Executor<'static>>) -> Self {
82 Self(Some(executor))
83 }
84}
85
86impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91 #[inline(always)]
92 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93 let _addr = addr.clone();
94 let stream = Async::<TcpStream>::connect(_addr).await?;
95 Self::to_async_fd_rw(stream.into_inner()?)
97 }
98
99 #[inline(always)]
100 async fn connect_unix(addr: &PathBuf) -> io::Result<Self::AsyncFd<UnixStream>> {
101 let _addr = addr.clone();
102 let stream = Async::<UnixStream>::connect(_addr).await?;
103 Self::to_async_fd_rw(stream.into_inner()?)
105 }
106
107 #[inline(always)]
108 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
109 fd: T,
110 ) -> io::Result<Self::AsyncFd<T>> {
111 Ok(SmolFD(Async::new(fd)?))
112 }
113
114 #[inline(always)]
115 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
116 fd: T,
117 ) -> io::Result<Self::AsyncFd<T>> {
118 Ok(SmolFD(Async::new(fd)?))
119 }
120}
121
122impl AsyncTime for SmolRT {
123 type Interval = SmolInterval;
124
125 #[inline(always)]
126 fn sleep(d: Duration) -> impl Future + Send {
127 Timer::after(d)
128 }
129
130 #[inline(always)]
131 fn tick(d: Duration) -> Self::Interval {
132 let later = std::time::Instant::now() + d;
133 SmolInterval(Timer::interval_at(later, d))
134 }
135}
136
/// Wraps a future expression for spawning.
///
/// * With the `unwind` feature: the future is wrapped in `AssertUnwindSafe` and
///   `catch_unwind()`, so the expansion is a different future type from the input.
/// * Without it: the expression is passed through untouched.
macro_rules! unwind_wrap {
    ($fut:expr) => {{
        #[cfg(feature = "unwind")]
        {
            use futures_lite::future::FutureExt;
            std::panic::AssertUnwindSafe($fut).catch_unwind()
        }
        #[cfg(not(feature = "unwind"))]
        {
            $fut
        }
    }};
}
148
149#[cfg(feature = "unwind")]
151pub struct SmolJoinHandle<T>(
152 Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
153);
154#[cfg(not(feature = "unwind"))]
155pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
156
157impl<T: Send> AsyncHandle<T> for SmolJoinHandle<T> {
158 #[inline(always)]
159 fn abort(self) {
160 }
162
163 #[inline]
164 fn detach(mut self) {
165 self.0.take().unwrap().detach();
166 }
167
168 #[inline]
169 fn is_finished(&self) -> bool {
170 self.0.as_ref().unwrap().is_finished()
171 }
172}
173
174impl<T> Future for SmolJoinHandle<T> {
175 type Output = Result<T, ()>;
176
177 #[inline]
178 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179 let _self = unsafe { self.get_unchecked_mut() };
180 if let Some(inner) = _self.0.as_mut() {
181 if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
182 #[cfg(feature = "unwind")]
183 {
184 return Poll::Ready(r.map_err(|_e| ()));
185 }
186 #[cfg(not(feature = "unwind"))]
187 {
188 return Poll::Ready(Ok(r));
189 }
190 }
191 Poll::Pending
192 } else {
193 Poll::Ready(Err(()))
194 }
195 }
196}
197
198impl<T> Drop for SmolJoinHandle<T> {
199 fn drop(&mut self) {
200 if let Some(handle) = self.0.take() {
201 handle.detach();
202 }
203 }
204}
205
206pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
207
208impl<T> ThreadHandle<T> for BlockingJoinHandle<T> {
209 #[inline]
210 fn is_finished(&self) -> bool {
211 self.0.is_finished()
212 }
213}
214
215impl<T> Future for BlockingJoinHandle<T> {
216 type Output = Result<T, ()>;
217
218 #[inline]
219 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220 let _self = unsafe { self.get_unchecked_mut() };
221 if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
222 return Poll::Ready(Ok(r));
223 }
224 Poll::Pending
225 }
226}
227
228impl AsyncExec for SmolRT {
229 type AsyncHandle<R: Send> = SmolJoinHandle<R>;
230
231 type ThreadHandle<R: Send> = BlockingJoinHandle<R>;
232
233 fn spawn<F, R>(&self, f: F) -> Self::AsyncHandle<R>
235 where
236 F: Future<Output = R> + Send + 'static,
237 R: Send + 'static,
238 {
239 let handle = match &self.0 {
242 Some(exec) => exec.spawn(unwind_wrap!(f)),
243 None => {
244 #[cfg(feature = "global")]
245 {
246 smol::spawn(unwind_wrap!(f))
247 }
248 #[cfg(not(feature = "global"))]
249 unreachable!();
250 }
251 };
252 SmolJoinHandle(Some(handle))
253 }
254
255 #[inline]
257 fn spawn_detach<F, R>(&self, f: F)
258 where
259 F: Future<Output = R> + Send + 'static,
260 R: Send + 'static,
261 {
262 self.spawn(unwind_wrap!(f)).detach();
263 }
264
265 #[inline]
266 fn spawn_blocking<F, R>(f: F) -> Self::ThreadHandle<R>
267 where
268 F: FnOnce() -> R + Send + 'static,
269 R: Send + 'static,
270 {
271 BlockingJoinHandle(blocking::unblock(f))
272 }
273
274 #[inline]
279 fn block_on<F, R>(&self, f: F) -> R
280 where
281 F: Future<Output = R> + Send,
282 R: Send + 'static,
283 {
284 if let Some(exec) = &self.0 {
285 block_on(exec.run(f))
286 } else {
287 #[cfg(feature = "global")]
288 {
289 smol::block_on(f)
290 }
291 #[cfg(not(feature = "global"))]
292 unreachable!();
293 }
294 }
295}
296
297pub struct SmolInterval(Timer);
299
300impl TimeInterval for SmolInterval {
301 #[inline]
302 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
303 let _self = self.get_mut();
304 match _self.0.poll_next(ctx) {
305 Poll::Ready(Some(i)) => Poll::Ready(i),
306 Poll::Ready(None) => unreachable!(),
307 Poll::Pending => Poll::Pending,
308 }
309 }
310}
311
312pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
314
315impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
316 #[inline(always)]
317 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
318 self.0.read_with(f).await
319 }
320
321 #[inline(always)]
322 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
323 self.0.write_with(f).await
324 }
325}
326
327impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
328 type Target = T;
329
330 #[inline(always)]
331 fn deref(&self) -> &Self::Target {
332 self.0.get_ref()
333 }
334}