use super::batch::MpscBatchSlots;
use super::shared::Shared;
use crate::common::wait_backoff;
use crate::error::{SendError, TrySendError};
use std::sync::atomic::Ordering;
use std::sync::Arc;
/// Sending handle of the multi-producer channel.
///
/// Every live handle is counted in `shared.producer_count` (incremented by
/// `Clone`, decremented by `Drop`). Dropping the last handle marks the channel
/// closed.
pub struct Producer<T> {
    // Ring buffer, cursors and lifecycle flags. The `consumer_cursor` field
    // suggests the consumer side holds this state too — confirm.
    pub(super) shared: Arc<Shared<T>>,
}
impl<T> Producer<T> {
    /// Number of compare-and-swap attempts the non-blocking claim paths make
    /// before giving up and reporting the ring as full.
    const TRY_CLAIM_ATTEMPTS: u32 = 8;

    /// Sends `event`, blocking with backoff until the ring has room for it.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] carrying `event` back if the channel is
    /// disconnected, either before a sequence is claimed or while waiting for
    /// the consumer to free the slot.
    pub fn send(&self, event: T) -> Result<(), SendError<T>> {
        if self.is_disconnected() {
            return Err(SendError(event));
        }
        let sequence = self.shared.claim_cursor.increment();
        // NOTE(review): `sequence` is already claimed here. Bailing out on
        // disconnect leaves it unpublished. Confirm the consumer never waits
        // on such a gap once the channel is disconnected.
        if !self.wait_for_capacity(sequence - self.shared.capacity as i64) {
            return Err(SendError(event));
        }
        // SAFETY: `sequence` was handed to this producer alone by
        // `claim_cursor`. `wait_for_capacity` returned only once
        // `consumer_cursor >= sequence - capacity`. Assuming the cursor
        // records the last consumed sequence, the consumer is done with this
        // slot's previous occupant.
        unsafe {
            self.shared.buffer.write(sequence, event);
        }
        self.publish_sequence(sequence);
        Ok(())
    }

    /// Sends `event` without blocking.
    ///
    /// # Errors
    ///
    /// - [`TrySendError::Disconnected`] if the channel is disconnected.
    /// - [`TrySendError::Full`] if the ring has no free slot, or if the claim
    ///   CAS lost `TRY_CLAIM_ATTEMPTS` times in a row to other producers.
    pub fn try_send(&self, event: T) -> Result<(), TrySendError<T>> {
        if self.is_disconnected() {
            return Err(TrySendError::Disconnected(event));
        }
        let sequence = match self.try_claim(1) {
            Some(seq) => seq,
            None => return Err(TrySendError::Full(event)),
        };
        // SAFETY: `try_claim` won the CAS for `sequence`, so no other producer
        // owns it. It also checked `consumer_cursor >= sequence - capacity`
        // before claiming, so the slot's previous occupant has been consumed
        // (same cursor assumption as in `send`).
        unsafe {
            self.shared.buffer.write(sequence, event);
        }
        self.publish_sequence(sequence);
        Ok(())
    }

    /// Spins with backoff until `consumer_cursor >= wrap_point`.
    ///
    /// Returns `false` if the channel became disconnected while waiting, and
    /// `true` once there is room.
    #[inline]
    fn wait_for_capacity(&self, wrap_point: i64) -> bool {
        let mut iteration = 0u32;
        // NOTE(review): this is a relaxed load. The slot write that follows is
        // not ordered after the consumer's read of the old slot by a
        // happens-before edge. An Acquire load would make that formal —
        // confirm what the cursor type offers.
        while self.shared.consumer_cursor.value_relaxed() < wrap_point {
            if self.is_disconnected() {
                return false;
            }
            wait_backoff(&mut iteration);
        }
        true
    }

    /// Tries to claim `n` consecutive sequences without blocking.
    ///
    /// Returns the last claimed sequence (the range is `last - n + 1 ..= last`).
    /// Returns `None` if the ring lacks room for `n` more slots, or if the CAS
    /// on `claim_cursor` failed `TRY_CLAIM_ATTEMPTS` times.
    #[inline]
    fn try_claim(&self, n: i64) -> Option<i64> {
        for _ in 0..Self::TRY_CLAIM_ATTEMPTS {
            let current = self.shared.claim_cursor.value_relaxed();
            let next = current + n;
            let wrap_point = next - self.shared.capacity as i64;
            if self.shared.consumer_cursor.value_relaxed() < wrap_point {
                return None;
            }
            if self
                .shared
                .claim_cursor
                .compare_exchange(current, next)
                .is_ok()
            {
                return Some(next);
            }
            core::hint::spin_loop();
        }
        None
    }

    /// Makes a written sequence visible to the consumer.
    #[inline]
    fn publish_sequence(&self, sequence: i64) {
        self.shared.mark_published(sequence);
    }

    /// Claims `n` consecutive slots, blocking with backoff until they are free.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0` or if `n` exceeds the ring capacity.
    ///
    /// # Errors
    ///
    /// Returns `SendError(())` if the channel is disconnected, either before
    /// claiming or while waiting for room.
    pub fn claim_batch(&self, n: usize) -> Result<MpscBatchSlots<'_, T>, SendError<()>> {
        assert!(n > 0, "batch size must be > 0");
        assert!(n <= self.shared.capacity, "batch size exceeds capacity");
        if self.is_disconnected() {
            return Err(SendError(()));
        }
        let n_i64 = n as i64;
        let end_sequence = self.shared.claim_cursor.add(n_i64);
        // NOTE(review): as in `send`, an early return here leaves the `n`
        // claimed sequences unpublished.
        if !self.wait_for_capacity(end_sequence - self.shared.capacity as i64) {
            return Err(SendError(()));
        }
        Ok(MpscBatchSlots {
            shared: &self.shared,
            start: end_sequence - n_i64 + 1,
            count: n,
            published: false,
        })
    }

    /// Claims `n` consecutive slots without blocking.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0` or if `n` exceeds the ring capacity.
    ///
    /// # Errors
    ///
    /// - [`TrySendError::Disconnected`] if the channel is disconnected.
    /// - [`TrySendError::Full`] if there is no room for `n` slots, or if the
    ///   claim CAS kept losing to other producers.
    pub fn try_claim_batch(&self, n: usize) -> Result<MpscBatchSlots<'_, T>, TrySendError<()>> {
        assert!(n > 0, "batch size must be > 0");
        assert!(n <= self.shared.capacity, "batch size exceeds capacity");
        if self.is_disconnected() {
            return Err(TrySendError::Disconnected(()));
        }
        let n_i64 = n as i64;
        match self.try_claim(n_i64) {
            Some(end_sequence) => Ok(MpscBatchSlots {
                shared: &self.shared,
                start: end_sequence - n_i64 + 1,
                count: n,
                published: false,
            }),
            None => Err(TrySendError::Full(())),
        }
    }

    /// Marks the channel closed. Every producer then sees
    /// [`is_disconnected`](Self::is_disconnected) return `true`.
    pub fn close(&self) {
        self.shared.closed.store(true, Ordering::Release);
    }

    /// Returns `true` if the channel is closed. It also returns `true` when no
    /// producers are counted and this handle holds the only reference to the
    /// shared state.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        // NOTE(review): while `self` is alive, `producer_count == 0` can only
        // hold if the count does not include `self`. If the first producer
        // starts the count at 1, the second clause is dead — confirm.
        self.shared.closed.load(Ordering::Acquire)
            || (self.shared.producer_count.load(Ordering::Acquire) == 0
                && Arc::strong_count(&self.shared) == 1)
    }
}
impl<T> Drop for Producer<T> {
    /// Removes this handle from the live-producer count. The handle that takes
    /// the count from 1 to 0 also flags the channel as closed.
    fn drop(&mut self) {
        let was_last_producer = self.shared.producer_count.fetch_sub(1, Ordering::AcqRel) == 1;
        if was_last_producer {
            self.shared.closed.store(true, Ordering::Release);
        }
    }
}
impl<T> Clone for Producer<T> {
fn clone(&self) -> Self {
self.shared.producer_count.fetch_add(1, Ordering::AcqRel);
Self {
shared: Arc::clone(&self.shared),
}
}
}
// SAFETY: a `Producer` holds only an `Arc<Shared<T>>`. Moving it to another
// thread lets that thread move `T` values into the channel, hence `T: Send`.
// NOTE(review): soundness also depends on `Shared` synchronising slot access
// through its cursors, which is not visible in this file.
unsafe impl<T: Send> Send for Producer<T> {}
// SAFETY: the methods visible on `&Producer` take `T` by value and move it into
// the buffer; none hands out a shared `&T`. So sharing a handle only requires
// `T: Send`.
// NOTE(review): confirm `MpscBatchSlots` likewise exposes no `&T` that could
// be observed from several threads.
unsafe impl<T: Send> Sync for Producer<T> {}
impl<T> core::fmt::Debug for Producer<T> {
    /// Shows the current live-producer count and disconnection state. The
    /// buffer contents are not printed, so `T` needs no `Debug` bound.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let live_producers = self.shared.producer_count.load(Ordering::Relaxed);
        let disconnected = self.is_disconnected();
        let mut view = f.debug_struct("mpsc::Producer");
        view.field("producer_count", &live_producers);
        view.field("disconnected", &disconnected);
        view.finish()
    }
}