use crate::{
loom::{
cell::UnsafeCell,
sync::atomic::{AtomicBool, AtomicPtr, Ordering::*},
},
util::{Backoff, CachePadded},
Linked,
};
use core::{
fmt,
marker::PhantomPinned,
ptr::{self, NonNull},
};
/// An intrusive multi-producer, single-consumer queue of `T` nodes.
///
/// Nodes are linked through the [`Links`] embedded in each `T` (reached via
/// the [`Linked`] trait). The queue always owns a "stub" node. The consumer
/// re-enqueues the stub internally when it reaches the last linked node.
///
/// Any number of threads may [`enqueue`](Self::enqueue) through `&self`.
/// Dequeuing is serialized by the `has_consumer` flag.
pub struct MpscQueue<T: Linked<Links<T>>> {
    /// The most recently enqueued node. Producers atomically swap themselves
    /// into this slot. It is initialized to the stub, so it is never null.
    head: CachePadded<AtomicPtr<T>>,
    /// The next node to dequeue. Only the current consumer (or `Drop`, which
    /// has `&mut self`) touches this, which is why it is a plain cell rather
    /// than an atomic.
    tail: CachePadded<UnsafeCell<*mut T>>,
    /// `true` while some consumer holds the right to dequeue (guards `tail`).
    has_consumer: CachePadded<AtomicBool>,
    /// `true` if `stub` was supplied as a `&'static T`. In that case `Drop` must
    /// not turn it back into a handle and drop it.
    stub_is_static: bool,
    /// The stub node. It is owned by the queue unless `stub_is_static` is set.
    stub: NonNull<T>,
}
/// A guard granting exclusive dequeue access to an [`MpscQueue`].
///
/// Returned by [`MpscQueue::consume`] and [`MpscQueue::try_consume`] after the
/// queue's consumer lock has been acquired. The lock is released when this
/// value is dropped.
pub struct Consumer<'q, T: Linked<Links<T>>> {
    q: &'q MpscQueue<T>,
}
/// The intrusive link embedded in every node stored in an [`MpscQueue`].
pub struct Links<T> {
    /// The node enqueued after this one, or null if none has been linked yet.
    next: AtomicPtr<T>,
    /// Debug-only marker recording whether this node is a queue's stub. It is
    /// used by debug assertions in the queue.
    #[cfg(debug_assertions)]
    is_stub: AtomicBool,
    /// Makes `Links` (and any node type embedding it) `!Unpin`, because the
    /// queue holds raw pointers to nodes.
    _unpin: PhantomPinned,
}
/// Reasons a non-blocking dequeue from an [`MpscQueue`] can fail.
#[derive(Debug, Eq, PartialEq)]
pub enum TryDequeueError {
    /// No element was available.
    Empty,
    /// The queue was observed mid-enqueue: a producer has swapped `head` but
    /// has not yet linked the previous node to it. Retrying may succeed.
    Inconsistent,
    /// Another consumer currently holds the queue's consumer lock.
    Busy,
}
impl<T: Linked<Links<T>>> MpscQueue<T> {
    /// Returns a new, empty queue whose stub node is created with
    /// `T::Handle::default()`.
    #[must_use]
    pub fn new() -> Self
    where
        T::Handle: Default,
    {
        Self::new_with_stub(Default::default())
    }

    /// Returns a new, empty queue that takes ownership of `stub` and uses it as
    /// the queue's stub node.
    ///
    /// The stub is dropped when the queue is dropped. It must not also be
    /// passed to [`enqueue`](Self::enqueue); debug builds assert this.
    #[must_use]
    pub fn new_with_stub(stub: T::Handle) -> Self {
        let stub = T::into_ptr(stub);

        // Mark the node as the stub so debug assertions can tell it apart
        // from regular elements.
        #[cfg(debug_assertions)]
        unsafe {
            links(stub).is_stub.store(true, Release);
        }

        // An empty queue has both `head` and `tail` pointing at the stub.
        let ptr = stub.as_ptr();
        Self {
            head: CachePadded(AtomicPtr::new(ptr)),
            tail: CachePadded(UnsafeCell::new(ptr)),
            has_consumer: CachePadded(AtomicBool::new(false)),
            stub_is_static: false,
            stub,
        }
    }

    /// Returns a new, empty queue that uses a `&'static` stub node. Because it
    /// is a `const fn`, it can initialize a `static`.
    ///
    /// The stub is borrowed, so the queue never drops it.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not stated in this file. From the
    /// code, the caller presumably must ensure `stub` serves as the stub of at
    /// most one queue and is never enqueued as a regular element. Note also
    /// that this constructor does NOT set the debug `is_stub` flag, while
    /// `Drop` and `try_dequeue_unchecked` debug-assert it on the stub. The stub
    /// therefore presumably needs to be built with `Links::new_stub()`.
    /// TODO confirm.
    #[cfg(not(loom))]
    #[must_use]
    pub const unsafe fn new_with_static_stub(stub: &'static T) -> Self {
        // The queue writes the stub's `next` link through this pointer. It does
        // so only via `AtomicPtr` through a shared `&Links<T>`, never `&mut`.
        let ptr = stub as *const T as *mut T;
        Self {
            head: CachePadded(AtomicPtr::new(ptr)),
            tail: CachePadded(UnsafeCell::new(ptr)),
            has_consumer: CachePadded(AtomicBool::new(false)),
            stub_is_static: true,
            stub: NonNull::new_unchecked(ptr),
        }
    }

    /// Enqueues `element` at the head of the queue.
    ///
    /// Takes `&self`, so any number of producers may call this concurrently.
    pub fn enqueue(&self, element: T::Handle) {
        let ptr = T::into_ptr(element);

        // The stub is managed internally; enqueueing it as an element is a bug.
        #[cfg(debug_assertions)]
        debug_assert!(!unsafe { T::links(ptr).as_ref() }.is_stub());

        self.enqueue_inner(ptr)
    }

    /// Links `ptr` onto the head of the queue.
    ///
    /// The node's `next` is cleared, the node is atomically swapped in as the
    /// new `head`, and then the previous head is pointed at it. Between the
    /// swap and the final store, the consumer can observe the queue as
    /// [`TryDequeueError::Inconsistent`].
    #[inline]
    fn enqueue_inner(&self, ptr: NonNull<T>) {
        unsafe { links(ptr).next.store(ptr::null_mut(), Relaxed) };
        let ptr = ptr.as_ptr();
        // `head` starts at the stub and is only ever swapped with non-null
        // pointers, so `prev` is never null.
        let prev = self.head.swap(ptr, AcqRel);
        unsafe {
            links(non_null(prev)).next.store(ptr, Release);
        }
    }

    /// Attempts to dequeue an element without waiting.
    ///
    /// # Errors
    ///
    /// - [`TryDequeueError::Busy`] if another consumer holds the consumer lock.
    /// - [`TryDequeueError::Empty`] if no element is available.
    /// - [`TryDequeueError::Inconsistent`] if a producer is mid-enqueue.
    pub fn try_dequeue(&self) -> Result<T::Handle, TryDequeueError> {
        if self
            .has_consumer
            .compare_exchange(false, true, AcqRel, Acquire)
            .is_err()
        {
            return Err(TryDequeueError::Busy);
        }

        // NOTE(review): if `try_dequeue_unchecked` panics (e.g. a debug
        // assertion), `has_consumer` is never reset below. Unlike `Consumer`,
        // this path has no drop guard.
        let res = unsafe {
            // SAFETY: the compare-exchange above acquired the consumer lock, so
            // no other thread is dequeuing.
            self.try_dequeue_unchecked()
        };

        self.has_consumer.store(false, Release);
        res
    }

    /// Dequeues an element. It spins with backoff while the queue is busy
    /// (another consumer is active) or inconsistent.
    ///
    /// Returns `None` if the queue is empty.
    pub fn dequeue(&self) -> Option<T::Handle> {
        let mut boff = Backoff::new();
        loop {
            match self.try_dequeue() {
                Ok(val) => return Some(val),
                Err(TryDequeueError::Empty) => return None,
                Err(_) => boff.spin(),
            }
        }
    }

    /// Spins until the consumer lock is acquired, then returns a [`Consumer`]
    /// that holds it until dropped.
    pub fn consume(&self) -> Consumer<'_, T> {
        self.lock_consumer();
        Consumer { q: self }
    }

    /// Tries once to acquire the consumer lock.
    ///
    /// Returns a [`Consumer`] on success, or `None` if another consumer holds
    /// the lock.
    pub fn try_consume(&self) -> Option<Consumer<'_, T>> {
        self.try_lock_consumer().map(|_| Consumer { q: self })
    }

    /// Attempts to dequeue an element without acquiring the consumer lock.
    ///
    /// # Safety
    ///
    /// The caller must ensure no other thread is dequeuing at the same time,
    /// for example by holding the consumer lock as [`Consumer`] does. This
    /// method reads and writes the non-atomic `tail` cell.
    ///
    /// # Errors
    ///
    /// Returns [`TryDequeueError::Empty`] or [`TryDequeueError::Inconsistent`].
    /// It never returns [`TryDequeueError::Busy`].
    pub unsafe fn try_dequeue_unchecked(&self) -> Result<T::Handle, TryDequeueError> {
        self.tail.with_mut(|tail| {
            let mut tail_node = NonNull::new(*tail).ok_or(TryDequeueError::Empty)?;
            let mut next = links(tail_node).next.load(Acquire);

            // The stub is not a real element. If it is at the tail, step past
            // it. If it has no successor, nothing has been enqueued.
            if tail_node == self.stub {
                #[cfg(debug_assertions)]
                debug_assert!(links(tail_node).is_stub());
                let next_node = NonNull::new(next).ok_or(TryDequeueError::Empty)?;

                *tail = next;
                tail_node = next_node;
                next = links(next_node).next.load(Acquire);
            }

            // Fast path: the tail node already has a successor, so it can be
            // unlinked just by advancing `tail`.
            if !next.is_null() {
                *tail = next;
                return Ok(T::from_ptr(tail_node));
            }

            // `tail_node` has no successor yet. If it is not also the head, a
            // producer has swapped `head` but not yet stored into
            // `tail_node.next`.
            let head = self.head.load(Acquire);

            if tail_node.as_ptr() != head {
                return Err(TryDequeueError::Inconsistent);
            }

            // `tail_node` is the last node. Re-enqueue the stub behind it so it
            // gains a successor and can be removed.
            self.enqueue_inner(self.stub);

            // A producer may have swapped `head` after the load above. In that
            // case the stub was linked behind that producer's node, and
            // `tail_node.next` may still be null; that is reported as `Empty`.
            next = links(tail_node).next.load(Acquire);
            if next.is_null() {
                return Err(TryDequeueError::Empty);
            }

            *tail = next;

            #[cfg(debug_assertions)]
            debug_assert!(!links(tail_node).is_stub());

            Ok(T::from_ptr(tail_node))
        })
    }

    /// Dequeues an element without acquiring the consumer lock. It spins with
    /// backoff while the queue is inconsistent.
    ///
    /// Returns `None` if the queue is empty.
    ///
    /// # Safety
    ///
    /// Same contract as [`try_dequeue_unchecked`](Self::try_dequeue_unchecked):
    /// no other thread may be dequeuing at the same time.
    pub unsafe fn dequeue_unchecked(&self) -> Option<T::Handle> {
        let mut boff = Backoff::new();
        loop {
            match self.try_dequeue_unchecked() {
                Ok(val) => return Some(val),
                Err(TryDequeueError::Empty) => return None,
                Err(TryDequeueError::Inconsistent) => boff.spin(),
                Err(TryDequeueError::Busy) => {
                    unreachable!("try_dequeue_unchecked never returns `Busy`!")
                }
            }
        }
    }

    /// Spins until the consumer lock (`has_consumer`) is acquired.
    ///
    /// After a failed compare-exchange, it waits with cheap relaxed loads until
    /// the flag clears before retrying (test-and-test-and-set).
    #[inline]
    fn lock_consumer(&self) {
        let mut boff = Backoff::new();
        while self
            .has_consumer
            .compare_exchange(false, true, AcqRel, Acquire)
            .is_err()
        {
            while self.has_consumer.load(Relaxed) {
                boff.spin();
            }
        }
    }

    /// Tries once to acquire the consumer lock. Returns `Some(())` on success.
    #[inline]
    fn try_lock_consumer(&self) -> Option<()> {
        self.has_consumer
            .compare_exchange(false, true, AcqRel, Acquire)
            .map(|_| ())
            .ok()
    }
}
/// Drops every element still linked in the queue. The stub is then dropped
/// too, unless it was supplied as `&'static`.
impl<T: Linked<Links<T>>> Drop for MpscQueue<T> {
    fn drop(&mut self) {
        // SAFETY: `drop` has `&mut self`, so nothing else can be accessing the
        // consumer-only `tail` cell.
        let mut current = self.tail.with_mut(|tail| unsafe {
            *tail
        });
        // Walk from the tail along `next` links until reaching a null link.
        while let Some(node) = NonNull::new(current) {
            unsafe {
                let links = links(node);
                // Read `next` before the node may be turned back into a handle
                // and dropped below.
                let next = links.next.load(Relaxed);

                // The stub can appear in the chain. It is handled separately
                // after the loop, so skip it here to avoid dropping it twice.
                if node != self.stub {
                    #[cfg(debug_assertions)]
                    debug_assert!(!links.is_stub());
                    drop(T::from_ptr(node));
                } else {
                    #[cfg(debug_assertions)]
                    debug_assert!(links.is_stub());
                }

                current = next;
            }
        }

        unsafe {
            // A `&'static` stub is borrowed, not owned, so it must not be
            // dropped.
            if !self.stub_is_static {
                drop(T::from_ptr(self.stub));
            }
        }
    }
}
impl<T> fmt::Debug for MpscQueue<T>
where
    T: Linked<Links<T>>,
{
    /// Formats the state of the queue that is safe to read from any thread.
    ///
    /// The `tail` pointer is deliberately elided. Only the consumer may read
    /// it, and `fmt` may be called from any thread.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let head = self.head.load(Acquire);
        let has_consumer = self.has_consumer.load(Acquire);

        let mut out = f.debug_struct("MpscQueue");
        out.field("head", &format_args!("{:p}", head));
        out.field("tail", &format_args!("..."));
        out.field("has_consumer", &has_consumer);
        out.field("stub", &self.stub);
        out.finish()
    }
}
impl<T> Default for MpscQueue<T>
where
T: Linked<Links<T>>,
T::Handle: Default,
{
fn default() -> Self {
Self::new()
}
}
// SAFETY: moving the queue to another thread moves the nodes it still links.
// `Drop` converts those nodes back into `T::Handle`s and drops them on that
// thread, so both `T: Send` and `T::Handle: Send` are required.
unsafe impl<T> Send for MpscQueue<T>
where
    T: Send + Linked<Links<T>>,
    T::Handle: Send,
{
}
// SAFETY: shared access goes through atomics (`head`, `has_consumer`), and the
// non-atomic `tail` is only touched while the consumer lock is held.
// NOTE(review): sharing `&MpscQueue` lets one thread `enqueue` a `T::Handle`
// that another thread later dequeues. This impl does not require
// `T::Handle: Send`, unlike the `Send` impl above. Confirm this is sound for
// every `Linked` implementation.
unsafe impl<T: Send + Linked<Links<T>>> Sync for MpscQueue<T> {}
impl<'q, T: Send + Linked<Links<T>>> Consumer<'q, T> {
    /// Dequeues an element. It spins with backoff while the queue is
    /// inconsistent and returns `None` if the queue is empty.
    #[inline]
    pub fn dequeue(&self) -> Option<T::Handle> {
        let queue = self.q;
        debug_assert!(queue.has_consumer.load(Acquire));
        // SAFETY: a `Consumer` only exists while the consumer lock is held, so
        // no other thread is dequeuing.
        unsafe { queue.dequeue_unchecked() }
    }

    /// Attempts to dequeue an element without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`TryDequeueError::Empty`] if no element is available, or
    /// [`TryDequeueError::Inconsistent`] if a producer is mid-enqueue.
    #[inline]
    pub fn try_dequeue(&self) -> Result<T::Handle, TryDequeueError> {
        let queue = self.q;
        debug_assert!(queue.has_consumer.load(Acquire));
        // SAFETY: a `Consumer` only exists while the consumer lock is held, so
        // no other thread is dequeuing.
        unsafe { queue.try_dequeue_unchecked() }
    }
}
/// Releases the queue's consumer lock so another consumer can be acquired.
impl<T: Linked<Links<T>>> Drop for Consumer<'_, T> {
    fn drop(&mut self) {
        let lock = &self.q.has_consumer;
        lock.store(false, Release);
    }
}
impl<T> fmt::Debug for Consumer<'_, T>
where
    T: Linked<Links<T>>,
{
    /// Formats the consumer's view of the queue. The output shows the
    /// producer-side `head` pointer, the consumer-owned `tail` pointer, and the
    /// `has_consumer` flag.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // SAFETY: a `Consumer` only exists while the consumer lock is held, so
        // it may read the consumer-only `tail` cell.
        let tail = self.q.tail.with(|tail| unsafe { *tail });
        // `head` is the atomic shared with producers. It must be loaded from
        // `self.q.head`; it is not the same pointer as `tail`.
        let head = self.q.head.load(Acquire);
        f.debug_struct("Consumer")
            .field("head", &format_args!("{:p}", head))
            .field("tail", &tail)
            .field("has_consumer", &self.q.has_consumer.load(Acquire))
            .finish()
    }
}
impl<T> Iterator for Consumer<'_, T>
where
T: Send + Linked<Links<T>>,
{
type Item = T::Handle;
fn next(&mut self) -> Option<Self::Item> {
self.dequeue()
}
}
impl<T> Links<T> {
    /// Returns new, unlinked links for a regular (non-stub) node.
    ///
    /// This is a `const fn` when not built under loom; the loom build gets a
    /// non-`const` twin below.
    #[cfg(not(loom))]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            next: AtomicPtr::new(ptr::null_mut()),
            _unpin: PhantomPinned,
            #[cfg(debug_assertions)]
            is_stub: AtomicBool::new(false),
        }
    }

    /// Returns new, unlinked links marked (in debug builds) as belonging to a
    /// stub node. Use this for stubs passed to `new_with_static_stub`, which
    /// does not set the marker itself.
    #[cfg(not(loom))]
    #[must_use]
    pub const fn new_stub() -> Self {
        Self {
            next: AtomicPtr::new(ptr::null_mut()),
            _unpin: PhantomPinned,
            #[cfg(debug_assertions)]
            is_stub: AtomicBool::new(true),
        }
    }

    /// Loom build of [`Links::new`]. It is not `const` because it builds loom
    /// atomics; keep it in sync with the version above.
    #[cfg(loom)]
    #[must_use]
    pub fn new() -> Self {
        Self {
            next: AtomicPtr::new(ptr::null_mut()),
            _unpin: PhantomPinned,
            #[cfg(debug_assertions)]
            is_stub: AtomicBool::new(false),
        }
    }

    /// Loom build of [`Links::new_stub`]; keep it in sync with the version
    /// above.
    #[cfg(loom)]
    #[must_use]
    pub fn new_stub() -> Self {
        Self {
            next: AtomicPtr::new(ptr::null_mut()),
            _unpin: PhantomPinned,
            #[cfg(debug_assertions)]
            is_stub: AtomicBool::new(true),
        }
    }

    /// Returns whether these links are marked as a stub node's (debug only).
    #[cfg(debug_assertions)]
    fn is_stub(&self) -> bool {
        self.is_stub.load(Acquire)
    }
}
impl<T> Default for Links<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> fmt::Debug for Links<T> {
    /// Formats the `next` pointer and, in debug builds, the stub marker.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let next = self.next.load(Acquire);
        let mut dbg = f.debug_struct("Links");
        dbg.field("next", &next);
        #[cfg(debug_assertions)]
        dbg.field("is_stub", &self.is_stub.load(Acquire));
        dbg.finish()
    }
}
// Items available only with the `alloc` feature. Only `//` comments are used
// here, because the `feature!` macro's matcher is not visible from this file.
feature! {
    #![feature = "alloc"]

    use alloc::sync::Arc;

    // An owned counterpart to `Consumer`: holds an `Arc` to the queue plus the
    // queue's consumer lock, and releases the lock when dropped.
    pub struct OwnedConsumer<T: Linked<Links<T>>> {
        q: Arc<MpscQueue<T>>
    }

    impl<T: Linked<Links<T>>> OwnedConsumer<T> {
        // Dequeues an element, spinning with backoff while the queue is
        // inconsistent; returns `None` if the queue is empty.
        #[inline]
        pub fn dequeue(&self) -> Option<T::Handle> {
            debug_assert!(self.q.has_consumer.load(Acquire));
            // SAFETY: an `OwnedConsumer` is only constructed after the consumer
            // lock has been acquired, and it holds the lock until dropped.
            unsafe {
                self.q.dequeue_unchecked()
            }
        }

        // Attempts to dequeue an element without waiting; returns `Empty` or
        // `Inconsistent` on failure.
        #[inline]
        pub fn try_dequeue(&self) -> Result<T::Handle, TryDequeueError> {
            debug_assert!(self.q.has_consumer.load(Acquire));
            // SAFETY: see `dequeue` above; the consumer lock is held.
            unsafe {
                self.q.try_dequeue_unchecked()
            }
        }

        // Returns `true` if any other `Arc` to the queue is alive. Those are
        // presumably held by producers; the count cannot tell them apart from
        // other holders.
        pub fn has_producers(&self) -> bool {
            Arc::strong_count(&self.q) > 1
        }
    }

    // Releases the queue's consumer lock.
    impl<T: Linked<Links<T>>> Drop for OwnedConsumer<T> {
        fn drop(&mut self) {
            self.q.has_consumer.store(false, Release);
        }
    }

    impl<T: Linked<Links<T>>> fmt::Debug for OwnedConsumer<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // SAFETY: the consumer lock is held, so reading the consumer-only
            // `tail` cell is permitted.
            let tail = self.q.tail.with(|tail| unsafe {
                *tail
            });
            f.debug_struct("OwnedConsumer")
                .field("head", &self.q.head.load(Acquire))
                .field("tail", &tail)
                .field("has_consumer", &self.q.has_consumer.load(Acquire))
                .finish()
        }
    }

    impl<T: Linked<Links<T>>> MpscQueue<T> {
        // Spins until the consumer lock is acquired, then returns an
        // `OwnedConsumer` that owns this `Arc`.
        pub fn consume_owned(self: Arc<Self>) -> OwnedConsumer<T> {
            self.lock_consumer();
            OwnedConsumer { q: self }
        }

        // Tries once to acquire the consumer lock. On failure it returns
        // `None`, and the `Arc` passed in is dropped.
        pub fn try_consume_owned(self: Arc<Self>) -> Option<OwnedConsumer<T>> {
            self.try_lock_consumer().map(|_| OwnedConsumer { q: self })
        }
    }
}
/// Returns a shared reference to the [`Links`] embedded in the node at `ptr`.
///
/// # Safety
///
/// `ptr` must point to a live `T`. The returned reference, whose lifetime
/// `'a` the caller chooses, must not outlive that node.
#[inline(always)]
unsafe fn links<'a, T: Linked<Links<T>>>(ptr: NonNull<T>) -> &'a Links<T> {
    let links_ptr: NonNull<Links<T>> = T::links(ptr);
    links_ptr.as_ref()
}
/// Debug-build conversion of `ptr` into a `NonNull`.
///
/// Panics at the caller's location (via `#[track_caller]`) if `ptr` is null,
/// instead of silently violating the `NonNull` invariant.
#[cfg(debug_assertions)]
#[track_caller]
#[inline(always)]
unsafe fn non_null<T>(ptr: *mut T) -> NonNull<T> {
    const NULL_PTR_MSG: &str = "/!\\ constructed a `NonNull` from a null pointer! /!\\ \n\
        in release mode, this would have called `NonNull::new_unchecked`, \
        violating the `NonNull` invariant! this is a bug in `cordyceps!`.";
    NonNull::new(ptr).expect(NULL_PTR_MSG)
}
/// Release-build conversion of `ptr` into a `NonNull`, with no null check.
///
/// # Safety
///
/// `ptr` must be non-null. Debug builds check this and panic instead.
#[cfg(not(debug_assertions))]
#[inline(always)]
unsafe fn non_null<T>(ptr: *mut T) -> NonNull<T> {
    NonNull::new_unchecked(ptr)
}
// Loom model-checking tests: explore interleavings of concurrent producers and
// consumers.
#[cfg(all(loom, test))]
mod loom {
    use super::*;
    use crate::loom::{self, sync::Arc, thread};
    use test_util::*;

    /// Every message sent by `THREADS` producers is dequeued by one consumer.
    #[test]
    fn basically_works_loom() {
        const THREADS: i32 = 2;
        const MSGS: i32 = THREADS;
        const TOTAL_MSGS: i32 = THREADS * MSGS;
        basically_works_test(THREADS, MSGS, TOTAL_MSGS);
    }

    /// The consumer stops after half of the messages. The remainder are still
    /// linked when the queue is dropped and must be freed by `Drop`.
    #[test]
    fn doesnt_leak() {
        const THREADS: i32 = 2;
        const MSGS: i32 = THREADS;
        const TOTAL_MSGS: i32 = (THREADS * MSGS) / 2;
        basically_works_test(THREADS, MSGS, TOTAL_MSGS);
    }

    /// Spawns `threads` producers that each enqueue `msgs` entries. The model's
    /// main thread then dequeues `total_msgs` entries with `try_dequeue`,
    /// yielding and retrying on `Empty`/`Inconsistent`.
    fn basically_works_test(threads: i32, msgs: i32, total_msgs: i32) {
        loom::model(move || {
            let stub = entry(666);
            let q = Arc::new(MpscQueue::<Entry>::new_with_stub(stub));

            let threads: Vec<_> = (0..threads)
                .map(|thread| thread::spawn(do_tx(thread, msgs, &q)))
                .collect();

            let mut i = 0;
            while i < total_msgs {
                match q.try_dequeue() {
                    Ok(val) => {
                        i += 1;
                        tracing::info!(?val, "dequeue {}/{}", i, total_msgs);
                    }
                    // Only this thread ever dequeues, so the lock is never
                    // contended.
                    Err(TryDequeueError::Busy) => panic!(
                        "the queue should never be busy, as there is only a single consumer!"
                    ),
                    Err(err) => {
                        tracing::info!(?err, "dequeue error");
                        thread::yield_now();
                    }
                }
            }

            for thread in threads {
                thread.join().unwrap();
            }
        })
    }

    /// Returns a producer closure that enqueues `msgs` entries with values
    /// `i + thread * 10`.
    fn do_tx(thread: i32, msgs: i32, q: &Arc<MpscQueue<Entry>>) -> impl FnOnce() + Send + Sync {
        let q = q.clone();
        move || {
            for i in 0..msgs {
                q.enqueue(entry(i + (thread * 10)));
                tracing::info!(thread, "enqueue msg {}/{}", i, msgs);
            }
        }
    }

    /// Two producers race with two consumers: one spawned and one on the main
    /// thread. Both use the lock-acquiring `dequeue`. Each consumer stops at
    /// its first `None`; anything left over is freed when the queue is dropped.
    #[test]
    fn mpmc() {
        const THREADS: i32 = 2;
        const MSGS: i32 = THREADS;

        fn do_rx(thread: i32, q: Arc<MpscQueue<Entry>>) {
            let mut i = 0;
            while let Some(val) = q.dequeue() {
                tracing::info!(?val, ?thread, "dequeue {}/{}", i, THREADS * MSGS);
                i += 1;
            }
        }

        loom::model(|| {
            let stub = entry(666);
            let q = Arc::new(MpscQueue::<Entry>::new_with_stub(stub));

            let mut threads: Vec<_> = (0..THREADS)
                .map(|thread| thread::spawn(do_tx(thread, MSGS, &q)))
                .collect();

            threads.push(thread::spawn({
                let q = q.clone();
                move || do_rx(THREADS + 1, q)
            }));
            do_rx(THREADS + 2, q);

            for thread in threads {
                thread.join().unwrap();
            }
        })
    }
}
// Non-loom unit tests. Sizes shrink under Miri; see `if_miri`.
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use test_util::*;

    use std::{ops::Deref, println, sync::Arc, thread};

    /// Dequeuing from a fresh queue yields `None`.
    #[test]
    fn dequeue_empty() {
        let stub = entry(666);
        let q = MpscQueue::<Entry>::new_with_stub(stub);
        assert_eq!(q.dequeue(), None)
    }

    /// `try_dequeue` on a fresh queue reports `Empty`.
    #[test]
    fn try_dequeue_empty() {
        let stub = entry(666);
        let q = MpscQueue::<Entry>::new_with_stub(stub);
        assert_eq!(q.try_dequeue(), Err(TryDequeueError::Empty))
    }

    /// While a `Consumer` is held, `MpscQueue::try_dequeue` reports `Busy`. The
    /// consumer itself can still dequeue, and dropping it releases the lock.
    #[test]
    fn try_dequeue_busy() {
        let stub = entry(666);
        let q = MpscQueue::<Entry>::new_with_stub(stub);

        let consumer = q.try_consume().expect("must acquire consumer");
        assert_eq!(consumer.try_dequeue(), Err(TryDequeueError::Empty));

        q.enqueue(entry(1));

        assert_eq!(q.try_dequeue(), Err(TryDequeueError::Busy));

        assert_eq!(consumer.try_dequeue(), Ok(entry(1)),);

        assert_eq!(q.try_dequeue(), Err(TryDequeueError::Busy));

        assert_eq!(consumer.try_dequeue(), Err(TryDequeueError::Empty));

        drop(consumer);
        assert_eq!(q.try_dequeue(), Err(TryDequeueError::Empty));
    }

    /// A single enqueued element is dequeued, after which the queue is empty.
    #[test]
    fn enqueue_dequeue() {
        let stub = entry(666);
        let e = entry(1);
        let q = MpscQueue::<Entry>::new_with_stub(stub);
        q.enqueue(e);
        assert_eq!(q.dequeue(), Some(entry(1)));
        assert_eq!(q.dequeue(), None)
    }

    /// Concurrent producers with a heap-allocated stub.
    #[test]
    fn basically_works() {
        let stub = entry(666);
        let q = MpscQueue::<Entry>::new_with_stub(stub);

        let q = Arc::new(q);
        test_basically_works(q);
    }

    /// Both the stub and the queue live in `static`s.
    #[test]
    fn basically_works_all_const() {
        static STUB_ENTRY: Entry = const_stub_entry(666);
        static MPSC: MpscQueue<Entry> =
            unsafe { MpscQueue::<Entry>::new_with_static_stub(&STUB_ENTRY) };
        test_basically_works(&MPSC);
    }

    /// A `static` stub with a queue constructed at runtime.
    #[test]
    fn basically_works_mixed_const() {
        static STUB_ENTRY: Entry = const_stub_entry(666);
        let q = unsafe { MpscQueue::<Entry>::new_with_static_stub(&STUB_ENTRY) };

        let q = Arc::new(q);
        test_basically_works(q)
    }

    /// Spawns `THREADS` producers that each enqueue `MSGS` entries. The calling
    /// thread consumes all `THREADS * MSGS` entries with `try_dequeue`,
    /// yielding and retrying on `Empty`/`Inconsistent`. `Busy` is treated as
    /// a failure because there is only one consumer.
    fn test_basically_works<Q>(q: Q)
    where
        Q: Deref<Target = MpscQueue<Entry>> + Clone,
        Q: Send + 'static,
    {
        const THREADS: i32 = if_miri(3, 8);
        const MSGS: i32 = if_miri(10, 1000);

        assert_eq!(q.dequeue(), None);

        let threads: Vec<_> = (0..THREADS)
            .map(|thread| {
                let q = q.clone();
                thread::spawn(move || {
                    for i in 0..MSGS {
                        q.enqueue(entry(i));
                        println!("thread {}; msg {}/{}", thread, i, MSGS);
                    }
                })
            })
            .collect();

        let mut i = 0;
        while i < THREADS * MSGS {
            match q.try_dequeue() {
                Ok(msg) => {
                    i += 1;
                    println!("recv {:?} ({}/{})", msg, i, THREADS * MSGS);
                }
                Err(TryDequeueError::Busy) => {
                    panic!("the queue should never be busy, as there is only one consumer")
                }
                Err(e) => {
                    println!("recv error {:?}", e);
                    thread::yield_now();
                }
            }
        }

        for thread in threads {
            thread.join().unwrap();
        }
    }

    /// Returns `miri` when compiled under Miri, otherwise `not_miri`. This keeps
    /// the slow interpreter's workload small.
    const fn if_miri(miri: i32, not_miri: i32) -> i32 {
        if cfg!(miri) {
            miri
        } else {
            not_miri
        }
    }
}
// Shared test fixtures: an `Entry` node type that implements `Linked`.
#[cfg(test)]
mod test_util {
    use super::*;
    use crate::loom::alloc;
    pub use std::{boxed::Box, pin::Pin, println, ptr, vec, vec::Vec};

    /// A test node carrying an `i32` payload.
    pub(super) struct Entry {
        links: Links<Entry>,
        pub(super) val: i32,
        // NOTE(review): presumably lets the test harness detect leaked entries
        // via `loom::alloc::Track`. Confirm.
        _track: alloc::Track<()>,
    }

    // Entries compare equal by payload alone, so a dequeued handle can be
    // compared against a freshly built `entry(n)`.
    impl std::cmp::PartialEq for Entry {
        fn eq(&self, other: &Self) -> bool {
            self.val == other.val
        }
    }

    // SAFETY: handles are `Pin<Box<Entry>>`. `into_ptr` leaks the box and
    // `from_ptr` reconstitutes it from that same pointer. `links` projects to
    // the embedded field with `addr_of_mut!`, without creating an intermediate
    // reference.
    unsafe impl Linked<Links<Self>> for Entry {
        type Handle = Pin<Box<Entry>>;

        fn into_ptr(handle: Pin<Box<Entry>>) -> NonNull<Entry> {
            unsafe { NonNull::from(Box::leak(Pin::into_inner_unchecked(handle))) }
        }

        unsafe fn from_ptr(ptr: NonNull<Entry>) -> Pin<Box<Entry>> {
            Pin::new_unchecked(Box::from_raw(ptr.as_ptr()))
        }

        unsafe fn links(target: NonNull<Entry>) -> NonNull<Links<Entry>> {
            let links = ptr::addr_of_mut!((*target.as_ptr()).links);
            NonNull::new_unchecked(links)
        }
    }

    impl fmt::Debug for Entry {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Entry")
                .field("links", &self.links)
                .field("val", &self.val)
                .finish()
        }
    }

    /// Builds an `Entry` in a `const` context, with its links marked as a
    /// stub. Intended for `new_with_static_stub`.
    #[cfg(not(loom))]
    pub(super) const fn const_stub_entry(val: i32) -> Entry {
        Entry {
            links: Links::new_stub(),
            val,
            _track: alloc::Track::new_const(()),
        }
    }

    /// Builds a heap-allocated, pinned `Entry` with unlinked, non-stub links.
    pub(super) fn entry(val: i32) -> Pin<Box<Entry>> {
        Box::pin(Entry {
            links: Links::new(),
            val,
            _track: alloc::Track::new(()),
        })
    }
}