use std::num::NonZeroUsize;
use std::sync::atomic::Ordering;
use crate::QueueConfig;
use crate::cacheline_aligned;
use crate::error::*;
use crate::shm::{Chunk, Span};
use crate::AtomicIndex;
use crate::Index;
use crate::MIN_MSGS;
/// Sentinel meaning "no slot": an empty tail/head, or the end of a chain of links.
const INVALID_INDEX: Index = Index::MAX;
/// Most-significant bit of an index word. The consumer ORs it into the tail
/// (`fetch_or` in `pop`/`flush`) when it claims the slot the tail points at.
const CONSUMED_FLAG: Index = !(Index::MAX >> 1);
/// Bits of an index word that carry flags rather than a slot number.
const ORIGIN_MASK: Index = CONSUMED_FLAG;
/// Bits of an index word that carry the slot number.
const INDEX_MASK: Index = !ORIGIN_MASK;
/// Outcome of a consumer operation (`ConsumerQueue::pop` / `ConsumerQueue::flush`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConsumeResult {
    /// An index read from shared memory was out of range.
    QueueError,
    /// Nothing has been published yet (the tail is still `INVALID_INDEX`).
    NoMessage,
    /// There is no message after the one currently held.
    NoNewMessage,
    /// The consumer advanced to the next message.
    Success,
    /// The consumer took the slot the tail now points at after finding the tail
    /// unflagged. This happens when the producer moved the tail past unread
    /// messages, and also on the first pop after the first message is published.
    SuccessMessagesDiscarded,
}
/// Outcome of `ProducerQueue::force_push`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProduceForceResult {
    /// The tail index read back from shared memory was out of range.
    QueueError,
    /// The message was published and the next slot was free.
    Success,
    /// The message was published, but an unread message had to be reclaimed
    /// to obtain the next slot to write into.
    SuccessMessageDiscarded,
}
/// Outcome of `ProducerQueue::try_push`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProduceTryResult {
    /// The tail index read from shared memory was out of range; nothing was published.
    QueueError,
    /// No free slot was available; nothing was published.
    QueueFull,
    /// The message was published.
    Success,
}
/// View of a queue laid out inside a shared-memory `Chunk`.
///
/// Holds raw pointers (computed in `Queue::new`) to the shared tail and head
/// indices, to one chain link per slot, and to each message slot.
pub(crate) struct Queue {
    // Owns the memory the pointers below point into; never read directly.
    // NOTE(review): presumably keeps the mapping alive while `Queue` exists — confirm in `crate::shm`.
    _chunk: Chunk,
    // Size of one message slot, rounded up to whole cache lines.
    message_size: NonZeroUsize,
    // Shared index of the newest published slot; `INVALID_INDEX` after `init`.
    head: *mut Index,
    // Shared index of the oldest listed slot; the consumer ORs `CONSUMED_FLAG`
    // into it when it claims that slot.
    tail: *mut Index,
    // Shared "next" link of each slot; `INVALID_INDEX` terminates the list.
    chain: Vec<*mut Index>,
    // Start of each message slot.
    messages: Vec<*mut ()>,
}
impl Queue {
    /// Lays out a queue inside `chunk` and records pointers to its shared parts.
    ///
    /// Byte layout of the chunk:
    /// - `0`: tail index
    /// - `size_of::<Index>()`: head index
    /// - `2 * size_of::<Index>()..`: one chain link per slot (`queue_len` links)
    /// - `cacheline_aligned(queue_size)..`: `queue_len` message slots, each
    ///   `cacheline_aligned(config.message_size)` bytes
    ///
    /// where `queue_len = config.additional_messages + MIN_MSGS`. This only
    /// computes pointers; the shared indices are initialised by `init` and
    /// `ProducerQueue::new`.
    ///
    /// # Errors
    /// Propagates the `ShmMapError` returned by `Chunk::get_ptr` /
    /// `Chunk::get_span_ptr` (presumably when an offset lies outside the
    /// chunk — TODO confirm against `Chunk`).
    ///
    /// NOTE(review): slot numbers must stay below `INDEX_MASK` so they never
    /// collide with `CONSUMED_FLAG` / `INVALID_INDEX`; not checked here.
    pub(crate) fn new(chunk: Chunk, config: &QueueConfig) -> Result<Self, ShmMapError> {
        let queue_len = config.additional_messages + MIN_MSGS;
        let index_size = size_of::<Index>();
        // Index area: tail, head, then one chain link per slot.
        let queue_size = (2 + queue_len) * index_size;
        // Each message slot is padded to a whole number of cache lines.
        let message_size = NonZeroUsize::new(cacheline_aligned(config.message_size.get()))
            .expect("rounding a non-zero message size up to a cache line is non-zero");
        let mut offset_index = 0;
        // Message payloads start on the first cache line after the index area.
        let mut offset = cacheline_aligned(queue_size);
        let tail: *mut Index = chunk.get_ptr(offset_index)?;
        offset_index += index_size;
        let head: *mut Index = chunk.get_ptr(offset_index)?;
        offset_index += index_size;
        let mut chain: Vec<*mut Index> = Vec::with_capacity(queue_len);
        let mut messages: Vec<*mut ()> = Vec::with_capacity(queue_len);
        for _ in 0..queue_len {
            let index: *mut Index = chunk.get_ptr(offset_index)?;
            let message: *mut () = chunk.get_span_ptr(&Span {
                offset,
                size: message_size,
            })?;
            chain.push(index);
            messages.push(message);
            offset_index += index_size;
            offset += message_size.get();
        }
        Ok(Self {
            _chunk: chunk,
            message_size,
            head,
            tail,
            chain,
            messages,
        })
    }

    /// Returns `true` if `idx` names one of this queue's slots.
    ///
    /// Values carrying `CONSUMED_FLAG`, and `INVALID_INDEX`, are rejected as
    /// long as the queue has at most `INDEX_MASK` slots.
    fn is_valid_index(&self, idx: Index) -> bool {
        // Widen the index instead of narrowing `len()` so a large length can
        // never be truncated, and so this does not hard-code the `Index` width.
        (idx as usize) < self.len()
    }

    /// Marks the queue as empty by setting the shared tail and head to
    /// `INVALID_INDEX`. Chain links are left untouched.
    pub(crate) fn init(&self) {
        self.tail_store(INVALID_INDEX);
        self.head_store(INVALID_INDEX);
    }

    /// Size of one message slot in bytes (already cache-line aligned).
    pub(crate) fn message_size(&self) -> NonZeroUsize {
        self.message_size
    }

    /// Atomic view of the shared tail index.
    fn tail(&self) -> &AtomicIndex {
        // SAFETY: `self.tail` was derived from `self._chunk` in `new`, and the
        // chunk is owned by `self`. Every access goes through this atomic view.
        // NOTE(review): assumes `Chunk::get_ptr` returns pointers with the
        // alignment `AtomicIndex::from_ptr` requires and that the mapping
        // outlives `self` — confirm in `crate::shm`.
        unsafe { AtomicIndex::from_ptr(self.tail) }
    }

    /// Atomic view of the shared head index.
    fn head(&self) -> &AtomicIndex {
        // SAFETY: same reasoning as `tail`: the pointer comes from the owned
        // `_chunk` and is only ever accessed atomically.
        unsafe { AtomicIndex::from_ptr(self.head) }
    }

    /// Atomic view of the shared chain link of slot `idx`.
    ///
    /// # Panics
    /// Panics if `idx` is not a valid slot index.
    fn chain(&self, idx: Index) -> &AtomicIndex {
        // SAFETY: same reasoning as `tail`; the bounds-checked index selects a
        // pointer that `new` obtained from the owned `_chunk`.
        unsafe { AtomicIndex::from_ptr(self.chain[idx as usize]) }
    }

    fn tail_load(&self) -> Index {
        self.tail().load(Ordering::SeqCst)
    }

    fn tail_store(&self, val: Index) {
        self.tail().store(val, Ordering::SeqCst)
    }

    /// ORs `val` into the tail and returns the previous tail value.
    fn tail_fetch_or(&self, val: Index) -> Index {
        self.tail().fetch_or(val, Ordering::SeqCst)
    }

    /// Replaces the tail with `new` only if it still equals `current`.
    /// Returns whether the swap happened.
    fn tail_compare_exchange(&self, current: Index, new: Index) -> bool {
        self.tail()
            .compare_exchange(current, new, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok()
    }

    fn head_load(&self) -> Index {
        self.head().load(Ordering::SeqCst)
    }

    fn head_store(&self, val: Index) {
        self.head().store(val, Ordering::SeqCst);
    }

    /// Reads the shared chain link of slot `idx`.
    fn chain_load(&self, idx: Index) -> Index {
        self.chain(idx).load(Ordering::SeqCst)
    }

    /// Writes the shared chain link of slot `idx`.
    fn queue_store(&self, idx: Index, val: Index) {
        self.chain(idx).store(val, Ordering::SeqCst);
    }

    /// Number of message slots.
    fn len(&self) -> usize {
        self.chain.len()
    }
}
// SAFETY: the raw pointers in `Queue` (which make it `!Send` by default) all
// point into the memory owned by its `_chunk` field, which moves together with
// the `Queue`. Every shared index is accessed only through atomics
// (`AtomicIndex::from_ptr`).
// NOTE(review): this also relies on `Chunk` being safe to hand to another
// thread, and on the mapped address not changing when `Chunk` is moved —
// confirm in `crate::shm`.
unsafe impl Send for Queue {}
/// Writing side of a `Queue`.
///
/// Keeps a private mirror of every chain link it writes, so it can follow
/// links without reading them back from shared memory.
pub struct ProducerQueue {
    queue: Queue,
    // Local copy of the chain links stored through `ProducerQueue::queue_store`.
    chain: Vec<Index>,
    // Slot of the last published message; `INVALID_INDEX` before the first push.
    head: Index,
    // Slot the next message is written into (see `current_message`).
    current: Index,
    // Slot the consumer still holds after the producer moved the tail past it,
    // or `INVALID_INDEX` when there is no such slot.
    overrun: Index,
}
impl ProducerQueue {
    /// Takes ownership of `queue` and links every slot into the ring
    /// `0 -> 1 -> ... -> len-1 -> 0`, both in shared memory and in the local
    /// mirror. Slot 0 becomes the first slot to write.
    ///
    /// # Panics
    /// Panics if the queue has no slots.
    pub(crate) fn new(queue: Queue) -> Self {
        let queue_len = queue.len();
        let mut chain: Vec<Index> = Vec::with_capacity(queue_len);
        let last = queue_len - 1;
        for i in 0..last {
            let next = i + 1;
            queue.queue_store(i as Index, next as Index);
            chain.push(next as Index);
        }
        // Close the ring.
        queue.queue_store(last as Index, 0);
        chain.push(0);
        Self {
            queue,
            head: INVALID_INDEX,
            chain,
            current: 0,
            overrun: INVALID_INDEX,
        }
    }

    /// Returns a write pointer to the slot the next message should be written into.
    pub(crate) fn current_message(&self) -> *mut () {
        let ptr = self.queue.messages.get(self.current as usize).unwrap();
        ptr.cast()
    }

    /// Stores chain link `idx -> val` in the local mirror and in shared memory.
    fn queue_store(&mut self, idx: Index, val: Index) {
        self.chain[idx as usize] = val;
        self.queue.queue_store(idx, val);
    }

    /// Tries to advance the shared tail from exactly `tail` to the slot after it.
    /// Fails if the tail changed meanwhile (e.g. the consumer set `CONSUMED_FLAG`).
    fn move_tail(&self, tail: Index) -> bool {
        let next = self.chain[(tail & INDEX_MASK) as usize];
        self.queue.tail_compare_exchange(tail, next)
    }

    /// Publishes `current` as the only message: it becomes both tail and head.
    fn enqueue_first_message(&mut self) {
        self.queue_store(self.current, INVALID_INDEX);
        self.head = self.current;
        // Publish `head` before `tail`. `ConsumerQueue::flush` loads `head`
        // after it sees a valid tail and treats an invalid head as a queue
        // error, so the head must already be valid once the tail is.
        self.queue.head_store(self.head);
        self.queue.tail_store(self.current);
    }

    /// Appends `current` after the current head and makes it the new head.
    fn enqueue_message(&mut self) {
        // Terminate the list at the new message before linking it in.
        self.queue_store(self.current, INVALID_INDEX);
        self.queue_store(self.head, self.current);
        self.head = self.current;
        self.queue.head_store(self.head);
    }

    /// Skips over the slot the consumer holds (`tail`, flagged as consumed).
    ///
    /// The message after it is reclaimed as the new `current`, and the tail is
    /// moved one further. On success, the held slot is remembered in
    /// `self.overrun` and `true` is returned (a message was discarded). If the
    /// tail changed in the meantime, the consumer has moved on and released
    /// `tail`, so that slot becomes `current` and `false` is returned.
    ///
    /// NOTE(review): if the slot after `tail` is the head, `new_tail` is
    /// `INVALID_INDEX` (the head's link) and the CAS empties the queue.
    /// Presumably `MIN_MSGS` is large enough that this cannot occur — confirm.
    fn overrun(&mut self, tail: Index) -> bool {
        let held = tail & INDEX_MASK;
        let new_current = self.chain[held as usize];
        let new_tail = self.chain[new_current as usize];
        if self.queue.tail_compare_exchange(tail, new_tail) {
            self.overrun = held;
            self.current = new_current;
            true
        } else {
            self.current = held;
            false
        }
    }

    /// Returns `true` if `try_push` would currently report `QueueFull`.
    ///
    /// An empty queue, or a tail that is out of range, is reported as not full.
    pub(crate) fn full(&self) -> bool {
        if self.head == INVALID_INDEX {
            return false;
        }
        let tail = self.queue.tail_load();
        if !self.queue.is_valid_index(tail & INDEX_MASK) {
            return false;
        }
        if self.overrun != INVALID_INDEX {
            // The consumer still holds `self.overrun`; until it claims the
            // tail (sets CONSUMED_FLAG) there is no slot to move to.
            (tail & CONSUMED_FLAG) == 0
        } else {
            // Full when the slot after `current` is the one the tail points at.
            self.chain[self.current as usize] == (tail & INDEX_MASK)
        }
    }

    /// Publishes the message in the current slot, reclaiming the oldest unread
    /// message if no free slot is left for the next write.
    ///
    /// Returns `SuccessMessageDiscarded` when an unread message had to be
    /// dropped. Returns `QueueError` if the tail read back is out of range; in
    /// that case the message has already been linked in, but `current` is not
    /// advanced.
    pub(crate) fn force_push(&mut self) -> ProduceForceResult {
        // Read the successor before `enqueue_*` overwrites the link.
        let next = self.chain[self.current as usize];
        if self.head == INVALID_INDEX {
            self.enqueue_first_message();
            self.current = next;
            return ProduceForceResult::Success;
        }
        let mut discarded = false;
        self.enqueue_message();
        let tail = self.queue.tail_load();
        if !self.queue.is_valid_index(tail & INDEX_MASK) {
            return ProduceForceResult::QueueError;
        }
        let consumed: bool = (tail & CONSUMED_FLAG) != 0;
        if self.overrun != INVALID_INDEX {
            if consumed {
                // The consumer released the overrun slot: splice it back in.
                self.queue_store(self.overrun, next);
                self.current = self.overrun;
                self.overrun = INVALID_INDEX;
            } else if self.move_tail(tail) {
                // Still no free slot: drop the unread tail message and reuse it.
                self.current = tail & INDEX_MASK;
                discarded = true;
            } else {
                // The consumer claimed the tail concurrently, releasing the
                // overrun slot.
                self.queue_store(self.overrun, next);
                self.current = self.overrun;
                self.overrun = INVALID_INDEX;
            }
        } else {
            let full: bool = next == (tail & INDEX_MASK);
            if !full {
                self.current = next;
            } else if !consumed {
                // The tail is unread and not held: try to drop it.
                if self.move_tail(tail) {
                    self.current = next;
                    discarded = true;
                } else {
                    // The consumer claimed it meanwhile; skip over the held slot.
                    discarded = self.overrun(tail | CONSUMED_FLAG);
                }
            } else {
                discarded = self.overrun(tail);
            }
        }
        if discarded {
            ProduceForceResult::SuccessMessageDiscarded
        } else {
            ProduceForceResult::Success
        }
    }

    /// Publishes the message in the current slot only if a free slot remains
    /// for the next write. Otherwise it returns `QueueFull` and publishes nothing.
    pub(crate) fn try_push(&mut self) -> ProduceTryResult {
        let next = self.chain[self.current as usize];
        if self.head == INVALID_INDEX {
            self.enqueue_first_message();
            self.current = next;
            return ProduceTryResult::Success;
        }
        let tail = self.queue.tail_load();
        if !self.queue.is_valid_index(tail & INDEX_MASK) {
            return ProduceTryResult::QueueError;
        }
        if self.overrun != INVALID_INDEX {
            let consumed = (tail & CONSUMED_FLAG) != 0;
            if consumed {
                self.enqueue_message();
                // Reuse the slot the consumer has now released.
                self.queue_store(self.overrun, next);
                self.current = self.overrun;
                self.overrun = INVALID_INDEX;
                return ProduceTryResult::Success;
            }
        } else {
            let full = next == (tail & INDEX_MASK);
            if !full {
                self.enqueue_message();
                self.current = next;
                return ProduceTryResult::Success;
            }
        }
        ProduceTryResult::QueueFull
    }
}
/// Reading side of a `Queue`.
pub struct ConsumerQueue {
    queue: Queue,
    // Slot the consumer currently holds; set by `pop` / `flush`.
    current: Index,
}
impl ConsumerQueue {
    /// Wraps `queue` for reading. `current` starts at slot 0. It refers to a
    /// received message only after `pop` or `flush` has returned a success value.
    pub(crate) fn new(queue: Queue) -> Self {
        Self { queue, current: 0 }
    }
    /// Returns a read-only pointer to the slot the consumer currently holds,
    /// or `None` if `current` is not a valid slot index.
    pub(crate) fn current_message(&self) -> Option<*const ()> {
        let ptr = self.queue.messages.get(self.current as usize)?;
        Some(ptr.cast())
    }
    /// Jumps straight to the newest published message. This releases the held
    /// slot and every message between it and the head.
    ///
    /// Returns `NoMessage` if nothing has been published yet, and `QueueError`
    /// if the tail or head index is out of range. Otherwise returns `Success`,
    /// even when messages were skipped.
    pub(crate) fn flush(&mut self) -> ConsumeResult {
        loop {
            // Claim the tail first so the producer sees it as consumed.
            let tail = self.queue.tail_fetch_or(CONSUMED_FLAG);
            if tail == INVALID_INDEX {
                return ConsumeResult::NoMessage;
            }
            if !self.queue.is_valid_index(tail & INDEX_MASK) {
                return ConsumeResult::QueueError;
            }
            // NOTE(review): relies on the producer publishing `head` no later
            // than `tail`; otherwise an invalid head here reads as QueueError.
            let head = self.queue.head_load();
            if !self.queue.is_valid_index(head) {
                return ConsumeResult::QueueError;
            }
            // Move the (flagged) tail to the head. This fails and retries if
            // the tail changed meanwhile; the producer's tail writes in this
            // file all store unflagged values.
            if self
                .queue
                .tail_compare_exchange(tail | CONSUMED_FLAG, head | CONSUMED_FLAG)
            {
                self.current = head;
                return ConsumeResult::Success;
            }
        }
    }
    /// Advances to the next message.
    ///
    /// - `NoMessage`: nothing has been published yet.
    /// - `NoNewMessage`: there is no message after the one held; `current` is unchanged.
    /// - `Success`: `current` moved to the next message.
    /// - `SuccessMessagesDiscarded`: the tail was found unflagged and `current`
    ///   was set to it. This happens after the producer moved the tail past
    ///   unread messages, and also on the first pop after the first message is
    ///   published.
    /// - `QueueError`: an index read from shared memory is out of range.
    pub(crate) fn pop(&mut self) -> ConsumeResult {
        // Claim whatever the tail points at.
        let tail = self.queue.tail_fetch_or(CONSUMED_FLAG);
        if tail == INVALID_INDEX {
            return ConsumeResult::NoMessage;
        }
        if !self.queue.is_valid_index(tail & INDEX_MASK) {
            return ConsumeResult::QueueError;
        }
        // An unflagged tail was not claimed by us before: take it as-is.
        if tail & CONSUMED_FLAG == 0 {
            self.current = tail;
            return ConsumeResult::SuccessMessagesDiscarded;
        }
        // A flagged tail is the slot this consumer claimed earlier
        // (`self.current`), so follow its link.
        let next = self.queue.chain_load(self.current);
        if next == INVALID_INDEX {
            return ConsumeResult::NoNewMessage;
        }
        if !self.queue.is_valid_index(next) {
            return ConsumeResult::QueueError;
        }
        if self.queue.tail_compare_exchange(tail, next | CONSUMED_FLAG) {
            self.current = next;
            ConsumeResult::Success
        } else {
            // The producer replaced the tail concurrently: claim what it
            // points at now.
            let current = self.queue.tail_fetch_or(CONSUMED_FLAG);
            // NOTE(review): unlike the checks above, this value is not masked
            // with INDEX_MASK. It assumes the producer never leaves
            // CONSUMED_FLAG set when it moves the tail.
            if !self.queue.is_valid_index(current) {
                return ConsumeResult::QueueError;
            }
            self.current = current;
            ConsumeResult::SuccessMessagesDiscarded
        }
    }
}