use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use crossbeam_utils::Backoff;
use crate::queue::mpsc as queue;
/// Longest single park performed by [`Receiver::recv`] when it is called without a timeout.
///
/// The receiver wakes at least this often to re-check the queue and the sender count, which
/// bounds the latency of any missed wake-up.
const DEFAULT_PARK_TIMEOUT: Duration = Duration::from_micros(100_000);
/// Creates a channel and returns its first [`Sender`] together with the matching [`Receiver`].
///
/// `capacity` is forwarded unchanged to `queue::new` (presumably a buffer size in bytes —
/// confirm against the queue module). Clone the returned `Sender` to add more producers.
pub fn channel(capacity: usize) -> (Sender, Receiver) {
    let (producer, consumer) = queue::new(capacity);
    let parker = crossbeam_utils::sync::Parker::new();
    // The sender count starts at 1 for the `Sender` handed back below.
    let shared = Arc::new(ChannelShared {
        sender_count: AtomicUsize::new(1),
        receiver_disconnected: AtomicBool::new(false),
        receiver_waiting: AtomicBool::new(false),
        receiver_unparker: parker.unparker().clone(),
    });
    let sender = Sender {
        inner: producer,
        shared: Arc::clone(&shared),
    };
    let receiver = Receiver {
        inner: consumer,
        parker,
        shared,
    };
    (sender, receiver)
}
/// State shared by every [`Sender`] and the [`Receiver`] of one channel.
struct ChannelShared {
    /// Number of live `Sender` handles; the receiver treats zero as disconnection.
    sender_count: AtomicUsize,
    /// Set when the `Receiver` is dropped; senders return an error once they observe it.
    receiver_disconnected: AtomicBool,
    /// Raised by the receiver around its park so senders know an `unpark` is needed.
    receiver_waiting: AtomicBool,
    /// Wakes the receiver's parker from the sending side.
    receiver_unparker: crossbeam_utils::sync::Unparker,
}
/// Producing half of a channel created by [`channel`]; clone it to add more producers.
pub struct Sender {
    /// This handle's producer end of the underlying queue.
    inner: queue::Producer,
    /// Sender count, disconnect flag and receiver wake-up state shared with the receiver.
    shared: Arc<ChannelShared>,
}
impl Clone for Sender {
fn clone(&self) -> Self {
self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
Self {
inner: self.inner.clone(),
shared: Arc::clone(&self.shared),
}
}
}
/// Error returned by [`Sender::send`] once the [`Receiver`] has been dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelClosed;

impl std::fmt::Display for ChannelClosed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "channel disconnected")
    }
}

impl std::error::Error for ChannelClosed {}
/// Reasons [`Sender::try_send`] can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError {
    /// The queue could not satisfy the claim right now.
    Full,
    /// The [`Receiver`] has been dropped.
    Disconnected,
}

impl std::fmt::Display for TrySendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            TrySendError::Full => "channel full",
            TrySendError::Disconnected => "channel disconnected",
        };
        f.write_str(text)
    }
}

impl std::error::Error for TrySendError {}
impl Sender {
    /// Claims `len` bytes of queue space, waiting (spin, then yield) while the queue is full.
    ///
    /// The returned claim must be committed for the record to become visible; call
    /// [`Sender::notify`] afterwards so a parked receiver wakes up.
    ///
    /// # Errors
    /// Returns [`ChannelClosed`] if the [`Receiver`] has been dropped, either before the first
    /// attempt or while waiting for space.
    ///
    /// # Panics
    /// Panics if `len` is zero.
    #[inline]
    pub fn send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, ChannelClosed> {
        assert!(len > 0, "payload length must be non-zero");
        if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
            return Err(ChannelClosed);
        }
        let producer: *mut queue::Producer = &raw mut self.inner;
        let backoff = Backoff::new();
        loop {
            // SAFETY: `producer` was derived from `&mut self`, so this call has exclusive access
            // to `self.inner`. A successful claim is returned immediately and the signature ties
            // its lifetime to that `&mut self` borrow; nothing else touches `self.inner` here.
            if let Some(claim) = unsafe { Self::claim_detached(producer, len) } {
                return Ok(claim);
            }
            backoff.snooze();
            if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
                return Err(ChannelClosed);
            }
            if backoff.is_completed() {
                backoff.reset();
            }
        }
    }

    /// Claims `len` bytes of queue space without waiting.
    ///
    /// # Errors
    /// * [`TrySendError::Disconnected`] if the [`Receiver`] has been dropped.
    /// * [`TrySendError::Full`] if the queue cannot satisfy the claim right now.
    ///
    /// # Panics
    /// Panics if `len` is zero.
    #[inline]
    pub fn try_send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, TrySendError> {
        assert!(len > 0, "payload length must be non-zero");
        if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
            return Err(TrySendError::Disconnected);
        }
        match self.inner.try_claim(len) {
            Ok(claim) => Ok(claim),
            Err(crate::BufferFull) => Err(TrySendError::Full),
        }
    }

    /// Wakes the receiver if it has raised its waiting flag; call after committing a claim.
    #[inline]
    pub fn notify(&self) {
        // The commit (a write to the queue) must be ordered before this read of the flag.
        // Without a full fence, the load can be satisfied before the commit is visible to the
        // receiver. The receiver could then see an empty queue and park while we read a stale
        // `false`, so the wake-up is lost until the receiver's park times out. This pairs with
        // a SeqCst fence the receiver must issue between raising the flag and re-checking
        // the queue.
        std::sync::atomic::fence(Ordering::SeqCst);
        if self.shared.receiver_waiting.load(Ordering::Relaxed) {
            self.shared.receiver_unparker.unpark();
        }
    }

    /// Returns the capacity reported by the underlying queue.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.inner.capacity()
    }

    /// Returns `true` once the [`Receiver`] has been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        self.shared.receiver_disconnected.load(Ordering::Relaxed)
    }

    /// Attempts a claim through a raw pointer so the borrow of the producer is not extended
    /// across loop iterations. The current borrow checker rejects a conditional return of a
    /// borrow from inside a loop.
    ///
    /// # Safety
    /// `producer` must point to a live `Producer` to which the caller has exclusive access for
    /// `'a`. The caller must not access that `Producer` in any other way while the returned
    /// claim is alive.
    #[inline]
    unsafe fn claim_detached<'a>(
        producer: *mut queue::Producer,
        len: usize,
    ) -> Option<queue::WriteClaim<'a>> {
        // SAFETY: forwarded to the caller through this function's `# Safety` contract.
        unsafe { (*producer).try_claim(len).ok() }
    }
}
impl Drop for Sender {
    /// Releases this handle's share of the sender count. When it was the last sender, wakes
    /// the receiver so a parked `recv` notices the disconnection promptly.
    fn drop(&mut self) {
        // Release: everything this sender committed happens-before any receiver that observes
        // the decremented count with an Acquire load. Decrements form a release sequence, so
        // observing zero synchronizes with every sender's drop.
        let prev = self.shared.sender_count.fetch_sub(1, Ordering::Release);
        if prev == 1 {
            self.shared.receiver_unparker.unpark();
        }
    }
}
impl std::fmt::Debug for Sender {
    /// Shows the queue capacity; the remaining (shared) state is intentionally elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.inner.capacity();
        let mut out = f.debug_struct("Sender");
        out.field("capacity", &capacity);
        out.finish_non_exhaustive()
    }
}
/// Consuming half of a channel created by [`channel`].
pub struct Receiver {
    /// Consumer end of the underlying queue.
    inner: queue::Consumer,
    /// Blocks the receiving thread inside `recv`; senders wake it through the matching
    /// unparker stored in `shared`.
    parker: crossbeam_utils::sync::Parker,
    /// Sender count, disconnect flag and wake-up state shared with every sender.
    shared: Arc<ChannelShared>,
}
/// Reasons [`Receiver::recv`] can return without a record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// The requested timeout elapsed with no record available.
    Timeout,
    /// Every [`Sender`] was dropped and no record was available.
    Disconnected,
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            RecvError::Timeout => "receive timed out",
            RecvError::Disconnected => "channel disconnected",
        })
    }
}

impl std::error::Error for RecvError {}
impl Receiver {
    /// Waits for the next committed record and claims it.
    ///
    /// * `None`: waits until a record arrives or every [`Sender`] is dropped. The thread still
    ///   wakes every [`DEFAULT_PARK_TIMEOUT`] to re-check, as a safety net.
    /// * `Some(Duration::ZERO)`: makes a single non-blocking attempt.
    /// * `Some(t)`: waits until a record arrives, every sender is dropped, or `t` has elapsed
    ///   since the call began. Early or stale wake-ups do not shorten the wait.
    ///
    /// Records committed before the last sender was dropped are delivered before
    /// [`RecvError::Disconnected`] is reported.
    ///
    /// # Errors
    /// * [`RecvError::Disconnected`]: all senders are gone and the queue is empty.
    /// * [`RecvError::Timeout`]: the timeout elapsed with no record available.
    #[inline]
    pub fn recv(&mut self, timeout: Option<Duration>) -> Result<queue::ReadClaim<'_>, RecvError> {
        let consumer: *mut queue::Consumer = &raw mut self.inner;
        // Absolute deadline for bounded waits. `None` means unbounded; this includes a `t` so
        // large that `now + t` is not representable.
        let deadline = timeout.and_then(|t| std::time::Instant::now().checked_add(t));
        let backoff = Backoff::new();
        loop {
            // SAFETY (every `claim_detached(consumer)` below): `consumer` was derived from
            // `&mut self`, so this call has exclusive access to `self.inner`. A successful claim
            // is returned immediately and the signature ties its lifetime to that borrow. No
            // other access to `self.inner` happens while a claim is alive.
            if let Some(claim) = unsafe { Self::claim_detached(consumer) } {
                return Ok(claim);
            }
            if self.shared.sender_count.load(Ordering::Acquire) == 0 {
                // The last sender may have committed a record and dropped between the failed
                // claim above and this load. The Acquire load makes its commits visible, so
                // drain once more instead of discarding that record.
                return match unsafe { Self::claim_detached(consumer) } {
                    Some(claim) => Ok(claim),
                    None => Err(RecvError::Disconnected),
                };
            }
            let park_for = match deadline {
                Some(deadline) => {
                    let left = deadline.saturating_duration_since(std::time::Instant::now());
                    if left.is_zero() {
                        return Err(RecvError::Timeout);
                    }
                    left
                }
                None => DEFAULT_PARK_TIMEOUT,
            };
            if !backoff.is_completed() {
                backoff.snooze();
                continue;
            }
            // Announce the intent to sleep, then look at the queue one final time. The SeqCst
            // fence orders the flag store before that check. If a sender also fences between
            // its commit and its read of the flag, then either its record is seen here or it
            // sees the flag and unparks us.
            self.shared.receiver_waiting.store(true, Ordering::Relaxed);
            std::sync::atomic::fence(Ordering::SeqCst);
            if let Some(claim) = unsafe { Self::claim_detached(consumer) } {
                self.shared.receiver_waiting.store(false, Ordering::Relaxed);
                return Ok(claim);
            }
            self.parker.park_timeout(park_for);
            self.shared.receiver_waiting.store(false, Ordering::Relaxed);
            // The wake-up may be stale or early. Loop back, re-check everything and, for
            // bounded waits, park again only for the time remaining.
            backoff.reset();
        }
    }

    /// Claims the next committed record without waiting.
    ///
    /// Returns `None` when no record is currently available. Unlike [`Receiver::recv`], it does
    /// not distinguish an empty queue from a disconnected channel.
    #[inline]
    pub fn try_recv(&mut self) -> Option<queue::ReadClaim<'_>> {
        self.inner.try_claim()
    }

    /// Returns the capacity reported by the underlying queue.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.inner.capacity()
    }

    /// Returns `true` once every [`Sender`] has been dropped.
    ///
    /// Records committed before that point may still be waiting to be received.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        self.shared.sender_count.load(Ordering::Acquire) == 0
    }

    /// Attempts a claim through a raw pointer so the borrow of the consumer is not extended
    /// across loop iterations. The current borrow checker rejects a conditional return of a
    /// borrow from inside a loop.
    ///
    /// # Safety
    /// `consumer` must point to a live `Consumer` to which the caller has exclusive access for
    /// `'a`. The caller must not access that `Consumer` in any other way while the returned
    /// claim is alive.
    #[inline]
    unsafe fn claim_detached<'a>(consumer: *mut queue::Consumer) -> Option<queue::ReadClaim<'a>> {
        // SAFETY: forwarded to the caller through this function's `# Safety` contract.
        unsafe { (*consumer).try_claim() }
    }
}
impl Drop for Receiver {
    /// Marks the channel closed from the receiving side; senders observe the flag and fail.
    fn drop(&mut self) {
        let closed = &self.shared.receiver_disconnected;
        closed.store(true, Ordering::Relaxed);
    }
}
impl std::fmt::Debug for Receiver {
    /// Shows the queue capacity; the parker and shared state are intentionally elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.inner.capacity();
        let mut out = f.debug_struct("Receiver");
        out.field("capacity", &capacity);
        out.finish_non_exhaustive()
    }
}
// End-to-end tests of the channel API. They run against the real queue backend and the
// crossbeam parker, so several of them spawn threads or depend on wall-clock timing.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// One record round-trips from a single sender to the receiver.
    #[test]
    fn basic_send_recv() {
        let (mut tx, mut rx) = channel(1024);
        let payload = b"hello world";
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        tx.notify();
        let record = rx.recv(None).unwrap();
        assert_eq!(&*record, payload);
    }

    /// `Sender` can be cloned, and both handles can be dropped without issue.
    #[test]
    // The clone is the subject of this test, so clippy's redundant-clone lint is a false positive.
    #[allow(clippy::redundant_clone)]
    fn sender_is_clone() {
        let (tx, _rx) = channel(1024);
        let _tx2 = tx.clone();
    }

    /// Several threads send concurrently; the receiver must get every record.
    #[test]
    fn multiple_senders() {
        const SENDERS: usize = 4;
        const MESSAGES: usize = 100;
        let (tx, mut rx) = channel(4096);
        let handles: Vec<_> = (0..SENDERS)
            .map(|id| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES {
                        let payload = format!("{}:{}", id, i);
                        let mut claim = tx.send(payload.len()).unwrap();
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        tx.notify();
                    }
                })
            })
            .collect();
        // Drop the original handle so only the worker clones keep the channel connected.
        drop(tx);
        let mut count = 0;
        // Stops at the first error (e.g. Disconnected); the final assertion then catches any
        // record that was not delivered.
        while let Ok(_record) = rx.recv(None) {
            count += 1;
            if count == SENDERS * MESSAGES {
                break;
            }
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(count, SENDERS * MESSAGES);
    }

    /// With every sender dropped and nothing queued, `recv` reports `Disconnected`.
    #[test]
    fn disconnection_all_senders_dropped() {
        let (tx, mut rx) = channel(1024);
        drop(tx);
        match rx.recv(None) {
            Err(RecvError::Disconnected) => {}
            _ => panic!("expected Disconnected"),
        }
    }

    /// After the receiver is dropped, `send` fails with `ChannelClosed`.
    #[test]
    fn disconnection_receiver_dropped() {
        let (mut tx, rx) = channel(1024);
        drop(rx);
        match tx.send(8) {
            Err(ChannelClosed) => {}
            _ => panic!("expected ChannelClosed"),
        }
    }

    /// A bounded `recv` with no traffic returns `Timeout` after roughly the requested time.
    /// `_tx` stays alive so the receiver times out instead of reporting disconnection. The
    /// 40–200 ms window leaves slack for timing jitter.
    #[test]
    fn recv_timeout_works() {
        let (_tx, mut rx) = channel(1024);
        let start = std::time::Instant::now();
        let result = rx.recv(Some(Duration::from_millis(50)));
        let elapsed = start.elapsed();
        assert!(matches!(result, Err(RecvError::Timeout)));
        assert!(elapsed >= Duration::from_millis(40));
        assert!(elapsed < Duration::from_millis(200));
    }

    /// `send` rejects zero-length payloads.
    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn send_zero_panics() {
        let (mut tx, _rx) = channel(1024);
        let _ = tx.send(0);
    }

    /// `try_send` rejects zero-length payloads.
    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn try_send_zero_panics() {
        let (mut tx, _rx) = channel(1024);
        let _ = tx.try_send(0);
    }

    /// Heavier concurrent load. Each 16-byte record carries (sender id, sequence number) as
    /// two little-endian `u64`s. The consumer checks per-sender FIFO order and, after
    /// joining, that every sender delivered all of its records.
    #[test]
    fn stress_multiple_senders() {
        const SENDERS: usize = 4;
        const MESSAGES_PER_SENDER: u64 = 10_000;
        const TOTAL: u64 = SENDERS as u64 * MESSAGES_PER_SENDER;
        const BUFFER_SIZE: usize = 64 * 1024;
        let (tx, mut rx) = channel(BUFFER_SIZE);
        let handles: Vec<_> = (0..SENDERS)
            .map(|sender_id| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES_PER_SENDER {
                        let mut payload = [0u8; 16];
                        payload[..8].copy_from_slice(&(sender_id as u64).to_le_bytes());
                        payload[8..].copy_from_slice(&i.to_le_bytes());
                        // Scoped so the claim (which borrows `tx`) is gone before `notify`.
                        {
                            let mut claim = tx.send(16).unwrap();
                            claim.copy_from_slice(&payload);
                            claim.commit();
                        }
                        tx.notify();
                    }
                })
            })
            .collect();
        // Drop the original handle so only the worker clones keep the channel connected.
        drop(tx);
        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut per_sender = vec![0u64; SENDERS];
            while received < TOTAL {
                match rx.recv(None) {
                    Ok(record) => {
                        let sender_id =
                            u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                        let seq = u64::from_le_bytes(record[8..].try_into().unwrap());
                        assert_eq!(
                            seq, per_sender[sender_id],
                            "sender {} out of order at {}",
                            sender_id, received
                        );
                        per_sender[sender_id] += 1;
                        received += 1;
                    }
                    // `recv(None)` never reports a timeout.
                    Err(RecvError::Timeout) => unreachable!(),
                    // Any shortfall is caught by the per-sender count check below.
                    Err(RecvError::Disconnected) => break,
                }
            }
            per_sender
        });
        for h in handles {
            h.join().unwrap();
        }
        let per_sender = consumer.join().unwrap();
        for (i, &count) in per_sender.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_SENDER, "sender {} count", i);
        }
    }
}