1use std::future::Future;
13use std::pin::Pin;
14use std::task::{Context, Poll, Wake};
15use std::sync::{Arc, Mutex, Condvar, atomic::{AtomicBool, AtomicUsize, Ordering}};
16use std::collections::VecDeque;
17use std::thread;
18use std::time::{Duration, Instant};
19
/// A type-erased, heap-pinned task as stored in the run queue.
type Task = Pin<Box<dyn Future<Output = ()> + Send>>;

/// Runs its callback exactly once when dropped: after normal completion, while
/// unwinding from a panic, or when the future that owns it is discarded unfinished.
struct OnDrop<F: FnOnce()>(Option<F>);

impl<F: FnOnce()> OnDrop<F> {
    fn new(callback: F) -> Self {
        Self(Some(callback))
    }
}

impl<F: FnOnce()> Drop for OnDrop<F> {
    fn drop(&mut self) {
        if let Some(callback) = self.0.take() {
            callback();
        }
    }
}

/// Handle to the eventual output of a task started with [`Runtime::spawn_with_handle`].
pub struct JoinHandle<T> {
    // Filled in by the task when its future completes.
    result: Arc<Mutex<Option<T>>>,
    // Set with `Release` ordering once the task has finished, panicked or been dropped.
    completed: Arc<AtomicBool>,
}

impl<T> JoinHandle<T> {
    /// Waits for the task to finish and returns its output.
    ///
    /// Returns `None` if the task panicked or was dropped before producing a value.
    /// Waiting yields to the scheduler between checks instead of blocking the thread.
    pub async fn await_result(self) -> Option<T> {
        while !self.completed.load(Ordering::Acquire) {
            yield_now().await;
        }
        let mut slot = self.result.lock().unwrap();
        slot.take()
    }
}

/// A multi-threaded polling executor.
///
/// Tasks live in one shared FIFO queue. A task that returns `Pending` is pushed back
/// onto the queue after every poll, so it is re-polled continuously rather than only
/// when its waker fires.
pub struct Runtime {
    queue: Arc<Mutex<VecDeque<Task>>>,
    shutdown: Arc<AtomicBool>,
    task_count: Arc<AtomicUsize>,
    condvar: Arc<Condvar>,
}

impl Runtime {
    /// Creates a runtime with an empty queue. No threads are started until `block_on`.
    pub fn new() -> Self {
        Self {
            queue: Arc::new(Mutex::new(VecDeque::new())),
            shutdown: Arc::new(AtomicBool::new(false)),
            task_count: Arc::new(AtomicUsize::new(0)),
            condvar: Arc::new(Condvar::new()),
        }
    }

    /// Number of spawned tasks that have neither finished nor been dropped.
    pub fn task_count(&self) -> usize {
        self.task_count.load(Ordering::Relaxed)
    }

    /// Requests shutdown: workers stop waiting for new work and exit once no
    /// spawned tasks remain.
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::Release);
        self.condvar.notify_all();
    }

    /// Queues `future` to be polled by this runtime's workers.
    ///
    /// The task is counted in [`Runtime::task_count`] until it finishes, panics, or is
    /// dropped unfinished.
    pub fn spawn<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        self.task_count.fetch_add(1, Ordering::Relaxed);
        let finished = {
            let task_count = Arc::clone(&self.task_count);
            let condvar = Arc::clone(&self.condvar);
            OnDrop::new(move || {
                task_count.fetch_sub(1, Ordering::Relaxed);
                // Wake idle workers so they re-check the shutdown condition.
                condvar.notify_all();
            })
        };

        let wrapped = async move {
            // Moved into a local so it is dropped on completion and during unwinding alike.
            let _finished = finished;
            future.await;
        };

        let mut queue = self.queue.lock().unwrap();
        queue.push_back(Box::pin(wrapped));
        self.condvar.notify_one();
    }

    /// Spawns `future` and returns a [`JoinHandle`] that can await its output.
    pub fn spawn_with_handle<F, T>(&self, future: F) -> JoinHandle<T>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let result = Arc::new(Mutex::new(None));
        let completed = Arc::new(AtomicBool::new(false));

        let finished = {
            let completed = Arc::clone(&completed);
            OnDrop::new(move || completed.store(true, Ordering::Release))
        };
        let slot = Arc::clone(&result);
        self.spawn(async move {
            // Declared first so it drops last: the output is stored before `completed` flips.
            let _finished = finished;
            let output = future.await;
            *slot.lock().unwrap() = Some(output);
        });

        JoinHandle { result, completed }
    }

    /// Runs `future` on this runtime's worker threads and returns its output once it
    /// completes.
    ///
    /// Other tasks that are still pending at that point stay queued and resume on the
    /// next call to `block_on`.
    ///
    /// # Panics
    ///
    /// Panics with "Task did not complete" if `future` panics. The original panic is
    /// caught on the worker thread and reported there by the panic hook.
    pub fn block_on<F, T>(&self, future: F) -> T
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let result = Arc::new(Mutex::new(None));
        let done = Arc::new(AtomicBool::new(false));

        let finished = {
            let done = Arc::clone(&done);
            let queue = Arc::clone(&self.queue);
            let condvar = Arc::clone(&self.condvar);
            OnDrop::new(move || {
                done.store(true, Ordering::Release);
                // Pass through the queue lock so no worker can sit between its `done`
                // check and its condvar wait when the notification is sent.
                drop(queue.lock());
                condvar.notify_all();
            })
        };
        let slot = Arc::clone(&result);
        self.spawn(async move {
            let _finished = finished;
            let output = future.await;
            *slot.lock().unwrap() = Some(output);
        });

        self.run(&done);

        let output = result.lock().unwrap().take();
        output.expect("Task did not complete")
    }

    /// Returns another `Runtime` value sharing this one's queue, flags, counter and condvar.
    fn handle(&self) -> Runtime {
        Runtime {
            queue: Arc::clone(&self.queue),
            shutdown: Arc::clone(&self.shutdown),
            task_count: Arc::clone(&self.task_count),
            condvar: Arc::clone(&self.condvar),
        }
    }

    /// Starts one worker per available CPU (4 if unknown) and joins them.
    ///
    /// Workers stop when `done` is set, or when shutdown was requested and no tasks remain.
    fn run(&self, done: &Arc<AtomicBool>) {
        let num_threads = thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);

        let workers: Vec<_> = (0..num_threads)
            .map(|_| {
                let rt = self.handle();
                let done = Arc::clone(done);
                thread::spawn(move || rt.work(&done))
            })
            .collect();

        for worker in workers {
            // Task panics are caught inside `work`; any join error has already been
            // reported by the panic hook and there is nothing further to recover.
            let _ = worker.join();
        }
    }

    /// Worker-thread body: pops a task, polls it once, and re-queues it while pending.
    fn work(self, done: &AtomicBool) {
        // Route the free-standing `spawn` on this thread to this runtime.
        RUNTIME.with(|current| *current.borrow_mut() = self.handle());

        let waker: std::task::Waker =
            Arc::new(RuntimeWaker { condvar: Arc::clone(&self.condvar) }).into();

        loop {
            let finished = done.load(Ordering::Acquire);
            let drained = self.shutdown.load(Ordering::Acquire)
                && self.task_count.load(Ordering::Relaxed) == 0;
            if finished || drained {
                break;
            }

            let next = {
                let mut queue = self.queue.lock().unwrap();
                if queue.is_empty()
                    && !self.shutdown.load(Ordering::Acquire)
                    && !done.load(Ordering::Acquire)
                {
                    // The timeout bounds the cost of any notification that was missed.
                    queue = self
                        .condvar
                        .wait_timeout(queue, Duration::from_millis(100))
                        .unwrap()
                        .0;
                }
                queue.pop_front()
            };

            match next {
                Some(mut task) => {
                    let mut cx = Context::from_waker(&waker);
                    // AssertUnwindSafe: a task that panics is dropped right below and never
                    // polled again, so its possibly-inconsistent state is not observed.
                    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                        task.as_mut().poll(&mut cx)
                    }));
                    if matches!(outcome, Ok(Poll::Pending)) {
                        self.queue.lock().unwrap().push_back(task);
                    }
                    // Otherwise the task finished or panicked; dropping it here runs its
                    // completion guards, which keeps `task_count` and handles accurate.
                }
                None if self.shutdown.load(Ordering::Acquire) => break,
                None => {}
            }
        }
    }
}

impl Default for Runtime {
    fn default() -> Self {
        Self::new()
    }
}

/// Waker that signals the shared condvar to rouse an idle worker.
///
/// Pending tasks are re-queued after every poll regardless, so waking only shortens an
/// idle worker's wait; it does not decide whether a task is polled.
struct RuntimeWaker {
    condvar: Arc<Condvar>,
}

impl Wake for RuntimeWaker {
    fn wake(self: Arc<Self>) {
        self.condvar.notify_one();
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.condvar.notify_one();
    }
}

/// Spawns `future` onto the current thread's runtime.
///
/// On a worker thread started by [`Runtime::block_on`], this is the runtime driving that
/// worker. On any other thread it is a thread-local [`Runtime`]. NOTE(review): no code
/// here ever drives that thread-local runtime, so tasks spawned there do not run.
pub fn spawn<F>(future: F)
where
    F: Future<Output = ()> + Send + 'static,
{
    RUNTIME.with(|rt| {
        rt.borrow().spawn(future);
    });
}

thread_local! {
    // Per-thread target for the free `spawn`; worker threads replace it with a handle
    // to the runtime they serve.
    static RUNTIME: std::cell::RefCell<Runtime> = std::cell::RefCell::new(Runtime::new());
}

/// Defines `fn main` that runs the given statements as the root future of a fresh
/// [`Runtime`].
#[macro_export]
macro_rules! main {
    ($($body:tt)*) => {
        fn main() {
            let rt = $crate::Runtime::new();
            rt.block_on(async { $($body)* });
        }
    };
}

/// Lets other tasks run. The returned future is `Pending` on its first poll (after
/// waking itself) and `Ready` on the second.
pub async fn yield_now() {
    struct YieldNow {
        yielded: bool,
    }

    impl Future for YieldNow {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.yielded {
                return Poll::Ready(());
            }
            self.yielded = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldNow { yielded: false }.await
}
252
/// Completes once `duration` has elapsed, measured from the first poll of the returned
/// future.
///
/// No timer is registered: while waiting, the future asks to be polled again right away.
/// A `duration` too large to be represented as an [`Instant`] (for example
/// [`Duration::MAX`]) means the future never completes, instead of panicking on overflow.
pub async fn sleep(duration: Duration) {
    struct Sleep {
        // `None` when the deadline overflows `Instant`, meaning "never".
        when: Option<std::time::Instant>,
    }

    impl Future for Sleep {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match self.when {
                Some(when) if std::time::Instant::now() >= when => Poll::Ready(()),
                _ => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }
        }
    }

    Sleep {
        when: std::time::Instant::now().checked_add(duration),
    }
    .await
}
277
/// Runs `future` to completion unless `duration` elapses first.
///
/// The inner future is polled before the deadline is checked. A future that is already
/// ready, or that becomes ready on the poll which observes the deadline, therefore yields
/// `Ok` even when `duration` is zero. The deadline is fixed on the first poll of the
/// returned future. A `duration` too large to be represented as an [`Instant`] means there
/// is no deadline.
///
/// # Errors
///
/// Returns [`TimeoutError`] if the deadline passes while `future` is still pending.
pub async fn timeout<F, T>(duration: Duration, future: F) -> Result<T, TimeoutError>
where
    F: Future<Output = T>,
{
    struct Timeout<F> {
        future: Pin<Box<F>>,
        // `None` when the deadline overflows `Instant`: wait indefinitely.
        deadline: Option<Instant>,
    }

    impl<F: Future> Future for Timeout<F> {
        type Output = Result<F::Output, TimeoutError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if let Poll::Ready(value) = self.future.as_mut().poll(cx) {
                return Poll::Ready(Ok(value));
            }
            match self.deadline {
                Some(deadline) if Instant::now() >= deadline => Poll::Ready(Err(TimeoutError)),
                _ => {
                    // No timer will fire at the deadline, so ask to be re-polled.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }
        }
    }

    Timeout {
        future: Box::pin(future),
        deadline: Instant::now().checked_add(duration),
    }
    .await
}

/// Error returned by [`timeout`] when the deadline passes before the future completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutError;

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "operation timed out")
    }
}

impl std::error::Error for TimeoutError {}
324
/// Multi-producer, single-consumer channels whose `send`/`recv` yield to the scheduler
/// while waiting instead of blocking the worker thread.
pub mod channel {
    use std::collections::VecDeque;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
    use std::task::{Context, Poll};

    /// Creates a channel that holds at most `capacity` queued messages.
    ///
    /// With a `capacity` of 0 no message is ever accepted: every `send` waits until the
    /// receiver is dropped and then fails.
    pub fn bounded<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
        let inner = Arc::new(ChannelInner {
            capacity,
            state: Mutex::new(State {
                // Grown on demand. Preallocating `capacity` slots overflows for `unbounded`.
                queue: VecDeque::new(),
                senders: 1,
                receiver_alive: true,
            }),
        });
        (Sender { inner: Arc::clone(&inner) }, Receiver { inner })
    }

    /// Creates a channel with no practical capacity limit.
    pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
        bounded(usize::MAX)
    }

    struct ChannelInner<T> {
        capacity: usize,
        state: Mutex<State<T>>,
    }

    // All mutable channel state sits behind one lock, so the queue and the liveness
    // bookkeeping are always observed together.
    struct State<T> {
        queue: VecDeque<T>,
        // Number of live `Sender`s; the channel is drained-and-closed when it reaches 0.
        senders: usize,
        receiver_alive: bool,
    }

    impl<T> ChannelInner<T> {
        fn lock(&self) -> MutexGuard<'_, State<T>> {
            // Critical sections only push, pop or adjust counters, so a poisoned lock
            // still guards consistent data.
            self.state.lock().unwrap_or_else(PoisonError::into_inner)
        }
    }

    // `Pending` on the first poll (after requesting a re-poll), `Ready` on the next,
    // letting other tasks run in between retries.
    struct YieldOnce {
        yielded: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.yielded {
                Poll::Ready(())
            } else {
                self.yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    /// Sending half of a channel. Clone it to get more producers.
    pub struct Sender<T> {
        inner: Arc<ChannelInner<T>>,
    }

    impl<T> Sender<T> {
        /// Queues `value`, yielding while the channel is full.
        ///
        /// # Errors
        ///
        /// Returns the value inside [`SendError`] if the receiver has been dropped.
        pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
            loop {
                {
                    let mut state = self.inner.lock();
                    if !state.receiver_alive {
                        return Err(SendError(value));
                    }
                    if state.queue.len() < self.inner.capacity {
                        state.queue.push_back(value);
                        return Ok(());
                    }
                }
                // The lock is released above, so no guard is held across this await.
                YieldOnce { yielded: false }.await;
            }
        }
    }

    impl<T> Clone for Sender<T> {
        fn clone(&self) -> Self {
            self.inner.lock().senders += 1;
            Self { inner: Arc::clone(&self.inner) }
        }
    }

    impl<T> Drop for Sender<T> {
        fn drop(&mut self) {
            self.inner.lock().senders -= 1;
        }
    }

    /// Receiving half of a channel.
    pub struct Receiver<T> {
        inner: Arc<ChannelInner<T>>,
    }

    impl<T> Receiver<T> {
        /// Takes the oldest queued message, yielding while the channel is empty.
        ///
        /// Returns `None` once the queue is empty and every [`Sender`] has been dropped.
        pub async fn recv(&self) -> Option<T> {
            loop {
                {
                    let mut state = self.inner.lock();
                    if let Some(value) = state.queue.pop_front() {
                        return Some(value);
                    }
                    if state.senders == 0 {
                        return None;
                    }
                }
                YieldOnce { yielded: false }.await;
            }
        }
    }

    impl<T> Drop for Receiver<T> {
        fn drop(&mut self) {
            // Pending and future sends now fail instead of waiting forever.
            self.inner.lock().receiver_alive = false;
        }
    }

    /// Error returned by [`Sender::send`]; carries back the value that was not delivered.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct SendError<T>(pub T);

    impl<T> std::fmt::Display for SendError<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "channel closed")
        }
    }

    impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}
}
430
431pub mod net {
433 use std::io;
434 use std::net::{TcpListener as StdListener, TcpStream as StdStream, SocketAddr};
435
436 pub struct TcpListener(StdListener);
437 pub struct TcpStream(StdStream);
438
439 impl TcpListener {
440 pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
441 let listener = StdListener::bind(addr)?;
442 listener.set_nonblocking(true)?;
443 Ok(Self(listener))
444 }
445
446 pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
447 loop {
448 match self.0.accept() {
449 Ok((stream, addr)) => {
450 stream.set_nonblocking(true)?;
451 return Ok((TcpStream(stream), addr));
452 }
453 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
454 crate::sleep(std::time::Duration::from_millis(10)).await;
455 }
456 Err(e) => return Err(e),
457 }
458 }
459 }
460 }
461
462 impl TcpStream {
463 pub async fn connect(addr: SocketAddr) -> io::Result<Self> {
464 let stream = StdStream::connect(addr)?;
465 stream.set_nonblocking(true)?;
466 Ok(Self(stream))
467 }
468
469 pub fn into_std(self) -> StdStream {
470 self.0
471 }
472
473 pub fn as_std(&self) -> &StdStream {
474 &self.0
475 }
476
477 pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
479 use std::io::Read;
480 loop {
481 match self.0.read(buf) {
482 Ok(n) => return Ok(n),
483 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
484 crate::sleep(std::time::Duration::from_millis(1)).await;
485 }
486 Err(e) => return Err(e),
487 }
488 }
489 }
490
491 pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
493 use std::io::Write;
494 loop {
495 match self.0.write(buf) {
496 Ok(n) => return Ok(n),
497 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
498 crate::sleep(std::time::Duration::from_millis(1)).await;
499 }
500 Err(e) => return Err(e),
501 }
502 }
503 }
504
505 pub async fn write_all(&mut self, mut buf: &[u8]) -> io::Result<()> {
507 while !buf.is_empty() {
508 let n = self.write(buf).await?;
509 buf = &buf[n..];
510 }
511 Ok(())
512 }
513 }
514}
515
516pub mod io {
518 use std::io::{self, Read, Write};
519
520 pub async fn copy<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> io::Result<u64> {
521 let mut buf = [0u8; 8192];
522 let mut total = 0u64;
523
524 loop {
525 match reader.read(&mut buf) {
526 Ok(0) => return Ok(total),
527 Ok(n) => {
528 writer.write_all(&buf[..n])?;
529 total += n as u64;
530 }
531 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
532 crate::sleep(std::time::Duration::from_millis(1)).await;
533 }
534 Err(e) => return Err(e),
535 }
536 }
537 }
538}