1use async_executor::Executor;
38use async_io::{Async, Timer};
39use futures_lite::future::block_on;
40use futures_lite::stream::StreamExt;
41use orb::io::{AsyncFd, AsyncIO};
42use orb::runtime::{AsyncExec, AsyncJoinHandle};
43use orb::time::{AsyncTime, TimeInterval};
44use std::fmt;
45use std::future::Future;
46use std::io;
47use std::net::SocketAddr;
48use std::net::TcpStream;
49use std::ops::Deref;
50use std::os::fd::{AsFd, AsRawFd};
51use std::os::unix::net::UnixStream;
52use std::path::PathBuf;
53use std::pin::Pin;
54use std::sync::Arc;
55use std::task::*;
56use std::time::{Duration, Instant};
57
58#[derive(Clone)]
59pub struct SmolRT(Option<Arc<Executor<'static>>>);
60
61impl fmt::Debug for SmolRT {
62 #[inline]
63 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64 if self.0.is_some() { write!(f, "smol") } else { write!(f, "smol(global)") }
65 }
66}
67
68impl SmolRT {
69 #[cfg(feature = "global")]
70 #[inline]
71 pub fn new_global() -> Self {
72 Self(None)
73 }
74
75 #[inline]
77 pub fn new(executor: Arc<Executor<'static>>) -> Self {
78 Self(Some(executor))
79 }
80}
81
82impl orb::AsyncRuntime for SmolRT {}
83
84impl AsyncIO for SmolRT {
85 type AsyncFd<T: AsRawFd + AsFd + Send + Sync + 'static> = SmolFD<T>;
86
87 #[inline(always)]
88 async fn connect_tcp(addr: &SocketAddr) -> io::Result<Self::AsyncFd<TcpStream>> {
89 let _addr = addr.clone();
90 let stream = Async::<TcpStream>::connect(_addr).await?;
91 Self::to_async_fd_rw(stream.into_inner()?)
93 }
94
95 #[inline(always)]
96 async fn connect_unix(addr: &PathBuf) -> io::Result<Self::AsyncFd<UnixStream>> {
97 let _addr = addr.clone();
98 let stream = Async::<UnixStream>::connect(_addr).await?;
99 Self::to_async_fd_rw(stream.into_inner()?)
101 }
102
103 #[inline(always)]
104 fn to_async_fd_rd<T: AsRawFd + AsFd + Send + Sync + 'static>(
105 fd: T,
106 ) -> io::Result<Self::AsyncFd<T>> {
107 Ok(SmolFD(Async::new(fd)?))
108 }
109
110 #[inline(always)]
111 fn to_async_fd_rw<T: AsRawFd + AsFd + Send + Sync + 'static>(
112 fd: T,
113 ) -> io::Result<Self::AsyncFd<T>> {
114 Ok(SmolFD(Async::new(fd)?))
115 }
116}
117
118impl AsyncTime for SmolRT {
119 type Interval = SmolInterval;
120
121 #[inline(always)]
122 fn sleep(d: Duration) -> impl Future + Send {
123 Timer::after(d)
124 }
125
126 #[inline(always)]
127 fn tick(d: Duration) -> Self::Interval {
128 let later = std::time::Instant::now() + d;
129 SmolInterval(Timer::interval_at(later, d))
130 }
131}
132
133pub struct SmolJoinHandle<T>(async_executor::Task<T>);
135
136impl<T: Send + 'static> AsyncJoinHandle<T> for SmolJoinHandle<T> {
137 #[inline]
138 async fn join(self) -> Result<T, ()> {
139 Ok(self.0.await)
140 }
141
142 #[inline]
143 fn detach(self) {
144 self.0.detach();
145 }
146}
147
148impl AsyncExec for SmolRT {
149 fn spawn<F, R>(&self, f: F) -> impl AsyncJoinHandle<R>
151 where
152 F: Future<Output = R> + Send + 'static,
153 R: Send + 'static,
154 {
155 let handle = match &self.0 {
156 Some(exec) => exec.spawn(f),
157 None => {
158 #[cfg(feature = "global")]
159 {
160 smol::spawn(f)
161 }
162 #[cfg(not(feature = "global"))]
163 unreachable!();
164 }
165 };
166 SmolJoinHandle(handle)
167 }
168
169 #[inline]
171 fn spawn_detach<F, R>(&self, f: F)
172 where
173 F: Future<Output = R> + Send + 'static,
174 R: Send + 'static,
175 {
176 self.spawn(f).detach();
177 }
178
179 #[inline]
184 fn block_on<F, R>(&self, f: F) -> R
185 where
186 F: Future<Output = R> + Send,
187 R: Send + 'static,
188 {
189 if let Some(exec) = &self.0 {
190 block_on(exec.run(f))
191 } else {
192 #[cfg(feature = "global")]
193 {
194 smol::block_on(f)
195 }
196 #[cfg(not(feature = "global"))]
197 unreachable!();
198 }
199 }
200}
201
202pub struct SmolInterval(Timer);
204
205impl TimeInterval for SmolInterval {
206 #[inline]
207 fn poll_tick(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Instant> {
208 let _self = self.get_mut();
209 match _self.0.poll_next(ctx) {
210 Poll::Ready(Some(i)) => Poll::Ready(i),
211 Poll::Ready(None) => unreachable!(),
212 Poll::Pending => Poll::Pending,
213 }
214 }
215}
216
217pub struct SmolFD<T: AsRawFd + AsFd + Send + Sync + 'static>(Async<T>);
219
220impl<T: AsRawFd + AsFd + Send + Sync + 'static> AsyncFd<T> for SmolFD<T> {
221 #[inline(always)]
222 async fn async_read<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
223 self.0.read_with(f).await
224 }
225
226 #[inline(always)]
227 async fn async_write<R>(&self, f: impl FnMut(&T) -> io::Result<R> + Send) -> io::Result<R> {
228 self.0.write_with(f).await
229 }
230}
231
232impl<T: AsRawFd + AsFd + Send + Sync + 'static> Deref for SmolFD<T> {
233 type Target = T;
234
235 #[inline(always)]
236 fn deref(&self) -> &Self::Target {
237 self.0.get_ref()
238 }
239}