use crate::coord::CapacityGate;
use crate::error::{
BatchSendErrorReason, CloseError, RecvError, SendBatchError, SendError, TryRecvError,
TrySendBatchError, TrySendError,
};
use crate::mpsc::unbounded_v2;
use crate::{sync_util, RecvErrorTimeout};
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use super::bounded_async::{AsyncReceiver, AsyncSender};
/// Capacity token that travels with a queued message.
///
/// Dropping a `Permit` calls `release` on `gate` unless `is_rendezvous` is set.
#[derive(Debug)]
pub(crate) struct Permit {
/// Gate this permit was created against.
pub(crate) gate: Arc<CapacityGate>,
/// When `true`, dropping the permit does not release anything on the gate.
pub(crate) is_rendezvous: bool,
}
impl Drop for Permit {
    /// Releases one unit on the gate, except for rendezvous permits.
    fn drop(&mut self) {
        if self.is_rendezvous {
            return;
        }
        self.gate.release();
    }
}
impl Permit {
    /// Splits the permit into `(gate, is_rendezvous)` without running `Drop`,
    /// so no release happens on the gate; the caller decides what to do next.
    pub(crate) fn into_parts(self) -> (Arc<CapacityGate>, bool /* is_rendezvous */) {
        let is_rendezvous = self.is_rendezvous;
        // Wrapping in `ManuallyDrop` suppresses `Permit::drop`.
        let disarmed = std::mem::ManuallyDrop::new(self);
        // SAFETY: `disarmed` is never dropped and `gate` is not touched again,
        // so the `Arc` is moved out exactly once.
        let gate = unsafe { std::ptr::read(&disarmed.gate) };
        (gate, is_rendezvous)
    }
}
pub(crate) fn unwrap_batch_messages<T: Send>(
gate: &CapacityGate,
msgs: Vec<BoundedMessage<T>>,
out: &mut Vec<T>,
) -> usize {
let k = msgs.len();
let mut to_release = 0usize;
out.reserve(k);
for msg in msgs {
out.push(msg.value);
let (_gate, is_rendezvous) = msg._permit.into_parts();
if !is_rendezvous {
to_release += 1;
}
}
if to_release > 0 {
gate.release_many(to_release);
}
k
}
/// Payload plus the permit that accompanies it through the queue.
pub(crate) struct BoundedMessage<T> {
/// The user's value.
pub(crate) value: T,
/// Permit held for as long as the message exists; dropped together with it.
pub(crate) _permit: Permit,
}
/// State shared by the sender and receiver handles of one bounded channel.
#[derive(Debug)]
pub(crate) struct BoundedMpscShared<T: Send> {
/// Gate that senders acquire from before enqueueing.
pub(crate) gate: Arc<CapacityGate>,
/// Underlying unbounded queue carrying the wrapped messages.
pub(crate) channel: Arc<unbounded_v2::MpscShared<BoundedMessage<T>>>,
}
/// Sending handle of the bounded MPSC channel.
#[derive(Debug)]
pub struct Sender<T: Send> {
/// Channel state shared with the other handles.
pub(crate) shared: Arc<BoundedMpscShared<T>>,
/// Set once this particular handle has been closed (explicitly or on drop).
pub(crate) closed: AtomicBool,
}
/// Receiving handle of the bounded MPSC channel.
#[derive(Debug)]
pub struct Receiver<T: Send> {
/// Channel state shared with the sender handles.
pub(crate) shared: Arc<BoundedMpscShared<T>>,
/// Set once this handle has been closed (explicitly or on drop).
pub(crate) closed: AtomicBool,
}
impl<T: Send> Sender<T> {
pub fn send(&self, value: T) -> Result<(), SendError> {
if self.closed.load(Ordering::Relaxed)
|| self.shared.channel.receiver_dropped.load(Ordering::Acquire)
{
return Err(SendError::Closed);
}
self.shared.gate.acquire_sync();
let permit = Permit {
gate: self.shared.gate.clone(),
is_rendezvous: self.capacity() == 0,
};
let message = BoundedMessage {
value,
_permit: permit,
};
let mut cache = None;
if unbounded_v2::send_internal(&self.shared.channel, message, &mut cache).is_err() {
return Err(SendError::Closed);
}
Ok(())
}
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
if self.closed.load(Ordering::Relaxed)
|| self.shared.channel.receiver_dropped.load(Ordering::Acquire)
{
return Err(TrySendError::Closed(value));
}
if !self.shared.gate.try_acquire() {
return Err(TrySendError::Full(value));
}
let permit = Permit {
gate: self.shared.gate.clone(),
is_rendezvous: self.capacity() == 0,
};
let message = BoundedMessage {
value,
_permit: permit,
};
let mut cache = None;
if let Err(msg) = unbounded_v2::send_internal(&self.shared.channel, message, &mut cache) {
return Err(TrySendError::Closed(msg.value));
}
Ok(())
}
/// Tries to send as many of `items` as the gate grants via `try_acquire_many`.
///
/// Returns `Ok(total)` when every item was pushed.
///
/// # Errors
/// Returns a `TrySendBatchError` carrying the number of items sent, the
/// unsent items, and a reason:
/// * `Closed` - this handle is closed or the receiver is flagged as dropped
///   (nothing is sent in that case);
/// * `Full` - the gate granted fewer permits than requested (possibly zero).
pub fn try_send_batch(&self, items: Vec<T>) -> Result<usize, TrySendBatchError<T>> {
    let total = items.len();
    if total == 0 {
        return Ok(0);
    }
    if self.closed.load(Ordering::Relaxed)
        || self.shared.channel.receiver_dropped.load(Ordering::Acquire)
    {
        return Err(TrySendBatchError {
            sent: 0,
            unsent: items,
            reason: BatchSendErrorReason::Closed,
        });
    }
    let k = self.shared.gate.try_acquire_many(total);
    if k == 0 {
        return Err(TrySendBatchError {
            sent: 0,
            unsent: items,
            reason: BatchSendErrorReason::Full,
        });
    }
    // Re-check after acquiring: the receiver may have gone away meanwhile.
    if self.shared.channel.receiver_dropped.load(Ordering::Acquire) {
        // No message will carry these permits, so hand them back instead of
        // leaking them. Rendezvous permits are never released by senders
        // (see `Permit::drop`), so they are skipped here as well.
        if self.capacity() != 0 {
            self.shared.gate.release_many(k);
        }
        return Err(TrySendBatchError {
            sent: 0,
            unsent: items,
            reason: BatchSendErrorReason::Closed,
        });
    }
    let mut iter = items.into_iter();
    self.push_batch_messages(&mut iter, k);
    if k == total {
        Ok(total)
    } else {
        Err(TrySendBatchError {
            sent: k,
            unsent: iter.collect(),
            reason: BatchSendErrorReason::Full,
        })
    }
}
/// Sends every item, repeatedly calling `acquire_many_sync` on the gate
/// until all of them have been pushed.
///
/// Returns `Ok(total)` on success.
///
/// # Errors
/// Returns a `SendBatchError` with the count already sent and the remaining
/// items if this handle is closed or the receiver is flagged as dropped.
pub fn send_batch(&self, items: Vec<T>) -> Result<usize, SendBatchError<T>> {
    let total = items.len();
    if total == 0 {
        return Ok(0);
    }
    if self.closed.load(Ordering::Relaxed)
        || self.shared.channel.receiver_dropped.load(Ordering::Acquire)
    {
        return Err(SendBatchError {
            sent: 0,
            unsent: items,
        });
    }
    let mut iter = items.into_iter();
    let mut sent = 0;
    while sent < total {
        let k = self.shared.gate.acquire_many_sync(total - sent);
        if self.shared.channel.receiver_dropped.load(Ordering::Acquire) {
            // The permits just acquired will not be attached to any message;
            // return them rather than leaking them. Rendezvous permits are
            // never released by senders (see `Permit::drop`).
            if k > 0 && self.capacity() != 0 {
                self.shared.gate.release_many(k);
            }
            return Err(SendBatchError {
                sent,
                unsent: iter.collect(),
            });
        }
        self.push_batch_messages(&mut iter, k);
        sent += k;
    }
    Ok(total)
}
/// Tries to send a prefix of `items`, draining the sent elements from it.
///
/// Returns the number of items sent; `Ok(0)` means either `items` was empty
/// or `try_acquire_many` granted nothing.
///
/// # Errors
/// Returns `SendError::Closed` if this handle is closed or the receiver is
/// flagged as dropped; `items` is left untouched in that case.
pub fn try_send_batch_mut(&self, items: &mut Vec<T>) -> Result<usize, SendError> {
    if items.is_empty() {
        return Ok(0);
    }
    if self.closed.load(Ordering::Relaxed)
        || self.shared.channel.receiver_dropped.load(Ordering::Acquire)
    {
        return Err(SendError::Closed);
    }
    let k = self.shared.gate.try_acquire_many(items.len());
    if k == 0 {
        return Ok(0);
    }
    if self.shared.channel.receiver_dropped.load(Ordering::Acquire) {
        // Hand back the permits that will never be attached to a message.
        // Rendezvous permits are never released by senders (see `Permit::drop`).
        if self.capacity() != 0 {
            self.shared.gate.release_many(k);
        }
        return Err(SendError::Closed);
    }
    let mut drain = items.drain(..k);
    self.push_batch_messages(&mut drain, k);
    // Finish the drain so `items` holds only the unsent tail.
    drop(drain);
    Ok(k)
}
/// Sends every element of `items`, draining them as they are pushed and
/// calling `acquire_many_sync` on the gate for each round.
///
/// Returns the number of items sent.
///
/// # Errors
/// Returns `SendError::Closed` if this handle is closed or the receiver is
/// flagged as dropped; elements not yet sent remain in `items`.
pub fn send_batch_mut(&self, items: &mut Vec<T>) -> Result<usize, SendError> {
    if items.is_empty() {
        return Ok(0);
    }
    if self.closed.load(Ordering::Relaxed)
        || self.shared.channel.receiver_dropped.load(Ordering::Acquire)
    {
        return Err(SendError::Closed);
    }
    let mut sent = 0;
    while !items.is_empty() {
        let k = self.shared.gate.acquire_many_sync(items.len());
        if self.shared.channel.receiver_dropped.load(Ordering::Acquire) {
            // Hand back the permits that will never be attached to a message.
            // Rendezvous permits are never released by senders (see `Permit::drop`).
            if k > 0 && self.capacity() != 0 {
                self.shared.gate.release_many(k);
            }
            return Err(SendError::Closed);
        }
        let mut drain = items.drain(..k);
        self.push_batch_messages(&mut drain, k);
        // Finish the drain so `items` holds only the unsent tail.
        drop(drain);
        sent += k;
    }
    Ok(sent)
}
fn push_batch_messages(&self, iter: &mut impl Iterator<Item = T>, k: usize) {
let is_rendezvous = self.capacity() == 0;
let shared = &self.shared;
let mut msg_iter = iter.by_ref().map(|value| BoundedMessage {
value,
_permit: Permit {
gate: shared.gate.clone(),
is_rendezvous,
},
});
let mut cache = None;
unbounded_v2::send_batch_internal(&shared.channel, &mut msg_iter, k, &mut cache);
}
/// Closes this sender handle.
///
/// # Errors
/// Returns `CloseError` if the handle had already been closed.
pub fn close(&self) -> Result<(), CloseError> {
    let flipped = self
        .closed
        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed);
    match flipped {
        Ok(_) => {
            self.close_internal();
            Ok(())
        }
        Err(_) => Err(CloseError),
    }
}
/// Decrements the shared sender count; if this was the last sender, wakes
/// the consumer and releases one unit on the gate.
fn close_internal(&self) {
    let channel = &self.shared.channel;
    let before = channel.sender_count.fetch_sub(1, Ordering::AcqRel);
    if before != 1 {
        return;
    }
    channel.wake_consumer();
    self.shared.gate.release();
}
/// Returns `true` once the receiver has been flagged as dropped.
pub fn is_closed(&self) -> bool {
    let channel = &self.shared.channel;
    channel.receiver_dropped.load(Ordering::Acquire)
}
/// Current value of the shared sender counter (relaxed load).
pub fn sender_count(&self) -> usize {
    let counter = &self.shared.channel.sender_count;
    counter.load(Ordering::Relaxed)
}
/// Current value of the queue's length counter (relaxed load).
pub fn len(&self) -> usize {
    let counter = &self.shared.channel.current_len;
    counter.load(Ordering::Relaxed)
}
/// Returns `true` when `len()` reports zero.
pub fn is_empty(&self) -> bool {
    let queued = self.len();
    queued == 0
}
/// Capacity reported by the gate.
pub fn capacity(&self) -> usize {
    let gate = &self.shared.gate;
    gate.capacity()
}
/// Returns `true` when the queue length has reached the capacity.
///
/// Uses `>=` rather than `==`: both values are independent snapshots, and on
/// a zero-capacity channel the queue can hold a message while the capacity
/// is 0, which an equality test would report as "not full".
pub fn is_full(&self) -> bool {
    self.len() >= self.capacity()
}
pub fn to_async(self) -> AsyncSender<T> {
let shared = unsafe { std::ptr::read(&self.shared) };
mem::forget(self);
AsyncSender {
shared,
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Clone for Sender<T> {
fn clone(&self) -> Self {
self
.shared
.channel
.sender_count
.fetch_add(1, Ordering::Relaxed);
Self {
shared: self.shared.clone(),
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Drop for Sender<T> {
    /// Closes the handle on drop unless it was already closed.
    fn drop(&mut self) {
        let already_closed = self.closed.swap(true, Ordering::AcqRel);
        if already_closed {
            return;
        }
        self.close_internal();
    }
}
impl<T: Send> Receiver<T> {
/// Receives one value, parking the calling thread while the queue is empty.
///
/// On a zero-capacity channel the gate is released once up front.
///
/// # Errors
/// Returns `RecvError::Disconnected` if this handle is closed or the
/// underlying queue reports disconnection.
pub fn recv(&self) -> Result<T, RecvError> {
    if self.closed.load(Ordering::Relaxed) {
        return Err(RecvError::Disconnected);
    }
    if self.capacity() == 0 {
        self.shared.gate.release();
    }

    let channel = &self.shared.channel;
    // Clears the parked flag and, only if we were the one to clear it,
    // forgets the registered thread handle.
    let disarm = || {
        let cleared = channel
            .consumer_parked
            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok();
        if cleared {
            *channel.consumer_thread.lock().unwrap() = None;
        }
    };

    loop {
        match self.try_recv_internal_no_release() {
            Ok(value) => return Ok(value),
            Err(TryRecvError::Disconnected) => return Err(RecvError::Disconnected),
            Err(TryRecvError::Empty) => {}
        }

        // Register for wake-up, then poll once more so a message that
        // arrived before registration is not missed.
        *channel.consumer_thread.lock().unwrap() = Some(thread::current());
        channel.consumer_parked.store(true, Ordering::Release);

        let second_try = self.try_recv_internal_no_release();
        if matches!(second_try, Err(TryRecvError::Empty)) {
            sync_util::park_thread();
        }
        disarm();

        match second_try {
            Ok(value) => return Ok(value),
            Err(TryRecvError::Disconnected) => return Err(RecvError::Disconnected),
            Err(TryRecvError::Empty) => {}
        }
    }
}
/// Receives one value, parking the calling thread for at most `timeout`
/// in total while the queue is empty.
///
/// On a zero-capacity channel the gate is released exactly once per call,
/// matching `recv`. Releasing again on every loop iteration would hand out
/// extra permits for a single receive.
/// NOTE(review): nothing here withdraws that release when the call times
/// out - confirm whether that is intended.
///
/// # Errors
/// * `RecvErrorTimeout::Disconnected` if this handle is closed or the
///   underlying queue reports disconnection.
/// * `RecvErrorTimeout::Timeout` if `timeout` elapsed without a message.
pub fn recv_timeout(&self, timeout: Duration) -> Result<T, RecvErrorTimeout> {
    if self.closed.load(Ordering::Relaxed) {
        return Err(RecvErrorTimeout::Disconnected);
    }
    let start_time = Instant::now();
    if self.capacity() == 0 {
        self.shared.gate.release();
    }
    match self.try_recv_internal_no_release() {
        Ok(value) => return Ok(value),
        Err(TryRecvError::Disconnected) => return Err(RecvErrorTimeout::Disconnected),
        Err(TryRecvError::Empty) => {}
    }

    let lf_shared = &self.shared.channel;
    // Clears the parked flag and, only if we were the one to clear it,
    // forgets the registered thread handle.
    let disarm = || {
        let cleared = lf_shared
            .consumer_parked
            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok();
        if cleared {
            *lf_shared.consumer_thread.lock().unwrap() = None;
        }
    };

    loop {
        let elapsed = start_time.elapsed();
        if elapsed >= timeout {
            return Err(RecvErrorTimeout::Timeout);
        }
        // Cannot underflow: `elapsed < timeout` was just checked.
        let remaining_timeout = timeout - elapsed;

        // Register for wake-up, then poll once more so a message that
        // arrived before registration is not missed.
        *lf_shared.consumer_thread.lock().unwrap() = Some(thread::current());
        lf_shared.consumer_parked.store(true, Ordering::Release);

        let attempt = self.try_recv_internal_no_release();
        if matches!(attempt, Err(TryRecvError::Empty)) {
            sync_util::park_thread_timeout(remaining_timeout);
        }
        disarm();

        match attempt {
            Ok(value) => return Ok(value),
            Err(TryRecvError::Disconnected) => return Err(RecvErrorTimeout::Disconnected),
            Err(TryRecvError::Empty) => {}
        }

        // Poll after waking, before re-checking the deadline.
        match self.try_recv_internal_no_release() {
            Ok(value) => return Ok(value),
            Err(TryRecvError::Disconnected) => return Err(RecvErrorTimeout::Disconnected),
            Err(TryRecvError::Empty) => {}
        }
    }
}
/// Polls the queue once without the zero-capacity gate release that
/// `try_recv` performs. The message's permit is dropped here.
fn try_recv_internal_no_release(&self) -> Result<T, TryRecvError> {
    if self.closed.load(Ordering::Relaxed) {
        return Err(TryRecvError::Disconnected);
    }
    match self.shared.channel.try_recv_internal() {
        Ok(message) => Ok(message.value),
        Err(err) => Err(err),
    }
}
/// Polls the queue once without blocking.
///
/// On a zero-capacity channel the gate is released before polling.
///
/// # Errors
/// Returns `TryRecvError::Disconnected` if this handle is closed; otherwise
/// whatever error the underlying queue reports.
pub fn try_recv(&self) -> Result<T, TryRecvError> {
    if self.closed.load(Ordering::Relaxed) {
        return Err(TryRecvError::Disconnected);
    }
    let rendezvous = self.capacity() == 0;
    if rendezvous {
        self.shared.gate.release();
    }
    match self.shared.channel.try_recv_internal() {
        Ok(message) => Ok(message.value),
        Err(err) => Err(err),
    }
}
/// Pulls up to `max` messages from the queue into `out` without the
/// zero-capacity gate release, returning the number of values appended.
/// Permits are returned in bulk by `unwrap_batch_messages`.
fn try_recv_batch_internal_no_release(
    &self,
    out: &mut Vec<T>,
    max: usize,
) -> Result<usize, TryRecvError> {
    if self.closed.load(Ordering::Relaxed) {
        return Err(TryRecvError::Disconnected);
    }
    let channel = &self.shared.channel;
    let mut staged = Vec::new();
    let received = channel.try_recv_batch_internal(&mut staged, max)?;
    debug_assert_eq!(received, staged.len());
    let appended = unwrap_batch_messages(&self.shared.gate, staged, out);
    Ok(appended)
}
/// Non-blocking batch receive into a fresh `Vec`; see `try_recv_batch_mut`.
pub fn try_recv_batch(&self, max: usize) -> Result<Vec<T>, TryRecvError> {
    let mut collected = Vec::new();
    match self.try_recv_batch_mut(&mut collected, max) {
        Ok(_) => Ok(collected),
        Err(err) => Err(err),
    }
}
/// Non-blocking batch receive appending up to `max` values to `out`.
///
/// Returns `Ok(0)` immediately when `max` is zero. On a zero-capacity
/// channel the gate is released once before polling.
///
/// # Errors
/// Returns `TryRecvError::Disconnected` if this handle is closed; otherwise
/// whatever error the underlying queue reports.
pub fn try_recv_batch_mut(&self, out: &mut Vec<T>, max: usize) -> Result<usize, TryRecvError> {
    if max == 0 {
        return Ok(0);
    }
    if self.closed.load(Ordering::Relaxed) {
        return Err(TryRecvError::Disconnected);
    }
    let rendezvous = self.capacity() == 0;
    if rendezvous {
        self.shared.gate.release();
    }
    self.try_recv_batch_internal_no_release(out, max)
}
/// Blocking batch receive into a fresh `Vec`; see `recv_batch_mut`.
pub fn recv_batch(&self, max: usize) -> Result<Vec<T>, RecvError> {
    let mut collected = Vec::new();
    match self.recv_batch_mut(&mut collected, max) {
        Ok(_) => Ok(collected),
        Err(err) => Err(err),
    }
}
/// Blocking batch receive appending up to `max` values to `out`, parking
/// the calling thread while the queue is empty.
///
/// Returns `Ok(0)` immediately when `max` is zero. On a zero-capacity
/// channel the gate is released once up front.
///
/// # Errors
/// Returns `RecvError::Disconnected` if this handle is closed or the
/// underlying queue reports disconnection.
pub fn recv_batch_mut(&self, out: &mut Vec<T>, max: usize) -> Result<usize, RecvError> {
    if max == 0 {
        return Ok(0);
    }
    if self.closed.load(Ordering::Relaxed) {
        return Err(RecvError::Disconnected);
    }
    if self.capacity() == 0 {
        self.shared.gate.release();
    }

    let channel = &self.shared.channel;
    // Clears the parked flag and, only if we were the one to clear it,
    // forgets the registered thread handle.
    let disarm = || {
        let cleared = channel
            .consumer_parked
            .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok();
        if cleared {
            *channel.consumer_thread.lock().unwrap() = None;
        }
    };

    loop {
        match self.try_recv_batch_internal_no_release(out, max) {
            Ok(count) => return Ok(count),
            Err(TryRecvError::Disconnected) => return Err(RecvError::Disconnected),
            Err(TryRecvError::Empty) => {}
        }

        // Register for wake-up, then poll once more so messages that
        // arrived before registration are not missed.
        *channel.consumer_thread.lock().unwrap() = Some(thread::current());
        channel.consumer_parked.store(true, Ordering::Release);

        let second_try = self.try_recv_batch_internal_no_release(out, max);
        if matches!(second_try, Err(TryRecvError::Empty)) {
            sync_util::park_thread();
        }
        disarm();

        match second_try {
            Ok(count) => return Ok(count),
            Err(TryRecvError::Disconnected) => return Err(RecvError::Disconnected),
            Err(TryRecvError::Empty) => {}
        }
    }
}
/// Closes the receiver.
///
/// # Errors
/// Returns `CloseError` if the handle had already been closed.
pub fn close(&self) -> Result<(), CloseError> {
    let flipped = self
        .closed
        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed);
    match flipped {
        Ok(_) => {
            self.close_internal();
            Ok(())
        }
        Err(_) => Err(CloseError),
    }
}
/// Flags the receiver as dropped, discards every message still queued, and
/// then closes the gate.
fn close_internal(&self) {
    let channel = &self.shared.channel;
    channel.receiver_dropped.store(true, Ordering::Release);
    // Drain: each discarded message drops its permit along with it.
    while let Ok(leftover) = channel.try_recv_internal() {
        drop(leftover);
    }
    self.shared.gate.close();
}
/// Returns `true` when the sender count is zero and the queue is empty.
pub fn is_closed(&self) -> bool {
    let live_senders = self.shared.channel.sender_count.load(Ordering::Acquire);
    if live_senders != 0 {
        return false;
    }
    self.is_empty()
}
/// Current value of the shared sender counter (relaxed load).
pub fn sender_count(&self) -> usize {
    let counter = &self.shared.channel.sender_count;
    counter.load(Ordering::Relaxed)
}
/// Current value of the queue's length counter (relaxed load).
pub fn len(&self) -> usize {
    let counter = &self.shared.channel.current_len;
    counter.load(Ordering::Relaxed)
}
/// Returns `true` when `len()` reports zero.
pub fn is_empty(&self) -> bool {
    let queued = self.len();
    queued == 0
}
/// Capacity reported by the gate.
pub fn capacity(&self) -> usize {
    let gate = &self.shared.gate;
    gate.capacity()
}
/// Returns `true` when the queue length has reached the capacity.
///
/// Uses `>=` rather than `==`: both values are independent snapshots, and on
/// a zero-capacity channel the queue can hold a message while the capacity
/// is 0, which an equality test would report as "not full".
pub fn is_full(&self) -> bool {
    self.len() >= self.capacity()
}
pub fn to_async(self) -> AsyncReceiver<T> {
let shared = unsafe { std::ptr::read(&self.shared) };
mem::forget(self);
AsyncReceiver {
shared,
closed: AtomicBool::new(false),
}
}
}
impl<T: Send> Drop for Receiver<T> {
    /// Closes the receiver on drop unless it was already closed.
    fn drop(&mut self) {
        let already_closed = self.closed.swap(true, Ordering::AcqRel);
        if already_closed {
            return;
        }
        self.close_internal();
    }
}