// uvex_runtime/executor/mod.rs

pub mod scheduler;
16pub mod task;
17pub mod task_local;
18pub mod waker;
19pub mod work_stealing;
20
21use std::cell::Cell;
22use std::collections::HashMap;
23use std::future::Future;
24use std::sync::atomic::Ordering;
25use std::sync::Arc;
26use std::task::{Context, Poll};
27
28use crate::platform::sys::{create_pipe, events_with_capacity, Interest};
29use crate::reactor::{with_reactor, with_reactor_mut};
30use crate::time::{next_timer_deadline, tick_timer_wheel};
31
32use scheduler::{GlobalQueue, LocalQueue};
33use task::{JoinHandle, Task, STATE_CANCELLED, STATE_COMPLETED};
34use waker::make_waker;
35
36pub struct Executor {
40 local: LocalQueue,
42 global: Arc<GlobalQueue>,
44 tasks: HashMap<usize, Task>,
47 wake_rx: i32,
49 wake_tx: i32,
51}
52
53impl Executor {
54 fn new() -> std::io::Result<Self> {
55 let (wake_rx, wake_tx) = create_pipe()?;
56 with_reactor(|r| r.register(wake_rx, WAKE_TOKEN, Interest::READABLE))?;
57 Ok(Self {
58 local: LocalQueue::new(),
59 global: Arc::new(GlobalQueue::new()),
60 tasks: HashMap::new(),
61 wake_rx,
62 wake_tx,
63 })
64 }
65
66 pub fn spawn<F>(&mut self, future: F) -> JoinHandle<F::Output>
68 where
69 F: Future + 'static,
70 F::Output: Send + 'static,
71 {
72 let (task, jh) = Task::new(future);
73 let key = Arc::as_ptr(&task.header) as usize;
74 self.global.push_header(Arc::clone(&task.header));
75 self.tasks.insert(key, task);
76 jh
77 }
78
79 pub fn block_on<F: Future>(&mut self, future: F) -> F::Output {
81 let mut root = std::pin::pin!(future);
82 let mut root_done = false;
83 let mut root_output: Option<F::Output> = None;
84
85 let root_waker = self.make_root_waker();
86
87 loop {
88 let expired = tick_timer_wheel(std::time::Instant::now());
90 for w in expired {
91 w.wake();
92 }
93
94 if !root_done {
96 let mut cx = Context::from_waker(&root_waker);
97 if let Poll::Ready(val) = root.as_mut().poll(&mut cx) {
98 root_output = Some(val);
99 root_done = true;
100 }
101 }
102
103 if root_done && self.tasks.is_empty() {
105 break;
106 }
107
108 let mut did_work = false;
110 loop {
111 let Some(header) = self.next_task() else {
112 break;
113 };
114 did_work = true;
115 let key = Arc::as_ptr(&header) as usize;
116 let state = header.state.load(Ordering::Acquire);
117
118 if state == STATE_CANCELLED {
119 if let Some(task) = self.tasks.remove(&key) {
121 task.cancel();
122 }
123 continue;
124 }
125 if state == STATE_COMPLETED {
126 self.tasks.remove(&key);
128 continue;
129 }
130
131 let waker = make_waker(Arc::clone(&header), Arc::clone(&self.global));
133 let mut cx = Context::from_waker(&waker);
134
135 if let Some(task) = self.tasks.get(&key) {
136 let completed = task.poll_task(&mut cx);
137 if completed {
138 self.tasks.remove(&key);
139 }
140 }
141 }
143
144 if root_done && self.tasks.is_empty() {
146 break;
147 }
148
149 if !did_work && self.local.is_empty() && self.global.len() == 0 {
151 self.park();
152 }
153 }
154
155 root_output.expect("root future must complete before block_on returns")
156 }
157
158 fn next_task(&mut self) -> Option<Arc<task::TaskHeader>> {
160 self.local.pop().or_else(|| self.global.pop())
161 }
162
163 fn park(&self) {
168 const MAX_PARK_MS: u64 = 10;
169
170 let timeout_ms = match next_timer_deadline() {
172 None => MAX_PARK_MS,
173 Some(deadline) => {
174 let now = std::time::Instant::now();
175 if deadline <= now {
176 0 } else {
178 let ms = deadline.duration_since(now).as_millis() as u64;
179 ms.min(MAX_PARK_MS)
180 }
181 }
182 };
183
184 let mut events = events_with_capacity(64);
185 let _ = with_reactor_mut(|r| r.poll(&mut events, Some(timeout_ms)));
187 self.drain_wake_pipe();
188 }
189
190 #[cfg(unix)]
192 fn drain_wake_pipe(&self) {
193 let mut buf = [0u8; 64];
194 loop {
195 let n = unsafe { libc::read(self.wake_rx, buf.as_mut_ptr() as *mut _, buf.len()) };
197 if n <= 0 {
198 break;
199 } }
201 }
202
203 #[cfg(not(unix))]
204 fn drain_wake_pipe(&self) {
205 }
208
209 #[cfg(unix)]
212 fn make_root_waker(&self) -> std::task::Waker {
213 use std::task::{RawWaker, RawWakerVTable};
214
215 let tx = self.wake_tx;
216
217 unsafe fn clone_root(ptr: *const ()) -> RawWaker {
220 RawWaker::new(ptr, &ROOT_VTABLE)
221 }
222 unsafe fn wake_root(ptr: *const ()) {
223 let fd = ptr as usize as i32;
224 let b: u8 = 1;
225 libc::write(fd, &b as *const u8 as *const _, 1);
227 }
228 unsafe fn wake_root_by_ref(ptr: *const ()) {
229 wake_root(ptr);
230 }
231 unsafe fn drop_root(_: *const ()) {} static ROOT_VTABLE: RawWakerVTable =
234 RawWakerVTable::new(clone_root, wake_root, wake_root_by_ref, drop_root);
235
236 let raw = std::task::RawWaker::new(tx as usize as *const (), &ROOT_VTABLE);
237 unsafe { std::task::Waker::from_raw(raw) }
240 }
241
242 #[cfg(not(unix))]
243 fn make_root_waker(&self) -> std::task::Waker {
244 use std::task::{RawWaker, RawWakerVTable};
246 static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(
247 |p| RawWaker::new(p, &NOOP_VTABLE),
248 |_| {},
249 |_| {},
250 |_| {},
251 );
252 unsafe { std::task::Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) }
253 }
254}
255
256impl Drop for Executor {
257 fn drop(&mut self) {
258 let _ = with_reactor(|r| r.deregister(self.wake_rx));
259 unsafe {
261 libc::close(self.wake_rx);
262 libc::close(self.wake_tx);
263 }
264 }
265}
266
267const WAKE_TOKEN: usize = usize::MAX;
270
271thread_local! {
274 static CURRENT_EXECUTOR: Cell<*mut Executor> = const { Cell::new(std::ptr::null_mut()) };
277}
278
279pub fn block_on<F: Future>(future: F) -> F::Output {
286 let mut exec = Executor::new().expect("executor init failed");
287 exec.block_on(future)
288}
289
290pub fn block_on_with_spawn<F: Future>(future: F) -> F::Output {
295 let mut exec = Executor::new().expect("executor init failed");
296 CURRENT_EXECUTOR.with(|c| c.set(&mut exec as *mut Executor));
297 let result = exec.block_on(future);
298 CURRENT_EXECUTOR.with(|c| c.set(std::ptr::null_mut()));
299 result
300}
301
302pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
307where
308 F: Future + 'static,
309 F::Output: Send + 'static,
310{
311 CURRENT_EXECUTOR.with(|cell| {
312 let ptr = cell.get();
313 assert!(
314 !ptr.is_null(),
315 "spawn() called outside of block_on_with_spawn context"
316 );
317 unsafe { (*ptr).spawn(future) }
320 })
321}
322
323#[cfg(test)]
326mod tests {
327 use super::*;
328 use std::sync::atomic::{AtomicUsize, Ordering as Ord};
329
330 #[test]
331 fn block_on_simple_value() {
332 assert_eq!(block_on(async { 42u32 }), 42);
333 }
334
335 #[test]
336 fn block_on_chain_of_awaits() {
337 async fn double(x: u32) -> u32 {
338 x * 2
339 }
340 async fn compute() -> u32 {
341 double(double(3).await).await
342 }
343 assert_eq!(block_on(compute()), 12);
344 }
345
346 #[test]
347 fn block_on_string_output() {
348 assert_eq!(block_on(async { String::from("hello") }), "hello");
349 }
350
351 #[test]
352 fn spawn_and_join() {
353 let result = block_on_with_spawn(async {
354 let jh = spawn(async { 100u32 });
355 jh.await.unwrap()
356 });
357 assert_eq!(result, 100);
358 }
359
360 #[test]
361 fn spawn_multiple_and_join_all() {
362 let counter = Arc::new(AtomicUsize::new(0));
363 let c1 = counter.clone();
364 let c2 = counter.clone();
365 block_on_with_spawn(async move {
366 let jh1 = spawn(async move {
367 c1.fetch_add(1, Ord::SeqCst);
368 });
369 let jh2 = spawn(async move {
370 c2.fetch_add(1, Ord::SeqCst);
371 });
372 jh1.await.unwrap();
373 jh2.await.unwrap();
374 });
375 assert_eq!(counter.load(Ord::SeqCst), 2);
376 }
377
378 #[test]
379 fn join_handle_abort_returns_cancelled() {
380 use std::future::poll_fn;
381 use std::task::Poll as P;
382
383 let result = block_on_with_spawn(async {
384 let jh = spawn(async { poll_fn(|_| P::<()>::Pending).await });
385 jh.abort();
386 jh.await
387 });
388 assert!(matches!(result, Err(task::JoinError::Cancelled)));
389 }
390
391 #[test]
392 fn block_on_nested_spawn_ordering() {
393 let order = Arc::new(std::sync::Mutex::new(Vec::<u32>::new()));
395 let o1 = order.clone();
396 let o2 = order.clone();
397 block_on_with_spawn(async move {
398 let jh1 = spawn(async move {
399 o1.lock().unwrap().push(1);
400 });
401 let jh2 = spawn(async move {
402 o2.lock().unwrap().push(2);
403 });
404 jh1.await.unwrap();
405 jh2.await.unwrap();
406 });
407 let v = order.lock().unwrap();
408 assert_eq!(v.len(), 2);
409 }
410}