use crate::pod::Pod;
use crate::ring::SharedRing;
use crate::slot::Slot;
use alloc::sync::Arc;
use core::sync::atomic::{AtomicU64, Ordering};
use super::prefetch_write_next;
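
/// Multi-producer publisher handle for a `SharedRing`.
///
/// Each `publish` claims a unique sequence number with a `fetch_add`, writes
/// its slot, and then advances a shared cursor that marks the longest
/// contiguous prefix of completed writes visible to consumers.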
pub struct MpPublisher<T: Pod> {
pub(super) ring: Arc<SharedRing<T>>,
pub(super) slots_ptr: *const Slot<T>,
pub(super) capacity: u64,
pub(super) mask: u64,
pub(super) reciprocal: u64,
pub(super) is_pow2: bool,
pub(super) cursor_ptr: *const AtomicU64,
pub(super) next_seq_ptr: *const AtomicU64,
}
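
// Manual `Clone` so the impl does not require `T: Clone`, a bound that a
// derive would add.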
impl<T: Pod> Clone for MpPublisher<T> {
fn clone(&self) -> Self {
MpPublisher {
ring: self.ring.clone(),
slots_ptr: self.slots_ptr,
capacity: self.capacity,
mask: self.mask,
reciprocal: self.reciprocal,
is_pow2: self.is_pow2,
cursor_ptr: self.cursor_ptr,
next_seq_ptr: self.next_seq_ptr,
}
}
}
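
// SAFETY: the raw pointers reference memory owned by `ring` (kept alive by
// the `Arc`), all shared mutation goes through atomics and slot stamps, and
// `T: Pod` means the payload is plain data.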
unsafe impl<T: Pod> Send for MpPublisher<T> {}
unsafe impl<T: Pod> Sync for MpPublisher<T> {}
impl<T: Pod> MpPublisher<T> {
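    /// Maps a sequence number to a slot index: a mask for power-of-two
    /// capacities, otherwise a reciprocal-multiply division that avoids a
    /// hardware `%` on the hot path.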
#[inline(always)]
fn slot_index(&self, seq: u64) -> usize {
if self.is_pow2 {
(seq & self.mask) as usize
} else {
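            // Estimate q ~ seq / capacity with a 64x64 -> 128-bit multiply by
            // the precomputed reciprocal; the estimate can be one too low, so
            // a single conditional correction follows.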
let q = ((seq as u128 * self.reciprocal as u128) >> 64) as u64;
let mut r = seq - q.wrapping_mul(self.capacity);
if r >= self.capacity {
r -= self.capacity;
}
r as usize
}
}
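
    /// Claims the next sequence, writes `value` into its slot, and advances
    /// the shared cursor so consumers can see it.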
#[inline]
pub fn publish(&self, value: T) {
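        // Claim a unique sequence number so concurrent publishers never
        // target the same pending slot.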
let next_seq_atomic = unsafe { &*self.next_seq_ptr };
let seq = next_seq_atomic.fetch_add(1, Ordering::AcqRel);
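        // Resolve our slot and prefetch the following one to warm the cache
        // for the next publish.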
let slot = unsafe { &*self.slots_ptr.add(self.slot_index(seq)) };
prefetch_write_next(self.slots_ptr, self.slot_index(seq + 1) as u64);
slot.write(seq, value);
self.advance_cursor(seq);
}
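
    /// Like `publish`, but lets `f` construct the value in place inside the
    /// slot's `MaybeUninit`, avoiding an intermediate copy of `T`.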
#[inline]
pub fn publish_with(&self, f: impl FnOnce(&mut core::mem::MaybeUninit<T>)) {
let next_seq_atomic = unsafe { &*self.next_seq_ptr };
let seq = next_seq_atomic.fetch_add(1, Ordering::AcqRel);
let slot = unsafe { &*self.slots_ptr.add(self.slot_index(seq)) };
prefetch_write_next(self.slots_ptr, self.slot_index(seq + 1) as u64);
slot.write_with(seq, f);
self.advance_cursor(seq);
}
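
    /// Number of sequences claimed so far. Some of these may still be in
    /// flight, i.e. not yet visible to consumers via the cursor.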
#[inline]
pub fn published(&self) -> u64 {
unsafe { &*self.next_seq_ptr }.load(Ordering::Relaxed)
}
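
    /// Total number of slots in the ring.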
#[inline]
pub fn capacity(&self) -> u64 {
self.ring.capacity()
}
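
    /// Publishes `seq` to consumers. Fast path: CAS the cursor from
    /// `seq - 1` to `seq`. Otherwise wait for the predecessor's slot write,
    /// retry, and finally drag the cursor over any finished successors.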
#[inline]
fn advance_cursor(&self, seq: u64) {
let cursor_atomic = unsafe { &*self.cursor_ptr };
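        // The cursor starts at u64::MAX, which acts as the predecessor of 0.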
let expected_cursor = if seq == 0 { u64::MAX } else { seq - 1 };
if cursor_atomic
.compare_exchange(expected_cursor, seq, Ordering::Release, Ordering::Relaxed)
.is_ok()
{
self.catch_up_cursor(seq);
return;
}
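        // Fast path failed: either the predecessor is still writing, or
        // another publisher has already carried the cursor past us. Wait for
        // the predecessor's slot write before retrying.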
if seq > 0 {
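            // A slot holding sequence s is fully written once its stamp
            // reaches s * 2 + 2 (the stamping scheme used by `Slot`).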
let pred_slot = unsafe { &*self.slots_ptr.add(self.slot_index(seq - 1)) };
let pred_done = (seq - 1) * 2 + 2;
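            // sevl arms the event register so the first wfe below falls
            // through to a re-check; later iterations nap in wfe until an
            // event or interrupt instead of spinning hot.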
#[cfg(target_arch = "aarch64")]
unsafe {
core::arch::asm!("sevl", options(nomem, nostack));
}
while pred_slot.stamp_load() < pred_done {
#[cfg(target_arch = "aarch64")]
unsafe {
core::arch::asm!("wfe", options(nomem, nostack));
}
#[cfg(not(target_arch = "aarch64"))]
core::hint::spin_loop();
}
}
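        // Retry the CAS now that the predecessor's write (if any) has
        // completed.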
let _ = cursor_atomic.compare_exchange(
expected_cursor,
seq,
Ordering::Release,
Ordering::Relaxed,
);
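        // If the cursor landed exactly on our seq (via our CAS or a peer's
        // catch-up), pull it forward over any already-finished successors.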
if cursor_atomic.load(Ordering::Relaxed) == seq {
self.catch_up_cursor(seq);
}
}
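
    /// With the cursor at `seq`, advance it over contiguous successors whose
    /// slot writes have already completed, so finished work is not held back
    /// by publishers that have yet to advance the cursor themselves.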
#[inline]
fn catch_up_cursor(&self, mut seq: u64) {
let cursor_atomic = unsafe { &*self.cursor_ptr };
let next_seq_atomic = unsafe { &*self.next_seq_ptr };
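        // Each step is CAS-guarded, so concurrent catch-up attempts cannot
        // double-advance or skip a sequence.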
loop {
let next = seq + 1;
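            // Never advance past the claim frontier.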
if next >= next_seq_atomic.load(Ordering::Acquire) {
break;
}
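            // Stop at the first successor whose slot write has not completed.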
let done_stamp = next * 2 + 2;
let slot = unsafe { &*self.slots_ptr.add(self.slot_index(next)) };
if slot.stamp_load() < done_stamp {
break;
}
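            // If this CAS fails, another thread moved the cursor and now owns
            // the catch-up.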
if cursor_atomic
.compare_exchange(seq, next, Ordering::Release, Ordering::Relaxed)
.is_err()
{
break;
}
seq = next;
}
}
}