
clockworker/executor.rs

1use crate::{
2    join::{JoinHandle, JoinState},
3    preempt::PreemptState,
4    queue::{Queue, QueueKey, TaskId, TaskQueue},
5    stats::{ExecutorStats, QueueStats},
6    task::TaskHeader,
7    yield_once::yield_once,
8};
9use futures::FutureExt;
10use futures_util::task::AtomicWaker;
11use slab::Slab;
12use static_assertions::assert_not_impl_any;
13use std::sync::atomic::AtomicBool;
14use std::{
15    cell::Cell,
16    cell::RefCell,
17    future::Future,
18    mem,
19    pin::Pin,
20    rc::Rc,
21    sync::atomic::Ordering,
22    sync::Arc,
23    task::{Context, Poll},
24    time::{Duration, Instant},
25};
26
27thread_local! {
28    static YIELD_MAYBE_DEADLINE: Cell<Option<Instant>> = Cell::new(None);
29}
30
31fn set_yield_maybe_deadline(deadline: Instant) {
32    YIELD_MAYBE_DEADLINE.with(|cell| cell.set(Some(deadline)));
33}
34
35#[derive(Debug)]
36pub enum SpawnError<K: QueueKey> {
37    ShuttingDown,
38    QueueNotFound(K),
39    InvalidShare(u64),
40}
41
/// Wraps a user-given future to make it cancelable.
/// This future's output is `()`: when the underlying future completes,
/// the result is published to the `JoinState`, which, wrapped in a `JoinHandle`,
/// can be awaited by the user.
46struct CancelableFuture<T, K: QueueKey, F: Future<Output = T> + 'static> {
47    header: Arc<TaskHeader<K>>, // has `cancelled: AtomicBool`
48    join: Arc<JoinState<T>>,
49    fut: Pin<Box<F>>,
50    catch_panics: bool,
51}
52
53impl<T, K: QueueKey, F: Future<Output = T> + 'static> CancelableFuture<T, K, F> {
54    pub fn new(
55        header: Arc<TaskHeader<K>>,
56        join: Arc<JoinState<T>>,
57        fut: F,
58        catch_panics: bool,
59    ) -> Self {
60        Self {
61            header,
62            join,
63            fut: Box::pin(fut),
64            catch_panics,
65        }
66    }
67}
68
69impl<T, K: QueueKey, F: Future<Output = T> + 'static> Future for CancelableFuture<T, K, F> {
70    type Output = ();
71
72    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
73        // If already completed (maybe abort() completed join immediately), stop.
74        if self.join.is_done() {
75            return Poll::Ready(());
76        }
77
78        // Cancellation intent is owned by the task header.
79        if self.header.is_cancelled() {
80            self.join.try_complete_cancelled();
81            return Poll::Ready(());
82        }
83
84        // Poll with optional panic handling
85        let poll_result = if self.catch_panics {
86            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| self.fut.as_mut().poll(cx)))
87        } else {
88            Ok(self.fut.as_mut().poll(cx))
89        };
90
91        match poll_result {
92            Ok(Poll::Ready(out)) => {
93                self.join.try_complete_ok(out);
94                Poll::Ready(())
95            }
96            Ok(Poll::Pending) => Poll::Pending,
97            Err(panic_payload) => {
98                // Convert panic to JoinError::Panic
99                let panic_err = crate::join::PanicError::from_panic_payload(panic_payload);
100                self.join
101                    .try_complete_err(crate::join::JoinError::Panic(panic_err));
102                Poll::Ready(())
103            }
104        }
105    }
106}
107
108/// Lightweight wrapper for fire-and-forget tasks that don't need a JoinState.
109/// The result is discarded when the future completes.
110struct DetachedFuture<K: QueueKey, F: Future + 'static> {
111    header: Arc<TaskHeader<K>>,
112    fut: Pin<Box<F>>,
113    catch_panics: bool,
114}
115
116impl<K: QueueKey, F: Future + 'static> DetachedFuture<K, F> {
117    fn new(header: Arc<TaskHeader<K>>, fut: F, catch_panics: bool) -> Self {
118        Self {
119            header,
120            fut: Box::pin(fut),
121            catch_panics,
122        }
123    }
124}
125
126impl<K: QueueKey, F: Future + 'static> Future for DetachedFuture<K, F> {
127    type Output = ();
128
129    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
130        if self.header.is_cancelled() {
131            return Poll::Ready(());
132        }
133
134        let poll_result = if self.catch_panics {
135            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| self.fut.as_mut().poll(cx)))
136        } else {
137            Ok(self.fut.as_mut().poll(cx))
138        };
139
140        match poll_result {
141            Ok(Poll::Ready(_)) => Poll::Ready(()),
142            Ok(Poll::Pending) => Poll::Pending,
143            Err(_) => Poll::Ready(()),
144        }
145    }
146}
147
148/// Wrapper to create a waker that sets a flag and wakes an AtomicWaker when woken.
149/// Used by `run_until` to detect when the `until` future's dependencies become ready.
150struct UntilWakerWrapper {
151    woken: Arc<std::sync::atomic::AtomicBool>,
152    idle_waker: Arc<futures_util::task::AtomicWaker>,
153}
154
155impl futures_util::task::ArcWake for UntilWakerWrapper {
156    fn wake_by_ref(arc_self: &Arc<Self>) {
157        arc_self.woken.store(true, Ordering::Release);
158        arc_self.idle_waker.wake();
159    }
160}
161
162/// Local (executor-thread-only) task record containing the !Send future.
163struct TaskRecord<K: QueueKey> {
164    header: Arc<TaskHeader<K>>,
165    waker: std::task::Waker,
166    fut: Pin<Box<dyn Future<Output = ()> + 'static>>, // !Send ok - type-erased CancelableFuture
167}
168
169/// Global per-queue state maintained by the executor (vruntime/shares).
170struct QueueState<K: QueueKey> {
171    vruntime: u128, // total CPU time consumed (in nanoseconds)
172    share: u64,
173    task_queue: Arc<TaskQueue>,
174    stats: QueueStats<K>,
175}
176
177impl<K: QueueKey> QueueState<K> {
178    fn new(queue: Queue<K>, task_queue: Arc<TaskQueue>) -> Self {
179        Self {
180            vruntime: 0,
181            stats: QueueStats::new(queue.id(), queue.share()),
182            share: queue.share(),
183            task_queue,
184        }
185    }
186}
187
188pub struct QueueHandle<K: QueueKey> {
189    executor: Rc<Executor<K>>,
190    qid: K,
191}
192impl<K: QueueKey> QueueHandle<K> {
193    pub fn spawn<T, F>(self: &Self, fut: F) -> JoinHandle<T, K>
194    where
195        T: 'static,
196        F: Future<Output = T> + 'static, // !Send ok
197    {
198        self.executor.spawn_inner(self.qid, fut)
199    }
200
201    /// Spawn a fire-and-forget task. No JoinHandle or JoinState is allocated,
202    /// making this cheaper than `spawn` when the caller doesn't need the result.
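    ///
    /// A sketch of intended usage (the queue handle is obtained as in `spawn`):
    ///
    /// ```ignore
    /// let queue = executor.queue(0).unwrap();
    /// queue.spawn_detached(async {
    ///     // runs on the executor; the output is discarded on completion
    /// });
    /// ```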
203    pub fn spawn_detached<F>(self: &Self, fut: F)
204    where
205        F: Future + 'static,
206    {
207        self.executor.spawn_detached_inner(self.qid, fut)
208    }
209}
210
211pub struct ExecutorBuilder<K: QueueKey> {
212    options: ExecutorOptions,
213    queues: Vec<Queue<K>>,
214}
215impl<K: QueueKey> ExecutorBuilder<K> {
216    pub fn new() -> Self {
217        Self {
218            options: ExecutorOptions::default(),
219            queues: Vec::new(),
220        }
221    }
222    pub fn with_sched_latency(mut self, sched_latency: Duration) -> Self {
223        self.options.sched_latency = sched_latency;
224        self
225    }
226    pub fn with_min_slice(mut self, min_slice: Duration) -> Self {
227        self.options.min_slice = min_slice;
228        self
229    }
230    pub fn with_driver_yield(mut self, driver_yield: Duration) -> Self {
231        self.options.driver_yield = driver_yield;
232        self
233    }
234
235    /// Add a queue with FIFO scheduling.
236    pub fn with_queue(mut self, qid: K, share: u64) -> Self {
237        let queue = Queue::new(qid, share);
238        self.queues.push(queue);
239        self
240    }
241    pub fn with_panic_on_task_panic(mut self, panic_on_task_panic: bool) -> Self {
242        self.options.panic_on_task_panic = panic_on_task_panic;
243        self
244    }
245    /// Set the maximum number of task polls before yielding to the driver.
246    /// This ensures I/O-heavy workloads don't starve the reactor.
247    /// Default is 61.
248    pub fn with_max_polls_per_yield(mut self, max_polls: u32) -> Self {
249        self.options.max_polls_per_yield = max_polls;
250        self
251    }
252    /// Enable or disable LIFO slot optimization.
253    /// When enabled, the most recently enqueued task is prioritized for cache locality.
    /// Default is false.
255    pub fn with_enable_lifo(mut self, enable: bool) -> Self {
256        self.options.enable_lifo = enable;
257        self
258    }
259    /// Set the LIFO skip interval.
260    /// Every N pops, the LIFO slot is skipped to maintain fairness.
261    /// Only used when LIFO is enabled.
262    /// Default is 16.
263    pub fn with_lifo_skip_interval(mut self, interval: usize) -> Self {
264        self.options.lifo_skip_interval = interval;
265        self
266    }
267    pub fn build(self) -> Result<Rc<Executor<K>>, String> {
268        Executor::new(self.options, self.queues)
269    }
270}
271
272pub struct ExecutorOptions {
273    sched_latency: Duration,
274    min_slice: Duration,
275    driver_yield: Duration,
276    panic_on_task_panic: bool,
277    /// Maximum number of task polls before yielding to the driver.
278    /// This ensures I/O-heavy workloads don't starve the reactor.
279    max_polls_per_yield: u32,
280    /// Enable LIFO slot optimization for cache locality.
281    enable_lifo: bool,
282    /// Skip LIFO slot every N pops to maintain fairness.
283    lifo_skip_interval: usize,
284}
285impl Default for ExecutorOptions {
286    fn default() -> Self {
287        Self {
288            sched_latency: Duration::from_millis(2),
289            min_slice: Duration::from_micros(100),
290            driver_yield: Duration::from_micros(500),
291            panic_on_task_panic: true,
292            max_polls_per_yield: 61, // same as tokio
293            enable_lifo: false, // disabled by default
294            lifo_skip_interval: 16,
295        }
296    }
297}
298
299/// The priority executor: single-thread polling + class vruntime selection.
300pub struct Executor<K: QueueKey> {
301    options: ExecutorOptions,
302    task_queues: Vec<Arc<TaskQueue>>,
303    is_runnable: RefCell<Vec<bool>>, // true iff ith queue is runnable
304
305    tasks: RefCell<Slab<TaskRecord<K>>>,
306    queues: RefCell<Vec<QueueState<K>>>,
307    qids: RefCell<Vec<K>>,
308
309    min_vruntime: std::cell::Cell<u128>,
310
311    /// Shared preemption state - allows wakers to signal when a higher-priority
312    /// queue has tasks, enabling early timeslice termination.
313    preempt_state: Arc<PreemptState>,
314
315    // stats
316    stats: RefCell<ExecutorStats>,
317}
318
319assert_not_impl_any!(Executor<u8>: Send, Sync);
320
321impl<K: QueueKey> Executor<K> {
322    /// Create an executor with N classes, each with a weight (share).
323    pub fn new(options: ExecutorOptions, queues: Vec<Queue<K>>) -> Result<Rc<Self>, String> {
324        if queues.is_empty() {
325            return Err("Must have at least one queue".to_string());
326        }
327        // verify that all queues have unique ids
328        for i in 0..queues.len() {
329            for j in i + 1..queues.len() {
330                if queues[i].id() == queues[j].id() {
331                    return Err("All queues must have unique ids".to_string());
332                }
333            }
334        }
335        // no share can be 0
336        if queues.iter().any(|q| q.share() == 0) {
337            return Err("All queues must have a share > 0".to_string());
338        }
339
340        // Create one mpsc channel per queue
341        let num_queues = queues.len();
342        if num_queues > 256 {
343            return Err("Cannot have more than 256 queues (preemption mask limit)".to_string());
344        }
345
346        // Create shared preemption state
347        let preempt_state = Arc::new(PreemptState::new());
348
349        let task_queues: Vec<Arc<TaskQueue>> = (0..num_queues)
350            .map(|_| {
351                Arc::new(TaskQueue::new(
352                    options.enable_lifo,
353                    options.lifo_skip_interval,
354                ))
355            })
356            .collect();
357
358        let qids = queues.iter().map(|q| q.id()).collect::<Vec<_>>();
359        let queues: Vec<QueueState<K>> = queues
360            .into_iter()
361            .enumerate()
362            .map(|(idx, q)| QueueState::new(q, task_queues[idx].clone()))
363            .collect();
364
365        Ok(Rc::new(Self {
366            task_queues,
367            is_runnable: RefCell::new(vec![false; num_queues]),
368            tasks: RefCell::new(Slab::new()),
369            queues: RefCell::new(queues),
370            qids: RefCell::new(qids),
371            options,
372            min_vruntime: std::cell::Cell::new(0),
373            preempt_state,
374            stats: RefCell::new(ExecutorStats::new(Instant::now())),
375        }))
376    }
377
378    /// Get a handle to a queue through which tasks can be spawned
379    pub fn queue(self: &Rc<Self>, qid: K) -> Result<QueueHandle<K>, SpawnError<K>> {
380        let Some(_) = self.qids.borrow().iter().position(|q| *q == qid) else {
381            return Err(SpawnError::QueueNotFound(qid));
382        };
383        Ok(QueueHandle {
384            executor: self.clone(),
385            qid,
386        })
387    }
388
389    /// Internal method to spawn a task onto a queue.
390    fn spawn_inner<T, F>(self: &Rc<Self>, qid: K, fut: F) -> JoinHandle<T, K>
391    where
392        T: 'static,
393        F: Future<Output = T> + 'static, // !Send ok
394    {
395        let qid = qid.into();
396        let qidx = self
397            .qids
398            .borrow()
399            .iter()
400            .position(|q| *q == qid)
401            .expect("queue should exist");
402        let mut tasks = self.tasks.borrow_mut();
403        let entry = tasks.vacant_entry();
404        let id = entry.key();
405        let preempt_state = if self.task_queues.len() > 1 {
406            Some(self.preempt_state.clone())
407        } else {
408            None
409        };
410        let header = Arc::new(TaskHeader::new(
411            id,
412            qid,
413            qidx,
414            self.task_queues[qidx].clone(),
415            preempt_state,
416        ));
417        let join = Arc::new(JoinState::<T>::new());
418        // Wrap user future to publish result into JoinState.
419        // catch_panics = !panic_on_task_panic (if executor panics on task panic, we don't catch)
420        let catch_panics = !self.options.panic_on_task_panic;
421        let wrapped = CancelableFuture::new(header.clone(), join.clone(), fut, catch_panics);
422
423        let waker = futures::task::waker(header.clone());
424
425        entry.insert(TaskRecord {
426            header: header.clone(),
427            waker,
428            fut: Box::pin(wrapped),
429        });
430
431        // Enqueue initially.
432        header.enqueue();
433
434        JoinHandle::new(header, join)
435    }
436
437    /// Internal method to spawn a detached (fire-and-forget) task onto a queue.
438    /// No JoinState is allocated, making this cheaper than spawn_inner.
439    fn spawn_detached_inner<F>(self: &Rc<Self>, qid: K, fut: F)
440    where
441        F: Future + 'static,
442    {
443        let qid = qid.into();
444        let qidx = self
445            .qids
446            .borrow()
447            .iter()
448            .position(|q| *q == qid)
449            .expect("queue should exist");
450        let mut tasks = self.tasks.borrow_mut();
451        let entry = tasks.vacant_entry();
452        let id = entry.key();
453        let preempt_state = if self.task_queues.len() > 1 {
454            Some(self.preempt_state.clone())
455        } else {
456            None
457        };
458        let header = Arc::new(TaskHeader::new(
459            id,
460            qid,
461            qidx,
462            self.task_queues[qidx].clone(),
463            preempt_state,
464        ));
465        let catch_panics = !self.options.panic_on_task_panic;
466        let wrapped = DetachedFuture::new(header.clone(), fut, catch_panics);
467
468        let waker = futures::task::waker(header.clone());
469
470        entry.insert(TaskRecord {
471            header: header.clone(),
472            waker,
473            fut: Box::pin(wrapped),
474        });
475
476        header.enqueue();
477    }
478
479    /// Pick the next runnable class by deadline among classes that have
    /// runnable tasks. Each class's deadline is vruntime + request / share,
    /// where request = max(sched_latency / num_runnable, min_slice), so for the
    /// same consumed CPU time a higher-share class gets an earlier deadline
    /// and is preferred.
483    ///
484    /// Returns (selected_idx, timeslice, selected_deadline, num_runnable) or None if no queues are runnable.
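    ///
    /// Illustrative numbers (using the default 2 ms sched_latency and made-up
    /// vruntimes/shares): with two runnable classes, request = 1_000_000 ns.
    /// A class with share 8 and vruntime 10_000 gets deadline
    /// 10_000 + 1_000_000 / 8 = 135_000, while a class with share 1 and the
    /// same vruntime gets 1_010_000, so the share-8 class is selected.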
485    fn pick_next_class(&self) -> Option<(usize, Duration, u128, usize)> {
486        let mut best: Option<(usize, u128)> = None;
487        let mut runnable = None;
488        let mut num_runnable = 0;
489        let mut is_runnable = self.is_runnable.borrow_mut();
490        for (idx, q) in self.queues.borrow_mut().iter_mut().enumerate() {
491            let was_runnable = is_runnable[idx];
492            is_runnable[idx] = !q.task_queue.is_empty();
493            if !was_runnable && is_runnable[idx] {
494                // wasn't runnable before, but is now - inherit vruntime
495                q.vruntime = q.vruntime.max(self.min_vruntime.get());
496            }
497            if is_runnable[idx] {
498                num_runnable += 1;
499                runnable = Some(idx);
500            }
501        }
502        if num_runnable == 0 {
503            return None;
504        }
505        let request = self.options.sched_latency.as_nanos() as u128 / num_runnable as u128;
506        let request = request.max(self.options.min_slice.as_nanos() as u128);
507
508        if num_runnable == 1 {
509            let selected_idx = runnable.unwrap();
510            let queues = self.queues.borrow();
511            let selected_deadline =
512                queues[selected_idx].vruntime + (request / queues[selected_idx].share as u128);
513            return Some((
514                selected_idx,
515                Duration::from_nanos(request as u64),
516                selected_deadline,
517                num_runnable,
518            ));
519        }
520
521        // Multiple runnable queues - find the best one
522        for (idx, q) in self.queues.borrow().iter().enumerate() {
523            if q.task_queue.is_empty() {
524                continue;
525            }
526            // d_i = vruntime_i + request / share_i
527            let deadline = q.vruntime + (request / q.share as u128);
528            match best {
529                None => best = Some((idx, deadline)),
530                Some((_, bv)) if deadline < bv => best = Some((idx, deadline)),
531                _ => {}
532            }
533        }
534
535        let (selected_idx, selected_deadline) = best.unwrap();
536        Some((
537            selected_idx,
538            Duration::from_nanos(request as u64),
539            selected_deadline,
540            num_runnable,
541        ))
542    }
543
544    /// Compute and update the preemption mask based on the selected queue.
545    /// Empty queues that would have higher priority than the selected queue
546    /// are marked in the mask so their wakers can trigger preemption.
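    ///
    /// Illustrative numbers: with num_runnable = 2 and the default 2 ms
    /// sched_latency, hypothetical_request is ~666_666 ns. If min_vruntime is
    /// 10_000 ns and the selected deadline is 135_000 ns, an idle queue with
    /// share 64 (hypothetical deadline ~20_400 ns) is marked as preempting,
    /// while an idle queue with share 1 (~676_666 ns) is not.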
547    fn update_preempt_mask(&self, selected_deadline: u128, num_runnable: usize) {
548        let is_runnable = self.is_runnable.borrow();
549        let queues = self.queues.borrow();
550
551        // Calculate hypothetical request if one more queue becomes runnable
552        let hypothetical_request =
553            self.options.sched_latency.as_nanos() as u128 / (num_runnable + 1) as u128;
554        let hypothetical_request =
555            hypothetical_request.max(self.options.min_slice.as_nanos() as u128);
556        let min_vruntime = self.min_vruntime.get();
557
558        // Find empty queues that would preempt if they got a task
559        let preempting = (0..queues.len()).filter(|&idx| {
560            if is_runnable[idx] {
561                return false; // Already runnable, skip
562            }
563            // Hypothetical deadline: inherits min_vruntime
564            let hypothetical_deadline =
565                min_vruntime + (hypothetical_request / queues[idx].share as u128);
566            hypothetical_deadline < selected_deadline
567        });
568        self.preempt_state.update_mask(preempting);
569    }
570
571    /// Charge elapsed CPU time to a class.
    /// The class's vruntime advances by ceil(elapsed_nanos / share), so classes
    /// with a larger share accumulate vruntime more slowly for the same CPU time.
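    ///
    /// For example (illustrative): 1_000_000 ns of elapsed CPU on a class with
    /// share 8 advances its vruntime by ceil(1_000_000 / 8) = 125_000.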
574    fn charge_class(&self, qidx: usize, elapsed: Duration) {
575        if self.task_queues.len() <= 1 {
576            return;
577        }
578        let mut queues = self.queues.borrow_mut();
579        let queue = &mut queues[qidx];
580        // ceil of (elapsed / share)
581        let incr = (elapsed.as_nanos() + queue.share as u128 - 1) / (queue.share as u128);
582        queue.vruntime += incr;
583        queue.stats.record_poll(elapsed);
584    }
585    fn update_min_vruntime(&self, including: u128) {
586        if self.task_queues.len() <= 1 {
587            return;
588        }
589        let min_vruntime = self
590            .queues
591            .borrow()
592            .iter()
593            .filter(|q| !q.task_queue.is_empty())
594            .map(|q| q.vruntime)
595            .chain(Some(including))
596            .min();
597        let min_vruntime = min_vruntime.unwrap();
598        // update executor's min_vruntime
599        let prev_min_vruntime = self.min_vruntime.get();
600        self.min_vruntime.set(prev_min_vruntime.max(min_vruntime));
601    }
602
603    /// Get the current executor stats.
604    pub fn stats(&self) -> ExecutorStats {
605        self.stats.borrow().clone()
606    }
607
608    /// Get the current queue stats.
609    pub fn qstats(&self) -> Vec<QueueStats<K>> {
610        self.queues
611            .borrow()
612            .iter()
613            .map(|q| q.stats.clone())
614            .collect()
615    }
616
617    /// Run the executor loop until the given future completes.
618    ///
619    /// Panic behavior: if any task panics while being polled, the executor
    /// panics (propagates) unless the executor has been configured to catch panics
621    /// with `with_panic_on_task_panic(false)`.
622    ///
623    /// The executor will continue running tasks until `until` completes, then
624    /// returns. When the executor stops, pending tasks remain pending and will
625    /// resume if `run_until` is called again (Tokio-like behavior).
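    ///
    /// # Example
    ///
    /// A minimal sketch of intended usage (queue id and values are illustrative;
    /// it mirrors the patterns used in the tests below):
    ///
    /// ```ignore
    /// let executor = ExecutorBuilder::new().with_queue(0, 1).build().unwrap();
    /// let queue = executor.queue(0).unwrap();
    /// let answer = executor
    ///     .run_until(async {
    ///         let handle = queue.spawn(async { 40 + 2 });
    ///         handle.await
    ///     })
    ///     .await;
    /// assert_eq!(answer, Ok(42));
    /// ```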
626    pub async fn run_until<F: Future>(&self, until: F) -> F::Output {
627        let mut until_pinned = std::pin::pin!(until.fuse());
628
629        // Flag that gets set when until's waker is called
630        let until_woken = Arc::new(AtomicBool::new(false));
631        // Waker to wake the idle wait loop when until_woken is set
632        let idle_waker = Arc::new(AtomicWaker::new());
633        // Create a waker that sets the flag and wakes idle_waker
634        let until_waker = self.create_until_waker(until_woken.clone(), idle_waker.clone());
635
636        let mut last_driver_yield_at = Instant::now();
637        let mut iter = 0u64;
638
639        // Initial poll to register our waker
640        {
641            let mut cx = Context::from_waker(&until_waker);
642            if let Poll::Ready(result) = until_pinned.as_mut().poll(&mut cx) {
643                return result;
644            }
645        }
646
647        loop {
648            iter += 1;
649            let enable_stats = iter % 128 == 0;
650            self.stats.borrow_mut().record_loop_iter();
651
652            // Only poll until if it was woken (its dependencies became ready)
653            if until_woken.swap(false, Ordering::AcqRel) {
654                let mut cx = Context::from_waker(&until_waker);
655                if let Poll::Ready(result) = until_pinned.as_mut().poll(&mut cx) {
656                    return result;
657                }
658            }
659
660            // Select next queue to run
661            let Some((qidx, timeslice)) = self.select_queue(enable_stats) else {
662                // Nothing runnable - wait for work or until_woken signal
663                self.wait_for_work_or_signal(&until_woken, &idle_waker)
664                    .await;
665                continue;
666            };
667
668            // Execute timeslice
669            let timeslice = timeslice.min(self.options.driver_yield);
670            let end = self.run_timeslice(qidx, timeslice, enable_stats);
671
672            // Update executor's min_vruntime
673            let new_vruntime = self.queues.borrow()[qidx].vruntime;
674            self.update_min_vruntime(new_vruntime);
675
676            // Yield to driver
677            last_driver_yield_at = self.yield_to_driver(last_driver_yield_at, end).await;
678        }
679    }
680
681    /// Create a waker that sets `until_woken` and wakes `idle_waker` when called.
682    fn create_until_waker(
683        &self,
684        until_woken: Arc<std::sync::atomic::AtomicBool>,
685        idle_waker: Arc<futures_util::task::AtomicWaker>,
686    ) -> std::task::Waker {
687        let wrapper = Arc::new(UntilWakerWrapper {
688            woken: until_woken,
689            idle_waker,
690        });
691        futures::task::waker(wrapper)
692    }
693
694    /// Wait for either a queue to receive a task or `until_woken` to be set.
695    /// This is a single poll_fn that registers on all wakers directly, avoiding allocations.
696    async fn wait_for_work_or_signal(
697        &self,
698        until_woken: &Arc<AtomicBool>,
699        idle_waker: &Arc<AtomicWaker>,
700    ) {
701        use futures_util::future::poll_fn;
702
703        poll_fn(|cx| {
704            // Register our waker with idle_waker (for until_woken signal)
705            idle_waker.register(cx.waker());
706
707            // Register our waker with each queue's TaskQueue
708            // Register on every poll to ensure we don't miss wakeups
709            // (AtomicWaker handles duplicate registrations efficiently)
710            for task_queue in &self.task_queues {
711                task_queue.register_waker(cx.waker());
712            }
713
714            // Check if until was woken
715            if until_woken.load(Ordering::Acquire) {
716                return Poll::Ready(());
717            }
718
719            // Check if any queue has items
720            for task_queue in &self.task_queues {
721                if !task_queue.is_empty() {
722                    return Poll::Ready(());
723                }
724            }
725
726            Poll::Pending
727        })
728        .await
729    }
730
731    /// Select the next queue to run and measure the decision time.
732    /// Clears the preempt flag when a new timeslice is selected.
733    fn select_queue(&self, enable_stats: bool) -> Option<(usize, Duration)> {
734        let start = if enable_stats {
735            Some(Instant::now())
736        } else {
737            None
738        };
739        // if there is only one queue, bypass all machinery
740        if self.task_queues.len() == 1 {
741            match self.task_queues[0].is_empty() {
742                true => return None,
743                false => return Some((0, self.options.sched_latency)),
744            }
745        }
746
747        let Some((selected_idx, timeslice, selected_deadline, num_runnable)) =
748            self.pick_next_class()
749        else {
750            // No runnable queues, clear preempt mask
751            self.preempt_state.update_mask(std::iter::empty());
752            return None;
753        };
754
755        // Clear preempt flag when starting a new timeslice
756        self.preempt_state.clear_preempt();
757
758        // Compute which empty queues would preempt the selected queue
759        self.update_preempt_mask(selected_deadline, num_runnable);
760
761        if let Some(start) = start {
762            let elapsed = Instant::now().duration_since(start);
763            self.stats.borrow_mut().record_schedule_decision(elapsed);
764        }
765
766        Some((selected_idx, timeslice))
767    }
768
769    /// Pop the next valid task from a queue, skipping stale/done tasks.
    /// Returns the task id, or `None` if the queue is empty.
771    fn pop_next_task_from_queue(&self, qidx: usize) -> Option<TaskId> {
772        loop {
773            let mut queues = self.queues.borrow_mut();
774            let queue = &mut queues[qidx];
            // Record a dequeue attempt for queue stats.
776            queue.stats.record_runnable_dequeue();
777            let maybe_id = queue.task_queue.pop();
778
779            drop(queues);
780
781            let Some(id) = maybe_id else {
782                return None;
783            };
784
785            let tasks = self.tasks.borrow();
786            let Some(task) = tasks.get(id) else {
787                // Stale id; try again
788                continue;
789            };
790
791            if task.header.is_done() {
792                // Spurious task; try again
793                continue;
794            }
795
796            return Some(id);
797        }
798    }
799
800    /// Poll a single task and return whether it completed.
801    /// Timing and vruntime charging are handled by the caller (run_timeslice).
802    fn poll_task(&self, id: TaskId) -> bool {
803        // Extract the future from the task while keeping the task in the Slab.
804        // This allows us to release the borrow before polling, enabling nested spawns.
805        let (waker, mut extracted_fut) = {
806            let mut tasks = self.tasks.borrow_mut();
807            let task = match tasks.get_mut(id) {
808                Some(task) => task,
809                None => return false,
810            };
811
812            // Clear queued before polling so a wake during poll can enqueue again.
813            task.header.set_queued(false);
814
815            // Clone what we need for the context
816            let waker = task.waker.clone();
817
818            // Extract the future using a dummy placeholder
819            // We use futures::future::ready(()) as a placeholder that immediately resolves
820            let placeholder = Box::pin(futures::future::ready(()));
821            let extracted_fut = mem::replace(&mut task.fut, placeholder);
822
823            (waker, extracted_fut)
824        };
825        // Borrow is now released - the task remains in the Slab with the placeholder future
826
827        let mut cx = Context::from_waker(&waker);
828
829        // CancelableFuture handles panics internally, so we can poll directly
830        // Now we can poll without holding the tasks borrow, allowing nested spawns
831        let poll = extracted_fut.as_mut().poll(&mut cx);
832
833        // Put the future back (or leave placeholder if task completed)
834        {
835            let mut tasks = self.tasks.borrow_mut();
836            let task = match tasks.get_mut(id) {
837                Some(task) => task,
838                None => return false,
839            };
840
841            match poll {
842                Poll::Ready(()) => true,
843                Poll::Pending => {
844                    // Task still pending - put the future back
845                    let placeholder = mem::replace(&mut task.fut, extracted_fut);
846                    drop(placeholder);
847                    false
848                }
849            }
850        }
851    }
852
853    /// Complete a task that finished (Ready).
854    fn complete_task(&self, id: TaskId, _qidx: usize) {
855        let mut tasks = self.tasks.borrow_mut();
856        let task = tasks.get_mut(id).expect("task should exist");
857        task.header.set_done();
858        tasks.remove(id);
859    }
860
861    /// Execute tasks from a selected queue until the timeslice is exhausted,
862    /// preemption is requested, or max polls is reached.
863    ///
864    /// Uses adaptive timing: measures the first few polls to calibrate how
865    /// often to call Instant::now(), avoiding per-poll syscall overhead for
866    /// fast tasks while maintaining accuracy for slow ones.
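    ///
    /// Illustrative numbers: if the 4 calibration polls take 400 ns in total
    /// (avg 100 ns), the sample interval clamps to 16 and Instant::now() is
    /// called once every 16 polls; if they take 40 µs in total (avg 10 µs),
    /// the interval clamps to 1 and every poll is timed.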
867    fn run_timeslice(&self, qidx: usize, timeslice: Duration, enable_stats: bool) -> Instant {
868        let now = Instant::now();
869        let until = now + timeslice;
870        if enable_stats {
871            self.queues.borrow_mut()[qidx]
872                .stats
873                .record_first_service_after_runnable(now);
874        }
875
876        // Drain LIFO slot at start of timeslice
877        {
878            let queue = &self.queues.borrow()[qidx];
879            queue.task_queue.drain_lifo_to_mpsc();
880        }
881
882        let max_polls = self.options.max_polls_per_yield;
883
884        // Phase 1: Calibrate — time the first K polls individually to measure
885        // average poll duration, then use that to set the sampling interval.
886        const CALIBRATE_POLLS: u32 = 4;
887        let mut polls_this_slice = 0u32;
888
889        for _ in 0..CALIBRATE_POLLS {
890            set_yield_maybe_deadline(until);
891            let Some(id) = self.pop_next_task_from_queue(qidx) else {
892                break;
893            };
894            let completed = self.poll_task(id);
895            if completed {
896                self.complete_task(id, qidx);
897            }
898            polls_this_slice += 1;
899            if polls_this_slice >= max_polls {
900                break;
901            }
902        }
903
904        // Sample after calibration
905        let sample_now = Instant::now();
906        let calibrate_elapsed = sample_now.saturating_duration_since(now);
907        self.charge_class(qidx, calibrate_elapsed);
908
909        if polls_this_slice == 0 || polls_this_slice >= max_polls || sample_now > until {
910            return sample_now;
911        }
912
913        // Compute adaptive sample interval: target ~2µs between Instant::now() calls,
914        // clamped to [1, 16]. Fast polls → sample every 16. Slow polls → every poll.
915        let avg_poll_ns = calibrate_elapsed.as_nanos() / polls_this_slice as u128;
916        let sample_interval = if avg_poll_ns == 0 {
917            16u32
918        } else {
919            (2000u128 / avg_poll_ns).clamp(1, 16) as u32
920        };
921
922        // Phase 2: Run remaining polls with sampled timing
923        let mut last_sample = sample_now;
924        let mut polls_since_sample = 0u32;
925
926        loop {
927            set_yield_maybe_deadline(until);
928
929            let Some(id) = self.pop_next_task_from_queue(qidx) else {
930                break;
931            };
932
933            let completed = self.poll_task(id);
934            if completed {
935                self.complete_task(id, qidx);
936            }
937            polls_this_slice += 1;
938            polls_since_sample += 1;
939
940            // Periodic time check
941            if polls_since_sample >= sample_interval {
942                let sample_now = Instant::now();
943                let elapsed = sample_now.saturating_duration_since(last_sample);
944                self.charge_class(qidx, elapsed);
945                last_sample = sample_now;
946                polls_since_sample = 0;
947
948                if sample_now > until {
949                    if enable_stats {
950                        self.stats.borrow_mut().record_poll(elapsed, true);
951                        let mut queues = self.queues.borrow_mut();
952                        queues[qidx].stats.record_slice_overrun();
953                        queues[qidx].stats.record_slice_exhausted();
954                    }
955                    break;
956                }
957            }
958
959            if polls_this_slice >= max_polls {
960                break;
961            }
962            if self.preempt_state.check() {
963                break;
964            }
965        }
966
967        // Charge any remaining unsampled polls
968        if polls_since_sample > 0 {
969            let now = Instant::now();
970            let elapsed = now.saturating_duration_since(last_sample);
971            self.charge_class(qidx, elapsed);
972            last_sample = now;
973        }
974
975        last_sample
976    }
977
978    /// Yield to the driver and record stats.
979    async fn yield_to_driver(&self, last_yield: Instant, now: Instant) -> Instant {
980        let since_last = now - last_yield;
981        yield_once().await;
982        let after_yield = Instant::now();
983        let in_driver = after_yield.duration_since(now);
984        self.stats
985            .borrow_mut()
986            .record_driver_yield(since_last, in_driver);
987        after_yield
988    }
989}
990
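/// Cooperative yield point for long-running tasks: yields only if the deadline
/// set by the executor for the current timeslice has already passed, and clears
/// the deadline so a tight loop does not yield repeatedly.
///
/// A sketch of intended usage inside a CPU-heavy task (`do_some_work` is a
/// hypothetical unit of work):
///
/// ```ignore
/// loop {
///     do_some_work();
///     yield_maybe().await; // cheap check; yields only past the deadline
/// }
/// ```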
991pub async fn yield_maybe() {
992    let should_yield = YIELD_MAYBE_DEADLINE.with(|d| {
993        if let Some(dl) = d.get() {
994            Instant::now() >= dl
995        } else {
996            false
997        }
998    });
999    if should_yield {
1000        // clear so we don't yield repeatedly in a tight loop
1001        YIELD_MAYBE_DEADLINE.with(|d| d.set(None));
1002        yield_once().await;
1003    }
1004}
1005
1006#[cfg(test)]
1007mod tests {
1008    use super::*;
1009    use crate::join::JoinError;
1010    use crate::yield_once::yield_once;
1011    use std::sync::atomic::AtomicBool;
1012    use std::sync::atomic::{AtomicU32, Ordering};
1013    use std::sync::{Arc, Mutex};
1014    use tokio::task::LocalSet;
1015    use tokio::time::{sleep, timeout, Duration};
1016
1017    #[tokio::test]
1018    async fn test_basic_task_completion() {
1019        let local = LocalSet::new();
1020        local
1021            .run_until(async {
1022                let executor = ExecutorBuilder::new()
1023                    .with_queue(0, 1)
1024                    .build()
1025                    .unwrap();
1026                let counter = Arc::new(AtomicU32::new(0));
1027
1028                let counter_clone = counter.clone();
1029                // Run executor until task completes
1030                let result = executor.run_until(async {
1031                    let queue = executor.queue(0).unwrap();
1032                    let handle = queue.spawn(async move {
1033                        counter_clone.fetch_add(1, Ordering::Relaxed);
1034                    });
1035                    handle.await
1036                });
1037                let result = timeout(Duration::from_millis(100), result).await;
1038                assert!(result.is_ok(), "Task should complete");
1039                assert_eq!(counter.load(Ordering::Relaxed), 1);
1040            })
1041            .await;
1042    }
1043
1044    #[tokio::test]
1045    async fn test_join_handle_returns_result() {
1046        let local = LocalSet::new();
1047        local
1048            .run_until(async {
1049                let executor = ExecutorBuilder::new()
1050                    .with_queue(0, 1)
1051                    .build()
1052                    .unwrap();
1053
1054                let result = executor.run_until(async {
1055                    let queue = executor.queue(0).unwrap();
1056                    let handle = queue.spawn(async move { 42 });
1057                    handle.await
1058                });
1059                let result = timeout(Duration::from_millis(100), result).await;
1060                assert!(result.is_ok(), "JoinHandle should complete");
1061                let join_result = result.unwrap();
1062                assert_eq!(join_result, Ok(42));
1063            })
1064            .await;
1065    }
1066
1067    #[tokio::test]
1068    async fn test_join_handle_abort() {
1069        let local = LocalSet::new();
1070        local
1071            .run_until(async {
1072                let executor = ExecutorBuilder::new()
1073                    .with_queue(0, 1)
1074                    .build()
1075                    .unwrap();
1076                let started = Arc::new(AtomicBool::new(false));
1077                let completed = Arc::new(AtomicBool::new(false));
1078                let started_clone = started.clone();
1079                let completed_clone = completed.clone();
1080
1081                let queue = executor.queue(0).unwrap();
1082                let handle = executor
1083                    .run_until(async {
1084                        let handle = queue.spawn(async move {
1085                            started_clone.store(true, Ordering::Relaxed);
1086                            // Task that runs for a while
1087                            for _ in 0..100 {
1088                                sleep(Duration::from_millis(10)).await;
1089                            }
1090                            completed_clone.store(true, Ordering::Relaxed);
1091                        });
1092                        // Wait a bit for task to start
1093                        sleep(Duration::from_millis(50)).await;
1094                        assert!(started.load(Ordering::Relaxed), "Task should have started");
1095
1096                        // Abort the task
1097                        handle.abort();
1098                        handle
1099                    })
1100                    .await;
1101
1102                // Wait for abort to be processed
1103                let result = timeout(Duration::from_millis(500), handle).await;
1104                assert!(result.is_ok(), "JoinHandle should complete after abort");
1105                let join_result = result.unwrap();
1106                assert!(matches!(join_result, Err(JoinError::Cancelled)));
1107
1108                // verify task didn't complete
1109                assert!(
1110                    !completed.load(Ordering::Relaxed),
1111                    "Task should not have completed"
1112                );
1113            })
1114            .await;
1115    }
1116
1117    #[tokio::test]
1118    async fn test_vruntime_scheduling() {
1119        let local = LocalSet::new();
1120        local
1121            .run_until(async {
1122                let executor = ExecutorBuilder::new()
1123                    .with_queue(0, 8)
1124                    .with_queue(1, 1)
1125                    .build()
1126                    .unwrap();
1127                let queue1 = executor.queue(0).unwrap();
1128                let queue2 = executor.queue(1).unwrap();
1129                let high = Arc::new(AtomicU32::new(0));
1130                let low = Arc::new(AtomicU32::new(0));
1131                let high_clone = high.clone();
1132                let low_clone = low.clone();
1133
1134                executor
1135                    .run_until(async {
1136                        // Spawn tasks that run indefinitely with some work per iteration.
                        // Note: we use yield_once() instead of sleep() because sleep() makes a
                        // task pending (not runnable), so it can't compete for CPU; the
                        // low-weight class would then get the CPU whenever the high-weight
                        // class is not runnable.
1141                        let handle1 = queue1.spawn(async move {
1142                            loop {
1143                                for _ in 0..100_000 {
1144                                    high_clone.fetch_add(1, Ordering::Relaxed);
1145                                }
1146                                yield_once().await;
1147                            }
1148                        });
1149                        let handle2 = queue2.spawn(async move {
1150                            loop {
1151                                for _ in 0..100_000 {
1152                                    low_clone.fetch_add(1, Ordering::Relaxed);
1153                                }
1154                                yield_once().await;
1155                            }
1156                        });
1157                        sleep(Duration::from_millis(100)).await;
1158                        handle1.abort();
1159                        handle2.abort();
1160                    })
1161                    .await;
1162                let high_count = high.load(Ordering::Relaxed);
1163                let low_count = low.load(Ordering::Relaxed);
1164                // High weight class should get more CPU time (roughly 8x)
1165                assert!(
1166                    low_count * 2 < high_count && high_count < low_count * 16,
1167                    "High weight class should get significantly more CPU time. High: {}, Low: {}",
1168                    high_count,
1169                    low_count
1170                );
1171            })
1172            .await;
1173    }
1174
1175    #[tokio::test]
1176    async fn test_policy_fifo_ordering() {
1177        let local = LocalSet::new();
1178        local
1179            .run_until(async {
1180                let executor = ExecutorBuilder::new()
1181                    .with_queue(0, 1)
1182                    .build()
1183                    .unwrap();
1184                let queue = executor.queue(0).unwrap();
1185                let execution_order = Arc::new(Mutex::new(Vec::new()));
1186
1187                // Spawn multiple tasks that should execute in FIFO order
1188                for i in 0..5 {
1189                    let order_clone = execution_order.clone();
1190                    let _handle = queue.spawn(async move {
1191                        order_clone.lock().unwrap().push(i);
1192                    });
1193                }
1194
1195                let executor_clone = executor.clone();
1196                local.spawn_local(async move {
1197                    // Run until timeout to let tasks complete
1198                    executor_clone
1199                        .run_until(sleep(Duration::from_millis(200)))
1200                        .await;
1201                });
1202
1203                // Wait for all tasks to complete
1204                sleep(Duration::from_millis(200)).await;
1205
1206                let order = execution_order.lock().unwrap();
1207                // Tasks should execute in FIFO order (0, 1, 2, 3, 4)
1208                assert_eq!(order.len(), 5, "All tasks should have executed");
1209                assert_eq!(
1210                    *order,
1211                    vec![0, 1, 2, 3, 4],
1212                    "Tasks should execute in FIFO order"
1213                );
1214            })
1215            .await;
1216    }
1217
1218    #[tokio::test]
1219    async fn test_multiple_tasks_same_class() {
1220        let local = LocalSet::new();
1221        local
1222            .run_until(async {
1223                let executor = ExecutorBuilder::new()
1224                    .with_queue(0, 1)
1225                    .build()
1226                    .unwrap();
1227                let queue = executor.queue(0).unwrap();
1228                let counter = Arc::new(AtomicU32::new(0));
1229                let counter_clone = counter.clone();
1230
1231                executor
1232                    .run_until(async {
1233                        let mut handles = Vec::new();
1234                        for _ in 0..5 {
1235                            let counter_clone = counter.clone();
1236                            let handle = queue.spawn(async move {
1237                                counter_clone.fetch_add(1, Ordering::Relaxed);
1238                            });
1239                            handles.push(handle);
1240                        }
1241                        for handle in handles {
1242                            let result = timeout(Duration::from_millis(100), handle).await;
1243                            assert!(result.is_ok(), "All tasks should complete");
1244                        }
1245                    })
1246                    .await;
1247                assert_eq!(counter_clone.load(Ordering::Relaxed), 5);
1248            })
1249            .await;
1250    }
1251
1252    #[tokio::test]
1253    async fn test_task_with_yield() {
1254        let local = LocalSet::new();
1255        local
1256            .run_until(async {
1257                let executor = ExecutorBuilder::new()
1258                    .with_queue(0, 1)
1259                    .build()
1260                    .unwrap();
1261                let queue = executor.queue(0).unwrap();
1262                let counter = Arc::new(AtomicU32::new(0));
1263
1264                let counter_clone = counter.clone();
1265                executor
1266                    .run_until(async {
1267                        let handle = queue.spawn(async move {
1268                            for _ in 0..3 {
1269                                counter_clone.fetch_add(1, Ordering::Relaxed);
1270                                sleep(Duration::from_millis(10)).await;
1271                            }
1272                        });
1273                        let result = timeout(Duration::from_millis(500), handle).await;
1274                        assert!(
1275                            result.is_ok(),
1276                            "Task with yields should complete, got {:?}",
1277                            result
1278                        );
1279                    })
1280                    .await;
1281
1282                assert_eq!(counter.load(Ordering::Relaxed), 3);
1283            })
1284            .await;
1285    }
1286
1287    #[tokio::test]
1288    async fn test_abort_before_task_starts() {
1289        let local = LocalSet::new();
1290        local
1291            .run_until(async {
1292                let executor = ExecutorBuilder::new()
1293                    .with_queue(0, 1)
1294                    .build()
1295                    .unwrap();
1296                let queue = executor.queue(0).unwrap();
1297                let executed = Arc::new(AtomicBool::new(false));
1298
1299                let executed_clone = executed.clone();
1300                let handle = queue.spawn(async move {
1301                    executed_clone.store(true, Ordering::Relaxed);
1302                });
1303
1304                // Abort immediately before executor runs
1305                handle.abort();
1306
1307                let executor_clone = executor.clone();
1308                local.spawn_local(async move {
1309                    executor_clone
1310                        .run_until(sleep(Duration::from_millis(100)))
1311                        .await;
1312                });
1313
1314                // Wait a bit
1315                sleep(Duration::from_millis(100)).await;
1316
1317                // Task should not have executed
1318                assert!(
1319                    !executed.load(Ordering::Relaxed),
1320                    "Task should not execute after abort"
1321                );
1322
1323                // JoinHandle should return Cancelled
1324                let result = timeout(Duration::from_millis(50), handle).await;
1325                assert!(result.is_ok());
1326                assert!(matches!(result.unwrap(), Err(JoinError::Cancelled)));
1327            })
1328            .await;
1329    }
1330
1331    #[tokio::test]
1332    async fn test_enum_queue_ids() {
1333        #[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
1334        enum QueueId {
1335            High,
1336            Low,
1337        }
1338        let local = LocalSet::new();
1339        local
1340            .run_until(async {
1341                let executor = ExecutorBuilder::new()
1342                    .with_queue(QueueId::High, 1)
1343                    .with_queue(QueueId::Low, 1)
1344                    .build()
1345                    .unwrap();
1346                let high = Arc::new(AtomicU32::new(0));
1347                let low = Arc::new(AtomicU32::new(0));
1348
1349                let high_clone = high.clone();
1350                let low_clone = low.clone();
1351
1352                let executor_clone = executor.clone();
1353                local.spawn_local(async move {
1354                    executor_clone
1355                        .run_until(sleep(Duration::from_millis(100)))
1356                        .await;
1357                });
1358                let q1 = executor.queue(QueueId::High).unwrap();
1359                let _ = q1.spawn(async move {
1360                    high_clone.fetch_add(1, Ordering::Relaxed);
1361                    yield_once().await;
1362                });
1363                let q2 = executor.queue(QueueId::Low).unwrap();
1364                let _ = q2.spawn(async move {
1365                    low_clone.fetch_add(1, Ordering::Relaxed);
1366                    yield_once().await;
1367                });
1368                sleep(Duration::from_millis(100)).await;
1369            })
1370            .await;
1371    }
1372
1373    #[tokio::test]
1374    async fn test_vruntime_resets() {
1375        let local = LocalSet::new();
1376        local
1377            .run_until(async {
1378                let executor = ExecutorBuilder::new()
1379                    .with_queue(0, 1)
1380                    .with_queue(1, 1)
1381                    .build()
1382                    .unwrap();
1383                let counter = Arc::new(AtomicU32::new(0));
1384                let counter_clone = counter.clone();
1385                let q1 = executor.queue(0).unwrap();
1386                executor
1387                    .run_until(async {
1388                        let handle = q1.spawn(async move {
1389                            for _ in 0..1000 {
1390                                counter_clone.fetch_add(1, Ordering::Relaxed);
1391                                yield_once().await;
1392                            }
1393                        });
1394                        let result = timeout(Duration::from_millis(100), handle).await;
1395                        assert!(result.is_ok(), "Task should complete");
1396                        assert_eq!(counter.load(Ordering::Relaxed), 1000);
1397                        let vruntime1 = executor.queues.borrow()[0].vruntime;
1398                        assert!(vruntime1 > 0);
1399                        // now spawn a task in the second queue
1400                        let counter_clone = counter.clone();
1401                        let q2 = executor.queue(1).unwrap();
1402                        let handle = q2.spawn(async move {
1403                            counter_clone.fetch_add(1, Ordering::Relaxed);
1404                        });
1405                        let result = timeout(Duration::from_millis(100), handle).await;
1406                        assert!(result.is_ok(), "Task should complete");
1407                        assert_eq!(counter.load(Ordering::Relaxed), 1001);
1408                        let vruntime2 = executor.queues.borrow()[1].vruntime;
1409                        // even though the second task only ran for a short time
1410                        // its vruntime should have "inherited" the vruntime of the
1411                        // first queue when it started running
1412                        assert!(
1413                            vruntime2 > vruntime1,
1414                            "vruntime2 should be greater than vruntime1, got {} and {}",
1415                            vruntime2,
1416                            vruntime1
1417                        );
1418                    })
1419                    .await;
1420            })
1421            .await;
1422    }
1423
    #[tokio::test]
    async fn test_yield_maybe() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let queue = executor.queue(0).unwrap();
                let counter1 = Arc::new(AtomicU32::new(0));
                let counter1_clone = counter1.clone();
                local.spawn_local(async move {
                    executor
                        .run_until(async {
                            let handle = queue.spawn(async move {
                                let mut i = 0;
                                loop {
                                    counter1_clone.fetch_add(1, Ordering::Relaxed);
                                    if i % 1000 == 0 {
                                        yield_maybe().await;
                                    }
                                    i += 1;
                                }
                            });
                            sleep(Duration::from_millis(100)).await;
                            let count = counter1.load(Ordering::Relaxed);
                            assert!(count > 0);
                            let yields = executor.stats.borrow().driver_yields;
                            assert!(yields > 0);
                            // yield_maybe() is called once every 1000 iterations;
                            // fewer than half of those calls should have actually
                            // yielded (in practice far fewer do)
                            assert!(yields < count as u64 / 1000 / 2);
                            handle.abort();
                        })
                        .await;
                });
            })
            .await;
    }

    // Test with smol runtime
    #[test]
    fn test_smol_runtime() {
        let executor = ExecutorBuilder::new().with_queue(0, 1).build().unwrap();
        let smol_local_ex = smol::LocalExecutor::new();
        let h2 = smol_local_ex.spawn(async move {
            let queue = executor.queue(0).unwrap();
            executor
                .run_until(async {
                    let handle = queue.spawn(async move { 42 });
                    handle.await
                })
                .await
        });

        let res = smol::future::block_on(smol_local_ex.run(async { h2.await }));
        assert_eq!(res, Ok(42));
    }

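    // Aborting a handle whose task has already completed must be a no-op:
    // the original result is still delivered to the awaiting JoinHandle.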
    #[tokio::test]
    async fn test_abort_after_done() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let counter = Arc::new(AtomicU32::new(0));
                let counter_clone = counter.clone();
                let queue = executor.queue(0).unwrap();
                let result = executor
                    .run_until(async {
                        let handle = queue.spawn(async move {
                            counter_clone.fetch_add(1, Ordering::Relaxed);
                            42
                        });
                        // wait for task to complete
                        sleep(Duration::from_millis(100)).await;
                        assert!(counter.load(Ordering::Relaxed) > 0);
                        // handle should still be abortable, though it's a no-op now
                        handle.abort();
                        handle.await
                    })
                    .await;
                assert_eq!(result, Ok(42));
            })
            .await;
    }

    // Test with monoio runtime
    #[test]
    fn test_monoio_runtime() {
        use monoio::LegacyDriver;
        let mut rt = monoio::RuntimeBuilder::<LegacyDriver>::new()
            .enable_timer() // Explicitly enable the timer
            .build()
            .unwrap();
        let _ = rt.block_on(async move {
            let executor = ExecutorBuilder::new().with_queue(0, 1).build().unwrap();
            let counter = Arc::new(AtomicU32::new(0));

            let counter_clone = counter.clone();
            let queue = executor.queue(0).unwrap();
            let result = executor
                .run_until(async {
                    // initial value should be 0
                    assert_eq!(counter.load(Ordering::Relaxed), 0);

                    let handle = queue.spawn(async move {
                        counter_clone.fetch_add(1, Ordering::Relaxed);
                        42
                    });
                    monoio::time::sleep(Duration::from_millis(100)).await;
                    // task should have completed
                    assert_eq!(counter.load(Ordering::Relaxed), 1);
                    handle.await
                })
                .await;
            assert_eq!(result, Ok(42));
        });
    }

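    // Builder/constructor validation: zero shares, duplicate queue keys, and
    // an empty queue list are all rejected.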
    #[test]
    fn test_bad_executor_creation() {
        // can't create executor with 0 shares
        let result = ExecutorBuilder::new().with_queue(0, 0).build();
        assert!(result.is_err());
        // can't create executor with duplicate queue IDs
        let result = ExecutorBuilder::new()
            .with_queue(0, 1)
            .with_queue(0, 1)
            .build();
        assert!(result.is_err());
        // can't create executor with 0 queues
        let result = Executor::<u8>::new(ExecutorOptions::default(), vec![]);
        assert!(result.is_err());
    }

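    // With the default configuration a task panic is not caught, so it
    // propagates out of run_until and the driving tokio task finishes with a
    // panic.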
    #[tokio::test]
    async fn test_panic_crashes_executor() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let queue = executor.queue(0).unwrap();
                let handle = tokio::task::spawn_local(async move {
                    executor.run_until(sleep(Duration::from_millis(100))).await;
                });
                let _ = queue.spawn(async {
                    panic!("test");
                });
                let result = handle.await;
                assert!(result.is_err());
                assert!(result.unwrap_err().is_panic());
            })
            .await;
    }

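    // With with_panic_on_task_panic(false), a task panic is caught and
    // surfaced to the caller as JoinError::Panic instead of crashing the
    // executor.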
    #[tokio::test]
    async fn test_panic_caught_when_configured() {
        let local = LocalSet::new();
        local
            .run_until(async {
                // Configure executor to catch panics instead of crashing
                let executor = ExecutorBuilder::new()
                    .with_panic_on_task_panic(false)
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let queue = executor.queue(0).unwrap();
                let result = executor.run_until(async {
                    let task_handle = queue.spawn(async {
                        panic!("test panic message");
                    });
                    task_handle.await
                });

                // Wait for the task to complete (should complete with Panic error)
                let result = timeout(Duration::from_millis(100), result).await;
                assert!(result.is_ok(), "Task should complete (with panic error)");

                let join_result = result.unwrap();
                assert!(join_result.is_err(), "Task should return an error");

                match join_result.unwrap_err() {
                    JoinError::Panic(_) => {
                        // Expected - panic was caught and converted to JoinError::Panic
                    }
                    other => panic!("Expected JoinError::Panic, got {:?}", other),
                }

                // Executor should still be running (not crashed)
                assert_eq!(executor.task_queues.len(), 1);
            })
            .await;
    }

    #[tokio::test]
    async fn test_preemption_mask_computed_correctly() {
        // Test that select_queue computes the preempt mask correctly
        let local = LocalSet::new();
        local
            .run_until(async {
                // Create executor with queues of different weights
                // Queue 0: weight 8 (highest priority when empty has lowest vruntime)
                // Queue 1: weight 4
                // Queue 2: weight 1 (lowest priority)
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 8)
                    .with_queue(1, 4)
                    .with_queue(2, 1)
                    .build()
                    .unwrap();

                let queue2 = executor.queue(2).unwrap();
                let preempt_state = executor.preempt_state.clone();

                executor
                    .run_until(async {
                        // Spawn a task only on queue 2 (lowest priority)
                        let handle = queue2.spawn(async {
                            loop {
                                yield_once().await;
                            }
                        });

                        // Use sleep to allow the executor to run select_queue() and compute the mask.
                        // Sleep gives the tokio driver a chance to run, and when it returns,
                        // the executor will have already run at least one iteration calling select_queue().
                        sleep(Duration::from_millis(10)).await;

                        // At this point, queue 2 should be selected (only runnable)
                        // and queues 0 and 1 should be in the preempt mask since they
                        // would have higher priority if they got tasks
                        assert!(
                            preempt_state.would_preempt(0),
                            "Queue 0 (weight 8) should preempt queue 2 (weight 1)"
                        );
                        assert!(
                            preempt_state.would_preempt(1),
                            "Queue 1 (weight 4) should preempt queue 2 (weight 1)"
                        );
                        assert!(
                            !preempt_state.would_preempt(2),
                            "Queue 2 is runnable, should not be in preempt mask"
                        );
                        assert!(
                            !preempt_state.check(),
                            "Preempt flag should not be set (no higher priority task enqueued)"
                        );

                        handle.abort();
                        let _ = handle.await;
                    })
                    .await;
            })
            .await;
    }

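    // spawn_detached() returns no JoinHandle, so completion is observed
    // indirectly: a second task on the same queue serves as a sentinel and the
    // shared counter confirms the detached task ran.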
    #[tokio::test]
    async fn test_spawn_detached_runs_to_completion() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let counter = Arc::new(AtomicU32::new(0));

                let counter_clone = counter.clone();
                let result = executor.run_until(async {
                    let queue = executor.queue(0).unwrap();
                    queue.spawn_detached(async move {
                        counter_clone.fetch_add(1, Ordering::Relaxed);
                    });
                    // Spawn a second task whose completion signals the executor to stop.
                    // Since both tasks are on the same queue, the detached task
                    // will have run by the time the sentinel completes.
                    let queue2 = executor.queue(0).unwrap();
                    let handle = queue2.spawn(async { 99 });
                    handle.await
                });
                let result = timeout(Duration::from_millis(100), result).await;
                assert!(result.is_ok(), "Executor should finish");
                assert_eq!(counter.load(Ordering::Relaxed), 1, "Detached task should have run");
            })
            .await;
    }

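    // Stress variant of the detached-spawn test: all 1000 fire-and-forget
    // tasks on the queue should run before the sentinel task completes.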
    #[tokio::test]
    async fn test_spawn_detached_many_tasks() {
        let local = LocalSet::new();
        local
            .run_until(async {
                let executor = ExecutorBuilder::new()
                    .with_queue(0, 1)
                    .build()
                    .unwrap();
                let counter = Arc::new(AtomicU32::new(0));

                let counter_clone = counter.clone();
                let result = executor.run_until(async {
                    let queue = executor.queue(0).unwrap();
                    for _ in 0..1000 {
                        let c = counter_clone.clone();
                        queue.spawn_detached(async move {
                            c.fetch_add(1, Ordering::Relaxed);
                        });
                    }
                    let queue2 = executor.queue(0).unwrap();
                    let handle = queue2.spawn(async { 0 });
                    handle.await
                });
                let result = timeout(Duration::from_millis(500), result).await;
                assert!(result.is_ok(), "Executor should finish");
                assert_eq!(counter.load(Ordering::Relaxed), 1000, "All detached tasks should have run");
            })
            .await;
    }
}