use crate::channel::{
INIT_STATE, MAX_BUFFER, MAX_CAPACITY, OPEN_MASK, Priority, SendError, SendErrorKind,
TryRecvError, TrySendError, decode_state, encode_state, queue::Queue,
};
use futures::{
future::poll_fn,
stream::{FusedStream, Stream},
task::AtomicWaker,
};
use std::{
pin::Pin,
sync::{
Arc,
atomic::{AtomicUsize, Ordering::SeqCst},
},
task::{Context, Poll},
};
/// Creates a two-priority unbounded channel, returning the connected
/// sender/receiver pair. The state word starts at `INIT_STATE` (open,
/// zero messages) and the sender count at one.
pub fn unbounded<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
    let shared = Arc::new(UnboundedInner {
        state: AtomicUsize::new(INIT_STATE),
        message_queue: Queue::new(),
        quick_message_queue: Queue::new(),
        num_senders: AtomicUsize::new(1),
        recv_task: AtomicWaker::new(),
    });
    let receiver = UnboundedReceiver {
        inner: Some(shared.clone()),
    };
    let sender = UnboundedSender(Some(UnboundedSenderInner { inner: shared }));
    (sender, receiver)
}
#[derive(Debug)]
struct UnboundedInner<T> {
    // Packed channel state: open/closed flag plus the queued-message count
    // (see `encode_state` / `decode_state`).
    state: AtomicUsize,
    // Normal-priority messages.
    message_queue: Queue<T>,
    // High-priority messages; the receiver drains this queue before
    // `message_queue`.
    quick_message_queue: Queue<T>,
    // Number of live sender handles; the channel closes when it reaches zero.
    num_senders: AtomicUsize,
    // Waker registered by the receiver so senders can signal new messages.
    recv_task: AtomicWaker,
}
impl<T> UnboundedInner<T> {
    /// Clears the OPEN bit in `state`, marking the channel closed.
    /// Skips the read-modify-write entirely when already closed.
    fn set_closed(&self) {
        let state = decode_state(self.state.load(SeqCst));
        if state.is_open {
            self.state.fetch_and(!OPEN_MASK, SeqCst);
        }
    }
}
// SAFETY: the shared core is accessed from sender and receiver threads;
// messages are moved across threads, never aliased, so `T: Send` suffices.
// NOTE(review): soundness additionally relies on `Queue`'s push/pop protocol
// being safe under this concurrent usage — confirm against `queue::Queue`'s
// documented contract.
unsafe impl<T: Send> Send for UnboundedInner<T> {}
unsafe impl<T: Send> Sync for UnboundedInner<T> {}
#[derive(Debug)]
struct UnboundedSenderInner<T> {
    // Handle to the shared channel core.
    inner: Arc<UnboundedInner<T>>,
}
// The sender holds no pinned/self-referential data, so it is trivially Unpin.
impl<T> Unpin for UnboundedSenderInner<T> {}
impl<T> UnboundedSenderInner<T> {
    /// Non-blocking readiness check: an unbounded channel is always ready
    /// while open, and permanently errors with `Disconnected` once closed.
    /// Never returns `Pending`.
    fn poll_ready_nb(&self) -> Poll<Result<(), SendError>> {
        let state = decode_state(self.inner.state.load(SeqCst));
        if state.is_open {
            Poll::Ready(Ok(()))
        } else {
            Poll::Ready(Err(SendError {
                kind: SendErrorKind::Disconnected,
            }))
        }
    }
    /// Pushes `msg` onto the queue selected by `priority`, then wakes the
    /// receiver so it can observe the new message.
    fn queue_push_and_signal(&self, msg: T, priority: Priority) {
        match priority {
            Priority::High => self.inner.quick_message_queue.push(msg),
            Priority::Normal => self.inner.message_queue.push(msg),
        }
        self.inner.recv_task.wake();
    }
    /// Reserves one message slot by incrementing the packed count via a CAS
    /// loop.
    ///
    /// Returns the new message count on success, or `None` if the channel is
    /// closed.
    ///
    /// # Panics
    /// Panics if the count would overflow the bits reserved for it in the
    /// state word.
    fn inc_num_messages(&self) -> Option<usize> {
        let mut curr = self.inner.state.load(SeqCst);
        loop {
            let mut state = decode_state(curr);
            if !state.is_open {
                return None;
            }
            // Fixed message typo: "this messages" -> "this message".
            assert!(
                state.num_messages < MAX_CAPACITY,
                "buffer space exhausted; sending this message would overflow the state"
            );
            state.num_messages += 1;
            let next = encode_state(&state);
            match self
                .inner
                .state
                .compare_exchange(curr, next, SeqCst, SeqCst)
            {
                Ok(_) => return Some(state.num_messages),
                // Lost the race; retry against the freshly observed state.
                Err(actual) => curr = actual,
            }
        }
    }
    /// True if `other` feeds the same channel (same shared allocation).
    fn same_receiver(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
    /// Stable address of the shared core; used as the sender's identity
    /// for hashing.
    fn ptr(&self) -> *const UnboundedInner<T> {
        &*self.inner
    }
    /// True once the channel has been closed (by the receiver or by the
    /// last sender dropping).
    fn is_closed(&self) -> bool {
        !decode_state(self.inner.state.load(SeqCst)).is_open
    }
    /// Marks the channel closed and wakes the receiver so it can observe
    /// the closure.
    fn close_channel(&self) {
        self.inner.set_closed();
        self.inner.recv_task.wake();
    }
}
impl<T> Clone for UnboundedSenderInner<T> {
    /// Increments `num_senders` with a CAS loop, panicking rather than
    /// letting the counter run past `MAX_BUFFER`.
    fn clone(&self) -> UnboundedSenderInner<T> {
        let mut curr = self.inner.num_senders.load(SeqCst);
        loop {
            if curr == MAX_BUFFER {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }
            debug_assert!(curr < MAX_BUFFER);
            let next = curr + 1;
            // A successful `compare_exchange` guarantees the previous value
            // was exactly `curr`, so no further equality check is needed.
            // (The original re-compared `actual == curr` on the Ok arm; that
            // comparison can never be false, and had it been, the loop would
            // have spun forever without reloading `curr`.)
            match self
                .inner
                .num_senders
                .compare_exchange(curr, next, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return UnboundedSenderInner {
                        inner: self.inner.clone(),
                    };
                }
                Err(actual) => curr = actual,
            }
        }
    }
}
impl<T> Drop for UnboundedSenderInner<T> {
    /// Releases this handle's sender slot; the last sender to go away
    /// closes the channel and wakes the receiver.
    fn drop(&mut self) {
        if self.inner.num_senders.fetch_sub(1, SeqCst) == 1 {
            self.close_channel();
        }
    }
}
/// Sending half of the channel. The inner option is `None` once this
/// handle has been explicitly `disconnect`ed.
#[derive(Debug)]
pub struct UnboundedSender<T>(Option<UnboundedSenderInner<T>>);
impl<T> UnboundedSender<T> {
    /// Readiness check; errors once the channel is closed or this handle
    /// has been disconnected. Never returns `Pending` (unbounded channel).
    pub fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), SendError>> {
        match &self.0 {
            Some(inner) => inner.poll_ready_nb(),
            None => Poll::Ready(Err(SendError {
                kind: SendErrorKind::Disconnected,
            })),
        }
    }
    /// True if the channel is closed or this handle was disconnected.
    pub fn is_closed(&self) -> bool {
        match &self.0 {
            Some(inner) => inner.is_closed(),
            None => true,
        }
    }
    /// Closes the channel for all handles (no-op after `disconnect`).
    pub fn close_channel(&self) {
        if let Some(inner) = &self.0 {
            inner.close_channel();
        }
    }
    /// Severs this handle only; other senders keep working.
    pub fn disconnect(&mut self) {
        self.0 = None;
    }
    /// Number of queued messages, or `None` when this handle is
    /// disconnected or the channel is closed.
    pub fn len(&self) -> Option<usize> {
        let inner = self.0.as_ref()?;
        let state = decode_state(inner.inner.state.load(SeqCst));
        if state.is_open {
            Some(state.num_messages)
        } else {
            None
        }
    }
    /// Shared non-blocking send path: reserve a slot in the state word,
    /// then enqueue and signal. Hands the message back on failure.
    fn do_send_nb(&self, msg: T, priority: Priority) -> Result<(), TrySendError<T>> {
        match &self.0 {
            // The guard reserves the slot; it only succeeds while open.
            Some(inner) if inner.inc_num_messages().is_some() => {
                inner.queue_push_and_signal(msg, priority);
                Ok(())
            }
            _ => Err(TrySendError {
                err: SendError {
                    kind: SendErrorKind::Disconnected,
                },
                val: msg,
            }),
        }
    }
    /// Async normal-priority send. Readiness is immediate for an unbounded
    /// channel, so this resolves on the first poll.
    pub async fn async_send(&self, msg: T) -> Result<(), SendError> {
        let mut slot = Some(msg);
        poll_fn(|cx| {
            let item = slot.take().unwrap();
            match self.poll_ready(cx) {
                Poll::Ready(Ok(())) => Poll::Ready(self.start_send(item)),
                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                Poll::Pending => {
                    // Not ready yet: put the message back for the next poll.
                    slot = Some(item);
                    Poll::Pending
                }
            }
        })
        .await
    }
    /// Async high-priority send; otherwise identical to [`Self::async_send`].
    pub async fn async_quick_send(&self, msg: T) -> Result<(), SendError> {
        let mut slot = Some(msg);
        poll_fn(|cx| {
            let item = slot.take().unwrap();
            match self.poll_ready(cx) {
                Poll::Ready(Ok(())) => Poll::Ready(self.start_quick_send(item)),
                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                Poll::Pending => {
                    slot = Some(item);
                    Poll::Pending
                }
            }
        })
        .await
    }
    /// Normal-priority send; the message is dropped on failure.
    pub fn start_send(&self, msg: T) -> Result<(), SendError> {
        self.do_send_nb(msg, Priority::Normal).map_err(|e| e.err)
    }
    /// High-priority send; the message is dropped on failure.
    pub fn start_quick_send(&self, msg: T) -> Result<(), SendError> {
        self.do_send_nb(msg, Priority::High).map_err(|e| e.err)
    }
    /// Normal-priority send that returns the message on failure.
    pub fn unbounded_send(&self, msg: T) -> Result<(), TrySendError<T>> {
        self.do_send_nb(msg, Priority::Normal)
    }
    /// High-priority send that returns the message on failure.
    pub fn unbounded_quick_send(&self, msg: T) -> Result<(), TrySendError<T>> {
        self.do_send_nb(msg, Priority::High)
    }
    /// True when both handles are connected to the same channel.
    pub fn same_receiver(&self, other: &Self) -> bool {
        matches!((&self.0, &other.0), (Some(a), Some(b)) if a.same_receiver(b))
    }
    /// Hashes this sender's channel identity (address of the shared core;
    /// `None` for a disconnected handle).
    pub fn hash_receiver<H>(&self, hasher: &mut H)
    where
        H: std::hash::Hasher,
    {
        use std::hash::Hash;
        self.0.as_ref().map(UnboundedSenderInner::ptr).hash(hasher);
    }
}
impl<T> Clone for UnboundedSender<T> {
    /// Cloning a disconnected handle yields another disconnected handle;
    /// otherwise the inner sender (and its sender count) is cloned.
    fn clone(&self) -> UnboundedSender<T> {
        UnboundedSender(self.0.as_ref().map(Clone::clone))
    }
}
/// Receiving half of the channel.
#[derive(Debug)]
pub struct UnboundedReceiver<T> {
    // `None` once the stream has terminated (channel observed closed and
    // fully drained).
    inner: Option<Arc<UnboundedInner<T>>>,
}
// No pinned/self-referential data, so the receiver is trivially Unpin.
impl<T> Unpin for UnboundedReceiver<T> {}
impl<T> UnboundedReceiver<T> {
    /// Closes the channel; messages already queued can still be received.
    pub fn close(&mut self) {
        if let Some(inner) = &self.inner {
            inner.set_closed();
        }
    }
    /// Non-blocking receive: `Ok(Some(..))` for a message, `Ok(None)` once
    /// the channel is closed and drained, `Err` when it is merely empty.
    pub fn try_next(&mut self) -> Result<Option<(Priority, T)>, TryRecvError> {
        if let Poll::Ready(msg) = self.next_message() {
            Ok(msg)
        } else {
            Err(TryRecvError { _priv: () })
        }
    }
    /// Pops the next message, preferring the high-priority queue; checks
    /// the channel state only after both queues came up empty.
    fn next_message(&mut self) -> Poll<Option<(Priority, T)>> {
        let inner = self
            .inner
            .as_mut()
            .expect("Receiver::next_message called after `None`");
        // SAFETY: `&mut self` makes this the only consumer of both queues.
        // NOTE(review): assumed to satisfy `pop_spin`'s single-consumer
        // contract — confirm against `queue::Queue`.
        let popped = unsafe {
            inner
                .quick_message_queue
                .pop_spin()
                .map(|m| (Priority::High, m))
                .or_else(|| inner.message_queue.pop_spin().map(|m| (Priority::Normal, m)))
        };
        match popped {
            Some(entry) => {
                self.dec_num_messages();
                Poll::Ready(Some(entry))
            }
            None => {
                // Both queues empty: terminate only if the channel is
                // closed, otherwise report "not ready yet".
                let state = decode_state(inner.state.load(SeqCst));
                if state.is_closed() {
                    self.inner = None;
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }
        }
    }
    /// Decrements the packed message count by one.
    fn dec_num_messages(&self) {
        // NOTE(review): a plain `fetch_sub(1)` only decrements the count if
        // `encode_state` packs `num_messages` in the low bits of the word —
        // confirm against `encode_state`/`decode_state`.
        if let Some(inner) = &self.inner {
            inner.state.fetch_sub(1, SeqCst);
        }
    }
}
impl<T> FusedStream for UnboundedReceiver<T> {
    // The stream is terminated once `next_message` has dropped the shared
    // core (channel closed and drained).
    fn is_terminated(&self) -> bool {
        self.inner.is_none()
    }
}
impl<T> Stream for UnboundedReceiver<T> {
    type Item = (Priority, T);
    /// Polls for the next `(priority, message)` pair, registering the task's
    /// waker when the channel is empty but still open.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Poll::Ready(item) = self.next_message() {
            if item.is_none() {
                // Already cleared by `next_message`; harmless to repeat.
                self.inner = None;
            }
            return Poll::Ready(item);
        }
        // Queue looked empty: park this task, then poll once more to close
        // the race with a sender that pushed between the check and register.
        self.inner.as_ref().unwrap().recv_task.register(cx.waker());
        self.next_message()
    }
}
impl<T> Drop for UnboundedReceiver<T> {
    // Closes the channel and drains any queued messages so their
    // destructors run before the receiver goes away.
    #[allow(clippy::unnecessary_unwrap)]
    fn drop(&mut self) {
        // Mark closed first so senders stop enqueueing new messages.
        self.close();
        if self.inner.is_some() {
            loop {
                match self.next_message() {
                    // Drop each drained message.
                    Poll::Ready(Some(_)) => {}
                    // Closed and fully drained: done.
                    Poll::Ready(None) => break,
                    Poll::Pending => {
                        // NOTE(review): we just closed above, so Pending here
                        // presumably means a sender reserved a slot (count
                        // incremented) but its push is not yet visible; the
                        // spin below waits for that push to land. This relies
                        // on `is_closed()` also requiring a drained count —
                        // confirm against `decode_state`.
                        let state = decode_state(self.inner.as_ref().unwrap().state.load(SeqCst));
                        if state.is_closed() {
                            break;
                        }
                        // Cheap CPU pause on common architectures...
                        #[cfg(any(
                            target_arch = "x86",
                            target_arch = "x86_64",
                            target_arch = "aarch64",
                            target_arch = "arm"
                        ))]
                        std::hint::spin_loop();
                        // ...fall back to yielding the thread elsewhere.
                        #[cfg(not(any(
                            target_arch = "x86",
                            target_arch = "x86_64",
                            target_arch = "aarch64",
                            target_arch = "arm"
                        )))]
                        std::thread::yield_now();
                    }
                }
            }
        }
    }
}