use super::shared::Shared;
use crate::barrier::SequenceBarrier;
use crate::error::{RecvError, TryRecvError};
use cpu::{fence_acquire, fence_release, SpinLoopHintWait, WaitStrategy};
use std::sync::atomic::Ordering;
use std::sync::Arc;
/// Receiving side of the channel: reads events from the shared buffer in
/// sequence order and reports its progress through the shared consumer cursor.
///
/// `W` is the wait strategy the [`SequenceBarrier`] uses while no event is
/// available; it defaults to `SpinLoopHintWait`.
pub struct Consumer<T, W: WaitStrategy = SpinLoopHintWait> {
/// State shared with the other side of the channel (presumably the producer):
/// the event buffer, the consumer cursor and the `closed` flag.
pub(super) shared: Arc<Shared<T>>,
/// Sequence number of the next event this consumer will read.
pub(super) next_sequence: i64,
/// Barrier used to observe the published cursor and to wait for new events.
pub(super) barrier: SequenceBarrier<W>,
/// Sticky cache of a positive `is_disconnected` result; once `true`, the
/// shared state is not re-checked.
pub(super) cached_disconnected: bool,
}
impl<T, W: WaitStrategy> Consumer<T, W> {
    /// Receives the next event, waiting through the barrier's wait strategy
    /// until one has been published.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] once this consumer is disconnected (see
    /// [`is_disconnected`](Self::is_disconnected)) and every event published
    /// before that point has already been received.
    pub fn recv(&mut self) -> Result<T, RecvError> {
        loop {
            if self.published_upto().is_some() {
                // SAFETY: `next_sequence` is published (checked by `published_upto`,
                // which also issued the acquire fence) and not yet consumed.
                // NOTE(review): assumes this consumer is the only reader of the
                // slot — confirm against the buffer's `read` contract.
                let event = unsafe { self.shared.buffer.read(self.next_sequence) };
                self.advance(1);
                return Ok(event);
            }
            self.wait_for_next()?;
        }
    }

    /// Receives the next event without waiting.
    ///
    /// # Errors
    ///
    /// - [`TryRecvError::Empty`] if nothing is published yet and the consumer
    ///   is still connected.
    /// - [`TryRecvError::Disconnected`] if the consumer is disconnected and no
    ///   published event remains.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        self.try_ready()?;
        // SAFETY: `try_ready` confirmed `next_sequence` is published (and fenced);
        // it has not been consumed yet. Same single-reader assumption as `recv`.
        let event = unsafe { self.shared.buffer.read(self.next_sequence) };
        self.advance(1);
        Ok(event)
    }

    /// Returns a reference to the next event without consuming it.
    ///
    /// # Errors
    ///
    /// Same as [`try_recv`](Self::try_recv).
    pub fn peek(&mut self) -> Result<&T, TryRecvError> {
        self.try_ready()?;
        // SAFETY: the slot is published and unconsumed. The consumer cursor is not
        // advanced here, so (presumably) the slot cannot be reused while the
        // returned borrow of `self` is alive.
        Ok(unsafe { self.shared.buffer.get(self.next_sequence) })
    }

    /// Sets the shared `closed` flag (Release ordering).
    ///
    /// After this, [`is_disconnected`](Self::is_disconnected) reports `true`.
    /// The consumer's `Drop` impl also calls this.
    pub fn close(&self) {
        self.shared.closed.store(true, Ordering::Release);
    }

    /// Returns `true` if the shared `closed` flag is set or this consumer holds
    /// the only remaining reference to the shared state.
    ///
    /// A `true` result is cached, so later calls return `true` without
    /// re-checking.
    #[inline]
    pub fn is_disconnected(&mut self) -> bool {
        if !self.cached_disconnected
            && (self.shared.closed.load(Ordering::Relaxed)
                || Arc::strong_count(&self.shared) == 1)
        {
            // The loads above are Relaxed. This fence pairs with the Release store
            // of `closed` and with the Release decrement std's `Arc` performs on
            // drop. Writes made before the other side closed or dropped (including
            // its final cursor publication) are then visible to the availability
            // re-checks that callers perform after seeing a disconnect.
            std::sync::atomic::fence(Ordering::Acquire);
            self.cached_disconnected = true;
        }
        self.cached_disconnected
    }

    /// Number of events published (per the barrier's cursor) but not yet
    /// consumed; `0` when the consumer is caught up.
    #[inline]
    pub fn pending(&self) -> usize {
        let backlog = self.barrier.get_cursor() - self.next_sequence + 1;
        backlog.max(0) as usize
    }

    /// Returns `true` if the barrier reports `next_sequence` as available.
    #[inline]
    pub fn has_pending(&self) -> bool {
        self.barrier.is_available(self.next_sequence)
    }

    /// Hands every currently published event to `handler` by reference, in
    /// sequence order, and then marks them all consumed.
    ///
    /// `handler` receives `(event, sequence, end_of_batch)`. `end_of_batch` is
    /// `true` only for the last event of this call. Events are borrowed, not
    /// moved out of the buffer. Returns the number of events handled, which is
    /// `0` if none were available; this method never waits.
    #[inline]
    pub fn poll<F>(&mut self, mut handler: F) -> usize
    where
        F: FnMut(&T, i64, bool),
    {
        let available = match self.published_upto() {
            Some(available) => available,
            None => return 0,
        };
        let first = self.next_sequence;
        for seq in first..=available {
            // SAFETY: `seq` lies in the published, unconsumed range
            // `[next_sequence, available]` (acquire fence issued by `published_upto`).
            let event = unsafe { self.shared.buffer.get(seq) };
            handler(event, seq, seq == available);
        }
        let count = (available - first + 1) as usize;
        self.advance(count);
        count
    }

    /// Moves up to `out.len()` events into `out`, waiting until at least one is
    /// published. Returns how many slots of `out` were overwritten. The previous
    /// values in those slots are dropped.
    ///
    /// An empty `out` returns `Ok(0)` immediately.
    ///
    /// # Errors
    ///
    /// Same as [`recv`](Self::recv).
    pub fn recv_batch(&mut self, out: &mut [T]) -> Result<usize, RecvError> {
        if out.is_empty() {
            return Ok(0);
        }
        loop {
            if let Some(available) = self.published_upto() {
                return Ok(self.read_into(out, available));
            }
            self.wait_for_next()?;
        }
    }

    /// Non-waiting variant of [`recv_batch`](Self::recv_batch).
    ///
    /// An empty `out` returns `Ok(0)` immediately.
    ///
    /// # Errors
    ///
    /// Same as [`try_recv`](Self::try_recv).
    pub fn try_recv_batch(&mut self, out: &mut [T]) -> Result<usize, TryRecvError> {
        if out.is_empty() {
            return Ok(0);
        }
        let available = self.try_published_upto()?;
        Ok(self.read_into(out, available))
    }

    /// Moves every currently published event into a new `Vec`, waiting until
    /// at least one is published.
    ///
    /// # Errors
    ///
    /// Same as [`recv`](Self::recv).
    pub fn recv_all(&mut self) -> Result<Vec<T>, RecvError> {
        loop {
            if let Some(available) = self.published_upto() {
                let first = self.next_sequence;
                let mut out = Vec::with_capacity((available - first + 1) as usize);
                for seq in first..=available {
                    // SAFETY: `seq` is in the published, unconsumed range; same
                    // single-reader assumption as `recv`.
                    out.push(unsafe { self.shared.buffer.read(seq) });
                }
                self.advance(out.len());
                return Ok(out);
            }
            self.wait_for_next()?;
        }
    }

    /// Probes the barrier cursor. If `next_sequence` has been published, returns
    /// the highest published sequence, after an acquire fence so that slot
    /// contents written before publication are visible.
    #[inline]
    fn published_upto(&self) -> Option<i64> {
        let available = self.barrier.get_cursor_relaxed();
        if available < self.next_sequence {
            return None;
        }
        fence_acquire();
        Some(available)
    }

    /// Non-waiting readiness check for `next_sequence`, used by `try_recv` and
    /// `peek`. Fences on success.
    fn try_ready(&mut self) -> Result<(), TryRecvError> {
        if !self.barrier.is_available(self.next_sequence) {
            if !self.is_disconnected() {
                return Err(TryRecvError::Empty);
            }
            // The final publish may have raced with the probe above. After the
            // fence in `is_disconnected`, one more probe sees anything the other
            // side published before disconnecting.
            if !self.barrier.is_available(self.next_sequence) {
                return Err(TryRecvError::Disconnected);
            }
        }
        fence_acquire();
        Ok(())
    }

    /// Non-waiting form of `published_upto`, used by `try_recv_batch`. It applies
    /// the same disconnect re-check as `try_ready`.
    fn try_published_upto(&mut self) -> Result<i64, TryRecvError> {
        if let Some(available) = self.published_upto() {
            return Ok(available);
        }
        if !self.is_disconnected() {
            return Err(TryRecvError::Empty);
        }
        self.published_upto().ok_or(TryRecvError::Disconnected)
    }

    /// Slow path of the blocking receivers.
    ///
    /// Returns `Ok(())` when the caller should probe the cursor again, or
    /// `Err(RecvError)` once the consumer is disconnected and every published
    /// event has been consumed.
    fn wait_for_next(&mut self) -> Result<(), RecvError> {
        if self.is_disconnected() {
            // Events published before the disconnect must still be drained.
            return if self.barrier.get_cursor() >= self.next_sequence {
                Ok(())
            } else {
                Err(RecvError)
            };
        }
        // The wait result is ignored: the caller re-probes the cursor either way.
        let _ = self.barrier.wait_for(self.next_sequence);
        Ok(())
    }

    /// Moves `min(out.len(), available - next_sequence + 1)` events into the
    /// front of `out`, commits them, and returns that count.
    ///
    /// The caller must have obtained `available` from `published_upto` or
    /// `try_published_upto` (so the acquire fence has run), and `out` must be
    /// non-empty.
    fn read_into(&mut self, out: &mut [T], available: i64) -> usize {
        let ready = (available - self.next_sequence + 1) as usize;
        let count = out.len().min(ready);
        let first = self.next_sequence;
        for (offset, slot) in out[..count].iter_mut().enumerate() {
            // SAFETY: `first + offset <= available`, so the slot is published and
            // unconsumed; same single-reader assumption as `recv`.
            *slot = unsafe { self.shared.buffer.read(first + offset as i64) };
        }
        self.advance(count);
        count
    }

    /// Marks `count` events starting at `next_sequence` as consumed. It advances
    /// `next_sequence` and stores the last consumed sequence into the shared
    /// consumer cursor.
    ///
    /// The release fence orders this consumer's reads of those slots before the
    /// cursor store. Whoever reuses slots based on that cursor (presumably the
    /// producer) therefore cannot overwrite a slot that is still being read.
    #[inline]
    fn advance(&mut self, count: usize) {
        debug_assert!(count > 0, "advance called with an empty batch");
        let last_seq = self.next_sequence + count as i64 - 1;
        self.next_sequence += count as i64;
        fence_release();
        self.shared.consumer_cursor.set_relaxed(last_seq);
    }
}
impl<T, W: WaitStrategy> Drop for Consumer<T, W> {
fn drop(&mut self) {
self.close();
}
}
// SAFETY: There is no `Sync` impl, so one thread at a time uses a `Consumer`.
// It moves `T` values to whichever thread owns it, hence `T: Send`. The wait
// strategy `W` lives inside the consumer (presumably in the barrier) and is
// used on the owning thread, so it must itself be `Send`; previously any `W`
// was accepted.
// NOTE(review): this impl overrides the auto-trait result for
// `Arc<Shared<T>>`. Confirm that `Shared<T>`'s interior mutability is only
// touched under the producer/consumer sequence protocol.
unsafe impl<T: Send, W: WaitStrategy + Send> Send for Consumer<T, W> {}
impl<T, W: WaitStrategy> core::fmt::Debug for Consumer<T, W> {
    /// Shows the read position, the backlog and whether the consumer currently
    /// looks disconnected.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `is_disconnected` needs `&mut self` to update its cache, so the same
        // checks are repeated here without touching the cached flag.
        let disconnected = if self.cached_disconnected {
            true
        } else {
            self.shared.closed.load(Ordering::Relaxed) || Arc::strong_count(&self.shared) == 1
        };
        let mut builder = f.debug_struct("Consumer");
        builder.field("next_sequence", &self.next_sequence);
        builder.field("pending", &self.pending());
        builder.field("disconnected", &disconnected);
        builder.finish()
    }
}