use std::{
collections::HashMap,
sync::{Arc, RwLock},
thread,
};
use crossbeam::channel::{bounded, unbounded, Receiver, Sender, TryRecvError};
/// Shared state behind a `Bus`: one channel sender per registered subscriber.
///
/// No trait bound is placed on `T` here; the `T: Clone` requirement lives on
/// the `impl` blocks that actually clone events.
struct BusInner<T> {
    /// Sending half of each subscriber's channel, keyed by subscription id.
    senders: HashMap<usize, Sender<T>>,
    /// Id that the next call to `add_rx` will assign.
    next_id: usize,
}
impl<T: Clone> BusInner<T> {
    /// Registers a new subscriber and returns the receiving end of its channel.
    ///
    /// Each subscriber gets its own unbounded channel, stored under the current
    /// `next_id`, which is then incremented.
    pub fn add_rx(&mut self) -> Receiver<T> {
        let (sender, receiver) = unbounded::<T>();
        self.senders.insert(self.next_id, sender);
        self.next_id += 1;
        receiver
    }

    /// Sends a copy of `event` to every registered subscriber, in ascending id
    /// order.
    ///
    /// `event` is cloned for all but the last subscriber, which receives the
    /// original value. Returns the ids whose receiver has been dropped (the
    /// send failed); the caller is expected to pass them to `remove_senders`.
    pub fn broadcast(&self, event: T) -> Vec<usize> {
        let mut disconnected = Vec::new();
        let senders = self.get_sorted_senders();
        if let Some(((last_id, last_sender), rest)) = senders.split_last() {
            for (id, sender) in rest {
                if sender.send(event.clone()).is_err() {
                    disconnected.push(**id);
                }
            }
            // Move the original into the final send to save one clone.
            if last_sender.send(event).is_err() {
                disconnected.push(**last_id);
            }
        }
        disconnected
    }

    /// Drops the senders registered under `ids`; ids that are not present are
    /// ignored.
    pub fn remove_senders(&mut self, ids: &[usize]) {
        for id in ids {
            // `id` is already `&usize`, exactly the `&Q` that `HashMap::remove` expects.
            self.senders.remove(id);
        }
    }

    /// Returns every `(id, sender)` pair sorted by ascending id, so broadcasts
    /// have a deterministic delivery order despite the `HashMap` storage.
    fn get_sorted_senders(&self) -> Vec<(&usize, &Sender<T>)> {
        let mut senders: Vec<(&usize, &Sender<T>)> = self.senders.iter().collect();
        // Keys are unique, so an unstable sort yields the same order without
        // the extra buffer a stable sort may allocate.
        senders.sort_unstable_by_key(|(id, _)| **id);
        senders
    }
}
impl<T: Clone> Default for BusInner<T> {
fn default() -> Self {
BusInner {
senders: Default::default(),
next_id: 0,
}
}
}
/// A cloneable broadcast bus.
///
/// The subscriber list sits behind an `Arc<RwLock<..>>`, so every clone of a
/// `Bus` shares it. An event passed to `broadcast` is delivered to each
/// receiver previously obtained from `add_rx`.
#[derive(Clone)]
pub struct Bus<T>
where
    T: Clone,
{
    inner: Arc<RwLock<BusInner<T>>>,
}
impl<T: Clone> Bus<T> {
pub fn new() -> Self {
Bus {
inner: Default::default(),
}
}
pub fn add_rx(&self) -> Receiver<T> {
self.inner.write().expect("Lock was poisoned").add_rx()
}
pub fn broadcast(&self, event: T) {
let disconnected = {
self.inner
.read()
.expect("Lock was poisoned")
.broadcast(event)
};
if !disconnected.is_empty() {
self.inner
.write()
.expect("Lock was poisoned")
.remove_senders(&disconnected);
}
}
}
impl<T: Clone> Default for Bus<T> {
fn default() -> Self {
Bus::new()
}
}
/// Heap-allocated event callback that can be moved to another thread.
/// (`'static` is spelled out here; it is already the default object lifetime for `Box<dyn ..>`.)
type BoxedFn<T> = Box<dyn FnMut(T) + Send + 'static>;
/// Sends a single `()` on its channel when it is dropped.
///
/// Handed out inside an `Arc`, so the signal fires only when the last owner
/// releases it.
struct DropSignal {
    tx_signal: Sender<()>,
}

impl DropSignal {
    /// Wraps `tx_signal` in a new reference-counted drop signal.
    pub fn new(tx_signal: Sender<()>) -> Arc<Self> {
        let signal = Self { tx_signal };
        Arc::new(signal)
    }
}

impl Drop for DropSignal {
    fn drop(&mut self) {
        // Best effort: if the receiving side is already gone there is nobody
        // left to notify, so a send error is deliberately discarded.
        let _ = self.tx_signal.send(());
    }
}
/// Handle that keeps a background subscription alive.
///
/// Clones share one underlying signal. When the last clone is dropped, `()` is
/// sent on the channel that was passed to [`Subscription::new`].
#[derive(Clone)]
pub struct Subscription {
    // Never read: held purely so the `DropSignal` fires when the last clone
    // goes away. The leading underscore documents that and silences `dead_code`.
    _terminate: Arc<DropSignal>,
}

impl Subscription {
    /// Creates a subscription that sends on `terminate` once every clone of it
    /// has been dropped.
    pub fn new(terminate: Sender<()>) -> Self {
        Subscription {
            _terminate: DropSignal::new(terminate),
        }
    }
}
/// Consumes a stream of `T` values by feeding each one to a callback.
pub trait SubscribeToReader<T: Send + 'static> {
    /// Runs `callback` for every event on a dedicated background thread.
    ///
    /// The returned [`Subscription`] controls that thread: dropping it (and all
    /// of its clones) stops delivery, so `let _ = source.subscribe_on_thread(..)`
    /// unsubscribes immediately.
    #[must_use = "dropping the Subscription immediately stops the background subscription"]
    fn subscribe_on_thread(&self, callback: BoxedFn<T>) -> Subscription;

    /// Runs `callback` for every event on the calling thread, blocking the
    /// caller while events are delivered.
    fn subscribe(&self, callback: BoxedFn<T>);
}
impl<T: Send + 'static> SubscribeToReader<T> for Receiver<T> {
#[must_use]
fn subscribe_on_thread(&self, mut callback: BoxedFn<T>) -> Subscription {
let (terminate_tx, terminate_rx) = bounded::<()>(0);
let receiver = self.clone();
thread::Builder::new()
.name("Receiver subscription thread".to_string())
.spawn(move || loop {
for event in receiver.try_iter() {
callback(event);
}
match terminate_rx.try_recv() {
Err(TryRecvError::Empty) => {}
_ => return,
}
})
.expect("Could not start Receiver subscription thread");
Subscription::new(terminate_tx)
}
fn subscribe(&self, mut callback: BoxedFn<T>) {
for event in self.iter() {
callback(event);
}
}
}
impl<T: Clone + Send + 'static> SubscribeToReader<T> for Bus<T> {
#[must_use]
fn subscribe_on_thread(&self, callback: BoxedFn<T>) -> Subscription {
self.add_rx().subscribe_on_thread(callback)
}
fn subscribe(&self, callback: BoxedFn<T>) {
self.add_rx().subscribe(callback)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crossbeam::channel::RecvTimeoutError;
    use std::time::Duration;

    #[derive(Clone, PartialEq, Debug)]
    struct Something;

    #[derive(Clone, PartialEq, Debug)]
    enum Event {
        Start,
        Stop(Vec<Something>),
    }

    /// Upper bound on how long each check waits for the subscriber thread.
    fn patience() -> Duration {
        Duration::from_millis(100)
    }

    /// Asserts that the next item on `rx` arrives in time and equals `expected`.
    fn expect_event(rx: &Receiver<Event>, expected: Event) {
        match rx.recv_timeout(patience()) {
            Ok(actual) => assert_eq!(actual, expected),
            Err(_) => panic!("Event not received"),
        }
    }

    /// Asserts that `rx` reports disconnection: the subscription callback that
    /// owned the only sender has been dropped.
    fn expect_disconnected(rx: &Receiver<Event>) {
        match rx.recv_timeout(patience()) {
            Err(RecvTimeoutError::Disconnected) => {}
            _ => panic!("Subscription has been dropped so we should not get any events"),
        }
    }

    #[test]
    fn subscribe_on_thread() {
        let bus = Bus::<Event>::new();
        // Idle subscribers kept alive for the whole test, so every broadcast
        // fans out to more than one receiver.
        let _idle_first = bus.subscribe_on_thread(Box::new(|_event| {}));
        let _idle_second = bus.subscribe_on_thread(Box::new(|_event| {}));

        let (seen_tx, seen_rx) = unbounded::<Event>();
        {
            let _sub = bus.subscribe_on_thread(Box::new(move |event| {
                seen_tx.send(event).unwrap();
            }));
            bus.broadcast(Event::Start);
            bus.broadcast(Event::Stop(vec![Something {}]));
            expect_event(&seen_rx, Event::Start);
            expect_event(&seen_rx, Event::Stop(vec![Something {}]));
        }

        // `_sub` is gone: its thread must have exited and dropped `seen_tx`.
        bus.broadcast(Event::Start);
        expect_disconnected(&seen_rx);
    }

    #[test]
    fn clone_subscription_without_dropping() {
        let bus = Bus::<Event>::new();
        let (seen_tx, seen_rx) = unbounded::<Event>();
        {
            let sub = bus.subscribe_on_thread(Box::new(move |event| {
                seen_tx.send(event).unwrap();
            }));
            {
                // Dropping a clone must not end the subscription; the clone is
                // the point of this test, hence the lint allowance.
                #[allow(clippy::redundant_clone)]
                let _sub_clone = sub.clone();
            }
            bus.broadcast(Event::Start);
            expect_event(&seen_rx, Event::Start);
        }

        bus.broadcast(Event::Start);
        expect_disconnected(&seen_rx);
    }
}