use super::errors::TryRecvError;
use crate::pod::Pod;
use crate::ring::{Padded, SharedRing};
use crate::slot::Slot;
use crate::wait::WaitStrategy;
use alloc::sync::Arc;
use core::sync::atomic::{AtomicU64, Ordering};
/// Consumer handle over a shared broadcast ring, reading sequences in order.
///
/// `N` is a compile-time count exposed via `aligned_count`; its exact meaning
/// (lane/subscriber count?) is not visible in this file — confirm at the
/// construction site.
pub struct SubscriberGroup<T: Pod, const N: usize> {
    // Shared ring; also keeps the slot storage behind `slots_ptr` alive.
    // NOTE(review): assumed owner of the slot array — confirm at construction.
    pub(super) ring: Arc<SharedRing<T>>,
    // Raw pointer to the ring's slot array; indexed via `slot_index`.
    pub(super) slots_ptr: *const Slot<T>,
    // Number of slots in the ring (modulus for `slot_index`).
    pub(super) capacity: u64,
    // Bitmask used instead of modulo when `is_pow2` is true
    // (presumably `capacity - 1` — confirm at construction).
    pub(super) mask: u64,
    // Fixed-point reciprocal of `capacity` for division-free modulo
    // (presumably ~2^64 / capacity — confirm at construction).
    pub(super) reciprocal: u64,
    // Whether `capacity` is a power of two (selects mask vs. reciprocal path).
    pub(super) is_pow2: bool,
    // Next sequence number this subscriber will attempt to read.
    pub(super) cursor: u64,
    // Total sequences lost to producer overruns (lag).
    pub(super) total_lagged: u64,
    // Total values successfully received.
    pub(super) total_received: u64,
    // Optional shared cursor published for producer-side backpressure;
    // unregistered from `ring.backpressure` on drop.
    pub(super) tracker: Option<Arc<Padded<AtomicU64>>>,
}
// SAFETY: the type is non-auto-Send only because of `slots_ptr`. Moving the
// group to another thread is presumably sound because the pointed-to slots are
// owned by the `Arc<SharedRing<T>>` moved along with it, and all receive
// methods take `&mut self` — confirm `Slot<T>` reads are synchronized via its
// stamp protocol. NOTE(review): `T: Pod` is assumed to imply `T: Send`
// (plain-old-data) — confirm the `Pod` trait's bounds. No `Sync` impl is
// provided, so shared concurrent access remains disallowed.
unsafe impl<T: Pod, const N: usize> Send for SubscriberGroup<T, N> {}
impl<T: Pod, const N: usize> SubscriberGroup<T, N> {
    /// Maps a sequence number onto a physical slot index in the ring.
    ///
    /// Power-of-two capacities take the fast mask path; other capacities use a
    /// multiply-high with a precomputed fixed-point reciprocal instead of a
    /// hardware division, followed by at most one correction step.
    #[inline(always)]
    fn slot_index(&self, seq: u64) -> usize {
        if self.is_pow2 {
            (seq & self.mask) as usize
        } else {
            // q ≈ seq / capacity. NOTE(review): assumes `reciprocal` is chosen
            // so the quotient estimate is off by at most one — confirm where
            // `reciprocal` is initialized; otherwise the single correction
            // below would not suffice.
            let q = ((seq as u128 * self.reciprocal as u128) >> 64) as u64;
            let mut r = seq - q.wrapping_mul(self.capacity);
            if r >= self.capacity {
                r -= self.capacity;
            }
            r as usize
        }
    }
    /// Attempts to receive the value at the current cursor without blocking.
    ///
    /// Returns:
    /// - `Ok(value)` — the slot held our sequence; cursor advances by one.
    /// - `Err(Empty)` — nothing new is published at the cursor yet.
    /// - `Err(Lagged { skipped })` — the producer wrapped past us; the cursor
    ///   jumps to the oldest still-available sequence and `skipped` sequences
    ///   are counted as lost.
    #[inline]
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let cur = self.cursor;
        let slot = unsafe { &*self.slots_ptr.add(self.slot_index(cur)) };
        // Stamp a reader expects once sequence `cur` is fully published.
        // NOTE(review): encoding looks like stamp = 2*seq + 2, with odd values
        // marking an in-progress write — confirm against `Slot::try_read`.
        let expected = cur * 2 + 2;
        match slot.try_read(cur) {
            Ok(Some(value)) => {
                self.cursor = cur + 1;
                self.total_received += 1;
                self.update_tracker();
                Ok(value)
            }
            // Slot not yet published for `cur`.
            Ok(None) => Err(TryRecvError::Empty),
            Err(actual_stamp) => {
                // Odd stamp: writer mid-publish. Below `expected`: slot still
                // holds an older sequence. Either way nothing is lost yet.
                if actual_stamp & 1 != 0 || actual_stamp < expected {
                    return Err(TryRecvError::Empty);
                }
                // Stamp is ahead of us: the producer may have overwritten our
                // sequence. Check the published head cursor to decide.
                let head = self.ring.cursor.0.load(Ordering::Acquire);
                let cap = self.ring.capacity();
                // NOTE(review): u64::MAX appears to be the "nothing published
                // yet" sentinel for the head cursor — confirm producer side.
                if head == u64::MAX || cur > head {
                    return Err(TryRecvError::Empty);
                }
                if head >= cap {
                    // Oldest sequence whose slot has not been reused yet.
                    let oldest = head - cap + 1;
                    if cur < oldest {
                        let skipped = oldest - cur;
                        self.cursor = oldest;
                        self.total_lagged += skipped;
                        self.update_tracker();
                        return Err(TryRecvError::Lagged { skipped });
                    }
                }
                // Stamp raced ahead but our sequence is still in range; treat
                // it as transient and let the caller retry.
                Err(TryRecvError::Empty)
            }
        }
    }
    /// Blocking receive: spins until a value is available.
    ///
    /// Lag is silently absorbed (it is still recorded in `total_lagged`); the
    /// loop simply retries from the jumped cursor.
    #[inline]
    pub fn recv(&mut self) -> T {
        // Prime the aarch64 event register so the first `wfe` below does not
        // stall before the initial poll.
        #[cfg(target_arch = "aarch64")]
        unsafe {
            core::arch::asm!("sevl", options(nomem, nostack));
        }
        loop {
            match self.try_recv() {
                Ok(val) => return val,
                Err(TryRecvError::Empty) => {
                    // Low-power wait on aarch64. NOTE(review): relies on the
                    // platform delivering events (event stream, or a
                    // producer-side SEV/monitor clear) so `wfe` cannot stall
                    // indefinitely — confirm.
                    #[cfg(target_arch = "aarch64")]
                    unsafe {
                        core::arch::asm!("wfe", options(nomem, nostack));
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    core::hint::spin_loop();
                }
                // Cursor already jumped forward; just poll again.
                Err(TryRecvError::Lagged { .. }) => {}
            }
        }
    }
    /// Blocking receive using a caller-supplied wait strategy.
    ///
    /// Fast path: polls only the one slot the cursor points at, avoiding the
    /// head load done by `try_recv`. Falls back to the cold path once the
    /// slot's stamp shows the producer has moved past this sequence
    /// (possible lag).
    #[inline]
    pub fn recv_with(&mut self, strategy: WaitStrategy) -> T {
        let cur = self.cursor;
        let slot = unsafe { &*self.slots_ptr.add(self.slot_index(cur)) };
        // Stamp expected once `cur` is published (see `try_recv`).
        let expected = cur * 2 + 2;
        let mut iter: u32 = 0;
        loop {
            match slot.try_read(cur) {
                Ok(Some(value)) => {
                    self.cursor = cur + 1;
                    self.total_received += 1;
                    self.update_tracker();
                    return value;
                }
                Ok(None) => {
                    // Not published yet: back off per strategy.
                    strategy.wait(iter);
                    iter = iter.saturating_add(1);
                }
                Err(stamp) => {
                    if stamp >= expected {
                        // Producer moved past this slot; the slow path sorts
                        // out lag vs. transient states.
                        return self.recv_with_slow(strategy);
                    }
                    strategy.wait(iter);
                    iter = iter.saturating_add(1);
                }
            }
        }
    }
    /// Cold path for `recv_with`: full `try_recv` loop that handles lag.
    #[cold]
    #[inline(never)]
    fn recv_with_slow(&mut self, strategy: WaitStrategy) -> T {
        let mut iter: u32 = 0;
        loop {
            match self.try_recv() {
                Ok(val) => return val,
                Err(TryRecvError::Empty) => {
                    strategy.wait(iter);
                    iter = iter.saturating_add(1);
                }
                Err(TryRecvError::Lagged { .. }) => {
                    // Cursor jumped to live data; restart the backoff from its
                    // most aggressive setting so the next poll is immediate.
                    iter = 0;
                }
            }
        }
    }
    /// The compile-time `N` parameter of this group.
    /// NOTE(review): its meaning (lane/subscriber count?) is not visible in
    /// this file — confirm at the construction site.
    #[inline]
    pub fn aligned_count(&self) -> usize {
        N
    }
    /// Number of published-but-unread sequences, clamped to the ring capacity
    /// (anything beyond capacity has already been overwritten).
    #[inline]
    pub fn pending(&self) -> u64 {
        let head = self.ring.cursor.0.load(Ordering::Acquire);
        // u64::MAX: nothing published yet (see `try_recv`); cursor past head:
        // fully caught up.
        if head == u64::MAX || self.cursor > head {
            0
        } else {
            let raw = head - self.cursor + 1;
            raw.min(self.ring.capacity())
        }
    }
    /// Total values successfully received over this group's lifetime.
    #[inline]
    pub fn total_received(&self) -> u64 {
        self.total_received
    }
    /// Total sequences lost to producer overruns (lag).
    #[inline]
    pub fn total_lagged(&self) -> u64 {
        self.total_lagged
    }
    /// Fraction of observed sequences actually received:
    /// `received / (received + lagged)`; `0.0` before any activity.
    #[inline]
    pub fn receive_ratio(&self) -> f64 {
        let total = self.total_received + self.total_lagged;
        if total == 0 {
            0.0
        } else {
            self.total_received as f64 / total as f64
        }
    }
    /// Non-blocking batch receive into `buf`; returns how many were written.
    ///
    /// Stops at the first `Empty`. A `Lagged` error is absorbed (the skip is
    /// still counted in `total_lagged`) and the element is retried once from
    /// the jumped cursor; a second consecutive failure ends the batch.
    #[inline]
    pub fn recv_batch(&mut self, buf: &mut [T]) -> usize {
        let mut count = 0;
        for slot in buf.iter_mut() {
            match self.try_recv() {
                Ok(value) => {
                    *slot = value;
                    count += 1;
                }
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Lagged { .. }) => {
                    // Cursor just jumped to the oldest live sequence; retry so
                    // this batch element is not wasted.
                    match self.try_recv() {
                        Ok(value) => {
                            *slot = value;
                            count += 1;
                        }
                        Err(_) => break,
                    }
                }
            }
        }
        count
    }
    /// Publishes the current cursor to the shared tracker (if registered) so
    /// the producer can observe this subscriber's progress.
    /// NOTE(review): Relaxed store — this looks like monotonic progress info
    /// rather than a synchronization edge; confirm the producer side tolerates
    /// stale reads.
    #[inline]
    fn update_tracker(&self) {
        if let Some(ref tracker) = self.tracker {
            tracker.0.store(self.cursor, Ordering::Relaxed);
        }
    }
}
impl<T: Pod, const N: usize> Drop for SubscriberGroup<T, N> {
    /// Unregisters this subscriber's cursor tracker from the ring's
    /// backpressure list so the producer stops waiting on a dead subscriber.
    fn drop(&mut self) {
        // Only act when both a tracker and a backpressure registry exist.
        if let (Some(tracker), Some(bp)) = (&self.tracker, &self.ring.backpressure) {
            // Compare by allocation identity: remove every weak entry that
            // points at our tracker.
            let ours = Arc::downgrade(tracker);
            bp.trackers.lock().retain(|entry| !entry.ptr_eq(&ours));
        }
    }
}