use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::{Arc, Mutex, MutexGuard};
use std::task::{Context, Poll};
use futures_channel::mpsc;
use futures_core::Stream;
/// An event that the event bus can route to subscribers by target.
///
/// `Clone` is required because `EventBusPublisher` hands a separate clone of the
/// event to every matching subscriber.
pub trait RoutableEvent: Clone {
/// Key that subscribers register under; the bus stores it as a `HashMap` key.
type Target: Eq + Hash + Clone;
/// Returns every target this event is addressed to.
///
/// `EventBusPublisher` handles each yielded target independently, so a target
/// yielded more than once causes repeated delivery to that target's subscribers.
fn targets(&self) -> impl Iterator<Item = Self::Target>;
}
/// A sink that events of type `E` can be handed to.
pub trait Publisher<E> {
/// Publishes `event`.
///
/// Takes `&self` and returns nothing: implementors that need to mutate state must
/// use interior mutability, and delivery failures are not reported to the caller.
fn publish(&self, event: E);
}
/// A [`Publisher`] that discards every event it is given.
///
/// Stateless, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopPublisher;

impl<E> Publisher<E> for NoopPublisher {
    /// Drops `event` without doing anything else.
    fn publish(&self, _event: E) {}
}
/// A handler type that is always constructed together with a matching [`Publisher`].
pub trait EventHandler<E> {
/// The publisher type returned alongside the handler by [`EventHandler::init`].
type Publish: Publisher<E>;
/// Creates a new `(publisher, handler)` pair.
fn init() -> (Self::Publish, Self);
}
/// An [`EventHandler`] paired with a [`NoopPublisher`], so events published to it
/// go nowhere.
///
/// Stateless, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopHandler;

impl<E> EventHandler<E> for NoopHandler {
    type Publish = NoopPublisher;

    /// Returns a `(NoopPublisher, NoopHandler)` pair; neither holds any state.
    fn init() -> (Self::Publish, Self) {
        (NoopPublisher, NoopHandler)
    }
}
/// Buffer size of each subscriber's bounded channel (passed to `mpsc::channel` in
/// `EventBus::subscribe`).
///
/// When `try_send` to a subscriber fails (e.g. because this buffer is full), `publish`
/// drops that subscriber from the bus instead of waiting for it.
// NOTE(review): futures' bounded mpsc adds slack per sender on top of this value,
// so the effective capacity is slightly larger — confirm against futures-channel docs.
const SUBSCRIBER_BUFFER: usize = 256;
/// Senders for one target's subscribers, keyed by the subscription id handed out
/// from `Inner::next_id`.
type Subscribers<T> = HashMap<u64, mpsc::Sender<T>>;
/// Routing table shared (behind `Arc<Mutex<_>>`) by the publisher, every `EventBus`
/// handle and every live `Subscription`.
struct Inner<T: RoutableEvent> {
/// Subscribers per target. `publish` and `Subscription::drop` remove a target as soon
/// as its subscriber map becomes empty, so entries present here are non-empty.
targets: HashMap<T::Target, Subscribers<T>>,
/// Id to assign to the next subscription; incremented by every `subscribe` call.
next_id: u64,
}
/// Publishing half of an event bus created by [`event_bus`]; implements [`Publisher`].
pub struct EventBusPublisher<T: RoutableEvent> {
// Shared routing table, also referenced by the `EventBus` half and subscriptions.
inner: Arc<Mutex<Inner<T>>>,
}
/// Subscribing half of an event bus created by [`event_bus`].
///
/// Cloning is cheap: clones share the same reference-counted routing table.
pub struct EventBus<T: RoutableEvent> {
// Shared routing table, also referenced by the publisher and subscriptions.
inner: Arc<Mutex<Inner<T>>>,
}
impl<T: RoutableEvent> Clone for EventBus<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
/// Creates a new event bus with no subscribers.
///
/// Returns the publishing half and the subscribing half; both point at the same
/// shared routing table.
pub fn event_bus<T: RoutableEvent>() -> (EventBusPublisher<T>, EventBus<T>) {
    let table = Inner {
        targets: HashMap::new(),
        next_id: 0,
    };
    let shared = Arc::new(Mutex::new(table));
    let publisher = EventBusPublisher {
        inner: Arc::clone(&shared),
    };
    let bus = EventBus { inner: shared };
    (publisher, bus)
}
impl<T: RoutableEvent> EventHandler<T> for EventBus<T> {
    type Publish = EventBusPublisher<T>;

    /// Builds a fresh, empty bus; equivalent to calling [`event_bus`] directly.
    fn init() -> (Self::Publish, Self) {
        event_bus::<T>()
    }
}
impl<T: RoutableEvent> Publisher<T> for EventBusPublisher<T> {
    /// Delivers a clone of `event` to every subscriber of each target it names.
    ///
    /// Delivery uses `try_send`. A subscriber whose send fails (e.g. because its
    /// buffer is full) is dropped from the bus, and a target left without subscribers
    /// is removed. Targets with no subscribers are skipped. A target yielded more than
    /// once by [`RoutableEvent::targets`] is delivered to once per occurrence.
    fn publish(&self, event: T) {
        let mut table = lock(&self.inner);
        for target in event.targets() {
            let Entry::Occupied(mut slot) = table.targets.entry(target) else {
                continue; // nobody is subscribed to this target
            };
            let subscribers = slot.get_mut();
            subscribers.retain(|_, sender| sender.try_send(event.clone()).is_ok());
            if subscribers.is_empty() {
                slot.remove();
            }
        }
    }
}
impl<T: RoutableEvent> EventBus<T> {
    /// Registers interest in `target` and returns a stream of the events routed to it.
    ///
    /// Each call creates its own bounded channel (`SUBSCRIBER_BUFFER`) and a fresh
    /// subscription id; dropping the returned [`Subscription`] unregisters it.
    pub fn subscribe(&self, target: T::Target) -> Subscription<T> {
        // Create the channel before taking the lock to keep the critical section short.
        let (sender, receiver) = mpsc::channel(SUBSCRIBER_BUFFER);
        let id = {
            let mut table = lock(&self.inner);
            let id = table.next_id;
            table.next_id = id + 1;
            let subscribers = table.targets.entry(target.clone()).or_default();
            subscribers.insert(id, sender);
            id
        };
        Subscription {
            inner: Arc::clone(&self.inner),
            target,
            id,
            rx: receiver,
        }
    }
}
/// A live subscription to one target, returned by [`EventBus::subscribe`].
///
/// Yields the events published to its target. Dropping it unsubscribes, so the
/// value must be kept alive for as long as events are wanted.
///
/// If `publish` fails to send to this subscription, it drops the sending side from
/// the bus. Presumably the stream then ends after the already-buffered events are
/// read — confirm against futures-channel semantics.
#[must_use = "dropping a `Subscription` unsubscribes immediately"]
pub struct Subscription<T: RoutableEvent> {
    // Shared routing table; used on drop to unregister this subscription.
    inner: Arc<Mutex<Inner<T>>>,
    // Target and id together locate this subscription's sender in the table.
    target: T::Target,
    id: u64,
    // Receiving end of the channel whose sender is registered in the table.
    rx: mpsc::Receiver<T>,
}
// `Subscription` never pins its fields structurally: `poll_next` only re-pins `rx`,
// which is itself `Unpin`. So it is sound to make it `Unpin` whatever `T::Target`
// is. This lets `Stream` (and `Unpin`-requiring combinators such as
// `StreamExt::next`) work for every target type.
impl<T: RoutableEvent> Unpin for Subscription<T> {}

impl<T: RoutableEvent> Stream for Subscription<T> {
    type Item = T;

    /// Polls the underlying channel for the next event routed to this subscription.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        Pin::new(&mut self.get_mut().rx).poll_next(cx)
    }

    /// Forwards the underlying channel's size hint.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.rx.size_hint()
    }
}
impl<T: RoutableEvent> Drop for Subscription<T> {
    /// Unregisters this subscription from the bus.
    ///
    /// Removes its sender from the target's subscriber map. If no subscribers remain,
    /// it also removes the target itself. Does nothing if `publish` already pruned
    /// this subscription.
    fn drop(&mut self) {
        let mut inner = lock(&self.inner);
        // Look up by reference so the target key is never cloned. The second lookup
        // only runs when this was the target's last subscriber.
        let now_empty = match inner.targets.get_mut(&self.target) {
            Some(subscribers) => {
                subscribers.remove(&self.id);
                subscribers.is_empty()
            }
            None => false,
        };
        if now_empty {
            inner.targets.remove(&self.target);
        }
    }
}
/// Acquires `m`, ignoring poisoning.
///
/// If a previous holder panicked, the guard is recovered from the poison error and
/// returned anyway. This way one panicking thread does not make the bus unusable.
fn lock<G>(m: &Mutex<G>) -> MutexGuard<'_, G> {
    match m.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
#[cfg(test)]
mod tests {
    use futures::{FutureExt, StreamExt};

    use super::*;

    /// Minimal routable event: payload `n`, delivered to every target listed in `to`.
    #[derive(Clone)]
    struct Ev {
        n: i32,
        to: Vec<u32>,
    }

    impl RoutableEvent for Ev {
        type Target = u32;

        fn targets(&self) -> impl Iterator<Item = u32> {
            self.to.clone().into_iter()
        }
    }

    fn ev(n: i32, to: &[u32]) -> Ev {
        Ev { n, to: to.to_vec() }
    }

    /// Payload of the next event if one is ready right now; `None` when the stream is
    /// pending or has ended.
    fn recv(sub: &mut Subscription<Ev>) -> Option<i32> {
        sub.next().now_or_never().flatten().map(|e| e.n)
    }

    /// Number of targets that currently have at least one registered subscriber.
    fn target_count(bus: &EventBus<Ev>) -> usize {
        lock(&bus.inner).targets.len()
    }

    #[test]
    fn delivers_to_subscriber() {
        let (p, bus) = event_bus::<Ev>();
        let mut s = bus.subscribe(1);
        p.publish(ev(10, &[1]));
        assert_eq!(recv(&mut s), Some(10));
    }

    #[test]
    fn fans_out_to_all_targets_of_an_event() {
        let (p, bus) = event_bus::<Ev>();
        let mut a = bus.subscribe(1);
        let mut b = bus.subscribe(2);
        p.publish(ev(7, &[1, 2]));
        assert_eq!(recv(&mut a), Some(7));
        assert_eq!(recv(&mut b), Some(7));
    }

    #[test]
    fn other_targets_are_not_delivered() {
        let (p, bus) = event_bus::<Ev>();
        let mut a = bus.subscribe(1);
        p.publish(ev(1, &[2]));
        assert_eq!(recv(&mut a), None);
    }

    #[test]
    fn last_unsubscribe_removes_the_target() {
        let (_p, bus) = event_bus::<Ev>();
        let a = bus.subscribe(1);
        let b = bus.subscribe(1);
        assert_eq!(target_count(&bus), 1);
        drop(a);
        assert_eq!(target_count(&bus), 1);
        drop(b);
        assert_eq!(target_count(&bus), 0);
    }

    #[test]
    fn dropping_one_subscriber_keeps_the_others() {
        let (p, bus) = event_bus::<Ev>();
        let a = bus.subscribe(1);
        let mut b = bus.subscribe(1);
        drop(a);
        p.publish(ev(3, &[1]));
        assert_eq!(recv(&mut b), Some(3));
    }

    #[test]
    fn overflowing_subscriber_is_pruned_on_publish() {
        let (p, bus) = event_bus::<Ev>();
        let _s = bus.subscribe(1);
        for i in 0..(SUBSCRIBER_BUFFER as i32 + 8) {
            p.publish(ev(i, &[1]));
        }
        assert_eq!(target_count(&bus), 0);
    }

    #[test]
    fn pruned_subscriber_stream_ends_after_buffered_events() {
        let (p, bus) = event_bus::<Ev>();
        let mut s = bus.subscribe(1);
        for i in 0..(SUBSCRIBER_BUFFER as i32 + 8) {
            p.publish(ev(i, &[1]));
        }
        assert_eq!(recv(&mut s), Some(0));
        // Drain whatever was buffered. With its sender gone, the stream must then
        // end (`Some(None)`) rather than stay pending (`None`).
        let terminated = loop {
            match s.next().now_or_never() {
                Some(Some(_)) => continue,
                Some(None) => break true,
                None => break false,
            }
        };
        assert!(terminated);
    }
}