1use async_executor::Executor;
38use async_io::{Async, Timer};
39use futures_lite::future::block_on;
40use futures_lite::stream::StreamExt;
41use orb::io::{AsyncFd, AsyncIO};
42use orb::runtime::{AsyncExec, AsyncJoinHandle};
43use orb::time::{AsyncTime, TimeInterval};
44use std::fmt;
45use std::future::Future;
46use std::io;
47use std::net::SocketAddr;
48use std::net::TcpStream;
49use std::ops::Deref;
50use std::os::fd::{AsFd, AsRawFd};
51use std::os::unix::net::UnixStream;
52use std::path::PathBuf;
53use std::pin::Pin;
54use std::sync::Arc;
55use std::task::*;
56use std::time::{Duration, Instant};
57
/// Runtime handle used to implement the `orb` runtime traits on top of smol.
///
/// The wrapped option selects where work is scheduled: `Some(executor)` uses
/// that shared [`Executor`], while `None` (produced by `new_global`, which is
/// only compiled with the `global` feature) routes to `smol::spawn` /
/// `smol::block_on`.
#[derive(Clone)]
pub struct SmolRT(Option<Arc<Executor<'static>>>);
61
62impl fmt::Debug for SmolRT {
63 #[inline]
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
66 }
67}
68
69impl SmolRT {
70 #[cfg(feature = "global")]
71 #[inline]
72 pub fn new_global() -> Self {
73 Self(None)
74 }
75
76 #[inline]
78 pub fn new(executor: Arc<Executor<'static>>) -> Self {
79 Self(Some(executor))
80 }
81}
82
// Marks `SmolRT` as an `orb::AsyncRuntime`. The impl body is empty: nothing is
// defined or overridden here.
impl orb::AsyncRuntime for SmolRT {}
84
85impl AsyncIO for SmolRT {
86 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
87
88 #[inline(always)]
89 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
90 let _addr = addr.clone();
91 let stream = Async::<TcpStream>::connect(_addr).await?;
92 Self::to_async_fd_rw(stream.into_inner()?)
94 }
95
96 #[inline(always)]
97 async fn connect_unix(addr: &PathBuf) -> io::Result<Self::AsyncFd<UnixStream>> {
98 let _addr = addr.clone();
99 let stream = Async::<UnixStream>::connect(_addr).await?;
100 Self::to_async_fd_rw(stream.into_inner()?)
102 }
103
104 #[inline(always)]
105 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
106 fd: T,
107 ) -> io::Result<Self::AsyncFd<T>> {
108 Ok(SmolFD(Async::new(fd)?))
109 }
110
111 #[inline(always)]
112 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
113 fd: T,
114 ) -> io::Result<Self::AsyncFd<T>> {
115 Ok(SmolFD(Async::new(fd)?))
116 }
117}
118
119impl AsyncTime for SmolRT {
120 type Interval = SmolInterval;
121
122 #[inline(always)]
123 fn sleep(d: Duration) -> impl Future + Send {
124 Timer::after(d)
125 }
126
127 #[inline(always)]
128 fn tick(d: Duration) -> Self::Interval {
129 let later = std::time::Instant::now() + d;
130 SmolInterval(Timer::interval_at(later, d))
131 }
132}
133
/// Join handle returned by `SmolRT::spawn` and `SmolRT::spawn_blocking`;
/// a thin wrapper around an `async_executor::Task`.
// NOTE(review): `async_executor::Task` is documented upstream as cancelling
// its task when dropped — confirm callers either `join` or `detach`.
pub struct SmolJoinHandle<T>(async_executor::Task<T>);
136
137impl<T: Send + 'static> AsyncJoinHandle<T> for SmolJoinHandle<T> {
138 #[inline]
139 async fn join(self) -> Result<T, ()> {
140 Ok(self.0.await)
141 }
142
143 #[inline]
144 fn detach(self) {
145 self.0.detach();
146 }
147}
148
149impl AsyncExec for SmolRT {
150 fn spawn<F, R>(&self, f: F) -> impl AsyncJoinHandle<R>
152 where
153 F: Future<Output = R> + Send + 'static,
154 R: Send + 'static,
155 {
156 let handle = match &self.0 {
157 Some(exec) => exec.spawn(f),
158 None => {
159 #[cfg(feature = "global")]
160 {
161 smol::spawn(f)
162 }
163 #[cfg(not(feature = "global"))]
164 unreachable!();
165 }
166 };
167 SmolJoinHandle(handle)
168 }
169
170 #[inline]
172 fn spawn_detach<F, R>(&self, f: F)
173 where
174 F: Future<Output = R> + Send + 'static,
175 R: Send + 'static,
176 {
177 self.spawn(f).detach();
178 }
179
180 #[inline]
181 fn spawn_blocking<F, R>(f: F) -> impl AsyncJoinHandle<R>
182 where
183 F: FnOnce() -> R + Send + 'static,
184 R: Send + 'static,
185 {
186 SmolJoinHandle(blocking::unblock(f))
187 }
188
189 #[inline]
194 fn block_on<F, R>(&self, f: F) -> R
195 where
196 F: Future<Output = R> + Send,
197 R: Send + 'static,
198 {
199 if let Some(exec) = &self.0 {
200 block_on(exec.run(f))
201 } else {
202 #[cfg(feature = "global")]
203 {
204 smol::block_on(f)
205 }
206 #[cfg(not(feature = "global"))]
207 unreachable!();
208 }
209 }
210}
211
/// Interval type produced by `SmolRT::tick`; a thin wrapper around an
/// `async_io::Timer`.
pub struct SmolInterval(Timer);
214
215impl TimeInterval for SmolInterval {
216 #[inline]
217 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
218 let _self = self.get_mut();
219 match _self.0.poll_next(ctx) {
220 Poll::Ready(Some(i)) => Poll::Ready(i),
221 Poll::Ready(None) => unreachable!(),
222 Poll::Pending => Poll::Pending,
223 }
224 }
225}
226
227pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
229
230impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
231 #[inline(always)]
232 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
233 self.0.read_with(f).await
234 }
235
236 #[inline(always)]
237 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
238 self.0.write_with(f).await
239 }
240}
241
242impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
243 type Target = T;
244
245 #[inline(always)]
246 fn deref(&self) -> &Self::Target {
247 self.0.get_ref()
248 }
249}