use std::collections::VecDeque;
use std::error::Error;
use std::fmt;
use std::fmt::Display;
use std::fmt::Formatter;
use std::future::Future;
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use crate::broad_rc::{BroadRc, BroadWeak};
/// Initial buffer size given to each receiver of an unbounded channel
/// (bounded channels size their buffers from the configured capacity instead).
const DEFAULT_CAPACITY: usize = 16;

/// Error returned by `Sender::try_send` when at least one receiver was full.
///
/// The wrapped value is the number of receivers the message *was* delivered to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct UnderCapacity(pub usize);

impl Display for UnderCapacity {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("subscribers are under capacity")
    }
}

impl Error for UnderCapacity {}
/// Per-receiver bookkeeping stored in the shared slab.
struct ReceiverState<T> {
    /// Message id last associated with this receiver; compared against
    /// `Shared::id` by the send path to decide whether a value is owed.
    id: u64,
    /// Waker registered by the most recent pending `recv` poll, if any.
    waker: Option<Waker>,
    /// Values handed to this receiver that it has not received yet.
    buf: VecDeque<T>,
    /// When `true` the buffer may grow without limit.
    unbounded: bool,
}

impl<T> ReceiverState<T> {
    /// Reports whether a bounded receiver has no free slot left.
    ///
    /// NOTE(review): the bound is read from `VecDeque::capacity`, which std only
    /// guarantees to be *at least* the requested capacity (and which can be huge
    /// for zero-sized `T`), so the effective limit may exceed the configured one.
    fn at_capacity(&self) -> bool {
        if self.unbounded {
            return false;
        }
        self.buf.len() == self.buf.capacity()
    }
}
/// State shared between the single `Sender` and all of its `Receiver`s.
struct Shared<T> {
    /// Per-receiver buffer bound; `None` marks an unbounded channel.
    capacity: Option<NonZeroUsize>,
    /// One slot per live receiver, keyed by the receiver's `index`.
    receivers: slab::Slab<ReceiverState<T>>,
    /// Id of the most recent send; advanced by `bump_message_id` and never 0
    /// after the first bump.
    id: u64,
    /// Waker registered by a pending send future, woken when receivers pop a
    /// value or are dropped.
    sender: Option<Waker>,
}
/// Sending half of the broadcast channel; create receivers with `subscribe`.
///
/// The `T: Clone` bound is required because every receiver gets its own copy
/// of each value, and the `Drop` impl repeats the same bound.
pub struct Sender<T: Clone> {
    inner: BroadRc<Shared<T>>,
}
impl<T> Sender<T>
where
T: Clone,
{
fn new_receiver(&mut self) -> usize {
unsafe {
let (inner, _) = self.inner.get_mut_unchecked();
let (capacity, unbounded) = match inner.capacity {
Some(capacity) => (capacity.get(), false),
None => (DEFAULT_CAPACITY, true),
};
inner.receivers.insert(ReceiverState {
id: inner.id,
waker: None,
buf: VecDeque::with_capacity(capacity),
unbounded,
})
}
}
fn bump_message_id(&mut self) {
unsafe {
let (inner, _) = self.inner.get_mut_unchecked();
inner.id = inner.id.wrapping_add(1);
if inner.id == 0 {
inner.id = 1;
}
}
}
pub fn subscribe(&mut self) -> Receiver<T> {
let index = self.new_receiver();
Receiver {
index,
inner: self.inner.weak(),
}
}
pub fn subscribers(&self) -> usize {
unsafe {
let (inner, _) = self.inner.get_mut_unchecked();
inner.receivers.len()
}
}
pub fn try_send(&mut self, value: T) -> Result<usize, UnderCapacity> {
self.bump_message_id();
unsafe {
let (inner, any_receivers_present) = self.inner.get_mut_unchecked();
if !any_receivers_present {
return Ok(0);
}
let mut delivered = 0;
for (_, receiver) in &mut inner.receivers {
if !receiver.at_capacity() {
delivered += 1;
receiver.buf.push_back(value.clone());
if let Some(waker) = &receiver.waker {
waker.wake_by_ref();
}
}
}
if delivered == inner.receivers.len() {
return Ok(delivered);
}
Err(UnderCapacity(delivered))
}
}
pub async fn send(&mut self, value: T) -> usize {
self.bump_message_id();
Send {
inner: &self.inner,
value,
}
.await
}
}
/// Future behind `Sender::send`: delivers one value to every receiver, waiting
/// while a receiver that has not been served yet is at capacity.
///
/// NOTE(review): this name shadows the prelude `Send` marker trait inside this
/// module; consider renaming it.
struct Send<'a, T> {
    inner: &'a BroadRc<Shared<T>>,
    value: T,
}

impl<'a, T> Future for Send<'a, T>
where
    T: Clone,
{
    /// Number of receivers that hold the value once the send completes.
    type Output = usize;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        unsafe {
            let (inner, any_receivers_present) = this.inner.get_mut_unchecked();
            if !any_receivers_present {
                return Poll::Ready(0);
            }
            // Register (or refresh) our waker so receivers that pop a value or
            // are dropped can resume this future.
            if !matches!(&inner.sender, Some(w) if w.will_wake(cx.waker())) {
                inner.sender = Some(cx.waker().clone());
            }
            let message_id = inner.id;
            let mut delivered = 0;
            for (_, receiver) in &mut inner.receivers {
                if receiver.id != message_id {
                    if receiver.at_capacity() {
                        // Retried on a later poll, after this receiver drains.
                        continue;
                    }
                    receiver.buf.push_back(this.value.clone());
                    // Record delivery so later polls never push a duplicate.
                    // Without this the value was re-pushed until the buffer filled,
                    // and never stopped for unbounded receivers.
                    receiver.id = message_id;
                    if let Some(waker) = &receiver.waker {
                        waker.wake_by_ref();
                    }
                }
                delivered += 1;
            }
            if delivered == inner.receivers.len() {
                Poll::Ready(delivered)
            } else {
                Poll::Pending
            }
        }
    }
}

impl<'a, T> Unpin for Send<'a, T> {}

/// Receiving half of the broadcast channel, obtained from `Sender::subscribe`.
pub struct Receiver<T> {
    /// Key of this receiver's `ReceiverState` in the shared slab.
    index: usize,
    inner: BroadWeak<Shared<T>>,
}

impl<T> Receiver<T> {
    /// Waits for the next value buffered for this receiver.
    ///
    /// Returns `None` once the sender is reported absent and the buffer is
    /// empty, or if this receiver's slot is missing from the shared slab.
    pub async fn recv(&mut self) -> Option<T> {
        Recv { receiver: self }.await
    }
}

/// Future behind `Receiver::recv`.
struct Recv<'a, T> {
    receiver: &'a mut Receiver<T>,
}

impl<'a, T> Future for Recv<'a, T> {
    type Output = Option<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        unsafe {
            let index = this.receiver.index;
            let (inner, sender_present) = this.receiver.inner.get_mut_unchecked();
            let receiver = match inner.receivers.get_mut(index) {
                Some(receiver) => receiver,
                None => return Poll::Ready(None),
            };
            if let Some(value) = receiver.buf.pop_front() {
                // `receiver.id` is deliberately left untouched. It records which
                // message the send path already pushed here, and overwriting it
                // with the current id would mark a still-pending message as
                // delivered to a full receiver that never got it.
                // Popping frees a slot, so a send blocked on this receiver can retry.
                if let Some(waker) = &inner.sender {
                    waker.wake_by_ref();
                }
                return Poll::Ready(Some(value));
            }
            if !sender_present {
                receiver.waker = None;
                return Poll::Ready(None);
            }
            // Store our waker unless the stored one already wakes this task.
            // (The previous guard was inverted and kept stale wakers.)
            if !matches!(&receiver.waker, Some(w) if w.will_wake(cx.waker())) {
                receiver.waker = Some(cx.waker().clone());
            }
            // NOTE(review): an empty buffer is never at capacity, so this wake does
            // not appear necessary for delivery. It only re-polls a pending send.
            if let Some(waker) = &inner.sender {
                waker.wake_by_ref();
            }
            Poll::Pending
        }
    }
}
impl<T> Drop for Recv<'_, T> {
    /// Unregisters the waker stored by this future's last pending poll.
    ///
    /// A `Recv` is dropped after every completed `recv().await` as well as on
    /// cancellation. It must therefore leave buffered values alone: clearing the
    /// buffer here would discard every message queued behind the one just
    /// returned.
    fn drop(&mut self) {
        unsafe {
            let index = self.receiver.index;
            let (inner, _) = self.receiver.inner.get_mut_unchecked();
            if let Some(receiver) = inner.receivers.get_mut(index) {
                // Nobody is awaiting this receiver any more; drop the stale waker.
                receiver.waker = None;
            }
        }
    }
}
impl<T> Drop for Sender<T>
where
T: Clone,
{
fn drop(&mut self) {
unsafe {
let (inner, _) = self.inner.get_mut_unchecked();
for (_, r) in &mut inner.receivers {
if let Some(waker) = r.waker.take() {
waker.wake();
}
}
}
}
}
impl<T> Drop for Receiver<T> {
    /// Removes this receiver's slot and wakes a pending send, if any.
    ///
    /// The send future counts delivered receivers against the slab length, so
    /// removing a slot can let it finish.
    fn drop(&mut self) {
        // SAFETY: NOTE(review) assumes exclusive access to `Shared` here — confirm.
        let (shared, _) = unsafe { self.inner.get_mut_unchecked() };
        // The slot may already be gone; that is not an error.
        let _ = shared.receivers.try_remove(self.index);
        if let Some(sender_waker) = shared.sender.take() {
            sender_waker.wake();
        }
    }
}
/// Creates a bounded broadcast channel.
///
/// Each receiver buffers up to `capacity` values (see the capacity note on
/// `ReceiverState::at_capacity`).
///
/// # Panics
///
/// Panics if `capacity` is 0.
pub fn channel<T>(capacity: usize) -> Sender<T>
where
    T: Clone,
{
    let bound = NonZeroUsize::new(capacity).expect("capacity cannot be 0");
    let shared = Shared {
        capacity: Some(bound),
        id: 0,
        sender: None,
        receivers: slab::Slab::new(),
    };
    Sender {
        inner: BroadRc::new(shared),
    }
}
/// Creates a broadcast channel whose receiver buffers grow without limit.
///
/// Receivers start with room for `DEFAULT_CAPACITY` values.
pub fn unbounded<T>() -> Sender<T>
where
    T: Clone,
{
    let shared = Shared {
        capacity: None,
        id: 0,
        sender: None,
        receivers: slab::Slab::new(),
    };
    Sender {
        inner: BroadRc::new(shared),
    }
}