// async-cpupool 0.4.0
//
// A simple async threadpool for CPU-bound tasks.
// (See the crate documentation for the public API.)
use std::{
    collections::VecDeque,
    future::{poll_fn, Future},
    task::{Poll, Waker},
};

use crate::sync::{Arc, AtomicU8, Mutex, Ordering};

thread_local! {
    // Per-thread net tally of `Notify` constructions minus drops; only compiled for
    // test and loom builds. `Wrapping` keeps cross-thread imbalances from panicking.
    #[cfg(any(loom, test))]
    static NOTIFY_COUNT: std::cell::RefCell<std::num::Wrapping<u64>> =
        std::cell::RefCell::new(std::num::Wrapping(0));
}

/// Bumps the test-only tally when a `Notify` is created. Compiles to nothing otherwise.
#[inline(always)]
fn increment_notify_count() {
    #[cfg(any(loom, test))]
    NOTIFY_COUNT.with(|count| *count.borrow_mut() += 1);
}

/// Lowers the test-only tally when a `Notify` is dropped. Compiles to nothing otherwise.
#[inline(always)]
fn decrement_notify_count() {
    #[cfg(any(loom, test))]
    NOTIFY_COUNT.with(|count| *count.borrow_mut() -= 1);
}

/// Reads the current thread's tally (wrapping arithmetic; may wrap below zero).
#[cfg(any(test, loom))]
#[doc(hidden)]
pub fn notify_count() -> u64 {
    NOTIFY_COUNT.with(|count| count.borrow().0)
}

// Listener flag states, held in each listener's shared `AtomicU8`.
// A queued listener starts UNNOTIFIED (or NOTIFIED_ONE if it consumed a stored token);
// `notify_one` moves it to NOTIFIED_ONE; completing a poll or dropping it moves it to RESOLVED.
const UNNOTIFIED: u8 = 0;
const NOTIFIED_ONE: u8 = 1;
const RESOLVED: u8 = 2;

/// Identifier assigned to each listener from an incrementing counter in `NotifyState`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct NotifyId(u64);

/// Notification primitive: `notify_one` wakes the oldest waiting listener, or stores a
/// token for the next listener when nobody is waiting.
#[doc(hidden)]
pub struct Notify {
    state: Mutex<NotifyState>,
}

/// Mutex-protected bookkeeping shared by a `Notify` and its listeners.
struct NotifyState {
    /// Notifications that found no waiting listener and are banked for later listeners.
    token: u64,
    /// Identifier handed to the next registered listener.
    next_id: u64,
    /// Waiting listeners in registration order (ids ascending), each with its shared
    /// state flag and the waker to call when notified.
    listeners: VecDeque<(NotifyId, Arc<AtomicU8>, Waker)>,
}

/// A registration produced by `Notify::listen`; as a future it resolves once notified.
#[doc(hidden)]
pub struct Listener<'a> {
    id: NotifyId,
    state: &'a Mutex<NotifyState>,
    // `waker` is declared before `woken` so they drop in the same order as before.
    waker: Waker,
    woken: Arc<AtomicU8>,
}

impl Notify {
    /// Creates a `Notify` with no waiting listeners and no banked tokens.
    #[doc(hidden)]
    pub fn new() -> Self {
        increment_notify_count();
        metrics::counter!("async-cpupool.notify.created").increment(1);

        let state = NotifyState {
            listeners: VecDeque::new(),
            token: 0,
            next_id: 0,
        };

        Notify {
            state: Mutex::new(state),
        }
    }

    /// Registers a listener bound to the calling task's waker.
    ///
    /// Although `async`, this never yields to the executor: the inner `poll_fn`
    /// resolves on its first poll.
    #[doc(hidden)]
    pub async fn listen(&self) -> Listener<'_> {
        poll_fn(|cx| {
            let waker = cx.waker().clone();
            Poll::Ready(self.make_listener(waker))
        })
        .await
    }

    /// Returns the number of banked tokens (test/loom builds only).
    #[doc(hidden)]
    #[cfg(any(loom, test))]
    pub fn token(&self) -> u64 {
        let guard = self.state.lock().unwrap();
        guard.token
    }

    /// Registers `waker` with the shared state and wraps the result in a `Listener`.
    pub(super) fn make_listener(&self, waker: Waker) -> Listener<'_> {
        let (id, woken) = {
            let mut state = self.state.lock().expect("not poisoned");
            state.insert(waker.clone())
        };

        Listener {
            state: &self.state,
            waker,
            woken,
            id,
        }
    }

    /// Wakes one waiting listener, or banks a token if none can take the notification.
    #[doc(hidden)]
    pub fn notify_one(&self) {
        let mut state = self.state.lock().expect("not poisoned");
        state.notify_one();
    }
}

impl NotifyState {
    /// Delivers one notification to the oldest listener that can still accept it.
    /// If the queue runs dry, the notification is banked as a token instead.
    fn notify_one(&mut self) {
        while let Some((_, woken, waker)) = self.listeners.pop_front() {
            metrics::counter!("async-cpupool.notify.removed").increment(1);

            // Strong (not weak) CAS: a spurious failure would discard a listener that
            // was never actually woken.
            let claimed = woken
                .compare_exchange(UNNOTIFIED, NOTIFIED_ONE, Ordering::Release, Ordering::Relaxed)
                .is_ok();

            if claimed {
                waker.wake();
                return;
            }
            // The listener already resolved itself (raced `Listener`'s Drop); try the next one.
        }

        self.token += 1;
    }

    /// Allocates an id for a new listener and returns it with the listener's state flag.
    /// A banked token is consumed if available (listener starts notified, not queued);
    /// otherwise the listener is queued with `waker`.
    fn insert(&mut self, waker: Waker) -> (NotifyId, Arc<AtomicU8>) {
        let id = NotifyId(self.next_id);
        self.next_id += 1;

        if self.token > 0 {
            // No need to queue the waker: the next poll will see NOTIFIED_ONE and complete.
            self.token -= 1;
            return (id, Arc::new(AtomicU8::new(NOTIFIED_ONE)));
        }

        let woken = Arc::new(AtomicU8::new(UNNOTIFIED));
        self.listeners.push_back((id, Arc::clone(&woken), waker));
        metrics::counter!("async-cpupool.notify.inserted").increment(1);

        (id, woken)
    }

    /// Removes the queued listener with `id`, if it is still queued.
    fn remove(&mut self, id: NotifyId) {
        let Some(index) = self.find(&id) else {
            return;
        };

        self.listeners.remove(index);
        metrics::counter!("async-cpupool.notify.removed").increment(1);
    }

    /// Locates a queued listener by id; relies on the queue being sorted by id.
    fn find(&self, needle_id: &NotifyId) -> Option<usize> {
        self.listeners
            .binary_search_by(|(candidate_id, _, _)| candidate_id.cmp(needle_id))
            .ok()
    }

    /// Replaces the stored waker of the queued listener with `id`, if it is still queued.
    fn update(&mut self, id: NotifyId, needle_waker: &Waker) {
        let Some(index) = self.find(&id) else {
            return;
        };

        if let Some((_, _, stored_waker)) = self.listeners.get_mut(index) {
            stored_waker.clone_from(needle_waker);
        }
    }
}

impl Drop for Notify {
    /// Records the teardown in the metrics counter and in the test-only tally.
    fn drop(&mut self) {
        metrics::counter!("async-cpupool.notify.dropped").increment(1);
        decrement_notify_count();
    }
}

impl Drop for Listener<'_> {
    fn drop(&mut self) {
        // Mark this listener finished. The swap races the compare_exchange in
        // `NotifyState::notify_one`; the previous value tells us who won.
        match self.woken.swap(RESOLVED, Ordering::AcqRel) {
            // Already resolved: nothing left to hand back.
            RESOLVED => {}
            // Notified but never consumed: pass the notification on to another listener.
            NOTIFIED_ONE => self.state.lock().expect("not poisoned").notify_one(),
            // Never notified: drop out of the wait queue.
            UNNOTIFIED => self.state.lock().expect("not poisoned").remove(self.id),
            _ => unreachable!("No other states exist"),
        }
    }
}

impl Future for Listener<'_> {
    type Output = ();

    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let mut flags = self.woken.load(Ordering::Acquire);

        loop {
            if flags == UNNOTIFIED {
                break;
            } else if flags == RESOLVED {
                return Poll::Ready(());
            } else {
                match self.woken.compare_exchange_weak(
                    flags,
                    RESOLVED,
                    Ordering::Release,
                    Ordering::Acquire,
                ) {
                    Ok(_) => return Poll::Ready(()),
                    Err(updated) => flags = updated,
                };
            }
        }

        if !self.waker.will_wake(cx.waker()) {
            self.waker.clone_from(cx.waker());

            self.state
                .lock()
                .expect("not poisoned")
                .update(self.id, cx.waker());
        }

        Poll::Pending
    }
}