1use async_executor::Executor;
42use async_io::{Async, Timer};
43use futures_lite::{future::block_on, stream::StreamExt};
44use orb::io::{AsyncFd, AsyncIO};
45use orb::runtime::{AsyncExec, AsyncJoinHandle, ThreadJoinHandle};
46use orb::time::{AsyncTime, TimeInterval};
47use std::fmt;
48use std::future::Future;
49use std::io;
50use std::net::SocketAddr;
51use std::net::TcpStream;
52use std::ops::Deref;
53use std::os::fd::{AsFd, AsRawFd};
54use std::os::unix::net::UnixStream;
55use std::path::PathBuf;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::*;
59use std::time::{Duration, Instant};
60
61#[derive(Clone)]
63pub struct SmolRT(Option<Arc<Executor<'static>>>);
64
65impl fmt::Debug for SmolRT {
66 #[inline]
67 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68 if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
69 }
70}
71
72impl SmolRT {
73 #[cfg(feature = "global")]
74 #[inline]
75 pub fn new_global() -> Self {
76 Self(None)
77 }
78
79 #[inline]
81 pub fn new(executor: Arc<Executor<'static>>) -> Self {
82 Self(Some(executor))
83 }
84}
85
86impl orb::AsyncRuntime for SmolRT {}
87
88impl AsyncIO for SmolRT {
89 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
90
91 #[inline(always)]
92 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
93 let _addr = addr.clone();
94 let stream = Async::<TcpStream>::connect(_addr).await?;
95 Self::to_async_fd_rw(stream.into_inner()?)
97 }
98
99 #[inline(always)]
100 async fn connect_unix(addr: &PathBuf) -> io::Result<Self::AsyncFd<UnixStream>> {
101 let _addr = addr.clone();
102 let stream = Async::<UnixStream>::connect(_addr).await?;
103 Self::to_async_fd_rw(stream.into_inner()?)
105 }
106
107 #[inline(always)]
108 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
109 fd: T,
110 ) -> io::Result<Self::AsyncFd<T>> {
111 Ok(SmolFD(Async::new(fd)?))
112 }
113
114 #[inline(always)]
115 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
116 fd: T,
117 ) -> io::Result<Self::AsyncFd<T>> {
118 Ok(SmolFD(Async::new(fd)?))
119 }
120}
121
122impl AsyncTime for SmolRT {
123 type Interval = SmolInterval;
124
125 #[inline(always)]
126 fn sleep(d: Duration) -> impl Future + Send {
127 Timer::after(d)
128 }
129
130 #[inline(always)]
131 fn tick(d: Duration) -> Self::Interval {
132 let later = std::time::Instant::now() + d;
133 SmolInterval(Timer::interval_at(later, d))
134 }
135}
136
/// Optionally wraps a future in `catch_unwind`.
///
/// With the `unwind` feature, the future is wrapped with
/// `AssertUnwindSafe(..).catch_unwind()`, so it resolves to
/// `Result<Output, Box<dyn Any + Send>>` and a panic becomes `Err(payload)`.
/// Without the feature, the expression is returned unchanged.
macro_rules! unwind_wrap {
    ($f: expr) => {{
        #[cfg(not(feature = "unwind"))]
        let wrapped = $f;
        #[cfg(feature = "unwind")]
        let wrapped = {
            use futures_lite::future::FutureExt;
            std::panic::AssertUnwindSafe($f).catch_unwind()
        };
        wrapped
    }};
}
148
149#[cfg(feature = "unwind")]
151pub struct SmolJoinHandle<T>(
152 Option<async_executor::Task<Result<T, Box<dyn std::any::Any + Send>>>>,
153);
154#[cfg(not(feature = "unwind"))]
155pub struct SmolJoinHandle<T>(Option<async_executor::Task<T>>);
156
157impl<T: Send + 'static> AsyncJoinHandle<T> for SmolJoinHandle<T> {
158 #[inline(always)]
159 fn abort(self) {
160 }
162
163 #[inline]
164 fn detach(mut self) {
165 self.0.take().unwrap().detach();
166 }
167
168 #[inline]
169 fn is_finished(&self) -> bool {
170 self.0.as_ref().unwrap().is_finished()
171 }
172}
173
174impl<T: Send + 'static> Future for SmolJoinHandle<T> {
175 type Output = Result<T, ()>;
176
177 #[inline]
178 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179 let _self = unsafe { self.get_unchecked_mut() };
180 if let Some(inner) = _self.0.as_mut() {
181 if let Poll::Ready(r) = Pin::new(inner).poll(cx) {
182 #[cfg(feature = "unwind")]
183 {
184 return Poll::Ready(r.map_err(|_e| ()));
185 }
186 #[cfg(not(feature = "unwind"))]
187 {
188 return Poll::Ready(Ok(r));
189 }
190 }
191 Poll::Pending
192 } else {
193 Poll::Ready(Err(()))
194 }
195 }
196}
197
198impl<T> Drop for SmolJoinHandle<T> {
199 fn drop(&mut self) {
200 if let Some(handle) = self.0.take() {
201 handle.detach();
202 }
203 }
204}
205
206pub struct BlockingJoinHandle<T>(async_executor::Task<T>);
207
208impl<T: Send + 'static> ThreadJoinHandle<T> for BlockingJoinHandle<T> {
209 #[inline]
210 fn is_finished(&self) -> bool {
211 self.0.is_finished()
212 }
213}
214
215impl<T: Send + 'static> Future for BlockingJoinHandle<T> {
216 type Output = Result<T, ()>;
217
218 #[inline]
219 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220 let _self = unsafe { self.get_unchecked_mut() };
221 if let Poll::Ready(r) = Pin::new(&mut _self.0).poll(cx) {
222 return Poll::Ready(Ok(r));
223 }
224 Poll::Pending
225 }
226}
227
228impl AsyncExec for SmolRT {
229 fn spawn<F, R>(&self, f: F) -> impl AsyncJoinHandle<R>
231 where
232 F: Future<Output = R> + Send + 'static,
233 R: Send + 'static,
234 {
235 let handle = match &self.0 {
236 Some(exec) => exec.spawn(unwind_wrap!(f)),
237 None => {
238 #[cfg(feature = "global")]
239 {
240 smol::spawn(unwind_wrap!(f))
241 }
242 #[cfg(not(feature = "global"))]
243 unreachable!();
244 }
245 };
246 SmolJoinHandle(Some(handle))
247 }
248
249 #[inline]
251 fn spawn_detach<F, R>(&self, f: F)
252 where
253 F: Future<Output = R> + Send + 'static,
254 R: Send + 'static,
255 {
256 self.spawn(unwind_wrap!(f)).detach();
257 }
258
259 #[inline]
260 fn spawn_blocking<F, R>(f: F) -> impl ThreadJoinHandle<R>
261 where
262 F: FnOnce() -> R + Send + 'static,
263 R: Send + 'static,
264 {
265 BlockingJoinHandle(blocking::unblock(f))
266 }
267
268 #[inline]
273 fn block_on<F, R>(&self, f: F) -> R
274 where
275 F: Future<Output = R> + Send,
276 R: Send + 'static,
277 {
278 if let Some(exec) = &self.0 {
279 block_on(exec.run(f))
280 } else {
281 #[cfg(feature = "global")]
282 {
283 smol::block_on(f)
284 }
285 #[cfg(not(feature = "global"))]
286 unreachable!();
287 }
288 }
289}
290
291pub struct SmolInterval(Timer);
293
294impl TimeInterval for SmolInterval {
295 #[inline]
296 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
297 let _self = self.get_mut();
298 match _self.0.poll_next(ctx) {
299 Poll::Ready(Some(i)) => Poll::Ready(i),
300 Poll::Ready(None) => unreachable!(),
301 Poll::Pending => Poll::Pending,
302 }
303 }
304}
305
306pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
308
309impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
310 #[inline(always)]
311 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
312 self.0.read_with(f).await
313 }
314
315 #[inline(always)]
316 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
317 self.0.write_with(f).await
318 }
319}
320
321impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
322 type Target = T;
323
324 #[inline(always)]
325 fn deref(&self) -> &Self::Target {
326 self.0.get_ref()
327 }
328}