use std::fmt;
use std::future::Future;
use std::future::poll_fn;
use std::pin::pin;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use crate::atomicbox::AtomicOptionBox;
use crate::internal::Acquire;
use crate::internal::Semaphore;
use crate::mpsc::RecvError;
use crate::mpsc::SendError;
use crate::mpsc::TryRecvError;
use crate::mpsc::error::TrySendError;
/// Creates a bounded channel whose buffer holds at most `buffer` queued values.
///
/// Values are stored in a [`std::sync::mpsc::sync_channel`]. The async layer on top adds
/// only wake-ups: a semaphore that full senders wait on, and a stored waker for the receiver.
///
/// # Panics
///
/// Panics if `buffer` is zero. A zero-capacity `sync_channel` would be a rendezvous channel,
/// which this constructor rejects.
#[track_caller]
pub fn bounded<T>(buffer: usize) -> (BoundedSender<T>, BoundedReceiver<T>) {
    assert!(buffer > 0, "mpsc bounded channel requires buffer > 0");

    let (std_tx, std_rx) = std::sync::mpsc::sync_channel(buffer);

    // Exactly one sender handle exists at creation; `Clone`/`Drop` on the sender maintain it.
    let shared = Arc::new(BoundedState {
        senders: AtomicUsize::new(1),
        tx_permits: Semaphore::new(0),
        rx_task: AtomicOptionBox::none(),
    });

    let tx = BoundedSender {
        state: Arc::clone(&shared),
        sender: Some(std_tx),
    };
    // The receiver takes the last reference we hold; no extra clone is needed.
    let rx = BoundedReceiver {
        state: shared,
        receiver: Some(std_rx),
    };
    (tx, rx)
}
/// State shared (through an `Arc`) between every sender handle and the receiver.
struct BoundedState {
    /// Number of live `BoundedSender` handles. It starts at 1 in `bounded`, is incremented by
    /// `Clone` and decremented by `Drop`. The handle that brings it to zero wakes the receiver.
    senders: AtomicUsize,
    /// Semaphore that `send` futures wait on while the buffer is full. It starts with 0 permits.
    /// The receiver calls `release_if_nonempty(1)` after each successful receive and
    /// `notify_all` when it is dropped.
    tx_permits: Semaphore,
    /// Waker stored by the receiver's `poll_recv`. Senders `take` and wake it after a successful
    /// `try_send`, and when the last sender is dropped.
    rx_task: AtomicOptionBox<Waker>,
}
/// Sending half of a channel created by [`bounded`]; it can be cloned to get more producers.
pub struct BoundedSender<T> {
    /// State shared with the receiver and with every other sender clone.
    state: Arc<BoundedState>,
    /// Underlying std sender. It is `Some` for a live handle and only becomes `None`
    /// inside `Drop`.
    sender: Option<std::sync::mpsc::SyncSender<T>>,
}
impl<T> Clone for BoundedSender<T> {
fn clone(&self) -> Self {
self.state.senders.fetch_add(1, Ordering::Release);
BoundedSender {
state: self.state.clone(),
sender: self.sender.clone(),
}
}
}
impl<T> fmt::Debug for BoundedSender<T> {
    /// Prints `BoundedSender { .. }` without exposing any channel internals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("BoundedSender");
        builder.finish_non_exhaustive()
    }
}
impl<T> Drop for BoundedSender<T> {
    /// Releases this handle. If it was the last sender, it wakes a receiver parked in `recv`.
    ///
    /// The std sender is dropped *before* the count is decremented. When the last handle goes
    /// away, the std channel is therefore already disconnected by the time the receiver is
    /// woken, and its next `try_recv` can observe `Disconnected`.
    fn drop(&mut self) {
        // Assigning `None` drops the previously held `SyncSender`.
        self.sender = None;

        let count_before = self.state.senders.fetch_sub(1, Ordering::AcqRel);
        if count_before != 1 {
            // Other sender handles are still alive; the receiver has nothing new to observe.
            return;
        }
        if let Some(waker) = self.state.rx_task.take() {
            waker.wake();
        }
    }
}
impl<T> BoundedSender<T> {
    /// Sends `value`, waiting asynchronously while the channel's buffer is full.
    ///
    /// It first tries a non-blocking send. Only when that reports `Full` does it start waiting
    /// on `tx_permits`, retrying the send each time the wait makes progress.
    ///
    /// If the returned future is dropped before it completes, the value it still holds is
    /// dropped with it.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] carrying `value` back if the underlying `SyncSender` reports
    /// `Disconnected`, i.e. the receiver has been dropped.
    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
        // Fast path: no waiting is needed when there is room (or the receiver is already gone).
        let value = match self.try_send(value) {
            Ok(()) => return Ok(()),
            Err(TrySendError::Disconnected(value)) => return Err(SendError::new(value)),
            Err(TrySendError::Full(value)) => value,
        };
        // In-flight state of a send that found the buffer full.
        struct SendState<'a, T> {
            sender: &'a BoundedSender<T>,
            // Value still to be delivered; `None` once this send has finished.
            value: Option<T>,
            // Outstanding request for one permit from `tx_permits.poll_acquire(1)`.
            acquire: Acquire<'a>,
        }
        impl<T> SendState<'_, T> {
            // Drives the wait. The `acquire` future is polled *before* each `try_send` retry,
            // so this task is already registered on `tx_permits` if a slot is freed right after
            // the retry fails. The receiver's release can then wake it instead of the
            // wake-up being lost.
            fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
                let mut value = match self.value.take() {
                    Some(value) => value,
                    // The send already completed on an earlier poll.
                    None => return Poll::Ready(Ok(())),
                };
                loop {
                    // Register interest in a permit first; the retry below must come after it.
                    let poll = pin!(&mut self.acquire).poll(cx);
                    value = match self.sender.try_send(value) {
                        Ok(()) => return Poll::Ready(Ok(())),
                        Err(TrySendError::Disconnected(value)) => {
                            return Poll::Ready(Err(SendError::new(value)));
                        }
                        Err(TrySendError::Full(value)) => value,
                    };
                    if poll.is_ready() {
                        // The wait finished, but the buffer is full again (e.g. another sender
                        // filled the slot first). The completed `Acquire` must not be polled
                        // again, so a fresh one replaces it and the loop retries.
                        self.acquire = self.sender.state.tx_permits.poll_acquire(1);
                    } else {
                        // Still waiting: park the value and rely on the waker registered above.
                        self.value = Some(value);
                        return Poll::Pending;
                    }
                }
            }
        }
        let acquire = self.state.tx_permits.poll_acquire(1);
        let mut send = SendState {
            sender: self,
            value: Some(value),
            acquire,
        };
        poll_fn(|cx| send.poll_send(cx)).await
    }
    /// Attempts to send `value` without waiting.
    ///
    /// On success, the receiver's stored waker (if any) is taken and woken so that a pending
    /// `recv` re-polls.
    ///
    /// # Errors
    ///
    /// - [`TrySendError::Full`] if the underlying buffer is at capacity.
    /// - [`TrySendError::Disconnected`] if the receiver has been dropped.
    ///
    /// Both variants hand `value` back to the caller.
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        // `sender` is `None` only once `Drop` has run, so this `unwrap` is not expected to fire.
        let sender = self.sender.as_ref().unwrap();
        match sender.try_send(value) {
            Ok(()) => {
                if let Some(waker) = self.state.rx_task.take() {
                    waker.wake();
                }
                Ok(())
            }
            Err(std::sync::mpsc::TrySendError::Full(value)) => Err(TrySendError::Full(value)),
            Err(std::sync::mpsc::TrySendError::Disconnected(value)) => {
                Err(TrySendError::Disconnected(value))
            }
        }
    }
}
/// Receiving half of a channel created by [`bounded`].
pub struct BoundedReceiver<T> {
    /// State shared with every sender handle (the waker slot and the sender permits).
    state: Arc<BoundedState>,
    /// Underlying std receiver. It is `Some` for a live handle and only becomes `None`
    /// inside `Drop`.
    receiver: Option<std::sync::mpsc::Receiver<T>>,
}
// SAFETY: `std::sync::mpsc::Receiver` is not `Sync`, so this impl is asserted by hand.
// In this file, every method that touches `receiver` (`try_recv`, `recv`, `poll_recv`, and
// `Drop`) takes `&mut self`. The only `&self` method, `Debug::fmt`, never reads it. A shared
// `&BoundedReceiver<T>` therefore grants no access to the std receiver.
// This impl also bypasses the auto-trait check on `state: Arc<BoundedState>`, so it relies on
// `BoundedState` being `Send + Sync`. That is the same condition that makes `BoundedSender`
// auto-`Send`. NOTE(review): confirm that `Semaphore`/`AtomicOptionBox<Waker>` satisfy it, and
// keep any future `&self` method away from `receiver`.
unsafe impl<T: Send> Sync for BoundedReceiver<T> {}
impl<T> fmt::Debug for BoundedReceiver<T> {
    /// Prints `BoundedReceiver { .. }` without exposing any channel internals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("BoundedReceiver");
        builder.finish_non_exhaustive()
    }
}
impl<T> Drop for BoundedReceiver<T> {
    /// Closes the receiving side, then calls `notify_all` on `tx_permits` to wake waiting senders.
    ///
    /// The std receiver is dropped *before* the notification. A woken sender's retried
    /// `try_send` therefore sees `Disconnected` rather than `Full`.
    fn drop(&mut self) {
        // Assigning `None` drops the previously held `Receiver`.
        self.receiver = None;
        self.state.tx_permits.notify_all();
    }
}
impl<T> BoundedReceiver<T> {
    /// Takes the next queued value without waiting.
    ///
    /// After a successful receive, `tx_permits.release_if_nonempty(1)` is called so that a
    /// sender blocked on a full buffer can retry.
    ///
    /// # Errors
    ///
    /// - [`TryRecvError::Empty`] if nothing is queued right now.
    /// - [`TryRecvError::Disconnected`] if the std channel reports that every sender is gone.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let rx = self.receiver.as_ref().unwrap();
        let value = rx.try_recv().map_err(|err| match err {
            std::sync::mpsc::TryRecvError::Empty => TryRecvError::Empty,
            std::sync::mpsc::TryRecvError::Disconnected => TryRecvError::Disconnected,
        })?;
        // A slot just opened in the buffer; give a waiting sender the chance to fill it.
        self.state.tx_permits.release_if_nonempty(1);
        Ok(value)
    }

    /// Receives the next value, waiting asynchronously until one is available.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Disconnected`] once `try_recv` reports `Disconnected`.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        poll_fn(|cx| self.poll_recv(cx)).await
    }

    /// Polls for the next value, storing `cx`'s waker when the queue is empty.
    ///
    /// The queue is checked a second time *after* the waker is stored. A value sent (or the
    /// last sender dropped) between the first check and the registration is therefore still
    /// observed.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        let mut waker_stored = false;
        loop {
            match self.try_recv() {
                Ok(value) => return Poll::Ready(Ok(value)),
                Err(TryRecvError::Disconnected) => {
                    return Poll::Ready(Err(RecvError::Disconnected));
                }
                // Second empty check: the waker is in place, so it is safe to park.
                Err(TryRecvError::Empty) if waker_stored => return Poll::Pending,
                Err(TryRecvError::Empty) => {
                    self.state.rx_task.store(Some(Box::new(cx.waker().clone())));
                    waker_stored = true;
                }
            }
        }
    }
}