use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
/// Error returned by [`Sender::send`] when the receiving half has been dropped.
///
/// The value that could not be delivered is handed back inside `Closed`, so the
/// caller can recover it by pattern matching.
pub enum SendError<T> {
    /// The receiver is gone; carries the rejected value.
    Closed(T),
}

// Hand-written so `SendError<T>` is debuggable even when `T: Debug` does not hold:
// the payload is always rendered as a `_` placeholder.
impl<T> std::fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SendError::Closed(_) => {
                let mut tuple = f.debug_tuple("SendError::Closed");
                tuple.field(&format_args!("_"));
                tuple.finish()
            }
        }
    }
}

// Equality compares only the error kind, never the payload, so no `T: PartialEq`
// bound is needed. With a single variant, any two `SendError`s are equal.
impl<T> PartialEq for SendError<T> {
    fn eq(&self, other: &Self) -> bool {
        matches!((self, other), (SendError::Closed(_), SendError::Closed(_)))
    }
}

impl<T> Eq for SendError<T> {}

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Channel closed")
    }
}

impl<T> std::error::Error for SendError<T> {}
/// Error returned by [`Receiver::try_recv`] when no message could be taken from
/// the buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// No message was available to receive.
    Closed,
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            RecvError::Closed => "Channel closed",
        };
        f.write_str(text)
    }
}

impl std::error::Error for RecvError {}
/// Creates a channel whose buffer grows without limit.
///
/// Returns the initial [`Sender`] (clone it for additional producers) together
/// with the single [`Receiver`].
#[must_use]
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
    let state = ChannelShared {
        buffer: Mutex::new(VecDeque::new()),
        sender_count: AtomicUsize::new(1),
        is_receiver_alive: AtomicBool::new(true),
        recv_waker: Mutex::new(None),
    };
    let shared = Arc::new(state);
    (
        Sender {
            shared: Arc::clone(&shared),
        },
        Receiver { shared },
    )
}
/// Creates a channel whose buffer is pre-allocated for `cap` messages.
///
/// NOTE: `cap` is only an allocation hint. It is **not** enforced: `Sender::send`
/// never rejects a message or waits for space, and the queue keeps growing past
/// `cap` exactly like [`unbounded`].
#[must_use]
pub fn bounded<T>(cap: usize) -> (Sender<T>, Receiver<T>) {
    let state = ChannelShared {
        buffer: Mutex::new(VecDeque::with_capacity(cap)),
        sender_count: AtomicUsize::new(1),
        is_receiver_alive: AtomicBool::new(true),
        recv_waker: Mutex::new(None),
    };
    let shared = Arc::new(state);
    (
        Sender {
            shared: Arc::clone(&shared),
        },
        Receiver { shared },
    )
}
/// State shared through an `Arc` by every `Sender`, the `Receiver` and any
/// `RecvFuture` of one channel.
struct ChannelShared<T> {
    /// Messages sent but not yet received, oldest at the front.
    buffer: Mutex<VecDeque<T>>,
    /// Waker registered by a pending `RecvFuture::poll`, if any.
    recv_waker: Mutex<Option<Waker>>,
    /// Number of live `Sender` handles.
    sender_count: AtomicUsize,
    /// Set to `false` when the `Receiver` is dropped.
    is_receiver_alive: AtomicBool,
}
/// Producer half of a channel created by [`unbounded`] or [`bounded`].
///
/// Cloneable; the channel counts as closed for the receiver once every clone
/// has been dropped.
pub struct Sender<T> {
    /// Channel state shared with the receiver and the other senders.
    shared: Arc<ChannelShared<T>>,
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
Self {
shared: Arc::clone(&self.shared),
}
}
}
impl<T> Sender<T> {
    /// Enqueues `value` for the receiver and wakes a pending [`RecvFuture`], if any.
    ///
    /// The buffer has no capacity limit, so this never waits for space.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Closed`] carrying `value` back if the [`Receiver`]
    /// has already been dropped.
    ///
    /// # Panics
    ///
    /// Panics if one of the channel's internal mutexes is poisoned.
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        if !self.shared.is_receiver_alive.load(Ordering::Acquire) {
            return Err(SendError::Closed(value));
        }
        let waker = {
            let mut buffer = self.shared.buffer.lock().unwrap();
            buffer.push_back(value);
            // Take the waker while the buffer lock is still held. `RecvFuture::poll`
            // registers its waker under that same lock, so it cannot register in
            // between this push and this take and then miss the value.
            self.shared.recv_waker.lock().unwrap().take()
        };
        // Wake only after BOTH guards are released. In the previous `if let` form,
        // the `recv_waker` guard (a scrutinee temporary) stayed locked during
        // `wake()`, and a waker that polls inline would deadlock on it.
        if let Some(waker) = waker {
            waker.wake();
        }
        Ok(())
    }

    /// Returns `true` once the [`Receiver`] has been dropped, after which
    /// [`send`](Self::send) fails.
    #[must_use]
    pub fn is_closed(&self) -> bool {
        !self.shared.is_receiver_alive.load(Ordering::Acquire)
    }

    /// Returns the number of live `Sender` handles for this channel, including
    /// `self`. Other threads may clone or drop senders concurrently, so the
    /// result is only a snapshot.
    #[must_use]
    pub fn sender_count(&self) -> usize {
        self.shared.sender_count.load(Ordering::Acquire)
    }
}
impl<T> Drop for Sender<T> {
    /// Releases this producer handle. If it was the last one, any task parked in
    /// `RecvFuture::poll` is woken so it can observe the closure and resolve to
    /// `None`.
    fn drop(&mut self) {
        let prev = self.shared.sender_count.fetch_sub(1, Ordering::AcqRel);
        if prev != 1 {
            return;
        }
        let waker = {
            // `RecvFuture::poll` reads `sender_count` and registers its waker while
            // holding the buffer lock. Taking the waker under that same lock closes a
            // lost-wakeup window. Without it, a poll that had just read the old count
            // could store its waker right after we found the slot empty, and it would
            // never be woken again.
            // Poisoning is tolerated here because panicking inside `drop` during
            // unwinding aborts the process.
            let _buffer = self
                .shared
                .buffer
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            self.shared
                .recv_waker
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .take()
        };
        // Wake only after both locks have been released.
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
/// Consumer half of a channel created by [`unbounded`] or [`bounded`].
///
/// Dropping it marks the channel closed, so later [`Sender::send`] calls fail.
pub struct Receiver<T> {
    /// Channel state shared with every sender.
    shared: Arc<ChannelShared<T>>,
}
impl<T> Receiver<T> {
    /// Returns a future resolving to the next message, or to `None` once the
    /// buffer is empty and every [`Sender`] has been dropped.
    pub fn recv(&mut self) -> RecvFuture<'_, T> {
        RecvFuture::new(self)
    }

    /// Pops the next buffered message without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError::Closed`] whenever the buffer is empty.
    ///
    /// NOTE(review): this is reported both when every sender is gone AND when
    /// senders are still alive but nothing has been sent yet. `RecvError` has no
    /// separate "empty" variant, so callers cannot tell the two apart here. Use
    /// [`recv`](Self::recv) or [`Sender::is_closed`]-style checks if that matters.
    /// The former code had two identical branches for these cases.
    ///
    /// # Panics
    ///
    /// Panics if the buffer mutex is poisoned.
    pub fn try_recv(&mut self) -> Result<T, RecvError> {
        self.shared
            .buffer
            .lock()
            .unwrap()
            .pop_front()
            .ok_or(RecvError::Closed)
    }

    /// Returns the number of messages currently buffered.
    #[must_use]
    pub fn len(&self) -> usize {
        self.shared.buffer.lock().unwrap().len()
    }

    /// Returns `true` if no message is currently buffered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<T> Drop for Receiver<T> {
    /// Flags the channel as closed from the receiving side, so subsequent
    /// `Sender::send` calls return an error. Messages still in the buffer are
    /// left untouched.
    fn drop(&mut self) {
        let alive = &self.shared.is_receiver_alive;
        alive.store(false, Ordering::Release);
    }
}
/// Future returned by [`Receiver::recv`].
///
/// It resolves to the next message, or to `None` once the buffer is empty and
/// all senders are gone. It is tied to the `&'a mut Receiver<T>` it came from,
/// so at most one exists per receiver at a time.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RecvFuture<'a, T> {
    /// Channel state, cloned from the receiver.
    shared: Arc<ChannelShared<T>>,
    /// Holds the mutable borrow of the originating receiver for `'a`.
    _marker: std::marker::PhantomData<&'a mut Receiver<T>>,
}
impl<'a, T> RecvFuture<'a, T> {
fn new(receiver: &'a mut Receiver<T>) -> Self {
Self {
shared: Arc::clone(&receiver.shared),
_marker: std::marker::PhantomData,
}
}
}
unsafe impl<T: Send> Send for RecvFuture<'_, T> {}
unsafe impl<T: Sync> Sync for RecvFuture<'_, T> {}
impl<T> Future for RecvFuture<'_, T> {
    type Output = Option<T>;

    /// Behavior depends on the buffer and the live senders:
    ///
    /// - resolves to `Some(message)` when a message is buffered;
    /// - resolves to `None` when the buffer is empty and no `Sender` remains;
    /// - otherwise stores this task's waker and returns `Poll::Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = &self.shared;
        // The buffer guard lives for the whole match, so the sender-count check
        // and the waker registration both happen under the buffer lock.
        let mut queue = shared.buffer.lock().unwrap();
        match queue.pop_front() {
            Some(message) => Poll::Ready(Some(message)),
            None if shared.sender_count.load(Ordering::Acquire) == 0 => Poll::Ready(None),
            None => {
                let waker = cx.waker().clone();
                *shared.recv_waker.lock().unwrap() = Some(waker);
                Poll::Pending
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test waker that counts how many times it has been woken.
    struct CountingWaker(AtomicUsize);

    impl std::task::Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Polls `fut` exactly once with `cx`.
    fn poll_once<F: Future + Unpin>(fut: &mut F, cx: &mut Context<'_>) -> Poll<F::Output> {
        Pin::new(fut).poll(cx)
    }

    #[test]
    fn test_unbounded_channel_creation() {
        let (tx, _rx) = unbounded::<i32>();
        assert!(!tx.is_closed());
        assert_eq!(tx.sender_count(), 1);
    }

    #[test]
    fn test_bounded_channel_creation() {
        let (tx, _rx) = bounded::<i32>(16);
        assert!(!tx.is_closed());
        assert_eq!(tx.sender_count(), 1);
    }

    #[test]
    fn test_sender_clone() {
        let (tx, _rx) = unbounded::<i32>();
        let tx2 = tx.clone();
        assert_eq!(tx.sender_count(), 2);
        assert_eq!(tx2.sender_count(), 2);
        drop(tx);
        assert_eq!(tx2.sender_count(), 1);
    }

    #[test]
    fn test_receiver_empty() {
        let (_tx, rx) = unbounded::<i32>();
        assert!(rx.is_empty());
        assert_eq!(rx.len(), 0);
    }

    #[test]
    fn test_sync_send() {
        let (tx, mut rx) = unbounded::<i32>();
        assert!(tx.send(42).is_ok());
        assert!(tx.send(100).is_ok());
        assert_eq!(rx.len(), 2);
        assert!(!rx.is_empty());
        assert_eq!(rx.try_recv().unwrap(), 42);
        assert_eq!(rx.try_recv().unwrap(), 100);
        // `try_recv` reports `Closed` for an empty buffer even while `tx` is alive.
        assert_eq!(rx.try_recv(), Err(RecvError::Closed));
    }

    #[test]
    fn test_send_after_receiver_drop() {
        let (tx, rx) = unbounded::<i32>();
        drop(rx);
        assert!(tx.is_closed());
        let err = tx.send(42).unwrap_err();
        assert!(matches!(err, SendError::Closed(42)));
        assert_eq!(err.to_string(), "Channel closed");
    }

    #[test]
    fn test_recv_error() {
        let err = RecvError::Closed;
        assert_eq!(err.to_string(), "Channel closed");
    }

    #[test]
    fn test_unbounded_send_recv_order() {
        let (tx, mut rx) = unbounded::<String>();
        for i in 0..10 {
            tx.send(format!("msg-{i}")).unwrap();
        }
        for i in 0..10 {
            assert_eq!(rx.try_recv().unwrap(), format!("msg-{i}"));
        }
    }

    /// `bounded`'s capacity is only a pre-allocation hint; sends past it succeed.
    #[test]
    fn test_bounded_capacity_is_not_enforced() {
        let (tx, rx) = bounded::<i32>(2);
        assert!(tx.send(1).is_ok());
        assert!(tx.send(2).is_ok());
        assert!(tx.send(3).is_ok());
        assert_eq!(rx.len(), 3);
    }

    #[test]
    fn test_close_after_all_senders_drop() {
        let (tx, mut rx) = unbounded::<i32>();
        let tx2 = tx.clone();
        tx.send(1).unwrap();
        drop(tx);
        assert!(!tx2.is_closed());
        tx2.send(2).unwrap();
        drop(tx2);
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
    }

    #[test]
    fn test_try_recv_after_all_senders_drop() {
        let (tx, mut rx) = unbounded::<i32>();
        tx.send(5).unwrap();
        drop(tx);
        assert_eq!(rx.try_recv(), Ok(5));
        assert_eq!(rx.try_recv(), Err(RecvError::Closed));
    }

    #[test]
    fn test_recv_future_wakes_on_send_and_on_close() {
        let wakes = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&wakes));
        let mut cx = Context::from_waker(&waker);
        let (tx, mut rx) = unbounded::<i32>();

        // An already-buffered message resolves immediately without a wake.
        tx.send(7).unwrap();
        assert_eq!(poll_once(&mut rx.recv(), &mut cx), Poll::Ready(Some(7)));
        assert_eq!(wakes.0.load(Ordering::SeqCst), 0);

        // An empty buffer parks the task; a send wakes it.
        {
            let mut fut = rx.recv();
            assert_eq!(poll_once(&mut fut, &mut cx), Poll::Pending);
            tx.send(8).unwrap();
            assert_eq!(wakes.0.load(Ordering::SeqCst), 1);
            assert_eq!(poll_once(&mut fut, &mut cx), Poll::Ready(Some(8)));
        }

        // Dropping the last sender wakes the parked task, which then sees `None`.
        let mut fut = rx.recv();
        assert_eq!(poll_once(&mut fut, &mut cx), Poll::Pending);
        drop(tx);
        assert_eq!(wakes.0.load(Ordering::SeqCst), 2);
        assert_eq!(poll_once(&mut fut, &mut cx), Poll::Ready(None));
    }
}