use core::ptr::{self, NonNull};
use crate::{
Backoff, Box,
padded::Padded,
read_guard::BatchReader,
spsc::{self, shards::ShardsPtr},
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};
/// Lock guarding consumer access to one shard (`true` = held).
///
/// Wrapped in `Padded`, presumably to keep neighbouring shard locks on separate
/// cache lines — see `padded::Padded` to confirm.
type Lock = Padded<AtomicBool>;

/// Consumer handle of a sharded, multi-consumer channel.
///
/// Every handle owns its own `spsc::Receiver` view of each shard. Consumer access
/// to a given shard is serialised across handles by that shard's lock.
pub struct Receiver<T> {
    /// This handle's consumer view of every shard, indexed by shard number.
    receivers: Box<[spsc::Receiver<T>]>,
    /// Shared pointer to the shard storage; each handle holds its own clone.
    shards: ShardsPtr<T>,
    /// First element of a heap array of `max_shards` locks, shared by all handles.
    locks: NonNull<Lock>,
    /// Heap-allocated number of live handles, shared by all handles. The handle
    /// that drops the count to zero frees it together with `locks`.
    alive_receivers: NonNull<AtomicUsize>,
    /// Shard count: the length of both `receivers` and the `locks` array.
    max_shards: usize,
    /// Shard at which the next receive attempt begins scanning.
    next_shard: usize,
}
impl<T> Receiver<T> {
    /// Creates the first receiver handle over `max_shards` shards of `shards`.
    ///
    /// This allocates the per-shard lock array (all unlocked) and the
    /// live-receiver counter (starting at 1). Both are shared with every handle
    /// later produced by [`Self::try_clone`].
    pub(super) fn new(shards: ShardsPtr<T>, max_shards: usize) -> Self {
        let mut locks = Box::<[Lock]>::new_uninit_slice(max_shards);
        let mut receivers = Box::new_uninit_slice(max_shards);
        for i in 0..max_shards {
            let shard = shards.clone_queue_ptr(i);
            receivers[i].write(spsc::Receiver::new(shard));
            locks[i].write(Padded::new(AtomicBool::new(false)));
        }
        // SAFETY: every lock slot was initialised in the loop above, and
        // `Box::into_raw` never returns null.
        let locks = unsafe { NonNull::new_unchecked(Box::into_raw(locks.assume_init())) }.cast();
        let alive_receivers_ptr = Box::into_raw(Box::new(AtomicUsize::new(1)));
        // SAFETY: `Box::into_raw` never returns null.
        let alive_receivers = unsafe { NonNull::new_unchecked(alive_receivers_ptr) };
        Self {
            // SAFETY: every receiver slot was initialised in the loop above.
            receivers: unsafe { receivers.assume_init() },
            locks,
            alive_receivers,
            shards,
            max_shards,
            next_shard: 0,
        }
    }

    /// Creates another receiver handle that shares this one's shards and locks.
    ///
    /// Returns `None` when `max_shards` handles are already alive. At most
    /// `max_shards` handles exist at any one time.
    pub fn try_clone(&self) -> Option<Self> {
        // SAFETY: the counter is freed only by the last live handle, and `self` is alive.
        let num_receivers_ref = unsafe { self.alive_receivers.as_ref() };
        // `fetch_add` yields the count *before* our increment. The comparison must
        // be `>=`, not `==`: a concurrent `try_clone` that is about to fail still
        // holds its provisional increment. The value seen here can therefore
        // already exceed `max_shards`, and `==` would wrongly admit an extra handle.
        if num_receivers_ref.fetch_add(1, Ordering::AcqRel) >= self.max_shards {
            num_receivers_ref.fetch_sub(1, Ordering::AcqRel);
            return None;
        }
        let mut receivers = Box::new_uninit_slice(self.max_shards);
        for i in 0..self.max_shards {
            // SAFETY: NOTE(review): the exact contract of `clone_via_ptr` lives in
            // `spsc`. Presumably it requires that consumer access to the shard be
            // serialised, which the per-shard locks provide. Confirm.
            receivers[i].write(unsafe { self.receivers[i].clone_via_ptr() });
        }
        Some(Self {
            // SAFETY: every receiver slot was initialised in the loop above.
            receivers: unsafe { receivers.assume_init() },
            alive_receivers: self.alive_receivers,
            shards: self.shards.clone(),
            locks: self.locks,
            max_shards: self.max_shards,
            next_shard: 0,
        })
    }

    /// Receives a value, retrying until one is available. Uses a backoff with a
    /// spin count of 128 between attempts.
    pub fn recv(&mut self) -> T {
        self.recv_with_spin_count(128)
    }

    /// Receives a value, calling `Backoff::backoff` between failed
    /// [`Self::try_recv`] attempts until one succeeds.
    pub fn recv_with_spin_count(&mut self, spin_count: u32) -> T {
        let mut backoff = Backoff::with_spin_count(spin_count);
        loop {
            match self.try_recv() {
                None => backoff.backoff(),
                Some(ret) => return ret,
            }
        }
    }

    /// Tries to receive one value without waiting.
    ///
    /// Visits each shard at most once, round-robin from `next_shard`. A shard is
    /// skipped if this handle sees it as empty or its lock is held by another
    /// handle. On success `next_shard` keeps pointing at the shard that produced
    /// the value, so the next call starts there.
    pub fn try_recv(&mut self) -> Option<T> {
        let start = self.next_shard;
        loop {
            let idx = self.next_shard;
            if !self.receivers[idx].is_empty() && self.try_lock(idx) {
                // Another handle may have consumed from this shard since we last
                // looked, so re-read the shared head while holding the lock.
                self.receivers[idx].refresh_head();
                let ret = self.receivers[idx].try_recv();
                // SAFETY: `try_lock(idx)` succeeded above, so we hold this lock.
                unsafe { self.unlock(idx) };
                if let Some(v) = ret {
                    return Some(v);
                }
            }
            self.next_shard += 1;
            if self.next_shard == self.max_shards {
                self.next_shard = 0;
            }
            if self.next_shard == start {
                return None;
            }
        }
    }

    /// Returns a `ReadGuard` for batch reading from this receiver.
    pub fn read_guard(&mut self) -> crate::read_guard::ReadGuard<'_, Self> {
        crate::read_guard::ReadGuard::new(self)
    }

    /// Returns the lock for `shard`.
    #[inline(always)]
    fn shard_lock(&self, shard: usize) -> &AtomicBool {
        // SAFETY: `locks` points to `max_shards` initialised locks that live until
        // the last handle drops. NOTE(review): the cast assumes `Padded<AtomicBool>`
        // stores its `AtomicBool` at offset 0 (e.g. `repr(C)`/`repr(transparent)`
        // plus alignment). Confirm in `padded.rs`. `shard` must be `< max_shards`.
        unsafe { self.locks.add(shard).cast::<AtomicBool>().as_ref() }
    }

    /// Attempts to take `shard`'s lock without blocking; returns `true` on success.
    #[inline(always)]
    fn try_lock(&self, shard: usize) -> bool {
        self.shard_lock(shard)
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    }

    /// Releases `shard`'s lock.
    ///
    /// # Safety
    /// The caller must currently hold `shard`'s lock (taken via `try_lock`).
    /// Otherwise this would release a lock owned by another handle.
    #[inline(always)]
    unsafe fn unlock(&self, shard: usize) {
        self.shard_lock(shard).store(false, Ordering::Release);
    }
}
/// Batch (slice-based) reading, used through `read_guard`/`ReadGuard`.
unsafe impl<T> BatchReader for Receiver<T> {
    type Item = T;

    /// Returns a readable slice from the first suitable shard.
    ///
    /// Shards are scanned round-robin from `next_shard`. A shard is suitable if
    /// this handle sees it as non-empty, its lock can be taken, and it yields a
    /// non-empty buffer. On success the shard's lock is left HELD, and
    /// `next_shard` still points at that shard, so `advance` and `release` act on
    /// it. If a full lap finds nothing, every lock taken during the scan has been
    /// released again and an empty slice is returned.
    fn read_buffer(&mut self) -> &[T] {
        let start = self.next_shard;
        loop {
            let idx = self.next_shard;
            // Check this handle's own view first, so shards that look empty are
            // skipped without touching their lock.
            if !self.receivers[idx].is_empty() && self.try_lock(idx) {
                // Presumably routed through a raw pointer so the slice returned
                // below is not tied to a borrow of `self.receivers`. Such a borrow
                // would conflict with the later uses of `self` in this loop.
                let receiver_ptr = &mut self.receivers[idx] as *mut spsc::Receiver<T>;
                // Re-read the shared head under the lock, since other handles may
                // have consumed from this shard.
                unsafe { (*receiver_ptr).refresh_head() };
                let ret = unsafe { (*receiver_ptr).read_buffer() };
                if !ret.is_empty() {
                    // `&[T]` -> `&[T]` is an identity transmute; only the lifetime
                    // is re-inferred. The lock stays held until `release`.
                    return unsafe { core::mem::transmute::<&[T], &[T]>(ret) };
                } else {
                    unsafe { self.unlock(idx) };
                }
            }
            self.next_shard += 1;
            if self.next_shard == self.max_shards {
                self.next_shard = 0;
            }
            if self.next_shard == start {
                return &[];
            }
        }
    }

    /// Forwards `advance(n)` to the shard returned by the last successful
    /// `read_buffer` (still `next_shard`). This presumably marks `n` items of that
    /// buffer as consumed — see `spsc::Receiver::advance`.
    unsafe fn advance(&mut self, n: usize) {
        unsafe { self.receivers[self.next_shard].advance(n) };
    }

    /// Releases the lock taken by the last successful `read_buffer`.
    ///
    /// NOTE(review): this unconditionally unlocks `next_shard`. After a
    /// `read_buffer` that returned an empty slice, this handle holds no lock, so
    /// calling `release` then would clear a lock possibly held by another
    /// receiver. Confirm that callers (e.g. `ReadGuard`) never `release` after an
    /// empty read.
    unsafe fn release(&mut self) {
        unsafe { self.unlock(self.next_shard) };
    }
}
impl<T> Drop for Receiver<T> {
    /// Gives up this handle's share of the state common to all clones.
    ///
    /// Only the handle that brings the live-receiver count to zero frees the
    /// per-shard lock array and the counter itself. The per-handle fields are
    /// then dropped normally.
    fn drop(&mut self) {
        // SAFETY: the counter is freed only below, by the last live handle, and
        // `self` still counts as live here.
        let previously_alive =
            unsafe { self.alive_receivers.as_ref() }.fetch_sub(1, Ordering::AcqRel);
        if previously_alive != 1 {
            return;
        }
        // SAFETY: this was the last handle, so nothing else can reach the locks or
        // the counter. Both were leaked from `Box`es in `Receiver::new`; the locks
        // came from a `Box<[Lock]>` of exactly `max_shards` elements.
        unsafe {
            let lock_slice = ptr::slice_from_raw_parts_mut(self.locks.as_ptr(), self.max_shards);
            drop(Box::from_raw(lock_slice));
            drop(Box::from_raw(self.alive_receivers.as_ptr()));
        }
    }
}
// SAFETY: the raw pointers (`locks`, `alive_receivers`) refer to atomics shared by
// all handles, which may be touched from any thread. A handle hands out owned `T`s
// (`try_recv`/`recv`) and `&[T]` slices (`read_buffer`) that were produced on other
// threads. Moving the handle to a new thread is therefore only sound when
// `T: Send`, the same bound `std::sync::mpsc::Receiver` uses. Without it, e.g.
// `Receiver<Rc<_>>` would be `Send` and allow unsynchronised refcount updates.
unsafe impl<T: Send> Send for Receiver<T> {}