use std::fmt;
use std::future::poll_fn;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use crate::atomicbox::AtomicOptionBox;
use crate::mpsc::RecvError;
use crate::mpsc::SendError;
use crate::mpsc::TryRecvError;
/// Creates an unbounded multi-producer, single-consumer channel.
///
/// Values are queued in a [`std::sync::mpsc`] channel with no capacity limit.
/// The returned [`UnboundedSender`] can be cloned to get more producers. The
/// [`UnboundedReceiver`] can pull values either with `try_recv` or by awaiting
/// `recv`. Senders wake a pending `recv` through a waker slot kept in shared
/// state.
pub fn unbounded<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
    // Exactly one sender handle exists at creation. Clones increment this
    // count and drops decrement it.
    let state = Arc::new(UnboundedState {
        senders: AtomicUsize::new(1),
        rx_task: AtomicOptionBox::none(),
    });
    let (tx, rx) = std::sync::mpsc::channel();
    let sender = UnboundedSender {
        state: Arc::clone(&state),
        sender: Some(tx),
    };
    // This is the last use of `state`, so move it rather than bumping the
    // refcount again.
    let receiver = UnboundedReceiver {
        state,
        receiver: rx,
    };
    (sender, receiver)
}
/// State shared, via `Arc`, between every sender handle and the receiver of
/// one channel.
struct UnboundedState {
    /// Number of live `UnboundedSender` handles. Starts at 1, is incremented on
    /// clone and decremented on drop. The handle that brings it to zero wakes
    /// the receiver.
    senders: AtomicUsize,
    /// Waker stored by the receiver when `poll_recv` finds the queue empty.
    /// It is taken and woken by `send`, and by the drop of the last sender.
    rx_task: AtomicOptionBox<Waker>,
}
/// Sending half of a channel created by [`unbounded`].
///
/// Cloning produces another handle, and each live handle is counted in the
/// shared state.
pub struct UnboundedSender<T> {
    /// State shared with the receiver and all other sender handles.
    state: Arc<UnboundedState>,
    /// The underlying std sender. It is `Some` for the whole life of the
    /// handle and is only taken at the start of `Drop`.
    sender: Option<std::sync::mpsc::Sender<T>>,
}
impl<T> Clone for UnboundedSender<T> {
fn clone(&self) -> Self {
self.state.senders.fetch_add(1, Ordering::Release);
UnboundedSender {
state: self.state.clone(),
sender: self.sender.clone(),
}
}
}
impl<T> fmt::Debug for UnboundedSender<T> {
    /// Prints only the type name. The inner channel and shared state are not
    /// shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("UnboundedSender");
        builder.finish_non_exhaustive()
    }
}
impl<T> Drop for UnboundedSender<T> {
    /// Releases this handle's std sender, then decrements the live-sender
    /// count.
    ///
    /// The handle that brings the count to zero wakes any registered receiver
    /// waker, so a pending `recv` re-polls and observes disconnection.
    fn drop(&mut self) {
        // Assigning `None` drops the std sender immediately, before the count
        // is decremented. By the time the last handle wakes the receiver,
        // every std sender has already been dropped.
        self.sender = None;
        let count_before = self.state.senders.fetch_sub(1, Ordering::AcqRel);
        if count_before != 1 {
            // Other handles are still alive, so there is nothing to signal.
            return;
        }
        if let Some(rx_waker) = self.state.rx_task.take() {
            rx_waker.wake();
        }
    }
}
impl<T> UnboundedSender<T> {
    /// Queues `value` on the channel without waiting, then wakes the receiver
    /// if it has registered a waker.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] carrying `value` back if the receiving half has
    /// been dropped.
    ///
    /// # Panics
    ///
    /// Never in practice. The inner std sender is `Some` for the whole life of
    /// the handle and is only taken in `Drop`.
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        let sender = self
            .sender
            .as_ref()
            .expect("UnboundedSender::sender is only taken in Drop");
        sender.send(value).map_err(|err| SendError::new(err.0))?;
        // Wake only after enqueueing. The receiver re-checks the queue after
        // storing its waker, so either it sees this value or we see its waker.
        if let Some(waker) = self.state.rx_task.take() {
            waker.wake();
        }
        Ok(())
    }
}
/// Receiving half of a channel created by [`unbounded`].
pub struct UnboundedReceiver<T> {
    /// State shared with the senders. The receiver stores its waker here.
    state: Arc<UnboundedState>,
    /// The underlying std receiver. `std::sync::mpsc::Receiver` is not `Sync`.
    /// The manual `Sync` impl for this type relies on this field only being
    /// accessed through `&mut self`.
    receiver: std::sync::mpsc::Receiver<T>,
}
// SAFETY: `std::sync::mpsc::Receiver` is not `Sync`, but every method on
// `UnboundedReceiver` that touches `receiver` (`try_recv`, `recv`,
// `poll_recv`) takes `&mut self`. The `Debug` impl does not read it. A shared
// `&UnboundedReceiver<T>` therefore gives no access to the inner receiver, so
// sharing the reference across threads cannot race on it.
// NOTE(review): this holds only if no `&self` method anywhere in the crate
// touches `receiver`. Re-check whenever methods are added. It also assumes
// `AtomicOptionBox<Waker>` is `Sync` (it is already shared with senders via
// `Arc`) — confirm.
unsafe impl<T: Send> Sync for UnboundedReceiver<T> {}
impl<T> fmt::Debug for UnboundedReceiver<T> {
    /// Prints only the type name. The queued values and shared state are not
    /// shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("UnboundedReceiver");
        builder.finish_non_exhaustive()
    }
}
impl<T> UnboundedReceiver<T> {
    /// Takes the next queued value without waiting.
    ///
    /// Returns `TryRecvError::Empty` when nothing is queued. Returns
    /// `TryRecvError::Disconnected` when the queue is empty and every sender
    /// has been dropped.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        self.receiver.try_recv().map_err(|e| match e {
            std::sync::mpsc::TryRecvError::Empty => TryRecvError::Empty,
            std::sync::mpsc::TryRecvError::Disconnected => TryRecvError::Disconnected,
        })
    }

    /// Waits for the next value.
    ///
    /// Resolves to `RecvError::Disconnected` once every sender has been
    /// dropped and the queue is empty.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        let next = poll_fn(|cx| self.poll_recv(cx));
        next.await
    }

    /// Polls for the next value, registering `cx`'s waker when the queue is
    /// empty.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        // Maps one non-blocking attempt to a final outcome. `None` means the
        // queue was empty.
        fn settle<U>(attempt: Result<U, TryRecvError>) -> Option<Result<U, RecvError>> {
            match attempt {
                Ok(v) => Some(Ok(v)),
                Err(TryRecvError::Disconnected) => Some(Err(RecvError::Disconnected)),
                Err(TryRecvError::Empty) => None,
            }
        }

        if let Some(outcome) = settle(self.try_recv()) {
            return Poll::Ready(outcome);
        }
        // Store the waker first, then look again. A value or disconnect that
        // landed between the first check and the store would otherwise go
        // unnoticed, because the sender found no waker to wake.
        self.state.rx_task.store(Some(Box::new(cx.waker().clone())));
        match settle(self.try_recv()) {
            Some(outcome) => Poll::Ready(outcome),
            None => Poll::Pending,
        }
    }
}