use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use crossbeam_utils::Backoff;
use crate::queue::spsc as queue;
/// Park duration used by [`Receiver::recv`] when the caller supplies no timeout; the receiver
/// re-polls the queue at least this often even if it is never unparked.
const DEFAULT_PARK_TIMEOUT: Duration = Duration::from_millis(100);
/// Creates a single-producer/single-consumer channel backed by `queue::new(capacity)`.
///
/// `capacity` is forwarded unchanged to the underlying queue. The returned [`Sender`] holds
/// an unparker for the [`Receiver`]'s parker, so it can wake a receiver blocked in `recv`.
pub fn channel(capacity: usize) -> (Sender, Receiver) {
    let (inner_tx, inner_rx) = queue::new(capacity);
    let state = Arc::new(ChannelShared {
        receiver_waiting: AtomicBool::new(false),
        sender_disconnected: AtomicBool::new(false),
        receiver_disconnected: AtomicBool::new(false),
    });
    let parker = crossbeam_utils::sync::Parker::new();
    let wake_handle = parker.unparker().clone();

    let sender = Sender {
        inner: inner_tx,
        receiver_unparker: wake_handle,
        shared: Arc::clone(&state),
    };
    let receiver = Receiver {
        inner: inner_rx,
        parker,
        shared: state,
    };
    (sender, receiver)
}
/// Flags shared (through an `Arc`) by the two halves of one channel.
struct ChannelShared {
    /// Raised by the receiver around its park; `Sender::notify` only unparks while it is set.
    receiver_waiting: AtomicBool,
    /// Set once the `Sender` has been dropped.
    sender_disconnected: AtomicBool,
    /// Set once the `Receiver` has been dropped.
    receiver_disconnected: AtomicBool,
}
/// Producing half of a channel created by [`channel`].
///
/// Fields are dropped in declaration order after `Drop::drop` runs.
pub struct Sender {
    /// Producer end of the underlying SPSC queue.
    inner: queue::Producer,
    /// Wakes the receiver's parker; used by `notify` and when the sender is dropped.
    receiver_unparker: crossbeam_utils::sync::Unparker,
    /// Flags shared with the `Receiver`.
    shared: Arc<ChannelShared>,
}
/// Error from [`Sender::send`]: the receiving half has been dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelClosed;

impl std::fmt::Display for ChannelClosed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "channel disconnected")
    }
}

impl std::error::Error for ChannelClosed {}
/// Error from [`Sender::try_send`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError {
    /// The queue could not satisfy the claim right now.
    Full,
    /// The receiving half has been dropped.
    Disconnected,
}

impl std::fmt::Display for TrySendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Full => "channel full",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(text)
    }
}

impl std::error::Error for TrySendError {}
impl Sender {
    /// Claims `len` bytes of queue space, spinning and yielding until room is available.
    ///
    /// The caller fills and commits the returned claim, then calls [`Sender::notify`] to wake
    /// a receiver that may be parked. The sending thread itself never parks.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelClosed`] if the receiver has been dropped, either before the call or
    /// while waiting for room.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero.
    #[inline]
    pub fn send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, ChannelClosed> {
        assert!(len > 0, "payload length must be non-zero");
        if self.shared.receiver_disconnected.load(Ordering::Acquire) {
            return Err(ChannelClosed);
        }
        let backoff = Backoff::new();
        loop {
            let inner_ptr: *mut queue::Producer = &raw mut self.inner;
            // SAFETY: `inner_ptr` is derived from `&mut self`, so the producer is valid and not
            // otherwise aliased during this call. Going through a raw pointer only detaches the
            // claim's lifetime from this loop iteration: the borrow checker rejects a borrow that
            // is conditionally returned from inside a loop. A successful claim is returned
            // immediately and the signature re-ties it to `&mut self`, so at most one mutable
            // borrow of the producer is ever live.
            if let Ok(claim) = unsafe { (*inner_ptr).try_claim(len) } {
                return Ok(claim);
            }
            backoff.snooze();
            if self.shared.receiver_disconnected.load(Ordering::Acquire) {
                return Err(ChannelClosed);
            }
            // Once the backoff has fully escalated, restart from the spin phase; the sender
            // keeps polling rather than parking.
            if backoff.is_completed() {
                backoff.reset();
            }
        }
    }

    /// Claims `len` bytes of queue space without waiting.
    ///
    /// # Errors
    ///
    /// - [`TrySendError::Disconnected`] if the receiver has been dropped.
    /// - [`TrySendError::Full`] if the queue cannot satisfy the claim right now.
    ///
    /// # Panics
    ///
    /// Panics if `len` is zero.
    #[inline]
    pub fn try_send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, TrySendError> {
        assert!(len > 0, "payload length must be non-zero");
        if self.shared.receiver_disconnected.load(Ordering::Acquire) {
            return Err(TrySendError::Disconnected);
        }
        match self.inner.try_claim(len) {
            Ok(claim) => Ok(claim),
            Err(crate::BufferFull) => Err(TrySendError::Full),
        }
    }

    /// Wakes the receiver if it has announced (via `receiver_waiting`) that it is parking.
    ///
    /// Call after committing one or more records.
    #[inline]
    pub fn notify(&self) {
        // Full fence between the caller's preceding commit and the flag load. This is the
        // sender half of a Dekker-style handshake: a receiver that raises `receiver_waiting`,
        // issues its own SeqCst fence and then re-checks the queue cannot miss this commit
        // while this load simultaneously misses its flag, so the wakeup cannot be lost.
        // NOTE(review): relies on the queue publishing commits with an atomic store.
        std::sync::atomic::fence(Ordering::SeqCst);
        if self.shared.receiver_waiting.load(Ordering::Relaxed) {
            self.receiver_unparker.unpark();
        }
    }

    /// Capacity of the underlying queue, as reported by the producer.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.inner.capacity()
    }

    /// Returns `true` once the receiving half has been dropped.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        self.shared.receiver_disconnected.load(Ordering::Acquire)
    }
}
impl Drop for Sender {
    /// Publishes that the sending half is gone and wakes the receiver so it can notice.
    fn drop(&mut self) {
        // Release: a receiver that observes this flag with an Acquire load is guaranteed to
        // also see every record committed before the drop, so it can drain them before it
        // reports a disconnect. A Relaxed store gives no such guarantee.
        self.shared
            .sender_disconnected
            .store(true, Ordering::Release);
        // Unconditional (unlike `notify`) so a parked receiver observes the disconnect
        // promptly instead of waiting out its park timeout.
        self.receiver_unparker.unpark();
    }
}
impl std::fmt::Debug for Sender {
    /// Formats as `Sender { capacity: .., .. }`; queue and wakeup internals are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Sender");
        builder.field("capacity", &self.capacity());
        builder.finish_non_exhaustive()
    }
}
/// Consuming half of a channel created by [`channel`].
///
/// Fields are dropped in declaration order after `Drop::drop` runs.
pub struct Receiver {
    /// Consumer end of the underlying SPSC queue.
    inner: queue::Consumer,
    /// Parker this receiver sleeps on in `recv`; the matching unparker is held by the `Sender`.
    parker: crossbeam_utils::sync::Parker,
    /// Flags shared with the `Sender`.
    shared: Arc<ChannelShared>,
}
/// Error from [`Receiver::recv`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// No record became available before the timeout.
    Timeout,
    /// The sending half has been dropped.
    Disconnected,
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Timeout => "receive timed out",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(text)
    }
}

impl std::error::Error for RecvError {}
impl Receiver {
#[inline]
pub fn recv(&mut self, timeout: Option<Duration>) -> Result<queue::ReadClaim<'_>, RecvError> {
if timeout == Some(Duration::ZERO) {
unsafe {
let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
if let Some(claim) = (*inner_ptr).try_claim() {
return Ok(std::mem::transmute::<
queue::ReadClaim<'_>,
queue::ReadClaim<'_>,
>(claim));
}
}
if self.shared.sender_disconnected.load(Ordering::Relaxed) {
return Err(RecvError::Disconnected);
}
return Err(RecvError::Timeout);
}
let park_timeout = timeout.unwrap_or(DEFAULT_PARK_TIMEOUT);
let backoff = Backoff::new();
loop {
unsafe {
let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
if let Some(claim) = (*inner_ptr).try_claim() {
return Ok(std::mem::transmute::<
queue::ReadClaim<'_>,
queue::ReadClaim<'_>,
>(claim));
}
}
if self.shared.sender_disconnected.load(Ordering::Relaxed) {
return Err(RecvError::Disconnected);
}
if !backoff.is_completed() {
backoff.snooze();
continue;
}
self.shared.receiver_waiting.store(true, Ordering::Relaxed);
self.parker.park_timeout(park_timeout);
self.shared.receiver_waiting.store(false, Ordering::Relaxed);
if timeout.is_some() {
unsafe {
let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
if let Some(claim) = (*inner_ptr).try_claim() {
return Ok(std::mem::transmute::<
queue::ReadClaim<'_>,
queue::ReadClaim<'_>,
>(claim));
}
}
if self.shared.sender_disconnected.load(Ordering::Relaxed) {
return Err(RecvError::Disconnected);
}
return Err(RecvError::Timeout);
}
backoff.reset();
}
}
#[inline]
pub fn try_recv(&mut self) -> Option<queue::ReadClaim<'_>> {
self.inner.try_claim()
}
#[inline]
pub fn capacity(&self) -> usize {
self.inner.capacity()
}
#[inline]
pub fn is_disconnected(&self) -> bool {
self.shared.sender_disconnected.load(Ordering::Relaxed)
}
}
impl Drop for Receiver {
    /// Publishes that the receiving half is gone; a sender waiting for room then gives up
    /// with `ChannelClosed`.
    fn drop(&mut self) {
        let gone = &self.shared.receiver_disconnected;
        gone.store(true, Ordering::Relaxed);
    }
}
impl std::fmt::Debug for Receiver {
    /// Formats as `Receiver { capacity: .., .. }`; queue and parking internals are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Receiver");
        builder.field("capacity", &self.capacity());
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn basic_send_recv() {
        let (mut tx, mut rx) = channel(1024);
        let payload = b"hello world";
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        tx.notify();
        let record = rx.recv(None).unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn try_send_try_recv() {
        let (mut tx, mut rx) = channel(1024);
        assert!(rx.try_recv().is_none());
        let payload = b"test";
        let mut claim = tx.try_send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        {
            let record = rx.try_recv().unwrap();
            assert_eq!(&*record, payload);
        }
        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn cross_thread() {
        let (mut tx, mut rx) = channel(4096);
        let producer = thread::spawn(move || {
            for i in 0..1000u64 {
                let payload = i.to_le_bytes();
                {
                    let mut claim = tx.send(payload.len()).unwrap();
                    claim.copy_from_slice(&payload);
                    claim.commit();
                }
                tx.notify();
            }
        });
        let consumer = thread::spawn(move || {
            for i in 0..1000u64 {
                let record = rx.recv(None).unwrap();
                let value = u64::from_le_bytes((*record).try_into().unwrap());
                assert_eq!(value, i);
            }
        });
        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn disconnection_sender_dropped() {
        let (tx, mut rx) = channel(1024);
        drop(tx);
        match rx.recv(None) {
            Err(RecvError::Disconnected) => {}
            _ => panic!("expected Disconnected"),
        }
    }

    #[test]
    fn buffered_record_survives_sender_drop() {
        let (mut tx, mut rx) = channel(1024);
        let payload = b"last words";
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        drop(tx);
        {
            let record = rx.recv(None).unwrap();
            assert_eq!(&*record, payload);
        }
        assert!(matches!(rx.recv(None), Err(RecvError::Disconnected)));
        assert!(rx.is_disconnected());
    }

    #[test]
    fn disconnection_receiver_dropped() {
        let (mut tx, rx) = channel(1024);
        drop(rx);
        match tx.send(8) {
            Err(ChannelClosed) => {}
            _ => panic!("expected ChannelClosed"),
        }
        assert!(tx.is_disconnected());
    }

    #[test]
    fn recv_zero_timeout_reports_state() {
        let (tx, mut rx) = channel(1024);
        assert!(matches!(rx.recv(Some(Duration::ZERO)), Err(RecvError::Timeout)));
        drop(tx);
        assert!(matches!(rx.recv(Some(Duration::ZERO)), Err(RecvError::Disconnected)));
    }

    #[test]
    fn recv_timeout_works() {
        let (_tx, mut rx) = channel(1024);
        let start = std::time::Instant::now();
        let result = rx.recv(Some(Duration::from_millis(50)));
        let elapsed = start.elapsed();
        assert!(matches!(result, Err(RecvError::Timeout)));
        assert!(elapsed >= Duration::from_millis(40));
        assert!(elapsed < Duration::from_millis(200));
    }

    #[test]
    fn recv_timeout_with_data() {
        let (mut tx, mut rx) = channel(1024);
        let payload = b"data";
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        tx.notify();
        let result = rx.recv(Some(Duration::from_secs(1)));
        assert!(result.is_ok());
        assert_eq!(&*result.unwrap(), payload);
    }

    #[test]
    fn try_send_returns_full() {
        let (mut tx, _rx) = channel(64);
        let mut count = 0;
        loop {
            match tx.try_send(8) {
                Ok(mut claim) => {
                    claim.copy_from_slice(b"12345678");
                    claim.commit();
                    count += 1;
                }
                Err(TrySendError::Full) => break,
                Err(e) => panic!("unexpected error: {e:?}"),
            }
        }
        assert!(count > 0);
    }

    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn send_zero_panics() {
        let (mut tx, _rx) = channel(1024);
        let _ = tx.send(0);
    }

    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn try_send_zero_panics() {
        let (mut tx, _rx) = channel(1024);
        let _ = tx.try_send(0);
    }
}