avila_async/
lib.rs

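//! Avila Async - Native async runtime
//! Tokio alternative - 100% Rust std
//!
//! # Features
//! - Multi-threaded executor whose workers share a single task queue
//!   (pending tasks are re-queued and polled again; there is no work stealing)
//! - Polling-based async I/O primitives (`net`, `io`)
//! - Channel support
//! - Sleep, timeout and cooperative-yield helpers (`timeout` drops the inner
//!   future once its deadline passes)
//! - Graceful shutdown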
11
12use std::future::Future;
13use std::pin::Pin;
14use std::task::{Context, Poll, Wake};
15use std::sync::{Arc, Mutex, Condvar, atomic::{AtomicBool, AtomicUsize, Ordering}};
16use std::collections::VecDeque;
17use std::thread;
18use std::time::{Duration, Instant};
19
/// A type-erased, heap-pinned unit of work that any worker thread may poll.
type Task = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
21
/// Handle to a task started with `Runtime::spawn_with_handle`.
///
/// The spawned task writes its output into `result` and then raises
/// `completed`; the handle watches `completed` to know when the output is ready.
pub struct JoinHandle<T> {
    /// Output slot written by the spawned task when its future finishes.
    result: Arc<Mutex<Option<T>>>,
    /// Raised (with `Release` ordering) after `result` has been written.
    completed: Arc<AtomicBool>,
}
27
28impl<T> JoinHandle<T> {
29    /// Wait for the task to complete and return its result
30    pub async fn await_result(self) -> Option<T> {
31        while !self.completed.load(Ordering::Acquire) {
32            yield_now().await;
33        }
34        self.result.lock().unwrap().take()
35    }
36}
37
38pub struct Runtime {
39    queue: Arc<Mutex<VecDeque<Task>>>,
40    shutdown: Arc<AtomicBool>,
41    task_count: Arc<AtomicUsize>,
42    condvar: Arc<Condvar>,
43}
44
45impl Runtime {
46    /// Create a new runtime instance
47    pub fn new() -> Self {
48        Self {
49            queue: Arc::new(Mutex::new(VecDeque::new())),
50            shutdown: Arc::new(AtomicBool::new(false)),
51            task_count: Arc::new(AtomicUsize::new(0)),
52            condvar: Arc::new(Condvar::new()),
53        }
54    }
55
56    /// Get the number of active tasks
57    pub fn task_count(&self) -> usize {
58        self.task_count.load(Ordering::Relaxed)
59    }
60
61    /// Initiate graceful shutdown
62    pub fn shutdown(&self) {
63        self.shutdown.store(true, Ordering::Release);
64        self.condvar.notify_all();
65    }
66
67    /// Spawn a future onto the runtime
68    pub fn spawn<F>(&self, future: F)
69    where
70        F: Future<Output = ()> + Send + 'static,
71    {
72        self.task_count.fetch_add(1, Ordering::Relaxed);
73        let task_count = Arc::clone(&self.task_count);
74        let condvar = Arc::clone(&self.condvar);
75
76        let wrapped = async move {
77            future.await;
78            task_count.fetch_sub(1, Ordering::Relaxed);
79            condvar.notify_all();
80        };
81
82        let mut queue = self.queue.lock().unwrap();
83        queue.push_back(Box::pin(wrapped));
84        self.condvar.notify_one();
85    }
86
87    /// Spawn a future and return a handle to await its result
88    pub fn spawn_with_handle<F, T>(&self, future: F) -> JoinHandle<T>
89    where
90        F: Future<Output = T> + Send + 'static,
91        T: Send + 'static,
92    {
93        let result = Arc::new(Mutex::new(None));
94        let completed = Arc::new(AtomicBool::new(false));
95        let result_clone = Arc::clone(&result);
96        let completed_clone = Arc::clone(&completed);
97
98        let task = async move {
99            let output = future.await;
100            *result_clone.lock().unwrap() = Some(output);
101            completed_clone.store(true, Ordering::Release);
102        };
103
104        self.spawn(task);
105        JoinHandle { result, completed }
106    }
107
108    pub fn block_on<F, T>(&self, future: F) -> T
109    where
110        F: Future<Output = T> + Send + 'static,
111        T: Send + 'static,
112    {
113        let result = Arc::new(Mutex::new(None));
114        let result_clone = Arc::clone(&result);
115
116        let task = async move {
117            let output = future.await;
118            *result_clone.lock().unwrap() = Some(output);
119        };
120
121        self.spawn(Box::pin(task));
122        self.run();
123
124        Arc::try_unwrap(result)
125            .ok()
126            .and_then(|m| m.into_inner().ok())
127            .and_then(|opt| opt)
128            .expect("Task did not complete")
129    }
130
131    fn run(&self) {
132        let num_threads = std::thread::available_parallelism()
133            .map(|n| n.get())
134            .unwrap_or(4);
135
136        let mut handles = vec![];
137
138        for _ in 0..num_threads {
139            let queue = Arc::clone(&self.queue);
140            let shutdown = Arc::clone(&self.shutdown);
141            let task_count = Arc::clone(&self.task_count);
142            let condvar = Arc::clone(&self.condvar);
143
144            let handle = thread::spawn(move || {
145                let waker = Arc::new(RuntimeWaker { condvar: Arc::clone(&condvar) }).into();
146
147                loop {
148                    if shutdown.load(Ordering::Acquire) && task_count.load(Ordering::Relaxed) == 0 {
149                        break;
150                    }
151
152                    let task = {
153                        let mut q = queue.lock().unwrap();
154                        if q.is_empty() && !shutdown.load(Ordering::Acquire) {
155                            q = condvar.wait_timeout(q, Duration::from_millis(100)).unwrap().0;
156                        }
157                        q.pop_front()
158                    };
159
160                    match task {
161                        Some(mut task) => {
162                            let mut context = Context::from_waker(&waker);
163                            match task.as_mut().poll(&mut context) {
164                                Poll::Ready(()) => {},
165                                Poll::Pending => {
166                                    let mut q = queue.lock().unwrap();
167                                    q.push_back(task);
168                                }
169                            }
170                        }
171                        None if shutdown.load(Ordering::Acquire) => break,
172                        None => {}
173                    }
174                }
175            });
176            handles.push(handle);
177        }
178
179        for handle in handles {
180            let _ = handle.join();
181        }
182    }
183}
184
185impl Default for Runtime {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
/// Waker given to every task polled by a worker thread.
///
/// Waking only signals the runtime's condition variable. It does not re-queue
/// anything, because workers already push every task that returns `Pending`
/// back onto the queue.
struct RuntimeWaker {
    condvar: Arc<Condvar>,
}

impl Wake for RuntimeWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Rouse one worker that may be waiting on an empty queue.
        self.condvar.notify_one();
    }
}
204
205// Global helper function
206pub fn spawn<F>(future: F)
207where
208    F: Future<Output = ()> + Send + 'static,
209{
210    RUNTIME.with(|rt| {
211        rt.borrow().spawn(future);
212    });
213}
214
215thread_local! {
216    static RUNTIME: std::cell::RefCell<Runtime> = std::cell::RefCell::new(Runtime::new());
217}
218
/// Generates a `fn main()` that runs the given statements as an async block on
/// a fresh [`Runtime`] via `block_on`, discarding the block's value.
///
/// ```ignore
/// avila_async::main! {
///     println!("hello from async");
/// }
/// ```
#[macro_export]
macro_rules! main {
    ($($body:tt)*) => {
        fn main() {
            $crate::Runtime::new().block_on(async { $($body)* });
        }
    };
}
229
/// Yield execution to allow other tasks to run.
///
/// The inner future returns `Pending` on its first poll, after waking its own
/// waker so the executor knows to poll it again, and `Ready` on later polls.
pub async fn yield_now() {
    struct YieldNow {
        yielded: bool,
    }

    impl Future for YieldNow {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if !self.yielded {
                self.yielded = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldNow { yielded: false }.await;
}
252
/// Sleep for a specified duration.
///
/// The timer is busy-polled: every pending poll wakes its own waker so the
/// executor polls again, and the future completes on the first poll at or after
/// the deadline. A `duration` too large to add to `Instant::now()` (for example
/// `Duration::MAX`) means "sleep forever" instead of panicking on overflow.
pub async fn sleep(duration: Duration) {
    struct Sleep {
        /// `None` when the deadline cannot be represented as an `Instant`.
        when: Option<Instant>,
    }

    impl Future for Sleep {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match self.when {
                Some(when) if Instant::now() >= when => Poll::Ready(()),
                _ => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }
        }
    }

    Sleep {
        when: Instant::now().checked_add(duration),
    }
    .await
}
277
/// Execute a future with a timeout.
///
/// The inner future is polled before the deadline is checked, so a future that
/// is already complete yields `Ok` even when `duration` is zero. A `duration`
/// too large to add to `Instant::now()` means the future never times out.
///
/// # Errors
///
/// Returns [`TimeoutError`] if `future` is still pending once the deadline has
/// passed; the inner future is then dropped without being polled again.
pub async fn timeout<F, T>(duration: Duration, future: F) -> Result<T, TimeoutError>
where
    F: Future<Output = T>,
{
    struct Timeout<F> {
        // Boxed so `Timeout` is `Unpin` (allowing `self.future` access through
        // `Pin<&mut Self>`) even when `F` itself is not.
        future: Pin<Box<F>>,
        /// `None` when the deadline cannot be represented as an `Instant`.
        deadline: Option<Instant>,
    }

    impl<F: Future> Future for Timeout<F> {
        type Output = Result<F::Output, TimeoutError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if let Poll::Ready(value) = self.future.as_mut().poll(cx) {
                return Poll::Ready(Ok(value));
            }
            match self.deadline {
                Some(deadline) if Instant::now() >= deadline => Poll::Ready(Err(TimeoutError)),
                _ => {
                    // Ask to be polled again so the deadline is re-checked even if
                    // the inner future never wakes us.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }
        }
    }

    Timeout {
        future: Box::pin(future),
        deadline: Instant::now().checked_add(duration),
    }
    .await
}

/// Error returned by [`timeout`] when the deadline passes before the future completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutError;

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("operation timed out")
    }
}

impl std::error::Error for TimeoutError {}
324
325/// Async channel for message passing
326pub mod channel {
327    use std::sync::{Arc, Mutex, Condvar};
328    use std::collections::VecDeque;
329
330    /// Create a bounded channel with specified capacity
331    pub fn bounded<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
332        let inner = Arc::new(ChannelInner {
333            queue: Mutex::new(VecDeque::with_capacity(capacity)),
334            condvar: Condvar::new(),
335            capacity,
336            closed: Mutex::new(false),
337        });
338        (Sender { inner: inner.clone() }, Receiver { inner })
339    }
340
341    /// Create an unbounded channel
342    pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
343        bounded(usize::MAX)
344    }
345
346    struct ChannelInner<T> {
347        queue: Mutex<VecDeque<T>>,
348        condvar: Condvar,
349        capacity: usize,
350        closed: Mutex<bool>,
351    }
352
353    /// Sender half of a channel
354    pub struct Sender<T> {
355        inner: Arc<ChannelInner<T>>,
356    }
357
358    impl<T> Sender<T> {
359        /// Send a value through the channel
360        pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
361            if *self.inner.closed.lock().unwrap() {
362                return Err(SendError(value));
363            }
364
365            loop {
366                let mut queue = self.inner.queue.lock().unwrap();
367                if queue.len() < self.inner.capacity {
368                    queue.push_back(value);
369                    self.inner.condvar.notify_one();
370                    return Ok(());
371                }
372                drop(queue);
373                let queue = self.inner.queue.lock().unwrap();
374                let _guard = self.inner.condvar.wait(queue).unwrap();
375            }
376        }
377    }
378
379    impl<T> Clone for Sender<T> {
380        fn clone(&self) -> Self {
381            Self { inner: self.inner.clone() }
382        }
383    }
384
385    impl<T> Drop for Sender<T> {
386        fn drop(&mut self) {
387            if Arc::strong_count(&self.inner) == 2 {
388                *self.inner.closed.lock().unwrap() = true;
389                self.inner.condvar.notify_all();
390            }
391        }
392    }
393
394    /// Receiver half of a channel
395    pub struct Receiver<T> {
396        inner: Arc<ChannelInner<T>>,
397    }
398
399    impl<T> Receiver<T> {
400        /// Receive a value from the channel
401        pub async fn recv(&self) -> Option<T> {
402            loop {
403                let mut queue = self.inner.queue.lock().unwrap();
404                if let Some(value) = queue.pop_front() {
405                    self.inner.condvar.notify_one();
406                    return Some(value);
407                }
408                if *self.inner.closed.lock().unwrap() && queue.is_empty() {
409                    return None;
410                }
411                drop(queue);
412                let queue = self.inner.queue.lock().unwrap();
413                let _guard = self.inner.condvar.wait(queue).unwrap();
414            }
415        }
416    }
417
418    /// Error returned when sending fails
419    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
420    pub struct SendError<T>(pub T);
421
422    impl<T> std::fmt::Display for SendError<T> {
423        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424            write!(f, "channel closed")
425        }
426    }
427
428    impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}
429}
430
431// Basic network modules
432pub mod net {
433    use std::io;
434    use std::net::{TcpListener as StdListener, TcpStream as StdStream, SocketAddr};
435
436    pub struct TcpListener(StdListener);
437    pub struct TcpStream(StdStream);
438
439    impl TcpListener {
440        pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
441            let listener = StdListener::bind(addr)?;
442            listener.set_nonblocking(true)?;
443            Ok(Self(listener))
444        }
445
446        pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
447            loop {
448                match self.0.accept() {
449                    Ok((stream, addr)) => {
450                        stream.set_nonblocking(true)?;
451                        return Ok((TcpStream(stream), addr));
452                    }
453                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
454                        crate::sleep(std::time::Duration::from_millis(10)).await;
455                    }
456                    Err(e) => return Err(e),
457                }
458            }
459        }
460    }
461
462    impl TcpStream {
463        pub async fn connect(addr: SocketAddr) -> io::Result<Self> {
464            let stream = StdStream::connect(addr)?;
465            stream.set_nonblocking(true)?;
466            Ok(Self(stream))
467        }
468
469        pub fn into_std(self) -> StdStream {
470            self.0
471        }
472
473        pub fn as_std(&self) -> &StdStream {
474            &self.0
475        }
476
477        /// Read data from the stream
478        pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
479            use std::io::Read;
480            loop {
481                match self.0.read(buf) {
482                    Ok(n) => return Ok(n),
483                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
484                        crate::sleep(std::time::Duration::from_millis(1)).await;
485                    }
486                    Err(e) => return Err(e),
487                }
488            }
489        }
490
491        /// Write data to the stream
492        pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
493            use std::io::Write;
494            loop {
495                match self.0.write(buf) {
496                    Ok(n) => return Ok(n),
497                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
498                        crate::sleep(std::time::Duration::from_millis(1)).await;
499                    }
500                    Err(e) => return Err(e),
501                }
502            }
503        }
504
505        /// Write all data to the stream
506        pub async fn write_all(&mut self, mut buf: &[u8]) -> io::Result<()> {
507            while !buf.is_empty() {
508                let n = self.write(buf).await?;
509                buf = &buf[n..];
510            }
511            Ok(())
512        }
513    }
514}
515
516// Basic I/O module
517pub mod io {
518    use std::io::{self, Read, Write};
519
520    pub async fn copy<R: Read, W: Write>(reader: &mut R, writer: &mut W) -> io::Result<u64> {
521        let mut buf = [0u8; 8192];
522        let mut total = 0u64;
523
524        loop {
525            match reader.read(&mut buf) {
526                Ok(0) => return Ok(total),
527                Ok(n) => {
528                    writer.write_all(&buf[..n])?;
529                    total += n as u64;
530                }
531                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
532                    crate::sleep(std::time::Duration::from_millis(1)).await;
533                }
534                Err(e) => return Err(e),
535            }
536        }
537    }
538}