use core::{cell::UnsafeCell, ptr::NonNull};
use crate::{
Backoff, Box,
padded::Padded,
queue::ShardOwnership,
read_guard::BatchReader,
spsc::{self, shards::ShardsPtr},
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};
/// State shared by every [`Receiver`] handle that was cloned from the same original.
///
/// The allocation is created in `Receiver::new`, reached through a raw pointer by each handle,
/// and freed by whichever handle brings `alive_receivers` down to zero in `Drop`.
struct Shared<T> {
/// One SPSC consumer endpoint per shard. Wrapped in `UnsafeCell` because handles obtain
/// `&mut` access to an endpoint through a shared pointer; `locks[i]` guards `receivers[i]`.
receivers: Box<[UnsafeCell<spsc::Receiver<T, ShardOwnership>>]>,
/// One try-lock flag per shard: `true` while some handle is using the shard at the same index.
/// (`Padded` presumably keeps each flag on its own cache line; confirm against `padded.rs`.)
locks: Box<[Padded<AtomicBool>]>,
/// Number of live `Receiver` handles. Starts at 1, is incremented by `try_clone` and
/// decremented by `Drop`.
alive_receivers: AtomicUsize,
/// Number of shards (length of `receivers` and `locks`); also the upper bound on
/// `alive_receivers` enforced by `try_clone`.
max_shards: usize,
}
/// Consumer handle over a set of sharded SPSC queues.
///
/// Handles produced by `try_clone` point at the same `Shared` state; access to an individual
/// shard is serialised by that shard's try-lock flag.
pub struct Receiver<T> {
/// Pointer to the heap-allocated state shared by all handles. Freed by the last handle dropped.
shared: NonNull<Shared<T>>,
/// Index of the shard this handle tries first on its next receive. While a batch obtained
/// from `BatchReader::read_buffer` is outstanding, this is the shard whose lock is held.
next_shard: usize,
}
impl<T> Receiver<T> {
    /// Creates the first receiver handle over `max_shards` shards.
    ///
    /// Claims the consumer side of every shard index in `0..max_shards` from `shards`, wraps
    /// each in an [`spsc::Receiver`], and pairs it with an initially-unlocked flag. The handle
    /// count of the shared state starts at 1.
    ///
    /// # Panics
    ///
    /// Panics if claiming the consumer side of any shard fails (the `unwrap` below).
    pub(super) fn new(shards: ShardsPtr<T>, max_shards: usize) -> Self {
        let mut lock_slots = Box::<[Padded<AtomicBool>]>::new_uninit_slice(max_shards);
        let mut receiver_slots = Box::new_uninit_slice(max_shards);

        let slots = receiver_slots.iter_mut().zip(lock_slots.iter_mut());
        for (shard_idx, (receiver_slot, lock_slot)) in slots.enumerate() {
            let queue = shards.claim_consumer_queue_ptr(shard_idx).unwrap();
            receiver_slot.write(UnsafeCell::new(spsc::Receiver::new(queue)));
            lock_slot.write(Padded::new(AtomicBool::new(false)));
        }

        // SAFETY: both slices have exactly `max_shards` slots and the loop above wrote every
        // one of them.
        let (receivers, locks) =
            unsafe { (receiver_slots.assume_init(), lock_slots.assume_init()) };

        let state = Box::new(Shared {
            receivers,
            locks,
            alive_receivers: AtomicUsize::new(1),
            max_shards,
        });

        // SAFETY: `Box::into_raw` never returns a null pointer.
        let shared = unsafe { NonNull::new_unchecked(Box::into_raw(state)) };
        Self {
            shared,
            next_shard: 0,
        }
    }

    /// Creates another handle onto the same shards.
    ///
    /// Returns `None` once the number of live handles has reached `max_shards`. The new
    /// handle starts scanning at shard 0.
    pub fn try_clone(&self) -> Option<Self> {
        let state = self.shared();
        let counter = &state.alive_receivers;

        let mut observed = counter.load(Ordering::Acquire);
        while observed < state.max_shards {
            // Only take a slot if the count is still what we observed; otherwise retry with
            // the value another thread installed.
            match counter.compare_exchange(
                observed,
                observed + 1,
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    return Some(Self {
                        shared: self.shared,
                        next_shard: 0,
                    });
                }
                Err(current) => observed = current,
            }
        }
        None
    }

    /// Receives one item, retrying until one is available.
    ///
    /// Equivalent to [`recv_with_spin_count`](Self::recv_with_spin_count) with a spin count
    /// of 128.
    pub fn recv(&mut self) -> T {
        const DEFAULT_SPIN_COUNT: u32 = 128;
        self.recv_with_spin_count(DEFAULT_SPIN_COUNT)
    }

    /// Receives one item, retrying until one is available.
    ///
    /// Each failed [`try_recv`](Self::try_recv) is followed by one call to
    /// `Backoff::backoff` on a `Backoff` built from `spin_count`. This never returns if no
    /// item ever arrives.
    pub fn recv_with_spin_count(&mut self, spin_count: u32) -> T {
        let mut backoff = Backoff::with_spin_count(spin_count);
        loop {
            if let Some(item) = self.try_recv() {
                return item;
            }
            backoff.backoff();
        }
    }

    /// Attempts to receive one item without waiting.
    ///
    /// Starting at `next_shard`, visits each shard once. Shards whose lock is currently held
    /// by another handle are skipped. Returns the first item found, leaving `next_shard` on
    /// the shard that produced it; returns `None` after a full pass with no item.
    pub fn try_recv(&mut self) -> Option<T> {
        let first = self.next_shard;
        loop {
            let shard = self.next_shard;
            if self.try_lock(shard) {
                let item = {
                    // SAFETY: we hold the lock for `shard`, so no other handle touches this
                    // endpoint until `unlock` below.
                    let inner = unsafe { self.receiver_mut(shard) };
                    inner.refresh_head();
                    inner.try_recv()
                };
                // SAFETY: the lock being released was acquired by the `try_lock` above.
                unsafe { self.unlock(shard) };
                if item.is_some() {
                    return item;
                }
            }

            // Move to the following shard, wrapping at the end.
            self.next_shard = if shard + 1 == self.shared().max_shards {
                0
            } else {
                shard + 1
            };
            if self.next_shard == first {
                return None;
            }
        }
    }

    /// Wraps this receiver in a [`ReadGuard`](crate::read_guard::ReadGuard) for batch access
    /// through the [`BatchReader`] implementation.
    pub fn read_guard(&mut self) -> crate::read_guard::ReadGuard<'_, Self> {
        crate::read_guard::ReadGuard::new(self)
    }

    /// Borrows the state shared by all handles.
    #[inline(always)]
    fn shared(&self) -> &Shared<T> {
        // SAFETY: the allocation is only freed when the last handle is dropped, and `self`
        // is a live handle.
        unsafe { self.shared.as_ref() }
    }

    /// Returns exclusive access to the SPSC endpoint of `shard`.
    ///
    /// # Safety
    ///
    /// The caller must hold the lock for `shard` (see `try_lock`) for as long as the returned
    /// reference is used.
    ///
    /// # Panics
    ///
    /// Panics if `shard` is out of bounds.
    #[inline(always)]
    unsafe fn receiver_mut(&mut self, shard: usize) -> &mut spsc::Receiver<T, ShardOwnership> {
        let cell = &self.shared().receivers[shard];
        // SAFETY: exclusivity is guaranteed by the caller holding the shard's lock.
        unsafe { &mut *cell.get() }
    }

    /// Returns the lock flag guarding `shard`. Panics if `shard` is out of bounds.
    #[inline(always)]
    fn shard_lock(&self, shard: usize) -> &AtomicBool {
        let padded = &self.shared().locks[shard];
        &padded.value
    }

    /// Tries to take the lock for `shard`; returns `true` if this call flipped the flag from
    /// `false` to `true`.
    #[inline(always)]
    fn try_lock(&self, shard: usize) -> bool {
        let flag = self.shard_lock(shard);
        flag.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    }

    /// Releases the lock for `shard` by storing `false` into its flag.
    ///
    /// # Safety
    ///
    /// The caller must currently hold the lock for `shard`; clearing a flag owned by another
    /// handle would let two handles use the same endpoint at once.
    #[inline(always)]
    unsafe fn unlock(&self, shard: usize) {
        let flag = self.shard_lock(shard);
        flag.store(false, Ordering::Release);
    }
}
// SAFETY: NOTE(review): the exact contract of `BatchReader` is defined elsewhere. This impl
// relies on `advance`/`release` only being called for the batch most recently returned by
// `read_buffer`, while that shard's lock is still held.
unsafe impl<T> BatchReader for Receiver<T> {
    type Item = T;

    /// Returns the readable buffer of the first non-empty shard whose lock can be taken.
    ///
    /// Scanning starts at `next_shard` and visits each shard once. On success the shard's lock
    /// is kept held and `next_shard` is left pointing at it, so `advance` and `release` know
    /// which shard to act on. If no shard yields data, an empty slice is returned and no lock
    /// is held.
    fn read_buffer(&mut self) -> &[T] {
        let first = self.next_shard;
        loop {
            let shard = self.next_shard;
            if self.try_lock(shard) {
                // SAFETY: the lock for `shard` was just acquired.
                let inner = unsafe { self.receiver_mut(shard) };
                inner.refresh_head();
                let batch = inner.read_buffer();
                if batch.is_empty() {
                    // Nothing to read here: give the shard back before moving on.
                    // SAFETY: the lock being released was acquired by the `try_lock` above.
                    unsafe { self.unlock(shard) };
                } else {
                    // The transmute only detaches the slice's lifetime from the local borrow
                    // of `self` so the other branch may keep using `self`.
                    // SAFETY: the slice points into the shard's buffer, and the shard stays
                    // locked by this handle after we return.
                    return unsafe { core::mem::transmute::<&[T], &[T]>(batch) };
                }
            }

            // Move to the following shard, wrapping at the end.
            self.next_shard = if shard + 1 == self.shared().max_shards {
                0
            } else {
                shard + 1
            };
            if self.next_shard == first {
                return &[];
            }
        }
    }

    /// Marks `n` items of the current batch as consumed on the shard at `next_shard`.
    ///
    /// # Safety
    ///
    /// Must only be called while the lock taken by a successful `read_buffer` is still held;
    /// `n` is forwarded unchanged to the shard's own `advance`.
    unsafe fn advance(&mut self, n: usize) {
        let shard = self.next_shard;
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        unsafe { self.receiver_mut(shard).advance(n) };
    }

    /// Releases the lock on the shard at `next_shard`.
    ///
    /// # Safety
    ///
    /// Must only be called while the lock taken by a successful `read_buffer` is still held.
    ///
    /// NOTE(review): this clears the flag unconditionally. `read_buffer` holds no lock when it
    /// returns an empty slice, so confirm the guard never calls `release` after an empty batch.
    unsafe fn release(&mut self) {
        let shard = self.next_shard;
        // SAFETY: upheld by the caller, see the `# Safety` section above.
        unsafe { self.unlock(shard) };
    }
}
impl<T> Drop for Receiver<T> {
    /// Gives up this handle's share of the common state, freeing it if this was the last
    /// live handle.
    fn drop(&mut self) {
        let handles_before = self.shared().alive_receivers.fetch_sub(1, Ordering::AcqRel);
        if handles_before != 1 {
            // Other handles still reference the shared state.
            return;
        }

        // SAFETY: the pointer came from `Box::into_raw` in `new`, and the count just reached
        // zero, so no other handle can reach the allocation any more.
        drop(unsafe { Box::from_raw(self.shared.as_ptr()) });
    }
}
// SAFETY: a handle only touches the shared state through atomics or while holding the
// per-shard lock, so moving it to another thread is fine as far as the bookkeeping goes.
// However, receiving hands out `T` values on whichever thread owns the handle, so the impl
// must be limited to `T: Send`; without the bound a `Receiver<Rc<_>>` could carry non-`Send`
// values across threads.
unsafe impl<T: Send> Send for Receiver<T> {}