use super::shared::Shared;
use crate::common::wait_backoff;
use crate::error::{RecvError, TryRecvError};
use cpu::{SpinLoopHintWait, WaitStrategy, fence_acquire};
use std::sync::atomic::Ordering;
use std::sync::Arc;
/// Receiving half of the `mpsc` channel.
///
/// Events are consumed in sequence order from the buffer held in [`Shared`].
/// The sequence of the next event to read is tracked locally in `next_sequence`.
/// `W` names the wait strategy type and defaults to [`SpinLoopHintWait`].
pub struct Consumer<T, W = SpinLoopHintWait>
where
    W: WaitStrategy,
{
    /// State shared with the producers: event buffer, consumer cursor, close flag.
    pub(super) shared: Arc<Shared<T>>,
    /// Sequence number of the next event this consumer will read.
    pub(super) next_sequence: i64,
    /// Wait strategy instance, stored by value.
    /// NOTE(review): the blocking methods in this file back off via `wait_backoff`
    /// rather than this value — confirm whether that is intentional.
    pub(super) _wait_strategy: W,
}
impl<T, W: WaitStrategy> Consumer<T, W> {
    /// Blocks until the next event is published, then moves it out and returns it.
    ///
    /// While nothing is available it backs off using `wait_backoff`.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] once the channel is closed and no event is published
    /// at this consumer's next sequence.
    pub fn recv(&mut self) -> Result<T, RecvError> {
        let mut iteration = 0u32;
        loop {
            if self.shared.is_published(self.next_sequence) {
                // SAFETY: `next_sequence` was observed as published just above.
                return Ok(unsafe { self.take_next() });
            }
            if self.is_disconnected() {
                // Re-check after observing `closed`. An event published right before
                // the close must still be delivered, not reported as an error.
                if self.shared.is_published(self.next_sequence) {
                    continue;
                }
                return Err(RecvError);
            }
            wait_backoff(&mut iteration);
        }
    }

    /// Returns the next event if one is already published, without blocking.
    ///
    /// # Errors
    ///
    /// - [`TryRecvError::Empty`] if nothing is published and the channel is open.
    /// - [`TryRecvError::Disconnected`] if nothing is published and the channel is closed.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        self.next_ready()?;
        // SAFETY: `next_ready` returned `Ok`, so `next_sequence` was just observed
        // as published.
        Ok(unsafe { self.take_next() })
    }

    /// Borrows the next event without consuming it.
    ///
    /// A later `recv`/`try_recv` still returns the same event.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Consumer::try_recv`].
    pub fn peek(&self) -> Result<&T, TryRecvError> {
        self.next_ready()?;
        // Same acquire fence `recv`/`try_recv` use before touching the slot.
        fence_acquire();
        // SAFETY: the sequence was just observed as published. `get` only borrows,
        // so ownership of the event stays in the buffer.
        Ok(unsafe { self.shared.buffer.get(self.next_sequence) })
    }

    /// Sets the shared `closed` flag (Release ordering).
    pub fn close(&self) {
        self.shared.closed.store(true, Ordering::Release);
    }

    /// Returns the shared `closed` flag (Acquire ordering).
    ///
    /// This becomes `true` once `close` has been called, either by this consumer
    /// (including on drop) or by any other holder of the shared state.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        self.shared.closed.load(Ordering::Acquire)
    }

    /// Number of events published but not yet consumed; `0` if none.
    #[inline]
    pub fn pending(&self) -> usize {
        let highest = self.shared.highest_published(self.next_sequence);
        let diff = highest - self.next_sequence + 1;
        if diff < 0 { 0 } else { diff as usize }
    }

    /// Blocks until at least one event is published, then moves up to `out.len()`
    /// consecutive events into the front of `out`.
    ///
    /// Returns how many were written. An empty `out` returns `Ok(0)` immediately.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] once the channel is closed and nothing is published
    /// at the next sequence.
    pub fn recv_batch(&mut self, out: &mut [T]) -> Result<usize, RecvError> {
        if out.is_empty() {
            return Ok(0);
        }
        let mut iteration = 0u32;
        loop {
            let highest = self.shared.highest_published(self.next_sequence);
            if highest >= self.next_sequence {
                // SAFETY: `highest` was just returned by `highest_published` and
                // covers `next_sequence`.
                return Ok(unsafe { self.drain_into(highest, out) });
            }
            if self.is_disconnected() {
                // Mirror `recv`: re-check after observing `closed` so events published
                // just before the close are not lost.
                if self.shared.highest_published(self.next_sequence) >= self.next_sequence {
                    continue;
                }
                return Err(RecvError);
            }
            wait_backoff(&mut iteration);
        }
    }

    /// Non-blocking variant of [`Consumer::recv_batch`].
    ///
    /// # Errors
    ///
    /// - [`TryRecvError::Empty`] if nothing is published and the channel is open.
    /// - [`TryRecvError::Disconnected`] if nothing is published and the channel is closed.
    pub fn try_recv_batch(&mut self, out: &mut [T]) -> Result<usize, TryRecvError> {
        if out.is_empty() {
            return Ok(0);
        }
        let mut highest = self.shared.highest_published(self.next_sequence);
        if highest < self.next_sequence {
            if !self.is_disconnected() {
                return Err(TryRecvError::Empty);
            }
            // Closed: re-check once so a just-published event is still delivered.
            highest = self.shared.highest_published(self.next_sequence);
            if highest < self.next_sequence {
                return Err(TryRecvError::Disconnected);
            }
        }
        // SAFETY: `highest` was just returned by `highest_published` and covers
        // `next_sequence`.
        Ok(unsafe { self.drain_into(highest, out) })
    }

    /// Hands every currently published event to `handler` by reference, in order.
    ///
    /// `handler` receives `(event, sequence, end_of_batch)`. `end_of_batch` is
    /// `true` only for the last event of this call. Returns the number of events
    /// handled, `0` if none were published.
    ///
    /// If `handler` panics, this consumer does not advance, so the same events are
    /// delivered again on the next poll (events are only borrowed, never moved out).
    #[inline]
    pub fn poll<F>(&mut self, mut handler: F) -> usize
    where
        F: FnMut(&T, i64, bool),
    {
        let first = self.next_sequence;
        let highest = self.shared.highest_published(first);
        if highest < first {
            return 0;
        }
        // Same acquire fence `recv`/`try_recv` use before touching slots.
        fence_acquire();
        for seq in first..=highest {
            // SAFETY: every sequence in `first..=highest` was reported published.
            let event = unsafe { self.shared.buffer.get(seq) };
            handler(event, seq, seq == highest);
        }
        self.next_sequence = highest + 1;
        self.shared.consumer_cursor.set_relaxed(highest);
        (highest - first + 1) as usize
    }

    /// Non-blocking readiness check for `next_sequence`, shared by `try_recv`/`peek`.
    ///
    /// Re-checks publication after observing `closed`, so `Disconnected` is only
    /// reported when the channel is closed and still empty.
    fn next_ready(&self) -> Result<(), TryRecvError> {
        if self.shared.is_published(self.next_sequence) {
            return Ok(());
        }
        if !self.is_disconnected() {
            return Err(TryRecvError::Empty);
        }
        if self.shared.is_published(self.next_sequence) {
            Ok(())
        } else {
            Err(TryRecvError::Disconnected)
        }
    }

    /// Moves the event at `next_sequence` out of the buffer and advances past it.
    ///
    /// # Safety
    ///
    /// The caller must have just observed `next_sequence` as published.
    unsafe fn take_next(&mut self) -> T {
        fence_acquire();
        // SAFETY: guaranteed published by the caller. `next_sequence` advances right
        // after, so this slot is moved out exactly once.
        let event = unsafe { self.shared.buffer.read(self.next_sequence) };
        // NOTE(review): a relaxed store does not by itself order the slot read above
        // before it. Confirm producers cannot reuse the slot early (a release store
        // or fence may be required).
        self.shared.consumer_cursor.set_relaxed(self.next_sequence);
        self.next_sequence += 1;
        event
    }

    /// Moves events `next_sequence..=highest`, capped at `out.len()`, into `out`.
    ///
    /// Returns the number of events moved.
    ///
    /// # Safety
    ///
    /// `highest` must have just been returned by
    /// `highest_published(self.next_sequence)` and be `>= self.next_sequence`.
    /// `out` must be non-empty.
    unsafe fn drain_into(&mut self, highest: i64, out: &mut [T]) -> usize {
        let available = (highest - self.next_sequence + 1) as usize;
        let count = available.min(out.len());
        fence_acquire();
        for slot in &mut out[..count] {
            // SAFETY: this sequence lies within the published range given by the caller.
            let event = unsafe { self.shared.buffer.read(self.next_sequence) };
            // Advance before storing into `out`. Dropping the previous `*slot` runs
            // user code that may panic, and a moved-out event must never be read
            // from the buffer a second time.
            self.next_sequence += 1;
            *slot = event;
        }
        // NOTE(review): relaxed store; see the ordering note in `take_next`.
        self.shared.consumer_cursor.set_relaxed(self.next_sequence - 1);
        count
    }
}
impl<T, W: WaitStrategy> Drop for Consumer<T, W> {
fn drop(&mut self) {
self.close();
}
}
// SAFETY: moving the consumer to another thread moves ownership of the events it
// will receive, hence `T: Send`. The wait strategy `W` is stored by value inside
// the consumer, so it must be `Send` too. Without that bound, a `!Send` strategy
// could cross threads unsoundly.
// NOTE(review): this also relies on `Shared<T>` tolerating concurrent access from
// producer threads — confirm `Shared`'s own Send/Sync impls.
unsafe impl<T: Send, W: WaitStrategy + Send> Send for Consumer<T, W> {}
impl<T, W> core::fmt::Debug for Consumer<T, W>
where
    W: WaitStrategy,
{
    /// Prints a snapshot of the consumer's progress.
    ///
    /// Shows the next sequence, the number of pending events and the
    /// disconnected flag. `T` itself is not required to be `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let pending = self.pending();
        let disconnected = self.is_disconnected();
        let mut out = f.debug_struct("mpsc::Consumer");
        out.field("next_sequence", &self.next_sequence);
        out.field("pending", &pending);
        out.field("disconnected", &disconnected);
        out.finish()
    }
}