use std::{fmt, sync::Arc};
use crate::{pdeque::Deque, MessageDeliveryPolicy};
use parking_lot::{Condvar, Mutex};
/// Errors returned by channel operations.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Error {
    /// The queue rejected the message (reported by non-blocking sends).
    Full,
    /// The opposite side of the channel is gone: every receiver was dropped
    /// (for sends), or every sender was dropped and the queue is drained
    /// (for receives).
    Closed,
    /// No message is queued right now (reported by non-blocking receives).
    Empty,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Error::Full => "channel full",
            Error::Closed => "channel closed",
            Error::Empty => "channel empty",
        };
        f.write_str(text)
    }
}

impl std::error::Error for Error {}
/// Reference-counted handle to the state shared by all endpoints of one channel.
struct Channel<T: MessageDeliveryPolicy>(Arc<ChannelInner<T>>);

// Implemented by hand instead of `#[derive(Clone)]`: the derive would demand
// `T: Clone`, while cloning a handle only bumps the `Arc` reference count.
impl<T: MessageDeliveryPolicy> Clone for Channel<T> {
    fn clone(&self) -> Self {
        Channel(Arc::clone(&self.0))
    }
}
/// State shared by every [`Sender`] and [`Receiver`] of one channel.
struct ChannelInner<T: MessageDeliveryPolicy> {
    /// Message queue plus live sender/receiver counts, guarded by one mutex.
    pc: Mutex<PolicyChannel<T>>,
    /// Signalled whenever the queue contents or the endpoint counts change.
    ///
    /// Blocked senders (waiting for room) and blocked receivers (waiting for
    /// data) both park on this single condvar, so every wake-up uses
    /// `notify_all`. With `notify_one` the notification could reach a waiter
    /// of the wrong kind. That waiter re-checks its condition and goes back to
    /// sleep, and the waiter that could actually make progress stays blocked
    /// forever (a lost wake-up).
    available: Condvar,
}

impl<T: MessageDeliveryPolicy> ChannelInner<T> {
    /// Attempts to enqueue `value` without blocking.
    ///
    /// # Errors
    /// * [`Error::Closed`] if every receiver has been dropped.
    /// * [`Error::Full`] if the queue rejected the value; the value is dropped.
    fn try_send(&self, value: T) -> Result<(), Error> {
        let mut pc = self.pc.lock();
        if pc.receivers == 0 {
            return Err(Error::Closed);
        }
        if pc.queue.try_push(value).is_none() {
            // Wake blocked receivers. This is `notify_all` so that a waiting
            // sender cannot swallow the wake-up.
            self.available.notify_all();
            Ok(())
        } else {
            Err(Error::Full)
        }
    }

    /// Enqueues `value`, blocking while the queue keeps rejecting it.
    ///
    /// # Errors
    /// [`Error::Closed`] if every receiver has been dropped. This is checked
    /// before each push attempt, including after every wake-up.
    fn send(&self, mut value: T) -> Result<(), Error> {
        let mut pc = self.pc.lock();
        loop {
            if pc.receivers == 0 {
                return Err(Error::Closed);
            }
            // `try_push` hands the value back when it could not be queued.
            let Some(val) = pc.queue.try_push(value) else {
                break;
            };
            value = val;
            self.available.wait(&mut pc);
        }
        // Wake blocked receivers. This is `notify_all` so that a waiting
        // sender cannot swallow the wake-up.
        self.available.notify_all();
        Ok(())
    }

    /// Dequeues the next message, blocking while none is available.
    ///
    /// # Errors
    /// [`Error::Closed`] once the queue yields nothing and every sender has
    /// been dropped. Messages still queued are delivered first.
    fn recv(&self) -> Result<T, Error> {
        let mut pc = self.pc.lock();
        loop {
            if let Some(val) = pc.queue.get() {
                // Taking a message may make room for blocked senders. This is
                // `notify_all` so that a waiting receiver cannot swallow the
                // wake-up.
                self.available.notify_all();
                return Ok(val);
            } else if pc.senders == 0 {
                return Err(Error::Closed);
            }
            self.available.wait(&mut pc);
        }
    }

    /// Dequeues the next message without blocking.
    ///
    /// # Errors
    /// * [`Error::Closed`] if nothing is queued and every sender has been dropped.
    /// * [`Error::Empty`] if nothing is queued but senders are still alive.
    fn try_recv(&self) -> Result<T, Error> {
        let mut pc = self.pc.lock();
        if let Some(val) = pc.queue.get() {
            // Taking a message may make room for blocked senders. This is
            // `notify_all` so that a waiting receiver cannot swallow the
            // wake-up.
            self.available.notify_all();
            Ok(val)
        } else if pc.senders == 0 {
            Err(Error::Closed)
        } else {
            Err(Error::Empty)
        }
    }
}
impl<T: MessageDeliveryPolicy> Channel<T> {
fn new(capacity: usize, ordering: bool) -> Self {
Self(
ChannelInner {
pc: Mutex::new(PolicyChannel::new(capacity, ordering)),
available: Condvar::new(),
}
.into(),
)
}
}
/// Mutex-protected channel state: the message queue and the number of live
/// endpoints on each side.
struct PolicyChannel<T: MessageDeliveryPolicy> {
    /// Bounded message queue (`pdeque::Deque`).
    queue: Deque<T>,
    /// Live `Sender` handles. Receives report `Closed` once this is 0 and the
    /// queue yields nothing.
    senders: usize,
    /// Live `Receiver` handles. Sends report `Closed` once this is 0.
    receivers: usize,
}

impl<T: MessageDeliveryPolicy> PolicyChannel<T> {
    /// Creates the state for a new channel, already counting the initial
    /// sender and receiver.
    ///
    /// # Panics
    /// Panics if `capacity` is 0.
    fn new(capacity: usize, ordering: bool) -> Self {
        if capacity == 0 {
            panic!("channel capacity MUST be > 0");
        }
        let queue = Deque::bounded(capacity).with_ordering(ordering);
        Self {
            queue,
            senders: 1,
            receivers: 1,
        }
    }
}
pub struct Sender<T>
where
T: MessageDeliveryPolicy,
{
channel: Channel<T>,
}
impl<T> Sender<T>
where
T: MessageDeliveryPolicy,
{
#[inline]
pub fn send(&self, value: T) -> Result<(), Error> {
self.channel.0.send(value)
}
#[inline]
pub fn try_send(&self, value: T) -> Result<(), Error> {
self.channel.0.try_send(value)
}
#[inline]
pub fn usage(&self) -> usize {
self.channel.0.pc.lock().queue.usage()
}
#[inline]
pub fn is_full(&self) -> bool {
self.channel.0.pc.lock().queue.is_full()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.channel.0.pc.lock().queue.is_empty()
}
#[inline]
pub fn is_alive(&self) -> bool {
self.channel.0.pc.lock().receivers > 0
}
}
impl<T> Clone for Sender<T>
where
T: MessageDeliveryPolicy,
{
fn clone(&self) -> Self {
self.channel.0.pc.lock().senders += 1;
Self {
channel: self.channel.clone(),
}
}
}
impl<T> Drop for Sender<T>
where
T: MessageDeliveryPolicy,
{
fn drop(&mut self) {
self.channel.0.pc.lock().senders -= 1;
self.channel.0.available.notify_all();
}
}
pub struct Receiver<T>
where
T: MessageDeliveryPolicy,
{
channel: Channel<T>,
}
impl<T> Receiver<T>
where
T: MessageDeliveryPolicy,
{
#[inline]
pub fn recv(&self) -> Result<T, Error> {
self.channel.0.recv()
}
#[inline]
pub fn try_recv(&self) -> Result<T, Error> {
self.channel.0.try_recv()
}
#[inline]
pub fn usage(&self) -> usize {
self.channel.0.pc.lock().queue.usage()
}
#[inline]
pub fn is_full(&self) -> bool {
self.channel.0.pc.lock().queue.is_full()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.channel.0.pc.lock().queue.is_empty()
}
#[inline]
pub fn is_alive(&self) -> bool {
self.channel.0.pc.lock().senders > 0
}
}
impl<T> Clone for Receiver<T>
where
T: MessageDeliveryPolicy,
{
fn clone(&self) -> Self {
self.channel.0.pc.lock().receivers += 1;
Self {
channel: self.channel.clone(),
}
}
}
impl<T> Drop for Receiver<T>
where
T: MessageDeliveryPolicy,
{
fn drop(&mut self) {
self.channel.0.pc.lock().receivers -= 1;
self.channel.0.available.notify_all();
}
}
/// Splits one shared channel into its initial sender/receiver pair.
///
/// No counter is touched here: `PolicyChannel::new` already starts with one
/// sender and one receiver.
fn make_channel<T: MessageDeliveryPolicy>(ch: Channel<T>) -> (Sender<T>, Receiver<T>) {
    let sender_side = ch.clone();
    (
        Sender {
            channel: sender_side,
        },
        Receiver { channel: ch },
    )
}
/// Creates a channel whose queue is bounded by `capacity`, with queue
/// ordering disabled (`ordering = false`).
///
/// # Panics
/// Panics if `capacity` is 0.
pub fn bounded<T: MessageDeliveryPolicy>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    make_channel(Channel::new(capacity, false))
}
/// Creates a channel whose queue is bounded by `capacity`, with queue
/// ordering enabled (`ordering = true`, forwarded to `Deque::with_ordering`).
///
/// # Panics
/// Panics if `capacity` is 0.
pub fn ordered<T: MessageDeliveryPolicy>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    make_channel(Channel::new(capacity, true))
}
#[cfg(test)]
mod test {
    use std::{thread, time::Duration};

    use crate::{DeliveryPolicy, MessageDeliveryPolicy};

    use super::bounded;

    /// Test message type with one variant per delivery policy exercised below.
    #[derive(Debug)]
    enum Message {
        Test(usize),
        Temperature(f64),
        Spam,
    }

    impl MessageDeliveryPolicy for Message {
        fn delivery_policy(&self) -> DeliveryPolicy {
            match self {
                Message::Test(_) => DeliveryPolicy::Always,
                Message::Temperature(_) => DeliveryPolicy::Single,
                Message::Spam => DeliveryPolicy::Optional,
            }
        }
    }

    #[test]
    fn test_delivery_policy_optional() {
        let (tx, rx) = bounded::<Message>(1);
        let producer = thread::spawn(move || {
            for _ in 0..10 {
                tx.send(Message::Test(123)).unwrap();
                tx.send(Message::Spam).unwrap();
                tx.send(Message::Temperature(123.0)).unwrap();
            }
        });
        thread::sleep(Duration::from_secs(1));
        while let Ok(msg) = rx.recv() {
            thread::sleep(Duration::from_millis(10));
            if matches!(msg, Message::Spam) {
                panic!("delivery policy not respected ({:?})", msg);
            }
        }
        // A producer panic (e.g. an unexpected send error) only ends the loop
        // above early; join so it fails the test instead of passing silently.
        producer.join().expect("producer thread panicked");
    }

    #[test]
    fn test_delivery_policy_single() {
        let (tx, rx) = bounded::<Message>(512);
        let producer = thread::spawn(move || {
            for _ in 0..10 {
                tx.send(Message::Test(123)).unwrap();
                tx.send(Message::Spam).unwrap();
                tx.send(Message::Temperature(123.0)).unwrap();
            }
        });
        thread::sleep(Duration::from_secs(1));
        let mut c = 0;
        let mut t = 0;
        while let Ok(msg) = rx.recv() {
            match msg {
                Message::Test(_) => c += 1,
                Message::Temperature(_) => t += 1,
                Message::Spam => {}
            }
        }
        producer.join().expect("producer thread panicked");
        assert_eq!(c, 10);
        assert_eq!(t, 1);
    }

    /// Dropping the only sender must wake a receiver blocked in `recv`.
    #[test]
    fn test_poisoning() {
        let n = 10_000;
        for i in 0..n {
            let (tx, rx) = bounded::<Message>(512);
            let rx_t = thread::spawn(move || while rx.recv().is_ok() {});
            thread::spawn(move || {
                let _t = tx;
            });
            for _ in 0..10 {
                if rx_t.is_finished() {
                    break;
                }
                thread::sleep(Duration::from_millis(1));
            }
            if !rx_t.is_finished() {
                panic!("RX poisoned {}", i);
            }
        }
    }
}