use core::{mem::MaybeUninit, num::NonZeroUsize, ptr::NonNull};
use crate::{
Box,
spsc::{self, shards::ShardsPtr},
sync::atomic::{AtomicUsize, Ordering},
};
/// Producer handle for a sharded channel.
///
/// Each `Sender` claims a distinct shard index when it is created and writes
/// only through the SPSC producer for that shard's queue.
pub struct Sender<T> {
    /// SPSC producer for the queue of the shard this sender claimed.
    inner: spsc::Sender<T>,
    /// Handle to the channel's shard storage; cloned into new senders so they
    /// can obtain the queue for the shard index they claim.
    shards: ShardsPtr<T>,
    /// Heap-allocated counter of shard indices handed out so far, shared by all
    /// senders of this channel. It never decreases below `max_shards` once
    /// reached, so shard indices are never reused.
    num_senders: NonNull<AtomicUsize>,
    /// Heap-allocated count of currently live senders, shared by all senders of
    /// this channel; the sender whose drop brings it to zero frees both counters.
    alive_senders: NonNull<AtomicUsize>,
    /// Maximum number of senders (one per shard) this channel can ever create.
    max_shards: usize,
}
impl<T> Sender<T> {
    /// Creates another sender for the same channel, bound to the next unclaimed
    /// shard.
    ///
    /// Shard indices are handed out in increasing order and are never returned
    /// when a sender is dropped, so at most `max_shards` senders can ever be
    /// created for a channel. Returns `None` once that limit has been reached.
    pub fn try_clone(&self) -> Option<Self> {
        // SAFETY: `self` is alive, so the alive-sender count is at least one and
        // neither shared counter can have been freed.
        unsafe {
            Self::init(
                self.shards.clone(),
                self.max_shards,
                self.num_senders,
                self.alive_senders,
            )
        }
    }

    /// Creates the first sender of a channel whose shard queues live in
    /// `shards`, claiming shard index 0.
    ///
    /// Allocates the two counters shared by every sender of the channel; they
    /// are freed by whichever sender is dropped last.
    pub(super) fn new(shards: ShardsPtr<T>, max_shards: NonZeroUsize) -> Self {
        let num_senders_ptr = Box::into_raw(Box::new(AtomicUsize::new(0)));
        let alive_senders_ptr = Box::into_raw(Box::new(AtomicUsize::new(0)));
        // SAFETY: `Box::into_raw` never returns a null pointer.
        let (num_senders, alive_senders) = unsafe {
            (
                NonNull::new_unchecked(num_senders_ptr),
                NonNull::new_unchecked(alive_senders_ptr),
            )
        };
        // SAFETY: both counters were allocated just above and are valid. The
        // shard counter starts at 0 and `max_shards >= 1` (it is a
        // `NonZeroUsize`), so `init` claims shard 0 and always returns `Some`.
        unsafe {
            Self::init(shards, max_shards.get(), num_senders, alive_senders).unwrap_unchecked()
        }
    }

    /// Claims the next shard index and builds a sender for that shard's queue.
    ///
    /// Returns `None` when all `max_shards` indices have already been claimed.
    /// In that case the shard counter is clamped back to `max_shards`.
    ///
    /// # Safety
    ///
    /// `num_senders` and `alive_senders` must point to live `AtomicUsize`
    /// allocations created by `Box::into_raw` in `Sender::new`. They must stay
    /// allocated for the duration of this call. That holds when the caller is
    /// `new` itself, or holds a live `Sender` that shares these counters.
    unsafe fn init(
        shards: ShardsPtr<T>,
        max_shards: usize,
        num_senders: NonNull<AtomicUsize>,
        alive_senders: NonNull<AtomicUsize>,
    ) -> Option<Self> {
        // SAFETY: the pointer is valid per this function's `# Safety` contract.
        let num_senders_ref = unsafe { num_senders.as_ref() };
        let next_shard = num_senders_ref.fetch_add(1, Ordering::Relaxed);
        if next_shard >= max_shards {
            // Every shard is taken. Pull the counter back to the limit so that
            // repeated failed clones cannot keep growing it toward overflow.
            num_senders_ref.store(max_shards, Ordering::Relaxed);
            return None;
        }
        // SAFETY: the pointer is valid per this function's `# Safety` contract.
        unsafe { alive_senders.as_ref() }.fetch_add(1, Ordering::AcqRel);
        let shard_ptr = shards.clone_queue_ptr(next_shard);
        let inner = spsc::Sender::new(shard_ptr);
        Some(Self {
            inner,
            shards,
            num_senders,
            alive_senders,
            max_shards,
        })
    }

    /// Enqueues `value` on this sender's shard.
    ///
    /// Delegates to `spsc::Sender::send`. Refer to that method for what happens
    /// when the shard's queue is full.
    pub fn send(&mut self, value: T) {
        self.inner.send(value)
    }

    /// Attempts to enqueue `value` on this sender's shard.
    ///
    /// Delegates to `spsc::Sender::try_send`. On failure the value is handed
    /// back in `Err` (presumably because the queue is full; confirm against the
    /// spsc module).
    pub fn try_send(&mut self, value: T) -> Result<(), T> {
        self.inner.try_send(value)
    }

    /// Returns the currently writable, uninitialized slots of this sender's
    /// shard, for writing values in place before publishing them with
    /// [`commit`](Self::commit).
    pub fn write_buffer(&mut self) -> &mut [MaybeUninit<T>] {
        self.inner.write_buffer()
    }

    /// Publishes `len` values previously written into the slice returned by
    /// [`write_buffer`](Self::write_buffer).
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `spsc::Sender::commit`. Presumably
    /// this means the first `len` slots of the most recent `write_buffer` slice
    /// must be initialized, and `len` must not exceed that slice's length.
    /// TODO: confirm against the spsc module.
    pub unsafe fn commit(&mut self, len: usize) {
        // SAFETY: forwarded unchanged from this method's caller contract.
        unsafe { self.inner.commit(len) }
    }
}
impl<T> Drop for Sender<T> {
    /// Gives up this sender's share of the channel's sender counters.
    ///
    /// Only the sender that brings the alive count from one to zero proceeds
    /// past the guard; it is the last holder of the two shared counters and
    /// frees them.
    fn drop(&mut self) {
        // SAFETY: the counter stays allocated while any sender is alive, and
        // `self` is alive for the whole of this call.
        let alive_before = unsafe { self.alive_senders.as_ref() }.fetch_sub(1, Ordering::AcqRel);
        if alive_before != 1 {
            return;
        }
        // SAFETY: both pointers came from `Box::into_raw` in `Sender::new`, and
        // the alive count has just reached zero, so no other sender remains that
        // could touch them.
        unsafe {
            drop(Box::from_raw(self.num_senders.as_ptr()));
            drop(Box::from_raw(self.alive_senders.as_ptr()));
        }
    }
}
// SAFETY: each `Sender` writes only to the shard index it uniquely claimed
// through an atomic counter, and the shared counters are atomics freed exactly
// once by the last live sender. This assumes `ShardsPtr` and `spsc::Sender` may
// be used from the thread the sender is moved to (TODO confirm in the spsc
// module). Moving a sender to another thread also moves the `T` values it sends
// across threads to the receiver, so `T` itself must be `Send`. This is the same
// bound `std::sync::mpsc::Sender` uses.
unsafe impl<T: Send> Send for Sender<T> {}