use core::mem::MaybeUninit;
use crate::{atomic::Ordering, spsc::queue::QueuePtr};
/// Producer half of the single-producer/single-consumer queue.
///
/// Keeps private copies of the queue indices so that most operations do not need to read the
/// shared atomics.
pub struct Sender<T> {
    /// Cached copy of the consumer's `head` index; it can lag behind the shared value.
    local_head: usize,
    /// The producer's `tail` index: the next slot to write, as last published by this sender.
    local_tail: usize,
    /// Handle to the shared queue storage and its `head`/`tail` atomics.
    ptr: QueuePtr<T>,
}
impl<T> Sender<T> {
    /// Creates the producer half over `queue_ptr`, with both cached indices starting at 0.
    pub(crate) fn new(queue_ptr: QueuePtr<T>) -> Self {
        Self {
            ptr: queue_ptr,
            local_head: 0,
            local_tail: 0,
        }
    }

    /// Enqueues `value` if there is room, without blocking.
    ///
    /// The shared `head` index is only re-read when the cached copy says the queue is full.
    /// With the `async` feature enabled, the receiver is woken after the value is published.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)`, handing the value back, when the queue is full.
    pub fn try_send(&mut self, value: T) -> Result<(), T> {
        if self.free_slots() == 0 {
            self.load_head();
            if self.free_slots() == 0 {
                return Err(value);
            }
        }
        // SAFETY: at least one slot is free, so the slot at `local_tail` is not one the receiver
        // may still read; it only becomes visible to the receiver once the new tail is published.
        unsafe { self.ptr.set(self.local_tail, value) };
        let new_tail = self.local_tail.wrapping_add(1);
        self.publish_tail(new_tail);
        Ok(())
    }

    /// Enqueues `value`, waiting until a slot is free.
    ///
    /// Shorthand for [`send_with_spin_count`](Self::send_with_spin_count) with a spin count of 128.
    pub fn send(&mut self, value: T) {
        self.send_with_spin_count(value, 128);
    }

    /// Enqueues `value`, retrying until a slot is free.
    ///
    /// Between retries it calls `backoff()` on a `crate::Backoff` built from `spin_count`, then
    /// re-reads the shared `head`. There is no timeout: this loops until the receiver frees a
    /// slot. With the `async` feature enabled, the receiver is woken after the value is published.
    pub fn send_with_spin_count(&mut self, value: T, spin_count: u32) {
        let mut backoff = crate::Backoff::with_spin_count(spin_count);
        while self.free_slots() == 0 {
            backoff.backoff();
            self.load_head();
        }
        // SAFETY: same reasoning as in `try_send`; the loop only exits once a slot is free.
        unsafe { self.ptr.set(self.local_tail, value) };
        let new_tail = self.local_tail.wrapping_add(1);
        self.publish_tail(new_tail);
    }

    /// Enqueues `value`, asynchronously waiting for a free slot.
    ///
    /// While the queue is full, the task registers its waker with the queue and yields
    /// `Poll::Pending`. The receiver is woken after the value is published.
    #[cfg(feature = "async")]
    pub async fn send_async(&mut self, value: T) {
        use core::task::Poll;

        if self.free_slots() == 0 {
            futures::future::poll_fn(|ctx| {
                self.load_head();
                if self.free_slots() == 0 {
                    self.ptr.register_sender_waker(ctx.waker());
                    // Re-check after registering the waker, so a slot freed between the first
                    // check and the registration is not missed.
                    self.local_head = self.ptr.head().load(Ordering::SeqCst);
                    if self.free_slots() == 0 {
                        return Poll::Pending;
                    }
                }
                Poll::Ready(())
            })
            .await;
        }
        // SAFETY: same reasoning as in `try_send`; we only get here once a slot is free.
        unsafe { self.ptr.set(self.local_tail, value) };
        let new_tail = self.local_tail.wrapping_add(1);
        self.publish_tail(new_tail);
    }

    /// Returns the free slots that follow the current tail contiguously in the ring buffer.
    ///
    /// Write values into a prefix of the returned slice, then call [`commit`](Self::commit) to
    /// publish them. The slice is empty when the queue is full. It can also be shorter than the
    /// total free space, for two reasons:
    /// - it stops at the end of the underlying buffer;
    /// - the shared `head` is only re-read when the cached copy says there is no space at all.
    pub fn write_buffer(&mut self) -> &mut [MaybeUninit<T>] {
        if self.free_slots() == 0 {
            self.load_head();
        }
        let start = self.local_tail & self.ptr.mask;
        let len = self.free_slots().min(self.ptr.capacity - start);
        // SAFETY: `len <= capacity - start`, so the slice ends at or before the end of the
        // buffer, and it only covers free slots the receiver will not read until committed.
        unsafe {
            let ptr = self.ptr.exact_at(start).cast();
            core::slice::from_raw_parts_mut(ptr.as_ptr(), len)
        }
    }

    /// Publishes the first `len` slots of the slice returned by
    /// [`write_buffer`](Self::write_buffer).
    ///
    /// With the `async` feature enabled, the receiver is woken afterwards.
    ///
    /// # Safety
    ///
    /// - `len` must not exceed the length of the slice last returned by `write_buffer`
    ///   (checked only in debug builds).
    /// - The first `len` slots of that slice must have been initialized. Committing makes them
    ///   visible to the receiver as queued values.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `len` exceeds the contiguous free space.
    #[inline(always)]
    pub unsafe fn commit(&mut self, len: usize) {
        #[cfg(debug_assertions)]
        {
            let start = self.local_tail & self.ptr.mask;
            let available = (self.ptr.capacity - start).min(self.free_slots());
            assert!(
                len <= available,
                "advancing ({len}) more than available space ({available})"
            );
        }
        let new_tail = self.local_tail.wrapping_add(len);
        self.publish_tail(new_tail);
    }

    /// Number of slots the producer may still fill, according to the cached `head`.
    ///
    /// This uses the wrapping distance `tail - head`, which stays correct after the indices wrap
    /// around `usize::MAX`. The distance never exceeds `size`, because:
    /// - the tail is only advanced into free slots;
    /// - the cached head only moves forward.
    ///
    /// Comparing raw indices instead (`tail + 1 > head + size`) would report a non-full queue as
    /// full whenever `head + size` has wrapped but `tail + 1` has not. That would stall the
    /// sender permanently.
    #[inline(always)]
    fn free_slots(&self) -> usize {
        self.ptr.size - self.local_tail.wrapping_sub(self.local_head)
    }

    /// Stores `new_tail` to the shared tail and updates the local copy. With the `async`
    /// feature enabled, it then wakes the receiver.
    #[inline(always)]
    fn publish_tail(&mut self, new_tail: usize) {
        self.store_tail(new_tail);
        self.local_tail = new_tail;
        #[cfg(feature = "async")]
        self.ptr.wake_receiver();
    }

    /// Release-stores `value` into the shared tail, so slot writes made before it are visible
    /// to an Acquire load of the tail.
    #[inline(always)]
    fn store_tail(&self, value: usize) {
        self.ptr.tail().store(value, Ordering::Release);
    }

    /// Refreshes the cached head from the shared head with an Acquire load.
    #[inline(always)]
    fn load_head(&mut self) {
        self.local_head = self.ptr.head().load(Ordering::Acquire);
    }
}
// SAFETY: a `Sender` only moves `T` values into the queue, so moving the sender to another
// thread is sound as long as the values themselves may cross threads (`T: Send`).
// NOTE(review): this also relies on `QueuePtr` being safe to move between threads — confirm
// against its definition.
unsafe impl<T> Send for Sender<T> where T: Send {}