use super::batch::BatchSlots;
use super::shared::Shared;
use crate::common::wait_backoff;
use crate::error::{SendError, TrySendError};
use std::sync::atomic::Ordering;
use std::sync::Arc;
/// Producer half of the channel: claims sequence slots in the shared ring
/// buffer, writes events, and publishes them by advancing `producer_cursor`.
pub struct Producer<T> {
    /// State shared with the consumer: ring buffer, cursors, and closed flag.
    pub(super) shared: Arc<Shared<T>>,
    /// Last observed consumer sequence; refreshed lazily from
    /// `consumer_cursor` so the hot path can often skip the atomic load.
    pub(super) cached_consumer: i64,
    /// Sequence number of the next slot this producer will write.
    pub(super) next_sequence: i64,
    /// Ring-buffer capacity in slots (stored as i64 for sequence arithmetic).
    pub(super) capacity: i64,
    /// True between a successful `claim()` and the matching `publish()`.
    pub(super) claimed: bool,
    /// Once disconnection is observed it is latched here so later checks
    /// avoid the atomic load and `Arc::strong_count` call.
    pub(super) cached_disconnected: bool,
}
impl<T> Producer<T> {
pub fn send(&mut self, event: T) -> Result<(), SendError<T>> {
assert!(!self.claimed, "must call publish() after claim()");
if self.is_disconnected() {
return Err(SendError(event));
}
if self.wait_for_space().is_err() {
return Err(SendError(event));
}
unsafe {
self.shared.buffer.write(self.next_sequence, event);
}
self.shared.producer_cursor.set(self.next_sequence);
self.next_sequence += 1;
Ok(())
}
pub fn try_send(&mut self, event: T) -> Result<(), TrySendError<T>> {
assert!(!self.claimed, "must call publish() after claim()");
if self.is_disconnected() {
return Err(TrySendError::Disconnected(event));
}
let wrap_point = self.next_sequence - self.capacity;
if self.cached_consumer < wrap_point {
self.cached_consumer = self.shared.consumer_cursor.value_relaxed();
if self.cached_consumer < wrap_point {
return Err(TrySendError::Full(event));
}
}
unsafe {
self.shared.buffer.write(self.next_sequence, event);
}
self.shared.producer_cursor.set(self.next_sequence);
self.next_sequence += 1;
Ok(())
}
pub fn claim(&mut self) -> Result<&mut T, SendError<()>> {
assert!(!self.claimed, "already claimed, must call publish() first");
if self.is_disconnected() {
return Err(SendError(()));
}
if self.wait_for_space().is_err() {
return Err(SendError(()));
}
self.claimed = true;
Ok(unsafe { self.shared.buffer.get_mut(self.next_sequence) })
}
pub fn try_claim(&mut self) -> Result<&mut T, TrySendError<()>> {
assert!(!self.claimed, "already claimed, must call publish() first");
if self.is_disconnected() {
return Err(TrySendError::Disconnected(()));
}
let wrap_point = self.next_sequence - self.capacity;
if self.cached_consumer < wrap_point {
self.cached_consumer = self.shared.consumer_cursor.value_relaxed();
if self.cached_consumer < wrap_point {
return Err(TrySendError::Full(()));
}
}
self.claimed = true;
Ok(unsafe { self.shared.buffer.get_mut(self.next_sequence) })
}
pub fn publish(&mut self) {
assert!(self.claimed, "must call claim() before publish()");
self.shared.producer_cursor.set(self.next_sequence);
self.next_sequence += 1;
self.claimed = false;
}
pub fn claim_batch(&mut self, n: usize) -> Result<BatchSlots<'_, T>, SendError<()>> {
assert!(!self.claimed, "already claimed, must call publish() first");
assert!(n > 0, "batch size must be > 0");
assert!(
n <= self.capacity as usize,
"batch size exceeds buffer capacity"
);
if self.is_disconnected() {
return Err(SendError(()));
}
self.wait_for_space_n(n as i64)?;
let start = self.next_sequence;
self.next_sequence += n as i64;
Ok(BatchSlots {
shared: &self.shared,
start,
count: n,
published: false,
})
}
pub fn try_claim_batch(&mut self, n: usize) -> Result<BatchSlots<'_, T>, TrySendError<()>> {
assert!(!self.claimed, "already claimed, must call publish() first");
assert!(n > 0, "batch size must be > 0");
assert!(
n <= self.capacity as usize,
"batch size exceeds buffer capacity"
);
if self.is_disconnected() {
return Err(TrySendError::Disconnected(()));
}
let wrap_point = self.next_sequence + n as i64 - 1 - self.capacity;
self.cached_consumer = self.shared.consumer_cursor.value_relaxed();
if self.cached_consumer < wrap_point {
return Err(TrySendError::Full(()));
}
let start = self.next_sequence;
self.next_sequence += n as i64;
Ok(BatchSlots {
shared: &self.shared,
start,
count: n,
published: false,
})
}
pub fn close(&self) {
self.shared.closed.store(true, Ordering::Release);
}
#[inline]
pub fn is_disconnected(&mut self) -> bool {
if self.cached_disconnected {
return true;
}
let disconnected =
self.shared.closed.load(Ordering::Relaxed) || Arc::strong_count(&self.shared) == 1;
if disconnected {
self.cached_disconnected = true;
}
disconnected
}
#[inline]
pub fn available_slots(&self) -> usize {
let consumer_seq = self.shared.consumer_cursor.value();
let pending = self.next_sequence - consumer_seq - 1;
if pending < 0 {
self.capacity as usize
} else {
(self.capacity as usize).saturating_sub(pending as usize)
}
}
#[inline]
fn wait_for_space(&mut self) -> Result<(), SendError<()>> {
self.wait_for_space_n(1)
}
#[inline]
fn wait_for_space_n(&mut self, n: i64) -> Result<(), SendError<()>> {
let wrap_point = self.next_sequence + n - 1 - self.capacity;
if self.cached_consumer >= wrap_point {
return Ok(());
}
let mut iteration = 0u32;
loop {
self.cached_consumer = self.shared.consumer_cursor.value_relaxed();
if self.cached_consumer >= wrap_point {
return Ok(());
}
if self.is_disconnected() {
return Err(SendError(()));
}
wait_backoff(&mut iteration);
}
}
}
impl<T> Drop for Producer<T> {
fn drop(&mut self) {
self.close();
}
}
unsafe impl<T: Send> Send for Producer<T> {}
impl<T> core::fmt::Debug for Producer<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Debug only has `&self`, so recompute the disconnect state here
        // instead of going through the cache-updating is_disconnected().
        let closed = self.shared.closed.load(Ordering::Relaxed);
        let orphaned = Arc::strong_count(&self.shared) == 1;
        let disconnected = self.cached_disconnected || closed || orphaned;
        let mut out = f.debug_struct("Producer");
        out.field("next_sequence", &self.next_sequence);
        out.field("capacity", &self.capacity);
        out.field("claimed", &self.claimed);
        out.field("disconnected", &disconnected);
        out.finish()
    }
}