use std::collections::BTreeMap;
use std::string::String;
use std::sync::mpsc;
use std::vec::Vec;
/// Message delivered to a connection's channel by [`Router`].
#[derive(Debug, Clone)]
pub enum RouterMsg {
    /// A payload published on `topic`, forwarded by [`Router::dispatch`].
    Sample {
        /// Topic the payload was published on.
        topic: String,
        /// Payload bytes; each recipient receives its own copy.
        payload: Vec<u8>,
    },
    /// Sent to every registered connection by [`Router::broadcast_shutdown`].
    Shutdown,
}

/// Topic-based fan-out router over `std::sync::mpsc` channels.
///
/// Connections are identified by a caller-chosen `u64` id and receive
/// [`RouterMsg`] values through the sender registered for that id.
#[derive(Debug, Default)]
pub struct Router {
    /// Topic -> ids of subscribed connections (no duplicates within a topic).
    subs: BTreeMap<String, Vec<u64>>,
    /// Connection id -> sending half of that connection's channel.
    conns: BTreeMap<u64, mpsc::Sender<RouterMsg>>,
}

impl Router {
    /// Creates an empty router with no connections or subscriptions.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `sender` as the channel for connection `id`.
    ///
    /// An existing sender for the same id is replaced; subscriptions are not
    /// touched.
    pub fn register_connection(&mut self, id: u64, sender: mpsc::Sender<RouterMsg>) {
        self.conns.insert(id, sender);
    }

    /// Removes connection `id` together with all of its topic subscriptions.
    ///
    /// Topics left without subscribers are dropped, so the map does not keep
    /// growing as connections come and go. Unknown ids are a no-op.
    pub fn deregister_connection(&mut self, id: u64) {
        self.conns.remove(&id);
        self.subs.retain(|_, ids| {
            ids.retain(|c| *c != id);
            !ids.is_empty()
        });
    }

    /// Subscribes connection `conn_id` to `topic`; repeated calls are idempotent.
    ///
    /// The connection does not need to be registered yet. Messages are only
    /// delivered once a sender exists for `conn_id`.
    pub fn subscribe(&mut self, conn_id: u64, topic: String) {
        let ids = self.subs.entry(topic).or_default();
        if !ids.contains(&conn_id) {
            ids.push(conn_id);
        }
    }

    /// Removes `conn_id`'s subscription to `topic`, if any.
    ///
    /// The topic entry is dropped once its last subscriber is gone.
    pub fn unsubscribe(&mut self, conn_id: u64, topic: &str) {
        if let Some(ids) = self.subs.get_mut(topic) {
            ids.retain(|c| *c != conn_id);
            if ids.is_empty() {
                self.subs.remove(topic);
            }
        }
    }

    /// Sends a [`RouterMsg::Sample`] for `topic` to every subscribed, registered
    /// connection and returns how many sends succeeded.
    ///
    /// Subscribers without a registered sender are skipped. If a send fails
    /// because the receiver was dropped, that connection is fully deregistered,
    /// including its subscriptions on other topics. A later reuse of the same
    /// id therefore does not inherit stale subscriptions.
    pub fn dispatch(&mut self, topic: &str, payload: Vec<u8>) -> usize {
        let Some(subscribers) = self.subs.get(topic) else {
            return 0;
        };
        let mut delivered = 0usize;
        // `subscribers` borrows `self.subs`, so dead connections are collected
        // here and cleaned up only after the loop has released that borrow.
        let mut dead = Vec::new();
        for &conn_id in subscribers {
            let Some(sender) = self.conns.get(&conn_id) else {
                continue;
            };
            let msg = RouterMsg::Sample {
                topic: topic.to_string(),
                payload: payload.clone(),
            };
            if sender.send(msg).is_ok() {
                delivered += 1;
            } else {
                dead.push(conn_id);
            }
        }
        for conn_id in dead {
            self.deregister_connection(conn_id);
        }
        delivered
    }

    /// Sends [`RouterMsg::Shutdown`] to every registered connection.
    ///
    /// This is best effort: failed sends (receiver dropped) are ignored and
    /// those connections stay registered.
    pub fn broadcast_shutdown(&self) {
        for sender in self.conns.values() {
            // Deliberately ignored: a closed channel has nobody left to notify.
            let _ = sender.send(RouterMsg::Shutdown);
        }
    }

    /// Returns the number of currently registered connections.
    #[must_use]
    pub fn connection_count(&self) -> usize {
        self.conns.len()
    }
}
#[cfg(test)]
// Tests surface failures by panicking, so unwrap/expect/panic are intended here.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::mpsc::channel;

    #[test]
    fn dispatch_to_subscribed_connection() {
        let mut router = Router::new();
        let (tx, rx) = channel();
        router.register_connection(1, tx);
        router.subscribe(1, "Trade".to_string());
        let n = router.dispatch("Trade", b"PAYLOAD".to_vec());
        assert_eq!(n, 1);
        match rx.recv().unwrap() {
            RouterMsg::Sample { topic, payload } => {
                assert_eq!(topic, "Trade");
                assert_eq!(payload, b"PAYLOAD");
            }
            RouterMsg::Shutdown => panic!("unexpected shutdown message"),
        }
    }

    #[test]
    fn dispatch_to_no_subscribers_is_zero() {
        let mut router = Router::new();
        let n = router.dispatch("Empty", b"x".to_vec());
        assert_eq!(n, 0);
    }

    #[test]
    fn unsubscribe_stops_delivery() {
        let mut router = Router::new();
        let (tx, rx) = channel();
        router.register_connection(2, tx);
        router.subscribe(2, "T".to_string());
        router.unsubscribe(2, "T");
        let n = router.dispatch("T", b"x".to_vec());
        assert_eq!(n, 0);
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn deregister_removes_subscription() {
        let mut router = Router::new();
        let (tx, _rx) = channel();
        router.register_connection(3, tx);
        router.subscribe(3, "T".to_string());
        router.deregister_connection(3);
        assert_eq!(router.connection_count(), 0);
        // Re-registering the same id must not resurrect the old subscription.
        let (tx2, rx2) = channel();
        router.register_connection(3, tx2);
        assert_eq!(router.dispatch("T", b"x".to_vec()), 0);
        assert!(rx2.try_recv().is_err());
    }

    #[test]
    fn dispatch_drops_connection_whose_receiver_is_gone() {
        let mut router = Router::new();
        let (tx, rx) = channel();
        router.register_connection(4, tx);
        router.subscribe(4, "T".to_string());
        drop(rx);
        assert_eq!(router.dispatch("T", b"x".to_vec()), 0);
        assert_eq!(router.connection_count(), 0);
    }

    #[test]
    fn shutdown_broadcasts_to_all() {
        let mut router = Router::new();
        let (tx_a, rx_a) = channel();
        let (tx_b, rx_b) = channel();
        router.register_connection(7, tx_a);
        router.register_connection(8, tx_b);
        router.broadcast_shutdown();
        assert!(matches!(rx_a.recv().unwrap(), RouterMsg::Shutdown));
        assert!(matches!(rx_b.recv().unwrap(), RouterMsg::Shutdown));
    }
}