use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::RawTask;
/// One ring-buffer cell: the stored task plus a sequence number that tells
/// producers and consumers whose turn it is to touch the cell.
///
/// For the ticket `t` that maps to this cell (`t & mask == index`):
/// * `sequence == t`     — cell is empty; the producer holding ticket `t` may write it.
/// * `sequence == t + 1` — cell is full; the consumer holding ticket `t` may read it.
struct Slot {
    sequence: AtomicUsize,
    task: UnsafeCell<MaybeUninit<RawTask>>,
}

/// Bounded FIFO ring buffer of [`RawTask`]s that any number of threads may
/// push to, pop from and steal from concurrently, without locks.
///
/// Synchronization follows Dmitry Vyukov's bounded MPMC queue. `tail` hands
/// out producer tickets and `head` hands out consumer tickets. A thread touches a
/// slot's `UnsafeCell` only after winning the CAS for that slot's ticket and
/// observing the matching per-slot `sequence`, so every cell access is exclusive.
/// Both counters grow monotonically and use wrapping arithmetic, so they may
/// overflow `usize` safely; the ring index is `ticket & mask`.
pub struct LocalQueue {
    buffer: Box<[Slot]>,
    /// Number of slots; always a power of two and at least 2.
    capacity: usize,
    /// `capacity - 1`, used to map a ticket to its slot index.
    mask: usize,
    /// Next consumer ticket.
    head: AtomicUsize,
    /// Next producer ticket.
    tail: AtomicUsize,
}

// SAFETY: each slot's `UnsafeCell` is only written by the producer that won that
// slot's ticket (while `sequence == ticket`) and only read by the consumer that won
// it (after an Acquire load sees the producer's Release `sequence == ticket + 1`), so
// no two threads ever access a cell concurrently.
// NOTE(review): this also lets `RawTask` values move between threads; that is
// assumed to be intended for the task type — confirm.
unsafe impl Send for LocalQueue {}
// SAFETY: see the `Send` impl above; all shared-state mutation goes through atomics
// or through cells protected by the sequence protocol.
unsafe impl Sync for LocalQueue {}

impl LocalQueue {
    /// Creates an empty queue able to hold at least `capacity` tasks.
    ///
    /// The capacity is rounded up to the next power of two so ticket-to-slot mapping
    /// is a mask. The minimum is 2: with a single slot the "empty" and "full"
    /// sequence states would coincide.
    ///
    /// # Panics
    ///
    /// Panics if rounding `capacity` up to a power of two overflows `usize`.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        let capacity = capacity
            .checked_next_power_of_two()
            .expect("LocalQueue capacity too large")
            .max(2);
        let buffer = (0..capacity)
            .map(|index| Slot {
                // Slot `index` is first written by the producer holding ticket `index`.
                sequence: AtomicUsize::new(index),
                task: UnsafeCell::new(MaybeUninit::uninit()),
            })
            .collect();
        Self {
            buffer,
            capacity,
            mask: capacity - 1,
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
        }
    }

    /// Appends `task` at the tail.
    ///
    /// Returns `false` (and stores nothing) when the queue is full. Under contention
    /// it may also report full while a concurrent `pop` is still releasing its slot.
    #[inline]
    pub fn push(&self, task: RawTask) -> bool {
        let mut tail = self.tail.load(Ordering::Relaxed);
        loop {
            let slot = &self.buffer[tail & self.mask];
            let seq = slot.sequence.load(Ordering::Acquire);
            // Signed distance between the slot's state and our ticket; the cast is an
            // intentional reinterpretation so that wrapped counters compare correctly.
            let diff = seq.wrapping_sub(tail) as isize;
            if diff == 0 {
                // The slot is free for ticket `tail`; try to claim the ticket.
                match self.tail.compare_exchange_weak(
                    tail,
                    tail.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // SAFETY: winning the CAS gives this thread exclusive write access
                        // until the Release store below publishes the slot to consumers.
                        unsafe { slot.task.get().write(MaybeUninit::new(task)) };
                        slot.sequence.store(tail.wrapping_add(1), Ordering::Release);
                        return true;
                    }
                    Err(current) => tail = current,
                }
            } else if diff < 0 {
                // The slot still holds the task from one lap ago: the queue is full.
                return false;
            } else {
                // Another producer already took this ticket; retry with a fresh one.
                tail = self.tail.load(Ordering::Relaxed);
            }
        }
    }

    /// Removes and returns the oldest task, or `None` if the queue is empty.
    ///
    /// It may also return `None` while a concurrent `push` has claimed the head slot
    /// but not yet published its task.
    #[inline]
    pub fn pop(&self) -> Option<RawTask> {
        let mut head = self.head.load(Ordering::Relaxed);
        loop {
            let slot = &self.buffer[head & self.mask];
            let seq = slot.sequence.load(Ordering::Acquire);
            // Zero when the slot holds the task for ticket `head` (see `Slot`).
            let diff = seq.wrapping_sub(head.wrapping_add(1)) as isize;
            if diff == 0 {
                match self.head.compare_exchange_weak(
                    head,
                    head.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // SAFETY: the Acquire load observed the producer's Release store,
                        // so the task is initialized, and winning the CAS makes this the
                        // only reader until the slot is handed back below.
                        let task = unsafe { slot.task.get().read().assume_init() };
                        // Hand the slot to the producer one lap ahead.
                        slot.sequence
                            .store(head.wrapping_add(self.capacity), Ordering::Release);
                        return Some(task);
                    }
                    // Lost the race; nothing was read, so simply retry.
                    Err(current) => head = current,
                }
            } else if diff < 0 {
                // Ticket `head` has not been published yet: the queue is empty.
                return None;
            } else {
                // Another consumer already took this ticket; retry with a fresh one.
                head = self.head.load(Ordering::Relaxed);
            }
        }
    }

    /// Approximate number of queued tasks.
    ///
    /// The two counters are read separately, so under contention the result is a
    /// snapshot; it is clamped to `0..=capacity()`.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        let tail = self.tail.load(Ordering::Relaxed);
        let head = self.head.load(Ordering::Relaxed);
        // The counters wrap, and the pair may be momentarily inconsistent, so
        // interpret the difference as signed before clamping.
        let diff = tail.wrapping_sub(head) as isize;
        if diff <= 0 {
            0
        } else {
            (diff as usize).min(self.capacity)
        }
    }

    /// Returns `true` if [`len`](Self::len) is zero.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of slots (the requested capacity rounded up to a power of two, minimum 2).
    #[inline]
    #[must_use]
    pub const fn capacity(&self) -> usize {
        self.capacity
    }

    /// Moves up to half of this queue's tasks, oldest first, into `dest`.
    ///
    /// Stops early when `self` runs dry or `dest` is full. Returns how many tasks
    /// ended up in `dest`. A task is never dropped: if `dest` fills up between
    /// taking a task and pushing it, the task is returned to the tail of `self`
    /// instead, so it loses its place in line.
    pub fn steal(&self, dest: &LocalQueue) -> usize {
        let quota = self.len() / 2;
        let mut stolen = 0;
        while stolen < quota {
            // Check for room first so we don't take tasks we cannot place.
            if dest.len() >= dest.capacity {
                break;
            }
            let task = match self.pop() {
                Some(task) => task,
                None => break,
            };
            if !self.hand_over(dest, task) {
                break;
            }
            stolen += 1;
        }
        stolen
    }

    /// Places `task`, already removed from `self`, into `dest`, falling back to `self`.
    ///
    /// Returns `true` if the task landed in `dest` and `false` if it went back into `self`.
    fn hand_over(&self, dest: &LocalQueue, task: RawTask) -> bool {
        loop {
            if dest.push(task) {
                return true;
            }
            if self.push(task) {
                return false;
            }
            // Both queues are momentarily full. Wait for a consumer to free a slot
            // rather than losing the task.
            std::hint::spin_loop();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_queue_push_pop() {
        let queue = LocalQueue::new(16);
        let task1 = 0x1000 as RawTask;
        let task2 = 0x2000 as RawTask;
        assert!(queue.push(task1));
        assert!(queue.push(task2));
        assert_eq!(queue.pop(), Some(task1));
        assert_eq!(queue.pop(), Some(task2));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_queue_empty_full() {
        let queue = LocalQueue::new(4);
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
        for i in 0..4 {
            assert!(queue.push(i as RawTask));
        }
        assert!(!queue.push(99 as RawTask));
        assert_eq!(queue.len(), 4);
        for i in 0..4 {
            assert_eq!(queue.pop(), Some(i as RawTask));
        }
        assert!(queue.is_empty());
    }

    #[test]
    fn test_queue_wrap_around() {
        let queue = LocalQueue::new(4);
        for round in 0..3 {
            for i in 0..4 {
                assert!(queue.push((round * 4 + i) as RawTask));
            }
            for i in 0..4 {
                assert_eq!(queue.pop(), Some((round * 4 + i) as RawTask));
            }
        }
    }

    #[test]
    fn test_queue_capacity_power_of_two() {
        let q = LocalQueue::new(5);
        assert_eq!(q.capacity(), 8);
        let q = LocalQueue::new(100);
        assert_eq!(q.capacity(), 128);
    }

    /// Requests below the minimum are bumped to a capacity of 2.
    #[test]
    fn test_queue_capacity_minimum() {
        assert_eq!(LocalQueue::new(0).capacity(), 2);
        assert_eq!(LocalQueue::new(1).capacity(), 2);
    }

    /// `steal` moves the oldest half of the source into `dest` and keeps order.
    #[test]
    fn test_steal_takes_oldest_half() {
        let src = LocalQueue::new(8);
        let dst = LocalQueue::new(8);
        for i in 1..=6 {
            assert!(src.push(i as RawTask));
        }
        assert_eq!(src.steal(&dst), 3);
        for i in 1..=3 {
            assert_eq!(dst.pop(), Some(i as RawTask));
        }
        for i in 4..=6 {
            assert_eq!(src.pop(), Some(i as RawTask));
        }
        assert_eq!(src.pop(), None);
        assert_eq!(dst.pop(), None);
    }

    /// `steal` stops once `dest` is full and leaves the remaining tasks in the source.
    #[test]
    fn test_steal_stops_when_dest_full() {
        let src = LocalQueue::new(8);
        let dst = LocalQueue::new(2);
        assert!(dst.push(99 as RawTask));
        for i in 1..=6 {
            assert!(src.push(i as RawTask));
        }
        assert_eq!(src.steal(&dst), 1);
        assert_eq!(dst.pop(), Some(99 as RawTask));
        assert_eq!(dst.pop(), Some(1 as RawTask));
        for i in 2..=6 {
            assert_eq!(src.pop(), Some(i as RawTask));
        }
        assert_eq!(src.pop(), None);
    }

    /// Half of one task rounds down to zero, so a single task is never stolen.
    #[test]
    fn test_steal_single_task_is_noop() {
        let src = LocalQueue::new(4);
        let dst = LocalQueue::new(4);
        assert!(src.push(7 as RawTask));
        assert_eq!(src.steal(&dst), 0);
        assert!(dst.is_empty());
        assert_eq!(src.pop(), Some(7 as RawTask));
    }
}