#![cfg_attr(
all(nightly, target_arch = "aarch64"),
feature(stdarch_aarch64_prefetch)
)]
use core::alloc::Layout;
use core::cell::UnsafeCell;
use core::marker::PhantomData;
use core::mem::{MaybeUninit, size_of};
use core::ptr;
use core::sync::atomic::{AtomicUsize, Ordering};
use std::alloc::{alloc, dealloc, handle_alloc_error};
use std::sync::Arc;
// Aligns the wrapped value so that each field occupies its own cache line,
// preventing false sharing between the producer and consumer threads.
// The per-architecture values mirror crossbeam's `CachePadded`:
//   - 128 bytes on x86_64 / aarch64 / arm64ec / powerpc64
//     (NOTE(review): presumably to cover adjacent-cache-line prefetch on
//     these targets — matches crossbeam's choice; confirm if changing)
//   - 32 bytes on arm / mips* / sparc / hexagon (smaller cache lines)
//   - 64 bytes everywhere else (the common default).
// Must stay in sync with the `CACHE_LINE_SIZE` constants below.
#[cfg_attr(
any(
target_arch = "x86_64",
target_arch = "aarch64",
target_arch = "arm64ec",
target_arch = "powerpc64",
),
repr(C, align(128))
)]
#[cfg_attr(
any(
target_arch = "arm",
target_arch = "mips",
target_arch = "mips32r6",
target_arch = "mips64",
target_arch = "mips64r6",
target_arch = "sparc",
target_arch = "hexagon",
),
repr(C, align(32))
)]
#[cfg_attr(
not(any(
target_arch = "x86_64",
target_arch = "aarch64",
target_arch = "arm64ec",
target_arch = "powerpc64",
target_arch = "arm",
target_arch = "mips",
target_arch = "mips32r6",
target_arch = "mips64",
target_arch = "mips64r6",
target_arch = "sparc",
target_arch = "hexagon",
)),
repr(C, align(64))
)]
struct CachePadded<T>(T);
// Assumed cache-line size for the current target, used as the alignment of
// the ring-buffer allocation. Each arm must agree with the corresponding
// `repr(align(..))` chosen for `CachePadded` above.
#[cfg(any(
target_arch = "x86_64",
target_arch = "aarch64",
target_arch = "arm64ec",
target_arch = "powerpc64",
))]
const CACHE_LINE_SIZE: usize = 128;
#[cfg(any(
target_arch = "arm",
target_arch = "mips",
target_arch = "mips32r6",
target_arch = "mips64",
target_arch = "mips64r6",
target_arch = "sparc",
target_arch = "hexagon",
))]
const CACHE_LINE_SIZE: usize = 32;
// Fallback for every architecture not listed above.
#[cfg(not(any(
target_arch = "x86_64",
target_arch = "aarch64",
target_arch = "arm64ec",
target_arch = "powerpc64",
target_arch = "arm",
target_arch = "mips",
target_arch = "mips32r6",
target_arch = "mips64",
target_arch = "mips64r6",
target_arch = "sparc",
target_arch = "hexagon",
)))]
const CACHE_LINE_SIZE: usize = 64;
/// Lock-free bounded single-producer single-consumer (SPSC) ring buffer.
///
/// Shared between exactly one [`Producer`] and one [`Consumer`] through an
/// `Arc`. Positions (`head`, `tail`) increase monotonically with wrapping
/// arithmetic; `capacity` is always a power of two, so `pos & mask` maps a
/// position onto a buffer index. Every field is `CachePadded` so the
/// producer- and consumer-owned counters never share a cache line.
pub struct FastQueue<T> {
mask: CachePadded<usize>, // capacity - 1; masks positions into indices
capacity: CachePadded<usize>, // number of slots; power of two >= 2
buffer: CachePadded<*mut MaybeUninit<T>>, // heap buffer of `capacity` slots
head: CachePadded<AtomicUsize>, // next position the producer writes
tail: CachePadded<AtomicUsize>, // next position the consumer reads
_pd: PhantomData<T>,
}
// SAFETY: the queue logically owns its elements and the raw buffer pointer,
// so moving it to another thread is sound whenever `T: Send`.
unsafe impl<T: Send> Send for FastQueue<T> {}
// SAFETY: cross-thread access is mediated by the `head`/`tail` atomics with
// Release/Acquire pairing; the producer and consumer never touch the same
// slot concurrently.
unsafe impl<T: Send> Sync for FastQueue<T> {}
impl<T> FastQueue<T> {
    /// Creates a bounded SPSC queue and returns its two handles.
    ///
    /// The requested `capacity` is rounded up to the next power of two (and
    /// at least 2) so that indexing can use `pos & mask` instead of a modulo.
    ///
    /// # Panics
    /// Panics if `T` is a zero-sized type, if the buffer size overflows
    /// `usize`, or if the resulting layout is invalid.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(capacity: usize) -> (Producer<T>, Consumer<T>) {
        // A zero-sized `T` would produce a zero-size layout, and calling
        // `alloc` with a zero-size layout is undefined behavior per the
        // `GlobalAlloc` contract.
        assert!(
            size_of::<T>() != 0,
            "FastQueue does not support zero-sized types"
        );
        let capacity = capacity.next_power_of_two().max(2);
        let mask = capacity - 1;
        // Checked multiply: in release builds a plain `*` would wrap on
        // overflow and silently under-allocate the buffer.
        let size = capacity
            .checked_mul(size_of::<MaybeUninit<T>>())
            .expect("FastQueue capacity overflows usize");
        // NOTE: this layout is recomputed in `Drop for FastQueue`; the two
        // must stay identical or `dealloc` would see a mismatched layout.
        let layout = Layout::from_size_align(size, CACHE_LINE_SIZE).expect("layout");
        let buffer = unsafe { alloc(layout) as *mut MaybeUninit<T> };
        if buffer.is_null() {
            handle_alloc_error(layout);
        }
        let queue = Arc::new(FastQueue {
            mask: CachePadded(mask),
            capacity: CachePadded(capacity),
            buffer: CachePadded(buffer),
            head: CachePadded(AtomicUsize::new(0)),
            tail: CachePadded(AtomicUsize::new(0)),
            _pd: PhantomData,
        });
        let producer = Producer {
            queue: CachePadded(Arc::clone(&queue)),
            head: CachePadded(UnsafeCell::new(0)),
            cached_tail: CachePadded(UnsafeCell::new(0)),
            _pd: PhantomData,
        };
        let consumer = Consumer {
            queue: CachePadded(queue),
            tail: CachePadded(UnsafeCell::new(0)),
            cached_head: CachePadded(UnsafeCell::new(0)),
            _pd: PhantomData,
        };
        (producer, consumer)
    }
}
impl<T> Drop for FastQueue<T> {
/// Drops any elements still in the queue, then frees the ring buffer.
///
/// By the time this runs both handles have been dropped (each holds an
/// `Arc` to this queue), so there is no concurrent access and `Relaxed`
/// loads are sufficient.
fn drop(&mut self) {
let head = self.head.0.load(Ordering::Relaxed);
let mut tail = self.tail.0.load(Ordering::Relaxed);
// Every position in [tail, head) holds an element that was pushed but
// never popped; drop each one in place.
while tail != head {
unsafe {
let index = tail & self.mask.0;
let slot = self.buffer.0.add(index);
// SAFETY: the slot was initialized by `push` and not yet consumed.
ptr::drop_in_place((*slot).as_mut_ptr());
}
tail = tail.wrapping_add(1);
}
unsafe {
// SAFETY: this layout computation must match the one in `new`
// exactly, so the buffer is deallocated with the same layout it
// was allocated with.
let layout = Layout::from_size_align(
self.capacity.0 * size_of::<MaybeUninit<T>>(),
CACHE_LINE_SIZE,
)
.expect("layout");
dealloc(self.buffer.0 as *mut u8, layout);
}
}
}
/// Write handle to a [`FastQueue`]; exactly one exists per queue.
///
/// Not `Sync`: only a single thread may push at a time, which is what makes
/// the `UnsafeCell` fields below safe to access without synchronization.
pub struct Producer<T> {
queue: CachePadded<Arc<FastQueue<T>>>,
// Producer-private mirror of the head position (accessed only by the
// owning thread, so no atomics needed).
head: CachePadded<UnsafeCell<usize>>,
// Last observed consumer tail; refreshed only when the queue looks full,
// to avoid reloading the shared atomic on every push.
cached_tail: CachePadded<UnsafeCell<usize>>,
_pd: PhantomData<T>,
}
// SAFETY: the `UnsafeCell` fields are only ever accessed by the single thread
// that owns this handle; moving that ownership is sound when `T: Send`.
unsafe impl<T: Send> Send for Producer<T> {}
impl<T> Producer<T> {
    /// Attempts to enqueue `value`, returning `Err(value)` if the queue is full.
    ///
    /// The `Release` store of `head` pairs with the consumer's `Acquire`
    /// load, making the slot write visible before the consumer reads it.
    #[inline(always)]
    pub fn push(&mut self, value: T) -> Result<(), T> {
        let head = unsafe { *self.head.0.get() };
        let next_head = head.wrapping_add(1);
        // Warm the cache for the slot we are about to write.
        self.prefetch_write(next_head);
        let cached_tail = unsafe { *self.cached_tail.0.get() };
        // Fast path: a stale cached tail can only make the queue look
        // *fuller* than it really is, never emptier, so the check is safe.
        if next_head.wrapping_sub(cached_tail) > self.queue.0.capacity.0 {
            // Slow path: refresh the cached tail from the shared atomic.
            let tail = self.queue.0.tail.0.load(Ordering::Acquire);
            if tail != cached_tail {
                unsafe {
                    *self.cached_tail.0.get() = tail;
                }
            }
            if next_head.wrapping_sub(tail) > self.queue.0.capacity.0 {
                return Err(value);
            }
        }
        unsafe {
            // SAFETY: the slot at `head & mask` is free (checked above) and
            // only this producer writes it before publishing `head`.
            let index = head & self.queue.0.mask.0;
            let slot = self.queue.0.buffer.0.add(index);
            (*slot).as_mut_ptr().write(value);
            *self.head.0.get() = next_head;
        }
        self.queue.0.head.0.store(next_head, Ordering::Release);
        Ok(())
    }
    /// Number of elements currently queued, as seen by the producer
    /// (the consumer may pop concurrently, so this can over-estimate).
    #[inline(always)]
    pub fn len(&self) -> usize {
        let head = unsafe { *self.head.0.get() };
        let tail = self.queue.0.tail.0.load(Ordering::Relaxed);
        head.wrapping_sub(tail)
    }
    /// `true` if the queue appears empty from the producer's side.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// `true` if the queue appears full from the producer's side.
    #[inline(always)]
    pub fn is_full(&self) -> bool {
        self.len() >= self.queue.0.capacity.0
    }
    /// Best-effort cache prefetch of the slot for position `index` ahead of a
    /// write. Compiles to nothing on targets without a supported prefetch.
    #[inline(always)]
    fn prefetch_write(&self, index: usize) {
        let slot_index = index & self.queue.0.mask.0;
        let _slot = unsafe { self.queue.0.buffer.0.add(slot_index) };
        // Prefer the dedicated write prefetch (PREFETCHW, `_MM_HINT_ET0`)
        // when `prfchw` is enabled; otherwise fall back to a read prefetch.
        // The arms are mutually exclusive — previously both instructions were
        // emitted when `prfchw` was available.
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "sse",
            not(target_feature = "prfchw")
        ))]
        unsafe {
            core::arch::x86_64::_mm_prefetch(_slot as *const i8, core::arch::x86_64::_MM_HINT_T0);
        }
        #[cfg(all(target_arch = "x86_64", target_feature = "prfchw"))]
        unsafe {
            core::arch::x86_64::_mm_prefetch(_slot as *const i8, core::arch::x86_64::_MM_HINT_ET0);
        }
        // On 32-bit x86 the prefetch instructions require SSE, and
        // `_MM_HINT_ET0` (PREFETCHW) additionally requires `prfchw`; gate on
        // the features instead of emitting PREFETCHW unconditionally.
        #[cfg(all(
            target_arch = "x86",
            target_feature = "sse",
            not(target_feature = "prfchw")
        ))]
        unsafe {
            core::arch::x86::_mm_prefetch(_slot as *const i8, core::arch::x86::_MM_HINT_T0);
        }
        #[cfg(all(target_arch = "x86", target_feature = "sse", target_feature = "prfchw"))]
        unsafe {
            core::arch::x86::_mm_prefetch(_slot as *const i8, core::arch::x86::_MM_HINT_ET0);
        }
        #[cfg(all(feature = "unstable", nightly, target_arch = "aarch64"))]
        unsafe {
            core::arch::aarch64::_prefetch::<
                { core::arch::aarch64::_PREFETCH_WRITE },
                { core::arch::aarch64::_PREFETCH_LOCALITY0 },
            >(_slot as *const i8);
        }
    }
}
/// Read handle to a [`FastQueue`]; exactly one exists per queue.
///
/// Not `Sync`: only a single thread may pop at a time, which is what makes
/// the `UnsafeCell` fields below safe to access without synchronization.
pub struct Consumer<T> {
queue: CachePadded<Arc<FastQueue<T>>>,
// Consumer-private mirror of the tail position (accessed only by the
// owning thread, so no atomics needed).
tail: CachePadded<UnsafeCell<usize>>,
// Last observed producer head; refreshed only when the queue looks empty,
// to avoid reloading the shared atomic on every pop.
cached_head: CachePadded<UnsafeCell<usize>>,
_pd: PhantomData<T>,
}
// SAFETY: the `UnsafeCell` fields are only ever accessed by the single thread
// that owns this handle; moving that ownership is sound when `T: Send`.
unsafe impl<T: Send> Send for Consumer<T> {}
impl<T> Consumer<T> {
    /// Removes and returns the front element, or `None` if the queue is empty.
    ///
    /// The `Acquire` load of `head` pairs with the producer's `Release`
    /// store, so the slot's contents are visible before being read here.
    #[inline(always)]
    pub fn pop(&mut self) -> Option<T> {
        let tail = unsafe { *self.tail.0.get() };
        // Warm the cache for the slot after this one.
        self.prefetch_read(tail.wrapping_add(1));
        let cached_head = unsafe { *self.cached_head.0.get() };
        // Fast path: a stale cached head can only make the queue look
        // *emptier* than it really is, so it is refreshed only when it
        // reports empty.
        if tail == cached_head {
            let head = self.queue.0.head.0.load(Ordering::Acquire);
            if head != cached_head {
                unsafe {
                    *self.cached_head.0.get() = head;
                }
            }
            if tail == head {
                return None;
            }
        }
        let value = unsafe {
            // SAFETY: `tail != head`, so the slot at `tail & mask` holds an
            // element the producer fully wrote before its Release store.
            let index = tail & self.queue.0.mask.0;
            let slot = self.queue.0.buffer.0.add(index);
            (*slot).as_ptr().read()
        };
        let next_tail = tail.wrapping_add(1);
        unsafe { *self.tail.0.get() = next_tail };
        // Publish the freed slot; pairs with the producer's Acquire load.
        self.queue.0.tail.0.store(next_tail, Ordering::Release);
        Some(value)
    }
    /// Returns a reference to the front element without removing it.
    #[inline(always)]
    pub fn peek(&self) -> Option<&T> {
        let tail = unsafe { *self.tail.0.get() };
        let head = self.queue.0.head.0.load(Ordering::Acquire);
        if tail == head {
            return None;
        }
        unsafe {
            // SAFETY: queue is non-empty, so the front slot is initialized;
            // the returned borrow of `self` blocks a concurrent `pop`.
            let index = tail & self.queue.0.mask.0;
            let slot = self.queue.0.buffer.0.add(index);
            Some(&*(*slot).as_ptr())
        }
    }
    /// Number of elements currently queued, as seen by the consumer
    /// (the producer may push concurrently, so this can under-estimate).
    #[inline(always)]
    pub fn len(&self) -> usize {
        let head = self.queue.0.head.0.load(Ordering::Relaxed);
        let tail = unsafe { *self.tail.0.get() };
        head.wrapping_sub(tail)
    }
    /// `true` if the queue appears empty from the consumer's side.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Best-effort cache prefetch of the slot for position `index` ahead of a
    /// read. Compiles to nothing on targets without a supported prefetch.
    #[inline(always)]
    fn prefetch_read(&self, index: usize) {
        let slot_index = index & self.queue.0.mask.0;
        let _slot = unsafe { self.queue.0.buffer.0.add(slot_index) };
        #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
        unsafe {
            core::arch::x86_64::_mm_prefetch(_slot as *const i8, core::arch::x86_64::_MM_HINT_T0);
        }
        // On 32-bit x86 the prefetch instructions are part of SSE; gate on
        // the feature instead of emitting the instruction unconditionally
        // (the x86_64 arm above was already gated, this one was not).
        #[cfg(all(target_arch = "x86", target_feature = "sse"))]
        unsafe {
            core::arch::x86::_mm_prefetch(_slot as *const i8, core::arch::x86::_MM_HINT_T0);
        }
        #[cfg(all(feature = "unstable", nightly, target_arch = "aarch64"))]
        unsafe {
            core::arch::aarch64::_prefetch::<
                { core::arch::aarch64::_PREFETCH_READ },
                { core::arch::aarch64::_PREFETCH_LOCALITY0 },
            >(_slot as *const i8);
        }
    }
}
impl<T> Iterator for Consumer<T> {
type Item = T;
#[inline(always)]
fn next(&mut self) -> Option<Self::Item> {
self.pop()
}
#[inline(always)]
fn size_hint(&self) -> (usize, Option<usize>) {
(self.len(), None)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread;

    /// A single value round-trips and the queue is empty afterwards.
    #[test]
    fn test_basic_push_pop() {
        let (mut producer, mut consumer) = FastQueue::<usize>::new(2);
        assert!(producer.push(42).is_ok());
        assert_eq!(consumer.pop(), Some(42));
        assert_eq!(consumer.pop(), None);
    }

    /// The queue holds exactly `capacity` elements and rejects one more;
    /// freeing a slot makes room again, and FIFO order is preserved.
    #[test]
    fn test_capacity() {
        let (mut producer, mut consumer) = FastQueue::<usize>::new(4);
        for value in 1..=4 {
            assert!(producer.push(value).is_ok());
        }
        assert!(producer.push(5).is_err());
        assert_eq!(consumer.pop(), Some(1));
        assert!(producer.push(5).is_ok());
        for expected in 2..=5 {
            assert_eq!(consumer.pop(), Some(expected));
        }
    }

    /// Producer and consumer on separate threads: every value arrives
    /// exactly once, in order.
    #[test]
    fn test_concurrent() {
        const COUNT: usize = 1_000_000;
        let (mut producer, mut consumer) = FastQueue::<usize>::new(1024);
        let done = Arc::new(AtomicBool::new(false));
        let done_clone = Arc::clone(&done);
        let producer_thread = thread::spawn(move || {
            for i in 0..COUNT {
                // Spin until a slot frees up.
                while producer.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
            done_clone.store(true, Ordering::Release);
        });
        let consumer_thread = thread::spawn(move || {
            let mut count = 0;
            while count < COUNT {
                match consumer.pop() {
                    Some(val) => {
                        assert_eq!(val, count);
                        count += 1;
                    }
                    // Stop only once the producer is finished AND the queue
                    // has been drained.
                    None if done.load(Ordering::Acquire) && consumer.is_empty() => break,
                    None => std::hint::spin_loop(),
                }
            }
            assert_eq!(count, COUNT);
        });
        producer_thread.join().unwrap();
        consumer_thread.join().unwrap();
    }

    /// Positions wrap around the ring without corrupting FIFO order.
    #[test]
    fn test_wraparound() {
        let (mut producer, mut consumer) = FastQueue::<usize>::new(4);
        for i in 0..4 {
            assert!(producer.push(i).is_ok());
        }
        for i in 0..2 {
            assert_eq!(consumer.pop(), Some(i));
        }
        for i in 4..6 {
            assert!(producer.push(i).is_ok());
        }
        for i in 2..6 {
            assert_eq!(consumer.pop(), Some(i));
        }
        assert!(consumer.pop().is_none());
    }
}