use crossbeam_utils::CachePadded;
use std::{
fmt,
sync::atomic::{AtomicU64, Ordering},
};
// Layout of the packed 64-bit state word:
//   bits 32..=63  local counter
//   bit  31       notified flag
//   bits 0..=30   shared counter

/// Bit offset at which the local counter starts inside the packed word.
const LOCAL_COUNTER_BIT_SIZE: u8 = 32;
/// Width of the shared counter in bits: one less than the local counter's offset,
/// which leaves a single bit between the two counters for the notified flag.
const SHARED_COUNTER_BIT_SIZE: u8 = LOCAL_COUNTER_BIT_SIZE - 1;
/// The single bit sitting directly above the shared counter.
const NOTIFIED_FLAG: u64 = 1 << SHARED_COUNTER_BIT_SIZE;
/// Selects the shared counter bits, i.e. every bit below the notified flag.
const SHARED_COUNTER_MASK: u64 = NOTIFIED_FLAG - 1;
/// Value that changes the shared counter by exactly one.
const SHARED_COUNTER_ONE: u64 = 1;
/// Value that changes the local counter by exactly one.
const LOCAL_COUNTER_ONE: u64 = 1 << LOCAL_COUNTER_BIT_SIZE;
/// Bookkeeping for a task queue: a single atomic 64-bit word that packs a local
/// counter, a shared counter and a "notified" flag, so that all three can be read
/// or modified together in one atomic operation.
pub struct TaskQueue {
// Packed state word; wrapped in `CachePadded` (crossbeam-utils).
counter: CachePadded<AtomicU64>,
}
/// A plain (non-atomic) copy of the packed state word of a [`TaskQueue`].
///
/// The wrapped `u64` is decoded through the accessor methods: local counter,
/// shared counter and notified flag. The type is `#[must_use]`, so silently
/// discarding a returned snapshot triggers a compiler warning.
#[derive(Clone, Copy)]
#[must_use]
pub struct Counter(u64);
impl Counter {
#[inline]
pub fn local(self) -> u64 {
self.0 >> LOCAL_COUNTER_BIT_SIZE
}
#[inline]
pub fn shared(self) -> u64 {
self.0 & SHARED_COUNTER_MASK
}
#[inline]
pub fn shared_queue_has_data(self) -> bool {
self.shared() > 0
}
#[inline]
pub fn total(self) -> u64 {
self.local() + self.shared()
}
#[inline]
pub(crate) fn is_notified(self) -> bool {
(self.0 & NOTIFIED_FLAG) == NOTIFIED_FLAG
}
}
impl TaskQueue {
    /// Creates a queue whose packed state is zero: both counters are 0 and the
    /// notified flag is clear.
    pub fn new() -> Self {
        let counter = CachePadded::new(AtomicU64::new(0));
        Self { counter }
    }

    /// Atomically clears the notified flag, but only while the flag is set and
    /// the shared counter is zero.
    ///
    /// Returns `(true, state)` when the flag was cleared and `(false, state)`
    /// when nothing was changed. In both cases `state` is the value observed
    /// before any modification.
    #[inline]
    pub fn accept_notify_once_if_shared_queue_is_empty(&self) -> (bool, Counter) {
        let outcome = self
            .counter
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |raw| {
                let seen = Counter(raw);
                if seen.shared() == 0 && seen.is_notified() {
                    Some(raw & !NOTIFIED_FLAG)
                } else {
                    // Leave the word untouched; `fetch_update` then reports `Err`.
                    None
                }
            });
        let flag_cleared = outcome.is_ok();
        // Both variants carry the previously stored value.
        let observed = match outcome {
            Ok(raw) | Err(raw) => raw,
        };
        (flag_cleared, Counter(observed))
    }

    /// Adds one to the local counter and returns the state from before the
    /// increment.
    pub fn increase_local(&self) -> Counter {
        let previous = self.counter.fetch_add(LOCAL_COUNTER_ONE, Ordering::Relaxed);
        Counter(previous)
    }

    /// Subtracts one from the local counter and returns the state from before
    /// the decrement.
    ///
    /// Debug builds assert that the local counter was non-zero beforehand.
    pub fn decrease_local(&self) -> Counter {
        let previous = self.counter.fetch_sub(LOCAL_COUNTER_ONE, Ordering::Relaxed);
        let snapshot = Counter(previous);
        debug_assert!(snapshot.local() > 0);
        snapshot
    }

    /// Clears the notified flag unconditionally, leaving both counters as they
    /// are.
    pub fn clear_notified_flag(&self) {
        let keep_everything_but_flag = !NOTIFIED_FLAG;
        self.counter
            .fetch_and(keep_everything_but_flag, Ordering::Release);
    }

    /// Sets the notified flag and adds one to the shared counter in a single
    /// atomic update, returning the state from before the update.
    pub fn increase_shared_and_mark_as_notified(&self) -> Counter {
        let flag_and_bump = |raw: u64| Some((raw | NOTIFIED_FLAG) + SHARED_COUNTER_ONE);
        let previous = self
            .counter
            .fetch_update(Ordering::AcqRel, Ordering::Relaxed, flag_and_bump)
            // The closure never yields `None`, so the update is never rejected.
            .unwrap();
        Counter(previous)
    }

    /// Transfers `n.shared()` units from the shared counter to the local counter
    /// with a single atomic addition.
    ///
    /// The amount comes from the snapshot `n`, not from the current state, so
    /// the shared counter ends up reduced by exactly that amount only while it
    /// currently holds at least that many units.
    pub fn move_shared_to_local(&self, n: Counter) {
        let moved = n.shared();
        // Shifting the amount up raises the local counter by `moved`; taking
        // `moved` off the same addend lowers the shared counter by as much.
        let raised_local = moved << LOCAL_COUNTER_BIT_SIZE;
        let delta = raised_local - moved;
        self.counter.fetch_add(delta, Ordering::Relaxed);
    }

    /// Reads the current packed state with relaxed ordering.
    #[inline]
    pub fn load(&self) -> Counter {
        let raw = self.counter.load(Ordering::Relaxed);
        Counter(raw)
    }
}
/// Prints the decoded fields of the snapshot rather than the raw packed word.
impl fmt::Debug for Counter {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Decode once up front, then hand the pieces to the formatter.
        let local = self.local();
        let shared = self.shared();
        let notified = self.is_notified();
        write!(
            f,
            "Counter {{ local: {}, shared: {}, notified: {} }}",
            local, shared, notified
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    impl TaskQueue {
        /// Test-only shorthand that forwards to
        /// `increase_shared_and_mark_as_notified`.
        pub fn increase_shared(&self) -> Counter {
            self.increase_shared_and_mark_as_notified()
        }
    }

    // Note for all tests: every mutating call returns the state from BEFORE
    // the change, so expectations lag one step behind the call they check.

    #[test]
    fn test_local_counter() {
        let queue = TaskQueue::new();

        let before_increase = queue.increase_local();
        assert_eq!(before_increase.total(), 0);

        let before_decrease = queue.decrease_local();
        assert_eq!(before_decrease.local(), 1);

        assert_eq!(queue.load().total(), 0);
    }

    #[test]
    fn test_shared_counter() {
        let queue = TaskQueue::new();

        let step1 = queue.increase_local();
        assert_eq!(step1.shared(), 0);

        let step2 = queue.increase_shared();
        assert_eq!(step2.local(), 1);

        let step3 = queue.increase_local();
        assert_eq!(step3.total(), 2);

        let step4 = queue.decrease_local();
        assert_eq!(step4.shared(), 1);

        assert_eq!(queue.load().total(), 2);
    }

    #[test]
    fn test_move_counter() {
        let queue = TaskQueue::new();

        let initial = queue.increase_local();
        assert_eq!(initial.total(), 0);

        // Nothing is shared yet, so this move changes nothing.
        let snapshot = queue.load();
        queue.move_shared_to_local(snapshot);
        assert_eq!(queue.load().total(), 1);

        let before_shared = queue.increase_shared();
        assert_eq!(before_shared.local(), 1);

        // The snapshot holds one shared unit, which gets moved to local.
        let before_local = queue.increase_local();
        queue.move_shared_to_local(before_local);

        let after_move = queue.load();
        assert_eq!(after_move.shared(), 0);
        let after_move = queue.load();
        assert_eq!(after_move.local(), 3);
    }

    #[test]
    fn test_notification_flag() {
        let queue = TaskQueue::new();
        assert!(!queue.load().is_notified());

        // The first shared increment reports the still-unflagged previous state.
        let before_first_shared = queue.increase_shared_and_mark_as_notified();
        assert!(!before_first_shared.is_notified());

        let prev = queue.increase_local();
        assert_eq!(prev.local(), 0);
        assert_eq!(prev.shared(), 1);
        assert!(prev.is_notified());

        // Shared counter is non-zero: the flag must stay in place.
        let (flag_removed, prev) = queue.accept_notify_once_if_shared_queue_is_empty();
        assert!(!flag_removed);
        assert!(prev.is_notified());
        assert!(queue.load().is_notified());

        // After draining the shared counter the flag can be consumed.
        queue.move_shared_to_local(prev);
        let (flag_removed, prev) = queue.accept_notify_once_if_shared_queue_is_empty();
        assert!(flag_removed);
        assert_eq!(prev.local(), 2);
        assert_eq!(prev.shared(), 0);
        assert!(prev.is_notified());

        let prev = queue.increase_shared_and_mark_as_notified();
        assert!(!prev.is_notified());

        let now = queue.load();
        assert_eq!(now.local(), 2);
        assert_eq!(now.shared(), 1);
        assert!(now.is_notified());

        // Unconditional clearing works even with shared work pending.
        queue.clear_notified_flag();
        assert!(!queue.load().is_notified());

        let prev = queue.increase_shared_and_mark_as_notified();
        assert_eq!(prev.shared(), 1);
        assert!(!prev.is_notified());

        let prev = queue.increase_local();
        assert_eq!(prev.local(), 2);
        assert!(prev.is_notified());

        let now = queue.load();
        assert_eq!(now.local(), 3);
        assert_eq!(now.shared(), 2);
        assert!(now.is_notified());
    }
}