mini_tokio/
async.rs

1use crossbeam::channel;
2
3use std::{
4    cell::RefCell,
5    future::Future,
6    pin::Pin,
7    sync::{
8        atomic::{AtomicUsize, Ordering},
9        Arc, Mutex,
10    },
11    task::{Context, Poll, Wake, Waker},
12    time::{Duration, Instant},
13};
14
15pub use mini_tokio_attr::main;
16
/// A heap-allocated, pinned, type-erased future that is `Send`, borrows for
/// at most `'a`, and resolves to a value of type `T`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
18
19thread_local! {
20    static CURRENT: RefCell<Option<Runtime>> = RefCell::new(None);
21}
22
23pub struct Runtime {
24    scheduled: channel::Receiver<Arc<Task>>,
25    sender: channel::Sender<Arc<Task>>,
26    task_count: Arc<AtomicUsize>,
27}
28
29struct Task {
30    future: Mutex<BoxFuture<'static, ()>>,
31    executor: channel::Sender<Arc<Task>>,
32    task_count: Arc<AtomicUsize>,
33}
34
35impl Runtime {
36    pub fn new() -> Self {
37        let (sender, scheduled) = channel::unbounded();
38        let task_count = Arc::new(AtomicUsize::new(0));
39        Runtime { scheduled, sender, task_count }
40    }
41
42    pub fn block_on<F>(&self, future: F) -> F::Output
43    where
44        F: Future + Send + 'static,
45        F::Output: Send + 'static,
46    {
47        CURRENT.with(|cell| {
48            if cell.borrow().is_some() {
49                panic!("Attempting to start a runtime from within a runtime");
50            }
51            *cell.borrow_mut() = Some(self.clone());
52        });
53
54        let (output_sender, output_receiver) = channel::bounded(1);
55        let wrapped = SpawnableFuture::new(future, output_sender);
56        let main_task = Arc::new(Task {
57            future: Mutex::new(Box::pin(wrapped)),
58            executor: self.sender.clone(),
59            task_count: self.task_count.clone(),
60        });
61
62        self.task_count.fetch_add(1, Ordering::SeqCst);
63        let _ = self.sender.send(main_task);
64
65        loop {
66            if let Ok(task) = self.scheduled.try_recv() {
67                task.poll();
68            }
69
70            if let Ok(output) = output_receiver.try_recv() {
71                // Ensure all spawned tasks are completed
72                while self.task_count.load(Ordering::SeqCst) > 1 {
73                    if let Ok(task) = self.scheduled.try_recv() {
74                        task.poll();
75                    }
76                }
77
78                CURRENT.with(|cell| {
79                    *cell.borrow_mut() = None;
80                });
81
82                return output;
83            }
84
85            if self.task_count.load(Ordering::SeqCst) == 0 {
86                break;
87            }
88        }
89
90        CURRENT.with(|cell| {
91            *cell.borrow_mut() = None;
92        });
93
94        panic!("Runtime exited without producing a result");
95    }
96}
97
98impl Clone for Runtime {
99    fn clone(&self) -> Self {
100        Runtime {
101            scheduled: self.scheduled.clone(),
102            sender: self.sender.clone(),
103            task_count: self.task_count.clone(),
104        }
105    }
106}
107
108impl Task {
109    fn poll(self: Arc<Self>) {
110        let waker = TaskWaker::new(self.clone());
111        let mut cx = Context::from_waker(&waker);
112
113        if let Ok(mut future) = self.future.try_lock() {
114            if future.as_mut().poll(&mut cx).is_ready() {
115                self.task_count.fetch_sub(1, Ordering::SeqCst);
116            }
117        } else {
118            let _ = self.executor.send(self.clone());
119        }
120    }
121}
122
123struct TaskWaker {
124    task: Arc<Task>,
125}
126
127impl TaskWaker {
128    fn new(task: Arc<Task>) -> Waker { Waker::from(Arc::new(TaskWaker { task })) }
129}
130
131impl Wake for TaskWaker {
132    fn wake(self: Arc<Self>) { let _ = self.task.executor.send(self.task.clone()); }
133
134    fn wake_by_ref(self: &Arc<Self>) { let _ = self.task.executor.send(self.task.clone()); }
135}
136
137struct SpawnableFuture<F: Future> {
138    inner: Pin<Box<F>>,
139    output_sender: Option<channel::Sender<F::Output>>,
140}
141
142impl<F: Future> SpawnableFuture<F> {
143    fn new(future: F, output_sender: channel::Sender<F::Output>) -> Self {
144        SpawnableFuture {
145            inner: Box::pin(future),
146            output_sender: Some(output_sender),
147        }
148    }
149}
150
151impl<F: Future + Send + 'static> Future for SpawnableFuture<F>
152where
153    F::Output: Send + 'static,
154{
155    type Output = ();
156
157    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158        match self.inner.as_mut().poll(cx) {
159            Poll::Ready(output) => {
160                if let Some(sender) = self.output_sender.take() {
161                    let _ = sender.send(output);
162                }
163                Poll::Ready(())
164            }
165            Poll::Pending => Poll::Pending,
166        }
167    }
168}
169
170pub async fn delay(dur: Duration) {
171    let start = Instant::now();
172    while start.elapsed() < dur {
173        yield_now().await;
174    }
175}
176
/// Gives other queued tasks a chance to run before the caller continues.
///
/// The first poll wakes the current task's waker and returns `Pending`;
/// every later poll returns `Ready`.
pub async fn yield_now() {
    /// Future that is pending exactly once.
    struct YieldOnce {
        polled_before: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Mark the future as polled and branch on the previous state.
            if std::mem::replace(&mut self.polled_before, true) {
                return Poll::Ready(());
            }
            // Wake immediately so the task is scheduled again.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce { polled_before: false }.await
}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// A future that awaits two delays with prints in between runs to
    /// completion under `block_on`.
    #[test]
    fn test_runtime() {
        let runtime = Runtime::new();
        let half_second = Duration::from_millis(500);

        runtime.block_on(async move {
            delay(half_second).await;

            println!("hello");
            println!("world");

            delay(half_second).await;
        });

        println!("Runtime exited");
    }

    /// Calling `block_on` from inside a running `block_on` must panic.
    #[test]
    #[should_panic(expected = "Attempting to start a runtime from within a runtime")]
    fn test_nested_runtime() {
        let outer = Runtime::new();

        outer.block_on(async {
            // The outer runtime is already registered on this thread, so
            // entering a second one is expected to panic.
            let inner = Runtime::new();
            inner.block_on(async {});
        });
    }
}