use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use crate::internal::Mutex;
use crate::internal::RwLock;
use crate::internal::WaitSet;
#[cfg(test)]
mod tests;
pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
assert!(capacity > 0, "capacity must be greater than 0");
let capacity = capacity.next_power_of_two();
let mask = capacity - 1;
let mut buffer = Vec::with_capacity(capacity);
for _ in 0..capacity {
buffer.push(RwLock::new(Slot {
msg: None,
version: 0,
}));
}
let shared = Arc::new(Shared {
buffer: buffer.into_boxed_slice(),
capacity,
mask,
tail_cnt: AtomicU64::new(0),
senders: AtomicUsize::new(1),
waiters: Mutex::new(WaitSet::new()),
});
let sender = Sender {
shared: shared.clone(),
};
let receiver = Receiver { shared, head: 0 };
(sender, receiver)
}
/// Error returned by [`Receiver::recv`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecvError {
    /// The receiver fell behind; the payload is how many messages were skipped.
    Lagged(u64),
    /// Every sender has been dropped and no unread messages remain.
    Disconnected,
}

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Lagged(skipped) => write!(f, "receiver has been lagged by {skipped}"),
            Self::Disconnected => f.write_str("receiving on a closed channel"),
        }
    }
}

impl std::error::Error for RecvError {}
/// Error returned by [`Receiver::try_recv`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TryRecvError {
    /// No message is available right now.
    Empty,
    /// The receiver fell behind; the payload is how many messages were skipped.
    Lagged(u64),
    /// Every sender has been dropped and no unread messages remain.
    Disconnected,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Lagged(missed) => write!(f, "receiver has been lagged by {missed}"),
            Self::Empty => f.write_str("receiving on an empty channel"),
            Self::Disconnected => f.write_str("receiving on a closed channel"),
        }
    }
}

impl std::error::Error for TryRecvError {}
/// One entry of the channel's ring buffer.
#[derive(Debug)]
struct Slot<T> {
    /// The stored message, or `None` if this slot has never been written.
    msg: Option<T>,
    /// Channel position of the stored message. It is 0 for a never-written
    /// slot, so only `msg` tells that apart from the message at position 0.
    version: u64,
}
/// State shared by every sender and receiver of one channel.
struct Shared<T> {
    /// Ring buffer; position `p` is stored at index `p & mask`.
    buffer: Box<[RwLock<Slot<T>>]>,
    /// Number of slots in `buffer` (a power of two when built by `channel`).
    capacity: usize,
    /// `capacity - 1`, used to map a position onto a buffer index.
    mask: usize,
    /// Total number of positions reserved by senders so far.
    tail_cnt: AtomicU64,
    /// Number of live `Sender` handles; 0 means the channel is disconnected.
    senders: AtomicUsize,
    /// Wakers of pending receive futures; woken after every send and when the
    /// last sender is dropped.
    waiters: Mutex<WaitSet>,
}
/// Sending half of the broadcast channel.
///
/// Cloning a `Sender` increments the shared sender count. Dropping the last
/// one wakes every registered waiter so pending receives can observe
/// disconnection.
pub struct Sender<T> {
    shared: Arc<Shared<T>>,
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        self.shared.senders.fetch_add(1, Ordering::Release);
        Self {
            shared: self.shared.clone(),
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous count: 1 means this was the last sender.
        if self.shared.senders.fetch_sub(1, Ordering::AcqRel) == 1 {
            self.shared.waiters.lock().wake_all();
        }
    }
}

impl<T> Sender<T> {
    /// Publishes `msg` to every receiver.
    ///
    /// This reserves the next position by bumping the tail counter, stores
    /// `msg` in that position's slot, and then wakes all registered waiters.
    /// There is no back-pressure: once the buffer wraps, older messages are
    /// overwritten and receivers that have not read them see a lag.
    pub fn send(&self, msg: T) {
        let pos = self.shared.tail_cnt.fetch_add(1, Ordering::SeqCst);
        let idx = (pos as usize) & self.shared.mask;
        {
            let mut slot = self.shared.buffer[idx].write();
            // With several senders, the one that reserved `pos` may reach this
            // slot only after another sender that reserved `pos + k * capacity`
            // has already written it. Overwriting would move `version` backwards
            // and hide the newer message from readers at that position. Dropping
            // `msg` instead turns it into an ordinary lag for readers still at
            // `pos`. Never-written slots have version 0, so they always accept.
            if slot.version <= pos {
                slot.msg = Some(msg);
                slot.version = pos;
            }
        }
        self.shared.waiters.lock().wake_all();
    }

    /// Creates a new receiver positioned at the current tail count.
    ///
    /// Messages whose positions were already reserved are not delivered to it.
    pub fn subscribe(&self) -> Receiver<T> {
        let head = self.shared.tail_cnt.load(Ordering::SeqCst);
        let shared = self.shared.clone();
        Receiver { shared, head }
    }
}
/// Receiving half of the broadcast channel.
///
/// Each receiver keeps its own read position, so every receiver independently
/// observes every message unless it lags behind the ring buffer.
pub struct Receiver<T> {
    shared: Arc<Shared<T>>,
    /// Position of the next message this receiver will read.
    head: u64,
}

impl<T> Clone for Receiver<T> {
    /// Returns a receiver on the same channel at the same read position. From
    /// then on, the two advance independently.
    fn clone(&self) -> Self {
        let head = self.head;
        let shared = Arc::clone(&self.shared);
        Self { shared, head }
    }
}
impl<T: Clone> Receiver<T> {
    /// Receives the next message, waiting until one is published.
    ///
    /// # Errors
    ///
    /// * [`RecvError::Lagged`]: more than `capacity` messages were sent past
    ///   this receiver's position, so some were overwritten. The receiver is
    ///   moved forward and the payload is the number of messages skipped.
    /// * [`RecvError::Disconnected`]: every [`Sender`] has been dropped and no
    ///   unread positions remain.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        Recv {
            receiver: self,
            index: None,
        }
        .await
    }

    /// Attempts to receive the next message without waiting.
    ///
    /// # Errors
    ///
    /// * [`TryRecvError::Empty`]: nothing new has been sent, or the next
    ///   position has been reserved by a sender that has not finished writing it.
    /// * [`TryRecvError::Lagged`]: the receiver fell more than `capacity`
    ///   messages behind. It is moved forward and the payload is the number of
    ///   messages skipped.
    /// * [`TryRecvError::Disconnected`]: every [`Sender`] has been dropped and no
    ///   unread positions remain.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let shared = &self.shared;
        let cap = shared.capacity as u64;
        let head = self.head;
        let tail = shared.tail_cnt.load(Ordering::SeqCst);
        let diff = tail.wrapping_sub(head);

        if diff == 0 {
            // Caught up: report disconnection only once everything was read.
            return if shared.senders.load(Ordering::Acquire) == 0 {
                Err(TryRecvError::Disconnected)
            } else {
                Err(TryRecvError::Empty)
            };
        }

        if diff > cap {
            // Our next position is more than a full lap behind the tail, so it
            // has certainly been overwritten. Skip to the oldest buffered position.
            let next = tail.wrapping_sub(cap);
            self.head = next;
            return Err(TryRecvError::Lagged(next.wrapping_sub(head)));
        }

        let idx = (head as usize) & shared.mask;
        let slot = shared.buffer[idx].read();

        if slot.version == head {
            return match &slot.msg {
                Some(msg) => {
                    self.head = head.wrapping_add(1);
                    Ok(msg.clone())
                }
                // Only never-written slot 0 (version 0, no message) gets here:
                // position 0 is reserved but not yet written.
                None => Err(TryRecvError::Empty),
            };
        }

        if slot.version < head {
            // The slot still holds an older lap or the initial placeholder.
            // The sender that reserved `head` has bumped the tail but has not
            // written yet, so this is not a lag. Report `Empty` so the caller
            // retries. (This relies on a slot's version never decreasing.)
            return Err(TryRecvError::Empty);
        }

        // The slot holds a newer lap, so it was overwritten after our `tail`
        // snapshot, and that snapshot may be within `cap` of `head`. Re-read the
        // tail: the overwriting send bumped it before writing the slot we just
        // observed, so the fresh value is past `slot.version` and `next > head`.
        drop(slot);
        let tail = shared.tail_cnt.load(Ordering::SeqCst);
        let next = tail.wrapping_sub(cap);
        self.head = next;
        Err(TryRecvError::Lagged(next.wrapping_sub(head)))
    }
}
impl<T> Receiver<T> {
pub fn resubscribe(&self) -> Self {
let head = self.shared.tail_cnt.load(Ordering::SeqCst);
let shared = self.shared.clone();
Self { shared, head }
}
}
/// Future returned by [`Receiver::recv`].
///
/// `index` is handed to `WaitSet::register_waker` on every pending poll;
/// presumably it records where this future's waker lives so re-registration can
/// reuse the entry. NOTE(review): confirm against `crate::internal::WaitSet`.
/// No `Drop` impl is visible here, so a registered waker is not explicitly
/// removed if the future is dropped early. Confirm `WaitSet` tolerates that.
struct Recv<'a, T> {
    receiver: &'a mut Receiver<T>,
    index: Option<usize>,
}

impl<T: Clone> Future for Recv<'_, T> {
    type Output = Result<T, RecvError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Both fields (`&mut Receiver<T>` and `Option<usize>`) are `Unpin`, so
        // `get_mut` is allowed and no pin projection is needed.
        let Self { receiver, index } = self.get_mut();
        loop {
            // Fast path: try to take a message without touching the wait set.
            match receiver.try_recv() {
                Ok(val) => return Poll::Ready(Ok(val)),
                Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))),
                Err(TryRecvError::Disconnected) => {
                    return Poll::Ready(Err(RecvError::Disconnected));
                }
                Err(TryRecvError::Empty) => {}
            }
            let shared = &receiver.shared;
            // Senders call `wake_all` while holding this same lock after they
            // publish, and the last sender does so on drop. Re-checking the tail
            // and the sender count while the lock is held closes the window in
            // which a wake-up could be lost between `try_recv` above and
            // `register_waker` below.
            let mut waiters = shared.waiters.lock();
            let tail_now = shared.tail_cnt.load(Ordering::SeqCst);
            if tail_now != receiver.head {
                // A position at or past our head has been reserved, possibly
                // still mid-write, since `try_recv` reported `Empty`. Release the
                // lock and retry instead of sleeping. This busy-loops until the
                // in-flight write lands.
                drop(waiters);
                continue;
            }
            if shared.senders.load(Ordering::Acquire) == 0 {
                return Poll::Ready(Err(RecvError::Disconnected));
            }
            // Nothing to read and senders remain: park until the next wake_all.
            waiters.register_waker(index, cx);
            return Poll::Pending;
        }
    }
}