1use crate::{
2 join::{JoinHandle, JoinState},
3 preempt::PreemptState,
4 queue::{Queue, QueueKey, TaskId, TaskQueue},
5 stats::{ExecutorStats, QueueStats},
6 task::TaskHeader,
7 yield_once::yield_once,
8};
9use futures::FutureExt;
10use futures_util::task::AtomicWaker;
11use slab::Slab;
12use static_assertions::assert_not_impl_any;
13use std::sync::atomic::AtomicBool;
14use std::{
15 cell::Cell,
16 cell::RefCell,
17 future::Future,
18 mem,
19 pin::Pin,
20 rc::Rc,
21 sync::atomic::Ordering,
22 sync::Arc,
23 task::{Context, Poll},
24 time::{Duration, Instant},
25};
26
thread_local! {
    // Deadline of the currently-running timeslice, published by the executor
    // (`set_yield_maybe_deadline` from `run_timeslice`) and consulted by
    // `yield_maybe()` so tasks can yield cooperatively once it has passed.
    static YIELD_MAYBE_DEADLINE: Cell<Option<Instant>> = Cell::new(None);
}
30
31fn set_yield_maybe_deadline(deadline: Instant) {
32 YIELD_MAYBE_DEADLINE.with(|cell| cell.set(Some(deadline)));
33}
34
/// Errors that can be reported when spawning onto an executor queue.
#[derive(Debug)]
pub enum SpawnError<K: QueueKey> {
    /// The executor is shutting down and not accepting tasks.
    /// NOTE(review): not constructed anywhere in this file — confirm it is
    /// produced elsewhere or slated for future use.
    ShuttingDown,
    /// No queue with the given key is registered on this executor.
    QueueNotFound(K),
    /// An invalid share value (carried in the variant).
    /// NOTE(review): also not constructed in this file; `Executor::new`
    /// reports share errors as strings instead — confirm intent.
    InvalidShare(u64),
}
41
/// Wraps a spawned user future so the executor can observe cancellation and
/// route the output (value, cancellation, or panic) into the task's
/// [`JoinState`].
struct CancelableFuture<T, K: QueueKey, F: Future<Output = T> + 'static> {
    // Shared task metadata; checked for cancellation on every poll.
    header: Arc<TaskHeader<K>>,
    // Completion slot shared with the JoinHandle.
    join: Arc<JoinState<T>>,
    // The user future, boxed so this wrapper can be polled without
    // structural pinning.
    fut: Pin<Box<F>>,
    // When true, panics during poll are caught and reported as JoinError::Panic.
    catch_panics: bool,
}
52
53impl<T, K: QueueKey, F: Future<Output = T> + 'static> CancelableFuture<T, K, F> {
54 pub fn new(
55 header: Arc<TaskHeader<K>>,
56 join: Arc<JoinState<T>>,
57 fut: F,
58 catch_panics: bool,
59 ) -> Self {
60 Self {
61 header,
62 join,
63 fut: Box::pin(fut),
64 catch_panics,
65 }
66 }
67}
68
impl<T, K: QueueKey, F: Future<Output = T> + 'static> Future for CancelableFuture<T, K, F> {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Already completed on an earlier poll (success, cancel, or panic):
        // never poll the inner future again.
        if self.join.is_done() {
            return Poll::Ready(());
        }

        // Cooperative cancellation: skip polling the user future entirely
        // and record Cancelled in the join state.
        if self.header.is_cancelled() {
            self.join.try_complete_cancelled();
            return Poll::Ready(());
        }

        // Optionally fence panics off from the executor loop.
        // AssertUnwindSafe: after a panic the join state is completed below,
        // and the `is_done()` guard above prevents any re-poll of the
        // possibly-broken future.
        let poll_result = if self.catch_panics {
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| self.fut.as_mut().poll(cx)))
        } else {
            Ok(self.fut.as_mut().poll(cx))
        };

        match poll_result {
            Ok(Poll::Ready(out)) => {
                self.join.try_complete_ok(out);
                Poll::Ready(())
            }
            Ok(Poll::Pending) => Poll::Pending,
            Err(panic_payload) => {
                // Surface the panic to the JoinHandle instead of unwinding.
                let panic_err = crate::join::PanicError::from_panic_payload(panic_payload);
                self.join
                    .try_complete_err(crate::join::JoinError::Panic(panic_err));
                Poll::Ready(())
            }
        }
    }
}
107
/// Waker adapter for the `until` future in [`Executor::run_until`]: a wake
/// sets a flag (so the driver loop re-polls `until`) and also wakes the
/// executor out of its idle wait.
struct UntilWakerWrapper {
    // Set on wake; swapped-and-cleared by the driver loop.
    woken: Arc<std::sync::atomic::AtomicBool>,
    // Wakes `wait_for_work_or_signal` when the executor is parked.
    idle_waker: Arc<futures_util::task::AtomicWaker>,
}
114
impl futures_util::task::ArcWake for UntilWakerWrapper {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        // Release pairs with the Acquire/AcqRel reads of `woken` in the
        // driver loop and idle wait.
        arc_self.woken.store(true, Ordering::Release);
        arc_self.idle_waker.wake();
    }
}
121
/// Executor-side bookkeeping for one live spawned task, stored in the slab
/// keyed by its `TaskId`.
struct TaskRecord<K: QueueKey> {
    // Shared metadata; also serves as the task's waker (see `spawn_inner`).
    header: Arc<TaskHeader<K>>,
    // Waker built once from `header` and cloned per poll.
    waker: std::task::Waker,
    // The type-erased `CancelableFuture` driving the user code.
    fut: Pin<Box<dyn Future<Output = ()> + 'static>>,
}
128
/// Per-queue scheduler state for the weighted-fair scheduling loop.
struct QueueState<K: QueueKey> {
    // Virtual runtime: elapsed nanoseconds divided by `share` (see
    // `charge_class`); the queue with the earliest deadline runs next.
    vruntime: u128,
    // Scheduling weight: larger share => slower vruntime growth => more CPU.
    share: u64,
    // The runnable-task queue backing this scheduling class.
    task_queue: Arc<TaskQueue>,
    stats: QueueStats<K>,
}
136
137impl<K: QueueKey> QueueState<K> {
138 fn new(queue: Queue<K>, task_queue: Arc<TaskQueue>) -> Self {
139 Self {
140 vruntime: 0,
141 stats: QueueStats::new(queue.id(), queue.share()),
142 share: queue.share(),
143 task_queue,
144 }
145 }
146}
147
/// Handle for spawning tasks onto one named queue of an [`Executor`].
/// Holds an `Rc` to the executor, so it is single-threaded like the executor.
pub struct QueueHandle<K: QueueKey> {
    executor: Rc<Executor<K>>,
    qid: K,
}
152impl<K: QueueKey> QueueHandle<K> {
153 pub fn spawn<T, F>(self: &Self, fut: F) -> JoinHandle<T, K>
154 where
155 T: 'static,
156 F: Future<Output = T> + 'static, {
158 self.executor.spawn_inner(self.qid, fut)
159 }
160}
161
162pub struct ExecutorBuilder<K: QueueKey> {
163 options: ExecutorOptions,
164 queues: Vec<Queue<K>>,
165}
166impl<K: QueueKey> ExecutorBuilder<K> {
167 pub fn new() -> Self {
168 Self {
169 options: ExecutorOptions::default(),
170 queues: Vec::new(),
171 }
172 }
173 pub fn with_sched_latency(mut self, sched_latency: Duration) -> Self {
174 self.options.sched_latency = sched_latency;
175 self
176 }
177 pub fn with_min_slice(mut self, min_slice: Duration) -> Self {
178 self.options.min_slice = min_slice;
179 self
180 }
181 pub fn with_driver_yield(mut self, driver_yield: Duration) -> Self {
182 self.options.driver_yield = driver_yield;
183 self
184 }
185
186 pub fn with_queue(mut self, qid: K, share: u64) -> Self {
188 let queue = Queue::new(qid, share);
189 self.queues.push(queue);
190 self
191 }
192 pub fn with_panic_on_task_panic(mut self, panic_on_task_panic: bool) -> Self {
193 self.options.panic_on_task_panic = panic_on_task_panic;
194 self
195 }
196 pub fn with_max_polls_per_yield(mut self, max_polls: u32) -> Self {
200 self.options.max_polls_per_yield = max_polls;
201 self
202 }
203 pub fn with_enable_lifo(mut self, enable: bool) -> Self {
207 self.options.enable_lifo = enable;
208 self
209 }
210 pub fn with_lifo_skip_interval(mut self, interval: usize) -> Self {
215 self.options.lifo_skip_interval = interval;
216 self
217 }
218 pub fn build(self) -> Result<Rc<Executor<K>>, String> {
219 Executor::new(self.options, self.queues)
220 }
221}
222
/// Tunables for the executor's scheduling loop.
pub struct ExecutorOptions {
    // Target total latency spread across all runnable queues; each queue's
    // timeslice request is sched_latency / num_runnable (see pick_next_class).
    sched_latency: Duration,
    // Floor on a single timeslice request.
    min_slice: Duration,
    // Cap on how long the executor runs before yielding to the host
    // runtime's driver (bounds the timeslice in `run_until`).
    driver_yield: Duration,
    // true: task panics unwind into the executor; false: they are caught
    // and reported through the JoinHandle (see spawn_inner).
    panic_on_task_panic: bool,
    // Max task polls within one timeslice before yielding.
    max_polls_per_yield: u32,
    // Forwarded to each TaskQueue (LIFO-slot behavior).
    enable_lifo: bool,
    lifo_skip_interval: usize,
}
impl Default for ExecutorOptions {
    fn default() -> Self {
        Self {
            sched_latency: Duration::from_millis(2),
            min_slice: Duration::from_micros(100),
            driver_yield: Duration::from_micros(500),
            panic_on_task_panic: true,
            // 61 — presumably chosen like tokio's prime event-budget constant
            // to avoid resonating with power-of-two patterns; confirm.
            max_polls_per_yield: 61,
            enable_lifo: false,
            lifo_skip_interval: 16,
        }
    }
}
249
/// Single-threaded executor that schedules tasks across weighted queues
/// using vruntime/deadline-based (EEVDF-style) selection.
///
/// Deliberately `!Send + !Sync` (Rc/RefCell/Cell); see the assertion below.
pub struct Executor<K: QueueKey> {
    options: ExecutorOptions,
    // One TaskQueue per scheduling class, index-aligned with `queues`/`qids`.
    task_queues: Vec<Arc<TaskQueue>>,
    // Runnable-ness per queue as of the last scheduling pass (pick_next_class).
    is_runnable: RefCell<Vec<bool>>,
    // All live tasks; the slab key is the TaskId.
    tasks: RefCell<Slab<TaskRecord<K>>>,
    // Per-queue scheduler state, index-aligned with `task_queues`.
    queues: RefCell<Vec<QueueState<K>>>,
    // Queue keys, index-aligned with `queues`; used for key -> index lookup.
    qids: RefCell<Vec<K>>,

    // Monotone floor of runnable queues' vruntime; queues waking from idle
    // are caught up to it so they cannot starve the others.
    min_vruntime: std::cell::Cell<u128>,

    // Shared with task headers so wakes on idle queues can request preemption.
    preempt_state: Arc<PreemptState>,

    stats: RefCell<ExecutorStats>,
}
269
// The executor is single-threaded by construction (Rc, RefCell, Cell);
// statically assert it never accidentally becomes Send or Sync.
assert_not_impl_any!(Executor<u8>: Send, Sync);
271
272impl<K: QueueKey> Executor<K> {
273 pub fn new(options: ExecutorOptions, queues: Vec<Queue<K>>) -> Result<Rc<Self>, String> {
275 if queues.is_empty() {
276 return Err("Must have at least one queue".to_string());
277 }
278 for i in 0..queues.len() {
280 for j in i + 1..queues.len() {
281 if queues[i].id() == queues[j].id() {
282 return Err("All queues must have unique ids".to_string());
283 }
284 }
285 }
286 if queues.iter().any(|q| q.share() == 0) {
288 return Err("All queues must have a share > 0".to_string());
289 }
290
291 let num_queues = queues.len();
293 if num_queues > 256 {
294 return Err("Cannot have more than 256 queues (preemption mask limit)".to_string());
295 }
296
297 let preempt_state = Arc::new(PreemptState::new());
299
300 let task_queues: Vec<Arc<TaskQueue>> = (0..num_queues)
301 .map(|_| {
302 Arc::new(TaskQueue::new(
303 options.enable_lifo,
304 options.lifo_skip_interval,
305 ))
306 })
307 .collect();
308
309 let qids = queues.iter().map(|q| q.id()).collect::<Vec<_>>();
310 let queues: Vec<QueueState<K>> = queues
311 .into_iter()
312 .enumerate()
313 .map(|(idx, q)| QueueState::new(q, task_queues[idx].clone()))
314 .collect();
315
316 Ok(Rc::new(Self {
317 task_queues,
318 is_runnable: RefCell::new(vec![false; num_queues]),
319 tasks: RefCell::new(Slab::new()),
320 queues: RefCell::new(queues),
321 qids: RefCell::new(qids),
322 options,
323 min_vruntime: std::cell::Cell::new(0),
324 preempt_state,
325 stats: RefCell::new(ExecutorStats::new(Instant::now())),
326 }))
327 }
328
329 pub fn queue(self: &Rc<Self>, qid: K) -> Result<QueueHandle<K>, SpawnError<K>> {
331 let Some(_) = self.qids.borrow().iter().position(|q| *q == qid) else {
332 return Err(SpawnError::QueueNotFound(qid));
333 };
334 Ok(QueueHandle {
335 executor: self.clone(),
336 qid,
337 })
338 }
339
340 fn spawn_inner<T, F>(self: &Rc<Self>, qid: K, fut: F) -> JoinHandle<T, K>
342 where
343 T: 'static,
344 F: Future<Output = T> + 'static, {
346 let qid = qid.into();
347 let qidx = self
348 .qids
349 .borrow()
350 .iter()
351 .position(|q| *q == qid)
352 .expect("queue should exist");
353 let mut tasks = self.tasks.borrow_mut();
354 let entry = tasks.vacant_entry();
355 let id = entry.key();
356 let preempt_state = if self.task_queues.len() > 1 {
357 Some(self.preempt_state.clone())
358 } else {
359 None
360 };
361 let header = Arc::new(TaskHeader::new(
362 id,
363 qid,
364 qidx,
365 self.task_queues[qidx].clone(),
366 preempt_state,
367 ));
368 let join = Arc::new(JoinState::<T>::new());
369 let catch_panics = !self.options.panic_on_task_panic;
372 let wrapped = CancelableFuture::new(header.clone(), join.clone(), fut, catch_panics);
373
374 let waker = futures::task::waker(header.clone());
375
376 entry.insert(TaskRecord {
377 header: header.clone(),
378 waker,
379 fut: Box::pin(wrapped),
380 });
381
382 header.enqueue();
384
385 JoinHandle::new(header, join)
386 }
387
    /// Choose the next scheduling class to run.
    ///
    /// Returns `(queue index, timeslice, virtual deadline, runnable count)`,
    /// or `None` when no queue has runnable tasks. Each runnable queue gets
    /// deadline = vruntime + request/share; the earliest deadline wins
    /// (EEVDF-style selection).
    fn pick_next_class(&self) -> Option<(usize, Duration, u128, usize)> {
        let mut best: Option<(usize, u128)> = None;
        let mut runnable = None;
        let mut num_runnable = 0;
        let mut is_runnable = self.is_runnable.borrow_mut();
        for (idx, q) in self.queues.borrow_mut().iter_mut().enumerate() {
            let was_runnable = is_runnable[idx];
            is_runnable[idx] = !q.task_queue.is_empty();
            // A queue waking from idle would otherwise carry a stale (small)
            // vruntime and monopolize the CPU; catch it up to the floor.
            if !was_runnable && is_runnable[idx] {
                q.vruntime = q.vruntime.max(self.min_vruntime.get());
            }
            if is_runnable[idx] {
                num_runnable += 1;
                runnable = Some(idx);
            }
        }
        if num_runnable == 0 {
            return None;
        }
        // Split the latency target across runnable queues, floored at min_slice.
        let request = self.options.sched_latency.as_nanos() as u128 / num_runnable as u128;
        let request = request.max(self.options.min_slice.as_nanos() as u128);

        // Fast path: a single runnable queue needs no deadline comparison.
        if num_runnable == 1 {
            let selected_idx = runnable.unwrap();
            let queues = self.queues.borrow();
            let selected_deadline =
                queues[selected_idx].vruntime + (request / queues[selected_idx].share as u128);
            return Some((
                selected_idx,
                Duration::from_nanos(request as u64),
                selected_deadline,
                num_runnable,
            ));
        }

        // Earliest-virtual-deadline-first over the runnable queues.
        for (idx, q) in self.queues.borrow().iter().enumerate() {
            if q.task_queue.is_empty() {
                continue;
            }
            let deadline = q.vruntime + (request / q.share as u128);
            match best {
                None => best = Some((idx, deadline)),
                Some((_, bv)) if deadline < bv => best = Some((idx, deadline)),
                _ => {}
            }
        }

        // Safe: num_runnable >= 1, so the loop set `best`.
        let (selected_idx, selected_deadline) = best.unwrap();
        Some((
            selected_idx,
            Duration::from_nanos(request as u64),
            selected_deadline,
            num_runnable,
        ))
    }
452
    /// Recompute which *idle* queues would preempt the selected queue if a
    /// task arrived on them, and publish that set as the preemption mask.
    ///
    /// The hypothetical deadline mirrors what `pick_next_class` would compute
    /// on wake-up: one more runnable queue (hence `num_runnable + 1`) and a
    /// vruntime caught up to the global floor.
    fn update_preempt_mask(&self, selected_deadline: u128, num_runnable: usize) {
        let is_runnable = self.is_runnable.borrow();
        let queues = self.queues.borrow();

        let hypothetical_request =
            self.options.sched_latency.as_nanos() as u128 / (num_runnable + 1) as u128;
        let hypothetical_request =
            hypothetical_request.max(self.options.min_slice.as_nanos() as u128);
        let min_vruntime = self.min_vruntime.get();

        let preempting = (0..queues.len()).filter(|&idx| {
            // Already-runnable queues were considered by pick_next_class;
            // only idle queues can newly preempt.
            if is_runnable[idx] {
                return false;
            }
            let hypothetical_deadline =
                min_vruntime + (hypothetical_request / queues[idx].share as u128);
            hypothetical_deadline < selected_deadline
        });
        self.preempt_state.update_mask(preempting);
    }
479
480 fn charge_class(&self, qidx: usize, elapsed: Duration) {
484 if self.task_queues.len() <= 1 {
485 return;
486 }
487 let mut queues = self.queues.borrow_mut();
488 let queue = &mut queues[qidx];
489 let incr = (elapsed.as_nanos() + queue.share as u128 - 1) / (queue.share as u128);
491 queue.vruntime += incr;
492 queue.stats.record_poll(elapsed);
493 }
494 fn update_min_vruntime(&self, including: u128) {
495 if self.task_queues.len() <= 1 {
496 return;
497 }
498 let min_vruntime = self
499 .queues
500 .borrow()
501 .iter()
502 .filter(|q| !q.task_queue.is_empty())
503 .map(|q| q.vruntime)
504 .chain(Some(including))
505 .min();
506 let min_vruntime = min_vruntime.unwrap();
507 let prev_min_vruntime = self.min_vruntime.get();
509 self.min_vruntime.set(prev_min_vruntime.max(min_vruntime));
510 }
511
    /// Snapshot of executor-wide counters.
    pub fn stats(&self) -> ExecutorStats {
        self.stats.borrow().clone()
    }
516
517 pub fn qstats(&self) -> Vec<QueueStats<K>> {
519 self.queues
520 .borrow()
521 .iter()
522 .map(|q| q.stats.clone())
523 .collect()
524 }
525
    /// Drive the executor until `until` completes, returning its output.
    ///
    /// Cooperates with the host runtime: work is done in bounded timeslices
    /// and the loop periodically yields (`yield_to_driver`) so host I/O and
    /// timers keep making progress.
    pub async fn run_until<F: Future>(&self, until: F) -> F::Output {
        let mut until_pinned = std::pin::pin!(until.fuse());

        // Wakes from `until` set this flag and kick the idle waker, telling
        // the loop when to re-poll `until`.
        let until_woken = Arc::new(AtomicBool::new(false));
        let idle_waker = Arc::new(AtomicWaker::new());
        let until_waker = self.create_until_waker(until_woken.clone(), idle_waker.clone());

        let mut last_driver_yield_at = Instant::now();
        let mut iter = 0u64;

        // Poll `until` once up front: it may already be complete, and this
        // registers our waker with it.
        {
            let mut cx = Context::from_waker(&until_waker);
            if let Poll::Ready(result) = until_pinned.as_mut().poll(&mut cx) {
                return result;
            }
        }

        loop {
            iter += 1;
            // Detailed stats are sampled 1-in-128 iterations to keep the hot
            // loop cheap.
            let enable_stats = iter % 128 == 0;
            self.stats.borrow_mut().record_loop_iter();

            // Re-poll `until` only when its waker fired since last check.
            if until_woken.swap(false, Ordering::AcqRel) {
                let mut cx = Context::from_waker(&until_waker);
                if let Poll::Ready(result) = until_pinned.as_mut().poll(&mut cx) {
                    return result;
                }
            }

            // Nothing runnable: park until a task arrives or `until` wakes.
            let Some((qidx, timeslice)) = self.select_queue(enable_stats) else {
                self.wait_for_work_or_signal(&until_woken, &idle_waker)
                    .await;
                continue;
            };

            // Never run longer than driver_yield without giving the host
            // runtime a turn.
            let timeslice = timeslice.min(self.options.driver_yield);
            let end = self.run_timeslice(qidx, timeslice, enable_stats);

            let new_vruntime = self.queues.borrow()[qidx].vruntime;
            self.update_min_vruntime(new_vruntime);

            last_driver_yield_at = self.yield_to_driver(last_driver_yield_at, end).await;
        }
    }
589
590 fn create_until_waker(
592 &self,
593 until_woken: Arc<std::sync::atomic::AtomicBool>,
594 idle_waker: Arc<futures_util::task::AtomicWaker>,
595 ) -> std::task::Waker {
596 let wrapper = Arc::new(UntilWakerWrapper {
597 woken: until_woken,
598 idle_waker,
599 });
600 futures::task::waker(wrapper)
601 }
602
    /// Park the driver until either a task becomes runnable on any queue or
    /// the `until` future is woken.
    ///
    /// The wakers are registered *before* the conditions are checked, and
    /// then the conditions are re-checked; this closes the race where a wake
    /// lands between check and registration.
    async fn wait_for_work_or_signal(
        &self,
        until_woken: &Arc<AtomicBool>,
        idle_waker: &Arc<AtomicWaker>,
    ) {
        use futures_util::future::poll_fn;

        poll_fn(|cx| {
            idle_waker.register(cx.waker());

            for task_queue in &self.task_queues {
                task_queue.register_waker(cx.waker());
            }

            // Acquire pairs with the Release store in UntilWakerWrapper.
            if until_woken.load(Ordering::Acquire) {
                return Poll::Ready(());
            }

            for task_queue in &self.task_queues {
                if !task_queue.is_empty() {
                    return Poll::Ready(());
                }
            }

            Poll::Pending
        })
        .await
    }
639
    /// Pick the queue to run next and its timeslice, maintaining the
    /// preemption state as a side effect. Returns `None` when idle.
    fn select_queue(&self, enable_stats: bool) -> Option<(usize, Duration)> {
        let start = if enable_stats {
            Some(Instant::now())
        } else {
            None
        };
        // Single-queue fast path: no fairness bookkeeping needed.
        if self.task_queues.len() == 1 {
            match self.task_queues[0].is_empty() {
                true => return None,
                false => return Some((0, self.options.sched_latency)),
            }
        }

        let Some((selected_idx, timeslice, selected_deadline, num_runnable)) =
            self.pick_next_class()
        else {
            // Nothing runnable: any wake should rouse us, so clear the mask.
            self.preempt_state.update_mask(std::iter::empty());
            return None;
        };

        // Drop any preemption request aimed at the previous selection before
        // publishing the mask for the new one.
        self.preempt_state.clear_preempt();

        self.update_preempt_mask(selected_deadline, num_runnable);

        if let Some(start) = start {
            let elapsed = Instant::now().duration_since(start);
            self.stats.borrow_mut().record_schedule_decision(elapsed);
        }

        Some((selected_idx, timeslice))
    }
677
    /// Pop the next live task id from queue `qidx`, skipping stale entries
    /// (tasks already removed from the slab or already done). Returns `None`
    /// when the queue is drained.
    fn pop_next_task_from_queue(&self, qidx: usize) -> Option<TaskId> {
        loop {
            let mut queues = self.queues.borrow_mut();
            let queue = &mut queues[qidx];
            // NOTE(review): this counts every pop attempt, including empty
            // pops and stale-id retries — confirm that is the intended
            // meaning of "runnable dequeue".
            queue.stats.record_runnable_dequeue();
            let maybe_id = queue.task_queue.pop();

            // Release the queues borrow before touching `tasks` (both are
            // RefCells on the same struct).
            drop(queues);

            let Some(id) = maybe_id else {
                return None;
            };

            // The queue can hold ids for tasks that have since completed;
            // skip them.
            let tasks = self.tasks.borrow();
            let Some(task) = tasks.get(id) else {
                continue;
            };

            if task.header.is_done() {
                continue;
            }

            return Some(id);
        }
    }
708
    /// Poll task `id` once and charge the elapsed time to queue `qidx`.
    ///
    /// Returns `(completed, elapsed)`. The future is temporarily moved out of
    /// the slab and replaced with a placeholder so that no `RefCell` borrow
    /// of `tasks` is held while user code runs (the poll may re-enter the
    /// executor, e.g. by spawning or waking tasks).
    fn poll_task(&self, id: TaskId, qidx: usize, start: Instant) -> (bool, Duration) {
        let (waker, mut extracted_fut) = {
            let mut tasks = self.tasks.borrow_mut();
            let task = match tasks.get_mut(id) {
                Some(task) => task,
                None => return (false, Duration::ZERO),
            };

            // Mark dequeued so a wake during the poll re-enqueues the task.
            task.header.set_queued(false);

            let waker = task.waker.clone();

            // Swap in a trivial placeholder while the real future is polled.
            let placeholder = Box::pin(futures::future::ready(()));
            let extracted_fut = mem::replace(&mut task.fut, placeholder);

            (waker, extracted_fut)
        };
        // No `tasks` borrow is held here — user code may re-enter freely.
        let mut cx = Context::from_waker(&waker);

        let poll = extracted_fut.as_mut().poll(&mut cx);

        let end = Instant::now();
        let elapsed = end.saturating_duration_since(start);
        self.charge_class(qidx, elapsed);

        {
            let mut tasks = self.tasks.borrow_mut();
            let task = match tasks.get_mut(id) {
                Some(task) => task,
                // Task record vanished during the poll; treat as not completed
                // here (nothing left to clean up).
                None => {
                    return (false, elapsed);
                }
            };

            match poll {
                Poll::Ready(()) => {
                    // Leave the placeholder in place; the caller removes the
                    // record via `complete_task`.
                    (true, elapsed)
                }
                Poll::Pending => {
                    // Put the real future back, dropping the placeholder.
                    let placeholder = mem::replace(&mut task.fut, extracted_fut);
                    drop(placeholder);
                    (false, elapsed)
                }
            }
        }
    }
772
773 fn complete_task(&self, id: TaskId, _qidx: usize) {
775 let mut tasks = self.tasks.borrow_mut();
776 let task = tasks.get_mut(id).expect("task should exist");
777 task.header.set_done();
778 tasks.remove(id);
779 }
780
    /// Run tasks from queue `qidx` for up to `timeslice`, returning the
    /// instant at which the last poll finished.
    ///
    /// The slice also ends early when the poll budget is exhausted or a
    /// higher-priority wake requests preemption.
    fn run_timeslice(&self, qidx: usize, timeslice: Duration, enable_stats: bool) -> Instant {
        let now = Instant::now();
        let until = now + timeslice;
        if enable_stats {
            self.queues.borrow_mut()[qidx]
                .stats
                .record_first_service_after_runnable(now);
        }

        {
            // Flush the queue's LIFO slot into the main queue for this slice
            // (scoped so the `queues` borrow ends before polling).
            let queue = &self.queues.borrow()[qidx];
            queue.task_queue.drain_lifo_to_mpsc();
        }

        let mut start = now;
        let mut polls_this_slice = 0u32;
        let max_polls = self.options.max_polls_per_yield;

        loop {
            // Refresh the deadline each iteration so `yield_maybe()` inside
            // tasks sees the current slice's end.
            set_yield_maybe_deadline(until);

            let Some(id) = self.pop_next_task_from_queue(qidx) else {
                break;
            };

            let (completed, elapsed) = self.poll_task(id, qidx, start);
            // Derive the end from the charged elapsed time instead of
            // re-reading the clock.
            let end = start + elapsed;

            if completed {
                self.complete_task(id, qidx);
            }
            start = end;
            polls_this_slice += 1;

            // Poll budget exhausted: hand control back regardless of time.
            if polls_this_slice >= max_polls {
                break;
            }

            if end > until {
                if enable_stats {
                    self.stats.borrow_mut().record_poll(elapsed, true);
                    let mut queues = self.queues.borrow_mut();
                    queues[qidx].stats.record_slice_overrun();
                    queues[qidx].stats.record_slice_exhausted();
                }
                break;
            }
            // A wake on a queue in the preemption mask cuts the slice short.
            if self.preempt_state.check() {
                break;
            }
        }
        start
    }
842
843 async fn yield_to_driver(&self, last_yield: Instant, now: Instant) -> Instant {
845 let since_last = now - last_yield;
846 yield_once().await;
847 let after_yield = Instant::now();
848 let in_driver = after_yield.duration_since(now);
849 self.stats
850 .borrow_mut()
851 .record_driver_yield(since_last, in_driver);
852 after_yield
853 }
854}
855
856pub async fn yield_maybe() {
857 let should_yield = YIELD_MAYBE_DEADLINE.with(|d| {
858 if let Some(dl) = d.get() {
859 Instant::now() >= dl
860 } else {
861 false
862 }
863 });
864 if should_yield {
865 YIELD_MAYBE_DEADLINE.with(|d| d.set(None));
867 yield_once().await;
868 }
869}
870
871#[cfg(test)]
872mod tests {
873 use super::*;
874 use crate::join::JoinError;
875 use crate::yield_once::yield_once;
876 use std::sync::atomic::AtomicBool;
877 use std::sync::atomic::{AtomicU32, Ordering};
878 use std::sync::{Arc, Mutex};
879 use tokio::task::LocalSet;
880 use tokio::time::{sleep, timeout, Duration};
881
    // End-to-end smoke test: a spawned task runs exactly once and its
    // JoinHandle resolves.
    #[tokio::test]
    async fn test_basic_task_completion() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let counter = Arc::new(AtomicU32::new(0));

                let counter_clone = counter.clone();
                let result = executor.run_until(async {
                    let queue = executor.queue(0).unwrap();
                    let handle = queue.spawn(async move {
                        counter_clone.fetch_add(1, Ordering::Relaxed);
                    });
                    handle.await
                });
                let result = timeout(Duration::from_millis(100), result).await;
                assert!(result.is_ok(), "Task should complete");
                assert_eq!(counter.load(Ordering::Relaxed), 1);
            })
            .await;
    }
908
    // Awaiting a JoinHandle yields the task's return value wrapped in Ok.
    #[tokio::test]
    async fn test_join_handle_returns_result() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();

                let result = executor.run_until(async {
                    let queue = executor.queue(0).unwrap();
                    let handle = queue.spawn(async move { 42 });
                    handle.await
                });
                let result = timeout(Duration::from_millis(100), result).await;
                assert!(result.is_ok(), "JoinHandle should complete");
                let join_result = result.unwrap();
                assert_eq!(join_result, Ok(42));
            })
            .await;
    }
931
    // Aborting a running task resolves its JoinHandle with Cancelled and
    // prevents the task body from finishing.
    #[tokio::test]
    async fn test_join_handle_abort() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let started = Arc::new(AtomicBool::new(false));
                let completed = Arc::new(AtomicBool::new(false));
                let started_clone = started.clone();
                let completed_clone = completed.clone();

                let queue = executor.queue(0).unwrap();
                let handle = executor
                    .run_until(async {
                        let handle = queue.spawn(async move {
                            started_clone.store(true, Ordering::Relaxed);
                            for _ in 0..100 {
                                sleep(Duration::from_millis(10)).await;
                            }
                            completed_clone.store(true, Ordering::Relaxed);
                        });
                        sleep(Duration::from_millis(50)).await;
                        assert!(started.load(Ordering::Relaxed), "Task should have started");

                        handle.abort();
                        handle
                    })
                    .await;

                let result = timeout(Duration::from_millis(500), handle).await;
                assert!(result.is_ok(), "JoinHandle should complete after abort");
                let join_result = result.unwrap();
                assert!(matches!(join_result, Err(JoinError::Cancelled)));

                assert!(
                    !completed.load(Ordering::Relaxed),
                    "Task should not have completed"
                );
            })
            .await;
    }
981
    // With shares 8:1, the heavier queue should make proportionally more
    // progress (between 2x and 16x the lighter queue's iterations).
    #[tokio::test]
    async fn test_vruntime_scheduling() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 8)
                    .with_queue(1, 1)
                    .build()
                    .unwrap();
                let queue1 = executor.queue(0).unwrap();
                let queue2 = executor.queue(1).unwrap();
                let high = Arc::new(AtomicU32::new(0));
                let low = Arc::new(AtomicU32::new(0));
                let high_clone = high.clone();
                let low_clone = low.clone();

                executor
                    .run_until(async {
                        let handle1 = queue1.spawn(async move {
                            loop {
                                for _ in 0..100_000 {
                                    high_clone.fetch_add(1, Ordering::Relaxed);
                                }
                                yield_once().await;
                            }
                        });
                        let handle2 = queue2.spawn(async move {
                            loop {
                                for _ in 0..100_000 {
                                    low_clone.fetch_add(1, Ordering::Relaxed);
                                }
                                yield_once().await;
                            }
                        });
                        sleep(Duration::from_millis(100)).await;
                        handle1.abort();
                        handle2.abort();
                    })
                    .await;
                let high_count = high.load(Ordering::Relaxed);
                let low_count = low.load(Ordering::Relaxed);
                assert!(
                    low_count * 2 < high_count && high_count < low_count * 16,
                    "High weight class should get significantly more CPU time. High: {}, Low: {}",
                    high_count,
                    low_count
                );
            })
            .await;
    }
1039
    // Tasks spawned into one queue execute in spawn (FIFO) order.
    #[tokio::test]
    async fn test_policy_fifo_ordering() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let queue = executor.queue(0).unwrap();
                let execution_order = Arc::new(Mutex::new(Vec::new()));

                for i in 0..5 {
                    let order_clone = execution_order.clone();
                    let _handle = queue.spawn(async move {
                        order_clone.lock().unwrap().push(i);
                    });
                }

                let executor_clone = executor.clone();
                local.spawn_local(async move {
                    executor_clone
                        .run_until(sleep(Duration::from_millis(200)))
                        .await;
                });

                sleep(Duration::from_millis(200)).await;

                let order = execution_order.lock().unwrap();
                assert_eq!(order.len(), 5, "All tasks should have executed");
                assert_eq!(
                    *order,
                    vec![0, 1, 2, 3, 4],
                    "Tasks should execute in FIFO order"
                );
            })
            .await;
    }
1082
    // Several tasks spawned into the same queue all run to completion.
    #[tokio::test]
    async fn test_multiple_tasks_same_class() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let queue = executor.queue(0).unwrap();
                let counter = Arc::new(AtomicU32::new(0));
                let counter_clone = counter.clone();

                executor
                    .run_until(async {
                        let mut handles = Vec::new();
                        for _ in 0..5 {
                            let counter_clone = counter.clone();
                            let handle = queue.spawn(async move {
                                counter_clone.fetch_add(1, Ordering::Relaxed);
                            });
                            handles.push(handle);
                        }
                        for handle in handles {
                            let result = timeout(Duration::from_millis(100), handle).await;
                            assert!(result.is_ok(), "All tasks should complete");
                        }
                    })
                    .await;
                assert_eq!(counter_clone.load(Ordering::Relaxed), 5);
            })
            .await;
    }
1116
    // A task that awaits host-runtime timers mid-way still completes.
    #[tokio::test]
    async fn test_task_with_yield() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let queue = executor.queue(0).unwrap();
                let counter = Arc::new(AtomicU32::new(0));

                let counter_clone = counter.clone();
                executor
                    .run_until(async {
                        let handle = queue.spawn(async move {
                            for _ in 0..3 {
                                counter_clone.fetch_add(1, Ordering::Relaxed);
                                sleep(Duration::from_millis(10)).await;
                            }
                        });
                        let result = timeout(Duration::from_millis(500), handle).await;
                        assert!(
                            result.is_ok(),
                            "Task with yields should complete, got {:?}",
                            result
                        );
                    })
                    .await;

                assert_eq!(counter.load(Ordering::Relaxed), 3);
            })
            .await;
    }
1151
    // Aborting before the executor ever polls the task prevents execution
    // entirely and resolves the handle with Cancelled.
    #[tokio::test]
    async fn test_abort_before_task_starts() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let queue = executor.queue(0).unwrap();
                let executed = Arc::new(AtomicBool::new(false));

                let executed_clone = executed.clone();
                let handle = queue.spawn(async move {
                    executed_clone.store(true, Ordering::Relaxed);
                });

                handle.abort();

                let executor_clone = executor.clone();
                local.spawn_local(async move {
                    executor_clone
                        .run_until(sleep(Duration::from_millis(100)))
                        .await;
                });

                sleep(Duration::from_millis(100)).await;

                assert!(
                    !executed.load(Ordering::Relaxed),
                    "Task should not execute after abort"
                );

                let result = timeout(Duration::from_millis(50), handle).await;
                assert!(result.is_ok());
                assert!(matches!(result.unwrap(), Err(JoinError::Cancelled)));
            })
            .await;
    }
1195
    // Queue keys can be any QueueKey type, e.g. a user-defined enum.
    #[tokio::test]
    async fn test_enum_queue_ids() {
        #[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
        enum QueueId {
            High,
            Low,
        }
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(QueueId::High, 1)
                    .with_queue(QueueId::Low, 1)
                    .build()
                    .unwrap();
                let high = Arc::new(AtomicU32::new(0));
                let low = Arc::new(AtomicU32::new(0));

                let high_clone = high.clone();
                let low_clone = low.clone();

                let executor_clone = executor.clone();
                local.spawn_local(async move {
                    executor_clone
                        .run_until(sleep(Duration::from_millis(100)))
                        .await;
                });
                let q1 = executor.queue(QueueId::High).unwrap();
                let _ = q1.spawn(async move {
                    high_clone.fetch_add(1, Ordering::Relaxed);
                    yield_once().await;
                });
                let q2 = executor.queue(QueueId::Low).unwrap();
                let _ = q2.spawn(async move {
                    low_clone.fetch_add(1, Ordering::Relaxed);
                    yield_once().await;
                });
                sleep(Duration::from_millis(100)).await;
            })
            .await;
    }
1237
    // A queue that becomes runnable later is caught up to the global minimum
    // vruntime, so after running it ends up ahead of the earlier queue.
    #[tokio::test]
    async fn test_vruntime_resets() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .with_queue(1, 1)
                    .build()
                    .unwrap();
                let counter = Arc::new(AtomicU32::new(0));
                let counter_clone = counter.clone();
                let q1 = executor.queue(0).unwrap();
                executor
                    .run_until(async {
                        let handle = q1.spawn(async move {
                            for _ in 0..1000 {
                                counter_clone.fetch_add(1, Ordering::Relaxed);
                                yield_once().await;
                            }
                        });
                        let result = timeout(Duration::from_millis(100), handle).await;
                        assert!(result.is_ok(), "Task should complete");
                        assert_eq!(counter.load(Ordering::Relaxed), 1000);
                        let vruntime1 = executor.queues.borrow()[0].vruntime;
                        assert!(vruntime1 > 0);
                        let counter_clone = counter.clone();
                        let q2 = executor.queue(1).unwrap();
                        let handle = q2.spawn(async move {
                            counter_clone.fetch_add(1, Ordering::Relaxed);
                        });
                        let result = timeout(Duration::from_millis(100), handle).await;
                        assert!(result.is_ok(), "Task should complete");
                        assert_eq!(counter.load(Ordering::Relaxed), 1001);
                        let vruntime2 = executor.queues.borrow()[1].vruntime;
                        assert!(
                            vruntime2 > vruntime1,
                            "vruntime2 should be greater than vruntime1, got {} and {}",
                            vruntime2,
                            vruntime1
                        );
                    })
                    .await;
            })
            .await;
    }
1288
    // yield_maybe() only yields once the timeslice deadline passes, so
    // driver yields are far rarer than calls to it.
    #[tokio::test]
    async fn test_yield_maybe() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let queue = executor.queue(0).unwrap();
                let counter1 = Arc::new(AtomicU32::new(0));
                let counter1_clone = counter1.clone();
                local.spawn_local(async move {
                    executor
                        .run_until(async {
                            let handle = queue.spawn(async move {
                                let mut i = 0;
                                loop {
                                    counter1_clone.fetch_add(1, Ordering::Relaxed);
                                    if i % 1000 == 0 {
                                        yield_maybe().await;
                                    }
                                    i += 1;
                                }
                            });
                            sleep(Duration::from_millis(100)).await;
                            let count = counter1.load(Ordering::Relaxed);
                            assert!(count > 0);
                            let yields = executor.stats.borrow().driver_yields;
                            assert!(yields > 0);
                            assert!(yields < count as u64 / 1000 / 2);
                            handle.abort();
                        })
                        .await;
                });
            })
            .await;
    }
1329
    // The executor is host-runtime agnostic: it runs on smol's LocalExecutor.
    #[test]
    fn test_smol_runtime() {
        let executor = ExecutorBuilder::new().with_queue(0, 1).build().unwrap();
        let smol_local_ex = smol::LocalExecutor::new();
        let h2 = smol_local_ex.spawn(async move {
            let queue = executor.queue(0).unwrap();
            executor
                .run_until(async {
                    let handle = queue.spawn(async move { 42 });
                    handle.await
                })
                .await
        });

        let res = smol::future::block_on(smol_local_ex.run(async { h2.await }));
        assert_eq!(res, Ok(42));
    }
1348
    // Aborting after the task has already completed leaves the Ok result
    // intact (abort does not retroactively cancel).
    #[tokio::test]
    async fn test_abort_after_done() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let counter = Arc::new(AtomicU32::new(0));
                let counter_clone = counter.clone();
                let queue = executor.queue(0).unwrap();
                let result = executor
                    .run_until(async {
                        let handle = queue.spawn(async move {
                            counter_clone.fetch_add(1, Ordering::Relaxed);
                            42
                        });
                        sleep(Duration::from_millis(100)).await;
                        assert!(counter.load(Ordering::Relaxed) > 0);
                        handle.abort();
                        handle.await
                    })
                    .await;
                assert_eq!(result, Ok(42));
            })
            .await;
    }
1379
1380 #[test]
1382 fn test_monoio_runtime() {
1383 use monoio::LegacyDriver;
1384 let mut rt = monoio::RuntimeBuilder::<LegacyDriver>::new()
1385 .enable_timer() .build()
1387 .unwrap();
1388 let _ = rt.block_on(async move {
1389 let executor = ExecutorBuilder::new().with_queue(0, 1).build().unwrap();
1390 let counter = Arc::new(AtomicU32::new(0));
1391
1392 let counter_clone = counter.clone();
1393 let queue = executor.queue(0).unwrap();
1394 let result = executor
1395 .run_until(async {
1396 assert_eq!(counter.load(Ordering::Relaxed), 0);
1398
1399 let handle = queue.spawn(async move {
1400 counter_clone.fetch_add(1, Ordering::Relaxed);
1401 42
1402 });
1403 monoio::time::sleep(Duration::from_millis(100)).await;
1404 assert_eq!(counter.load(Ordering::Relaxed), 1);
1406 handle.await
1407 })
1408 .await;
1409 assert_eq!(result, Ok(42));
1410 });
1411 }
1412
1413 #[test]
1414 fn test_bad_executor_creation() {
1415 let result = ExecutorBuilder::new().with_queue(0, 0).build();
1417 assert!(result.is_err());
1418 let result = ExecutorBuilder::new()
1420 .with_queue(0, 1)
1421 .with_queue(0, 1)
1422 .build();
1423 assert!(result.is_err());
1424 let result = Executor::<u8>::new(ExecutorOptions::default(), vec![]);
1426 assert!(result.is_err());
1427 }
1428
1429 #[tokio::test]
1430 async fn test_panic_crashes_executor() {
1431 let local = LocalSet::new();
1432 local
1433 .run_until(async {
1434 let executor = ExecutorBuilder::new()
1435 .with_queue(0, 1)
1436 .build()
1437 .unwrap();
1438 let queue = executor.queue(0).unwrap();
1439 let handle = tokio::task::spawn_local(async move {
1440 executor.run_until(sleep(Duration::from_millis(100))).await;
1441 });
1442 let _ = queue.spawn(async {
1443 panic!("test");
1444 });
1445 let result = handle.await;
1446 assert!(result.is_err());
1447 assert!(result.unwrap_err().is_panic());
1448 })
1449 .await;
1450 }
1451
1452 #[tokio::test]
1453 async fn test_panic_caught_when_configured() {
1454 let local = LocalSet::new();
1455 local
1456 .run_until(async {
1457 let executor = ExecutorBuilder::new()
1459 .with_panic_on_task_panic(false)
1460 .with_queue(0, 1)
1461 .build()
1462 .unwrap();
1463 let queue = executor.queue(0).unwrap();
1464 let result = executor.run_until(async {
1465 let task_handle = queue.spawn(async {
1466 panic!("test panic message");
1467 });
1468 task_handle.await
1469 });
1470
1471 let result = timeout(Duration::from_millis(100), result).await;
1473 assert!(result.is_ok(), "Task should complete (with panic error)");
1474
1475 let join_result = result.unwrap();
1476 assert!(join_result.is_err(), "Task should return an error");
1477
1478 match join_result.unwrap_err() {
1479 JoinError::Panic(_) => {
1480 }
1482 other => panic!("Expected JoinError::Panic, got {:?}", other),
1483 }
1484
1485 assert_eq!(executor.task_queues.len(), 1,);
1487 })
1488 .await;
1489 }
1490
1491 #[tokio::test]
1492 async fn test_preemption_mask_computed_correctly() {
1493 let local = LocalSet::new();
1495 local
1496 .run_until(async {
1497 let executor = ExecutorBuilder::new()
1502 .with_queue(0, 8)
1503 .with_queue(1, 4)
1504 .with_queue(2, 1)
1505 .build()
1506 .unwrap();
1507
1508 let queue2 = executor.queue(2).unwrap();
1509 let preempt_state = executor.preempt_state.clone();
1510
1511 executor
1512 .run_until(async {
1513 let handle = queue2.spawn(async {
1515 loop {
1516 yield_once().await;
1517 }
1518 });
1519
1520 sleep(Duration::from_millis(10)).await;
1524
1525 assert!(
1529 preempt_state.would_preempt(0),
1530 "Queue 0 (weight 8) should preempt queue 2 (weight 1)"
1531 );
1532 assert!(
1533 preempt_state.would_preempt(1),
1534 "Queue 1 (weight 4) should preempt queue 2 (weight 1)"
1535 );
1536 assert!(
1537 !preempt_state.would_preempt(2),
1538 "Queue 2 is runnable, should not be in preempt mask"
1539 );
1540 assert!(
1541 !preempt_state.check(),
1542 "Preempt flag should not be set (no higher priority task enqueued)"
1543 );
1544
1545 handle.abort();
1546 let _ = handle.await;
1547 })
1548 .await;
1549 })
1550 .await;
1551 }
1552}