// uasync 0.1.1
//
// fast, safe, async executor
use super::{
    queue::Queue,
    random::{RandomSequence, Rng},
    semaphore::Semaphore,
    task::Runnable,
    thread::Thread,
};
use std::{
    num::NonZeroUsize,
    sync::atomic::{fence, AtomicUsize, Ordering},
    sync::Arc,
    time::Duration,
};

/// Scheduler state shared by the worker threads: per-worker run queues that
/// other workers steal from, a shared injector queue, and a packed atomic word
/// tracking idle/searching workers and shutdown.
pub(super) struct Scheduler {
    /// Posted by `on_task_complete` when the last task finishes after shutdown;
    /// waited on by `join`.
    join_semaphore: Semaphore,
    /// Parked workers block on this in `park`; `unpark` and `shutdown` post to it.
    idle_semaphore: Semaphore,
    /// Source of randomized visiting orders over the run queues when stealing.
    rand_seq: RandomSequence,
    /// Number of outstanding tasks (`on_task_begin` increments, `on_task_complete` decrements).
    tasks: AtomicUsize,
    /// Packed word: idle-worker count, searching-worker count and the shutdown
    /// bit, laid out by the `*_SHIFT` / `*_MASK` constants.
    state: AtomicUsize,
    /// Shared queue used by `schedule` when the caller has no run queue of its own.
    injector: Queue,
    /// One run queue per worker thread, indexed by the worker's `queue_index`.
    run_queues: Box<[Queue]>,
    /// Optional callback invoked right before a worker blocks in `park`.
    on_thread_park: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    /// Optional callback invoked right after a worker wakes in `park`.
    on_thread_unpark: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl Scheduler {
    /// Builds a scheduler with one run queue per worker thread.
    ///
    /// `worker_threads` is clamped to `STATE_MASK`, the largest count that fits
    /// in one packed counter of `state`. The optional callbacks are invoked
    /// around each time a worker blocks in `park`.
    pub(super) fn new(
        worker_threads: NonZeroUsize,
        on_thread_park: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
        on_thread_unpark: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    ) -> Self {
        // Both operands of `min` are non-zero, so the result is non-zero too.
        let capped = worker_threads.get().min(Self::STATE_MASK);
        let worker_threads = NonZeroUsize::new(capped).unwrap();
        let queue_count = worker_threads.get();

        Self {
            join_semaphore: Semaphore::default(),
            idle_semaphore: Semaphore::default(),
            rand_seq: RandomSequence::new(worker_threads),
            tasks: AtomicUsize::new(0),
            state: AtomicUsize::new(0),
            injector: Queue::default(),
            run_queues: std::iter::repeat_with(Queue::default)
                .take(queue_count)
                .collect(),
            on_thread_park,
            on_thread_unpark,
        }
    }

    /// Number of worker threads, which equals the number of run queues.
    pub(super) fn worker_threads(&self) -> NonZeroUsize {
        let queue_count = self.run_queues.len();
        NonZeroUsize::new(queue_count).unwrap()
    }

    /// Queues `runnable` for execution and wakes an idle worker if one is needed.
    ///
    /// If `thread` is a worker that owns a run queue, the runnable stays local.
    /// - Normally it goes into the worker's LIFO slot. Whatever it displaces
    ///   is pushed onto the run queue.
    /// - With `be_fair` it goes straight onto the run queue, and the worker is
    ///   flagged to check the injector on its next iteration.
    ///
    /// Otherwise it is pushed onto the shared injector queue.
    pub(super) fn schedule(
        &self,
        runnable: Arc<dyn Runnable>,
        thread: Option<&Thread>,
        be_fair: bool,
    ) {
        let local = match thread {
            Some(thread) => thread.queue_index.map(|queue_index| (thread, queue_index)),
            None => None,
        };

        match local {
            Some((thread, queue_index)) => {
                let overflow = if be_fair {
                    thread.be_fair.set(true);
                    Some(runnable)
                } else {
                    thread.lifo_slot.replace(Some(runnable))
                };

                // Only work that lands in the run queue is visible to stealers,
                // so only that case needs a wake-up.
                if let Some(displaced) = overflow {
                    self.run_queues[queue_index].push(displaced);
                    self.unpark();
                }
            }
            None => {
                self.injector.push(runnable);
                // Presumably pairs with the SeqCst RMW in `park`, so either this
                // thread sees the parker as idle or the parker sees this push.
                fence(Ordering::SeqCst);
                self.unpark()
            }
        }
    }

    /// Width in bits of each packed counter (idle, searching) inside `state`.
    const STATE_BITS: u32 = (usize::BITS - 1) / 2;
    /// Mask that selects one counter after it has been shifted down to bit 0.
    const STATE_MASK: usize = (1 << Self::STATE_BITS) - 1;
    /// Most significant bit of `state`; set once `shutdown` has run.
    const SHUTDOWN_MASK: usize = 1 << (usize::BITS - 1);

    /// Bit offset of the idle-worker counter (the lowest field).
    const IDLE_SHIFT: u32 = 0;
    /// Bit offset of the searching-worker counter, directly above the idle field.
    const SEARCHING_SHIFT: u32 = Self::STATE_BITS;

    /// Wakes one idle worker so it starts searching for work, when that is needed.
    ///
    /// Nothing happens in any of these cases:
    /// - the shutdown bit is set;
    /// - some worker is already searching;
    /// - no worker is idle.
    ///
    /// Otherwise one worker is moved from the idle count to the searching
    /// count, and one permit is posted to `idle_semaphore`.
    fn unpark(&self) {
        self.state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |state| {
                // After shutdown, `unpark` never wakes anyone.
                if state & Self::SHUTDOWN_MASK != 0 {
                    return None;
                }

                // A searcher already exists. It calls `unpark` itself when it
                // stops searching (from `discovered` or `park`), so no extra
                // wake-up is issued here.
                let searching = (state >> Self::SEARCHING_SHIFT) & Self::STATE_MASK;
                assert!(searching <= self.run_queues.len());
                if searching > 0 {
                    return None;
                }

                // No idle worker: `checked_sub` yields None and `?` aborts the update.
                let mut idle = (state >> Self::IDLE_SHIFT) & Self::STATE_MASK;
                assert!(idle <= self.run_queues.len());
                idle = idle.checked_sub(1)?;

                // searching was 0 above, so the new state is exactly one
                // searcher plus the reduced idle count.
                Some((1 << Self::SEARCHING_SHIFT) | (idle << Self::IDLE_SHIFT))
            })
            // Only post a permit if the transition above actually happened.
            .map(|_| self.idle_semaphore.post(1))
            .unwrap_or(())
    }

    /// Tries to register the calling worker as a searcher (a thief).
    ///
    /// Returns `true` if the worker is allowed to search for work:
    /// - it was already searching, or
    /// - fewer than half of the workers are currently searching, in which case
    ///   it is added to the searching count.
    ///
    /// Returns `false` otherwise.
    fn search(&self, was_searching: bool) -> bool {
        let searchers = |word: usize| (word >> Self::SEARCHING_SHIFT) & Self::STATE_MASK;

        if !was_searching {
            let searching = searchers(self.state.load(Ordering::Relaxed));
            assert!(searching <= self.run_queues.len());

            // Cap concurrent searchers at half the workers to limit contention.
            if (2 * searching) >= self.run_queues.len() {
                return false;
            }

            let searching = searchers(
                self.state
                    .fetch_add(1 << Self::SEARCHING_SHIFT, Ordering::Acquire),
            );
            assert!(searching < self.run_queues.len());
        }

        true
    }

    /// Called when the worker found a runnable; ends its searching phase.
    ///
    /// If the worker was searching, it is removed from the searching count.
    /// When it was the last searcher, `unpark` is called so another idle
    /// worker can take over searching.
    ///
    /// Always returns `false`, which is the worker's new "is searching" flag.
    fn discovered(&self, was_searching: bool) -> bool {
        if !was_searching {
            return false;
        }

        let previous = self
            .state
            .fetch_sub(1 << Self::SEARCHING_SHIFT, Ordering::Release);
        let searching = (previous >> Self::SEARCHING_SHIFT) & Self::STATE_MASK;
        assert!(searching <= self.run_queues.len());
        assert!(searching > 0);

        let was_last_searcher = searching == 1;
        if was_last_searcher {
            self.unpark();
        }

        false
    }

    /// Puts the calling worker to sleep until `unpark` or `shutdown` wakes it.
    ///
    /// `was_searching` says whether the caller is currently counted as a
    /// searching worker. If so, that count is dropped as the worker becomes idle.
    ///
    /// Returns `None` when the shutdown bit is set: the worker should exit.
    /// Otherwise returns `Some(true)` after waking. Both wakers (`unpark` and
    /// `shutdown`) have already moved this worker into the searching count.
    fn park(&self, was_searching: bool) -> Option<bool> {
        // One RMW does idle += 1 and, if we were searching, searching -= 1.
        // `wrapping_sub` encodes the decrement as adding its two's complement.
        let mut update: usize = 1 << Self::IDLE_SHIFT;
        if was_searching {
            update = update.wrapping_sub(1 << Self::SEARCHING_SHIFT);
        }

        // SeqCst here presumably pairs with the SeqCst fence in `schedule`:
        // either the producer sees this worker as idle, or the `is_empty`
        // check below sees the producer's work.
        let state = self.state.fetch_add(update, Ordering::SeqCst);
        let idle = (state >> Self::IDLE_SHIFT) & Self::STATE_MASK;
        assert!(idle < self.run_queues.len());

        let searching = (state >> Self::SEARCHING_SHIFT) & Self::STATE_MASK;
        assert!(searching <= self.run_queues.len());
        assert!(searching >= was_searching as usize);

        // Shutting down: undo our idle increment and exit instead of waiting.
        // After shutdown, `unpark` never posts, so a wait could block forever.
        if state & Self::SHUTDOWN_MASK != 0 {
            let state = self
                .state
                .fetch_sub(1 << Self::IDLE_SHIFT, Ordering::Relaxed);

            let idle = (state >> Self::IDLE_SHIFT) & Self::STATE_MASK;
            assert!(idle <= self.run_queues.len());
            assert!(idle > 0);
            return None;
        }

        // We were the last searcher and work is still queued. Wake a worker
        // (possibly this one, since it is now counted idle) so the work is
        // not left unattended.
        if was_searching && searching == 1 && !self.is_empty() {
            self.unpark();
        }

        if let Some(callback) = self.on_thread_park.as_ref() {
            (callback)();
        }

        self.idle_semaphore.wait(None);

        if let Some(callback) = self.on_thread_unpark.as_ref() {
            (callback)();
        }

        Some(true)
    }

    /// Returns `true` when every run queue and the injector are empty.
    ///
    /// Run queues are checked first, and the scan stops at the first non-empty
    /// one. The injector is only consulted when all run queues are empty.
    fn is_empty(&self) -> bool {
        self.run_queues.iter().all(|queue| queue.is_empty()) && self.injector.is_empty()
    }

    /// Requests shutdown and wakes every currently idle worker.
    ///
    /// In one atomic update, all idle workers are re-counted as searching and
    /// the shutdown bit is set. Then one `idle_semaphore` permit is posted per
    /// formerly idle worker. Woken workers go back to looking for work, and
    /// exit from `park` once nothing is left.
    ///
    /// Idempotent: only the first call has any effect. A repeated call could
    /// otherwise catch a worker that `park` has only transiently counted as
    /// idle (it saw the shutdown bit and is about to undo its increment). That
    /// worker's count would then be moved to `searching`, and its own
    /// `fetch_sub` on the idle field would underflow.
    pub(super) fn shutdown(&self) {
        let transitioned =
            self.state
                .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |mut state| {
                    // Already shut down: leave the state word untouched.
                    if state & Self::SHUTDOWN_MASK != 0 {
                        return None;
                    }

                    let idle = (state >> Self::IDLE_SHIFT) & Self::STATE_MASK;
                    state -= idle << Self::IDLE_SHIFT;
                    state += idle << Self::SEARCHING_SHIFT;

                    let searching = (state >> Self::SEARCHING_SHIFT) & Self::STATE_MASK;
                    assert!(searching <= self.run_queues.len());

                    state |= Self::SHUTDOWN_MASK;
                    Some(state)
                });

        // `fetch_update` returns the state from *before* the update, so this
        // is the number of workers that were idle and now need a permit.
        if let Ok(previous) = transitioned {
            let idle = (previous >> Self::IDLE_SHIFT) & Self::STATE_MASK;
            if idle > 0 {
                self.idle_semaphore.post(idle);
            }
        }
    }

    /// Records that a new task has started.
    ///
    /// # Panics
    /// Panics if the task counter would wrap past `usize::MAX`.
    pub(super) fn on_task_begin(&self) {
        let previous = self.tasks.fetch_add(1, Ordering::Relaxed);
        assert_ne!(previous, usize::MAX);
    }

    pub(super) fn on_task_complete(&self) {
        let tasks = self.tasks.fetch_sub(1, Ordering::AcqRel);
        assert_ne!(tasks, 0);

        if tasks == 1 {
            let state = self.state.load(Ordering::Relaxed);
            if state & Self::SHUTDOWN_MASK != 0 {
                self.join_semaphore.post(1);
            }
        }
    }

    pub(super) fn join(&self, timeout: Option<Duration>) {
        let mut tasks = self.tasks.load(Ordering::Acquire);
        if tasks > 0 {
            self.join_semaphore.wait(timeout);
            tasks = self.tasks.load(Ordering::Acquire);
        }

        if timeout.is_none() {
            assert_eq!(tasks, 0);
        }
    }

    pub(super) fn run_worker(self: &Arc<Self>, queue_index: usize) {
        let thread = Thread::enter(self, Some(queue_index));

        let mut tick = queue_index;
        let mut is_searching = false;
        let mut rng = Rng::new(queue_index);

        let run_queue = &self.run_queues[queue_index];
        loop {
            let fairness = match thread.be_fair.take() || (tick % 61 == 0) {
                true => self.injector.steal().unwrap_or(None),
                false => None,
            };

            let polled = fairness.or_else(|| {
                thread.lifo_slot.take().or_else(|| {
                    run_queue.pop().or_else(|| {
                        is_searching = self.search(is_searching);
                        if is_searching {
                            for _ in 0..32 {
                                let mut is_empty = false;
                                let rand_seed = rng.next().unwrap().get();

                                for target_index in self.rand_seq.iter(rand_seed) {
                                    if target_index != queue_index {
                                        match self.run_queues[target_index].steal() {
                                            Ok(Some(runnable)) => return Some(runnable),
                                            Ok(None) => {}
                                            Err(_) => is_empty = false,
                                        }
                                    }
                                }

                                match self.injector.steal() {
                                    Ok(Some(runnable)) => return Some(runnable),
                                    Ok(None) => {}
                                    Err(_) => is_empty = false,
                                }

                                if is_empty {
                                    break;
                                }
                            }
                        }

                        None
                    })
                })
            });

            if let Some(runnable) = polled {
                is_searching = self.discovered(is_searching);
                tick = tick.wrapping_add(1);
                runnable.run(&thread);
                continue;
            }

            is_searching = match self.park(is_searching) {
                Some(searching) => searching,
                None => break,
            };
        }
    }
}