use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::{mpsc, Arc, RwLock};
use std::thread;
use std::time::Duration;
use crate::core::utils::{Closable, Closer, RingBuffer, UniqueId};
use crate::error::{
RecvError, RecvResult, RecvTimeoutError, RecvTimeoutResult, SendError, SendResult,
TryRecvError, TryRecvResult,
};
/// Maximum time the dispatcher thread blocks on the upstream queue per iteration
/// before looping back to re-check whether any subscribers are still registered
/// (see `BroadcastBus::start`).
const POOLING_TIMEOUT: Duration = Duration::from_micros(50);
/// Sending half of the broadcast channel built by `channel` / `retentive_channel`.
///
/// Values are pushed into a single upstream `mpsc` queue; the dispatcher thread
/// started by `BroadcastBus::start` drains that queue and copies each value to
/// every subscriber.
#[derive(Clone, Debug)]
pub struct Sender<T> {
/// Producer end of the upstream queue whose consumer end is given to the dispatcher.
inner: mpsc::Sender<T>,
/// Close flag consulted before every send; created from the bus's `Closer` via
/// `to_closable()`. Presumably reports closed once that `Closer` is dropped —
/// confirm against `core::utils`.
state: Closable,
}
/// Receiving half of the broadcast channel: one subscriber's private queue plus
/// the guard that keeps that subscriber registered on the bus.
pub struct Receiver<T: Clone + Sync + Send + 'static> {
/// Per-subscriber queue that the dispatcher thread sends cloned values into.
inner: mpsc::Receiver<T>,
/// Registration handle; dropping it removes this subscriber from the bus.
guard: RecvGuard<T>,
}
/// Registration handle for one subscriber.
///
/// Its `Drop` impl calls `BroadcastBus::remove` with `id`, which drops the
/// subscriber's sender and thereby disconnects the matching `mpsc::Receiver`.
pub struct RecvGuard<T: Clone + Sync + Send + 'static> {
/// Key of this subscriber in the bus's `recv_txs` map.
id: UniqueId,
/// Shared handle to the bus this subscriber is registered on.
bus: Arc<BroadcastBus<T>>,
}
// SAFETY: moving a `Sender<T>` to another thread only ever moves `T` values into the
// channel, which is sound for `T: Send`; this matches std's impl for `mpsc::Sender<T>`.
// NOTE(review): also asserts that `Closable` may be moved across threads — confirm.
unsafe impl<T: Send> Send for Sender<T> {}
// SAFETY: `send` takes `&self`, so a shared `&Sender<T>` lets any thread move a `T`
// into the channel. The bound that makes this sound is therefore `T: Send` (the same
// bound std uses for `mpsc::Sender<T>: Sync`), not `T: Sync`: a `Sync + !Send` payload
// must not be transferable to another thread through a shared sender.
// NOTE(review): also asserts that `Closable` may be shared across threads — confirm.
unsafe impl<T: Send> Sync for Sender<T> {}
impl<T> Sender<T> {
    /// Queues `value` for broadcast to every subscriber.
    ///
    /// # Errors
    ///
    /// Returns a [`SendError`] carrying `value` back when the channel state reports
    /// closed, or the converted `mpsc` error when the upstream queue rejects the value.
    pub fn send(&self, value: T) -> SendResult<T> {
        if self.state.is_closed() {
            Err(SendError(value))
        } else {
            self.inner.send(value).map_err(SendError::from)
        }
    }

    /// Consumes the wrapper and returns the raw upstream `mpsc::Sender`,
    /// discarding the close flag.
    #[must_use]
    #[allow(dead_code)]
    pub fn into_inner(self) -> mpsc::Sender<T> {
        let Self { inner, .. } = self;
        inner
    }
}
impl<T: Clone + Sync + Send + 'static> Debug for Receiver<T> {
    /// Formats the receiver showing only its `inner` queue; the guard is elided
    /// and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Receiver");
        out.field("inner", &self.inner);
        out.finish_non_exhaustive()
    }
}
// SAFETY (review): asserts a `Receiver<T>` may be moved to another thread.
// `mpsc::Receiver<T>` is itself `Send` for `T: Send`; the remaining obligation is the
// `RecvGuard<T>` field, covered by the impl below.
unsafe impl<T: Clone + Send + Sync + 'static> Send for Receiver<T> {}
// NOTE(review): `std::sync::mpsc::Receiver` is deliberately `!Sync`; this impl overrides
// that and asserts concurrent `&Receiver` use from several threads is sound. That relies
// on the std channel implementation tolerating concurrent receivers — TODO confirm for
// every supported toolchain before relying on it.
unsafe impl<T: Clone + Send + Sync + 'static> Sync for Receiver<T> {}
// NOTE(review): asserts `RecvGuard<T>` (a `UniqueId` plus `Arc<BroadcastBus<T>>`) may be
// moved across threads. An `Arc<X>` is only `Send` automatically when `X: Send + Sync`,
// so this presumes `BroadcastBus<T>` (including its `Closer`) is safe to share — confirm.
unsafe impl<T: Clone + Send + Sync + 'static> Send for RecvGuard<T> {}
impl<T: Clone + Sync + Send + 'static> Receiver<T> {
pub fn recv(&self) -> RecvResult<T> {
self.inner.recv().map_err(RecvError::from)
}
pub fn recv_timeout(&self, timeout: Duration) -> RecvTimeoutResult<T> {
self.inner
.recv_timeout(timeout)
.map_err(RecvTimeoutError::from)
}
pub fn try_recv(&self) -> TryRecvResult<T> {
self.inner.try_recv().map_err(TryRecvError::from)
}
pub fn subscribe(&self) -> Receiver<T> {
let (id, rx) = self.guard.bus.add(true);
Receiver {
inner: rx,
guard: RecvGuard {
id,
bus: self.guard.bus.clone(),
},
}
}
#[must_use]
#[allow(dead_code)]
pub fn into_inner(self) -> (mpsc::Receiver<T>, RecvGuard<T>) {
(self.inner, self.guard)
}
}
impl<T: Clone + Sync + Send + 'static> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.subscribe()
}
}
impl<T: Clone + Send + Sync> Clone for RecvGuard<T> {
fn clone(&self) -> Self {
RecvGuard {
id: self.id,
bus: self.bus.clone(),
}
}
}
impl<T: Clone + Send + Sync> Debug for RecvGuard<T> {
    /// Prints only the type name; both fields are elided.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RecvGuard");
        out.finish_non_exhaustive()
    }
}
impl<T: Clone + Sync + Send + 'static> Drop for RecvGuard<T> {
fn drop(&mut self) {
self.bus.remove(&self.id);
}
}
/// Creates a broadcast channel without retention: subscribers only receive values
/// dispatched after they registered.
///
/// Equivalent to `retentive_channel(0)`.
#[must_use]
#[inline(always)]
pub fn channel<T: Clone + Sync + Send + 'static>() -> (Sender<T>, Receiver<T>) {
    const NO_RETENTION: usize = 0;
    retentive_channel::<T>(NO_RETENTION)
}
/// Creates a broadcast channel whose bus is built with a retention buffer of `depth`.
///
/// With a non-zero `depth`, every dispatched value is also pushed into that buffer and
/// later subscribers (via `Receiver::subscribe` / `clone`) get its contents replayed.
/// The initial receiver is registered without replay, and the dispatcher thread is
/// started before this function returns.
#[must_use]
pub fn retentive_channel<T: Clone + Sync + Send + 'static>(
    depth: usize,
) -> (Sender<T>, Receiver<T>) {
    let (upstream_tx, upstream_rx) = mpsc::channel();
    let closer = Closer::new();

    // The closable view must be taken before `closer` moves into the bus.
    let sender = Sender {
        inner: upstream_tx,
        state: closer.to_closable(),
    };

    let bus = BroadcastBus {
        recv_txs: Arc::default(),
        recent: Arc::new(RwLock::new(RingBuffer::new(depth))),
        depth,
        _state: closer,
    };

    // Register the first subscriber before the dispatcher starts, so the dispatcher
    // does not find an empty subscriber map on its first check and exit.
    let (id, first_rx) = bus.add(false);
    let bus = Arc::new(bus);
    let guard = RecvGuard {
        id,
        bus: Arc::clone(&bus),
    };
    let receiver = Receiver {
        inner: first_rx,
        guard,
    };
    bus.start(upstream_rx);

    (sender, receiver)
}
/// Shared state behind one broadcast channel: the subscriber registry, the optional
/// retention buffer, and the close handle.
struct BroadcastBus<T> {
/// Subscriber id -> sending end of that subscriber's private queue. Shared with the
/// dispatcher thread, which also prunes entries whose receiver is gone.
recv_txs: Arc<RwLock<HashMap<UniqueId, mpsc::Sender<T>>>>,
/// Buffer of recently dispatched values, built with `RingBuffer::new(depth)` and
/// replayed to late subscribers.
recent: Arc<RwLock<RingBuffer<T>>>,
/// Retention depth; `0` disables both recording and replay.
depth: usize,
/// Never read in this file; held only so it is dropped together with the bus.
/// Presumably dropping it flips the senders' `Closable` to closed — confirm.
_state: Closer,
}
impl<T: Clone + Sync + Send + 'static> BroadcastBus<T> {
fn start(&self, send_rx: mpsc::Receiver<T>) {
let recv_txs = self.recv_txs.clone();
let latest = self.recent.clone();
let depth = self.depth;
thread::spawn(move || loop {
if recv_txs.read().unwrap().is_empty() {
return;
}
let data = match send_rx.recv_timeout(POOLING_TIMEOUT) {
Ok(data) => data,
Err(mpsc::RecvTimeoutError::Timeout) => continue,
Err(mpsc::RecvTimeoutError::Disconnected) => {
let mut recv_txs = recv_txs.write().unwrap();
recv_txs.clear();
return;
}
};
if depth > 0 {
let mut latest = latest.write().unwrap();
latest.push(data.clone());
}
let failed_recv_tx_ids = {
let recv_txs = recv_txs.read().unwrap();
if recv_txs.is_empty() {
return;
}
let mut failed_recv_tx_ids = Vec::new();
{
for (id, recv_tx) in recv_txs.iter() {
let recv_tx = recv_tx.clone();
let data = data.clone();
if recv_tx.send(data).is_err() {
failed_recv_tx_ids.push(*id);
}
}
}
failed_recv_tx_ids
};
if !failed_recv_tx_ids.is_empty() {
let mut recv_txs = recv_txs.write().unwrap();
for id in failed_recv_tx_ids {
recv_txs.remove(&id);
}
}
});
}
fn add(&self, push_recent: bool) -> (UniqueId, mpsc::Receiver<T>) {
let (recv_tx, recv_rx) = mpsc::channel();
let id = UniqueId::new();
if push_recent && self.depth > 0 {
let recent = self.recent.read().unwrap();
for msg in recent.iter() {
if recv_tx.send(msg.clone()).is_err() {
break;
}
}
}
{
let mut recv_txs = self.recv_txs.write().unwrap();
recv_txs.insert(id, recv_tx);
}
(id, recv_rx)
}
fn remove(&self, id: &UniqueId) {
let mut recv_txs = self.recv_txs.write().unwrap();
recv_txs.remove(id);
}
}
impl<T> Drop for BroadcastBus<T> {
    /// Drops every subscriber's sender so their receivers disconnect and the
    /// dispatcher thread, which shares this map, sees it empty and exits.
    ///
    /// A destructor must not panic: if the lock is poisoned the guard is recovered
    /// and the map is cleared anyway, rather than unwrapping and risking an abort
    /// when this runs during unwinding.
    fn drop(&mut self) {
        let mut subscribers = self
            .recv_txs
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        subscribers.clear();
    }
}
// Behavioral tests for the broadcast channel. Several of them synchronize with the
// dispatcher thread by sleeping (`wait` / `wait_long`), so they are timing-sensitive.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
// A value sent once is delivered to both the original receiver and its clone.
#[test]
fn mpmc_basic() {
let (tx, rx) = channel();
let rx_1 = rx;
let rx_2 = rx_1.clone();
tx.send(1).unwrap();
rx_1.recv().unwrap();
rx_2.recv().unwrap();
}
// Same as `mpmc_basic`, but using the non-blocking `try_recv` after giving the
// dispatcher time to fan the value out.
#[test]
fn mpmc_basic_try_recv() {
let (tx, rx) = channel();
let rx_1 = rx;
let rx_2 = rx_1.clone();
tx.send(1).unwrap();
wait();
rx_1.try_recv().unwrap();
rx_2.try_recv().unwrap();
tx.send(1).unwrap();
wait_long();
rx_1.try_recv().unwrap();
rx_2.try_recv().unwrap();
}
// Dropping the only sender makes a blocking `recv` fail instead of hanging.
#[test]
fn mpmc_close_on_sender_dropped() {
let (tx, rx) = channel();
drop(tx);
let res: Result<usize, _> = rx.recv();
assert!(res.is_err());
}
// Once every receiver (including clones) is dropped, `send` reports an error.
#[test]
fn mpmc_close_on_receivers_dropped() {
let (tx, rx) = channel();
drop(rx);
assert!(tx.send(1).is_err());
let (tx, rx) = channel();
let rx_1 = rx;
let rx_2 = rx_1.clone();
drop(rx_1);
drop(rx_2);
assert!(tx.send(1).is_err());
}
// The receiver can be moved to another thread and block there on `recv`.
#[test]
fn mpmc_multi_threading() {
let (tx, rx) = channel();
let handler = thread::spawn(move || -> usize { rx.recv().unwrap() });
tx.send(1).unwrap();
let res = handler.join().unwrap();
assert_eq!(res, 1);
}
// The receiver thread sleeps first and then expects `try_recv` to find the value.
#[test]
fn mpmc_multi_threading_try_recv_recv() {
let (tx, rx) = channel();
let handler = thread::spawn(move || -> usize {
wait();
rx.try_recv().unwrap()
});
tx.send(1).unwrap();
let res = handler.join().unwrap();
assert_eq!(res, 1);
}
// The sender is moved to another thread; the main thread receives with a timeout.
#[test]
fn mpmc_multi_threading_try_recv_send() {
let (tx, rx) = channel();
let handler = thread::spawn(move || -> () { tx.send(1).unwrap() });
wait();
assert_eq!(rx.recv_timeout(RECV_TIMEOUT).unwrap(), 1);
handler.join().unwrap();
}
// After `into_inner`, the raw receiver keeps working while the guard is held.
#[test]
fn inner_receiver_is_active_while_guard_is_present() {
let (tx, rx) = channel();
let (rx_inner, _guard) = rx.into_inner();
let handler = thread::spawn(move || -> Result<(), mpsc::RecvError> { rx_inner.recv() });
assert!(tx.send(()).is_ok());
assert!(handler.join().unwrap().is_ok());
}
// Dropping the guard unregisters the subscriber: sending fails and the raw
// receiver observes a disconnect.
#[test]
fn inner_receiver_disconnects_when_guard_is_dropped() {
let (tx, rx) = channel();
let (rx_inner, guard) = rx.into_inner();
let handler = { thread::spawn(move || -> Result<(), mpsc::RecvError> { rx_inner.recv() }) };
drop(guard);
assert!(tx.send(()).is_err());
assert!(handler.join().unwrap().is_err());
}
// Sleep durations used to let the dispatcher thread make progress.
const WAIT_DURATION: Duration = Duration::from_millis(10);
const WAIT_LONG_DURATION: Duration = Duration::from_millis(500);
const RECV_TIMEOUT: Duration = WAIT_DURATION;
// Short pause: 10 ms.
fn wait() {
thread::sleep(WAIT_DURATION);
}
// Long pause: 500 ms.
fn wait_long() {
thread::sleep(WAIT_LONG_DURATION);
}
}