use std::{marker::PhantomData, sync::Arc, time::Duration};
use ahash::RandomState;
use ic_bn_lib_common::traits::pubsub::{Message, TopicId};
use moka::sync::{Cache, CacheBuilder};
use prometheus::{
IntCounter, IntGauge, Registry, register_int_counter_with_registry,
register_int_gauge_with_registry,
};
use tokio::sync::broadcast::{Receiver, Sender, error::RecvError};
/// Tuning knobs for a [`Broker`].
#[derive(Clone, Debug)]
pub struct Opts {
    /// Maximum capacity of the cache that holds the broker's topics.
    pub max_topics: u64,
    /// Time-to-idle of a cached topic: once it has not been accessed through the
    /// cache for this long, it becomes eligible for eviction.
    pub idle_timeout: Duration,
    /// Capacity passed to each topic's broadcast channel.
    pub buffer_size: usize,
    /// Number of receivers a single topic accepts before `subscribe` is rejected.
    pub max_subscribers: usize,
}

impl Default for Opts {
    /// One million topics, a ten-minute idle timeout, and 10 000 for both the
    /// per-topic buffer and the per-topic subscriber limit.
    fn default() -> Self {
        let ten_minutes = Duration::from_secs(10 * 60);
        let ten_thousand: usize = 10_000;

        Self {
            max_topics: 1_000_000,
            idle_timeout: ten_minutes,
            buffer_size: ten_thousand,
            max_subscribers: ten_thousand,
        }
    }
}
#[derive(Debug, Clone, Eq, PartialEq, thiserror::Error)]
pub enum PublishError {
#[error("Topic does not exist")]
TopicDoesNotExist,
#[error("Topic has no subscribers")]
NoSubscribers,
}
#[derive(Debug, Clone, Eq, PartialEq, thiserror::Error)]
pub enum SubscribeError {
#[error("Too many subscribers")]
TooManySubscribers,
}
/// Prometheus metrics reported by the pub/sub broker.
#[derive(Debug, Clone)]
pub struct Metrics {
    /// Live topics (`pubsub_topics`).
    topics: IntGauge,
    /// Live subscribers (`pubsub_subscribers`).
    subscribers: IntGauge,
    /// Messages successfully handed to a topic's channel (`pubsub_msgs_published`).
    msgs_sent: IntCounter,
    /// Messages discarded because their topic was missing or had no
    /// subscribers (`pubsub_msgs_dropped`).
    msgs_dropped: IntCounter,
}

impl Metrics {
    /// Creates the broker metrics and registers them in `registry`.
    ///
    /// # Panics
    ///
    /// Panics if registering any of the metrics in `registry` fails. Presumably
    /// this happens when `new` is called twice with the same registry, because
    /// the metric names would collide. Confirm against the prometheus crate docs.
    pub fn new(registry: &Registry) -> Self {
        // The names/help texts are plain literals: wrapping them in `format!`
        // (as before) only allocated a copy of the same string.
        Self {
            topics: register_int_gauge_with_registry!(
                "pubsub_topics",
                "Number of topics currently active",
                registry
            )
            .expect("failed to register metric pubsub_topics"),
            msgs_sent: register_int_counter_with_registry!(
                "pubsub_msgs_published",
                "Number of messages published",
                registry
            )
            .expect("failed to register metric pubsub_msgs_published"),
            msgs_dropped: register_int_counter_with_registry!(
                "pubsub_msgs_dropped",
                "Number of messages dropped",
                registry
            )
            .expect("failed to register metric pubsub_msgs_dropped"),
            subscribers: register_int_gauge_with_registry!(
                "pubsub_subscribers",
                "Number of subscribers currently active",
                registry
            )
            .expect("failed to register metric pubsub_subscribers"),
        }
    }
}
/// Receiving side of a subscription to a [`Topic`].
///
/// `Topic::subscribe` increments the `subscribers` gauge when it creates one of
/// these; dropping it decrements the gauge again.
#[derive(Debug)]
pub struct Subscriber<M: Message> {
    rx: Receiver<M>,
    metrics: Arc<Metrics>,
}

impl<M: Message> Subscriber<M> {
    /// Waits for the next message published to the topic.
    ///
    /// # Errors
    ///
    /// Forwards the broadcast channel's [`RecvError`]. It is `Lagged` when this
    /// subscriber fell behind and older messages were overwritten. It is `Closed`
    /// once the topic's sending side is gone.
    pub async fn recv(&mut self) -> Result<M, RecvError> {
        let receiver = &mut self.rx;
        receiver.recv().await
    }
}

impl<M: Message> Drop for Subscriber<M> {
    fn drop(&mut self) {
        let gauge = &self.metrics.subscribers;
        gauge.dec();
    }
}
/// Keeps the `topics` gauge in step with the lifetime of a topic's channel.
///
/// All clones of a [`Topic`] share one instance through an `Arc`. The gauge is
/// therefore incremented once when the topic is created and decremented once
/// when the last handle is dropped, however many times the `Topic` was cloned.
#[derive(Debug)]
struct TopicLiveness {
    metrics: Arc<Metrics>,
}

impl TopicLiveness {
    fn new(metrics: Arc<Metrics>) -> Self {
        metrics.topics.inc();
        Self { metrics }
    }
}

impl Drop for TopicLiveness {
    fn drop(&mut self) {
        self.metrics.topics.dec();
    }
}

/// A single topic: a broadcast channel plus a cap on its number of subscribers.
///
/// Cloning a `Topic` yields another handle to the same channel. Previously,
/// each clone decremented the `topics` gauge on drop without ever incrementing
/// it, so the gauge drifted below the real count.
#[derive(Debug, Clone)]
pub struct Topic<M: Message> {
    tx: Sender<M>,
    max_subscribers: usize,
    metrics: Arc<Metrics>,
    // Shared by all clones; decrements the `topics` gauge when the last one is dropped.
    _liveness: Arc<TopicLiveness>,
}

impl<M: Message> Topic<M> {
    /// Creates a topic whose broadcast channel is built with `capacity`.
    ///
    /// NOTE(review): tokio documents `broadcast` channel construction as panicking
    /// for a capacity of 0, so `capacity` is expected to be non-zero. Confirm.
    fn new(capacity: usize, metrics: Arc<Metrics>, max_subscribers: usize) -> Self {
        // Build the channel before touching the gauge, so a panic here cannot
        // leave the gauge incremented without a matching decrement.
        let tx = Sender::new(capacity);
        let liveness = Arc::new(TopicLiveness::new(Arc::clone(&metrics)));
        Self {
            tx,
            max_subscribers,
            metrics,
            _liveness: liveness,
        }
    }

    /// Returns the number of receivers currently attached to this topic's channel.
    pub fn subscriber_count(&self) -> usize {
        self.tx.receiver_count()
    }

    /// Attaches a new subscriber to the topic.
    ///
    /// The limit is read and the subscription created as two separate steps,
    /// without a lock. Concurrent callers may therefore end up with more than
    /// `max_subscribers` receivers.
    ///
    /// # Errors
    ///
    /// Returns [`SubscribeError::TooManySubscribers`] when the topic already has
    /// at least `max_subscribers` receivers.
    pub fn subscribe(&self) -> Result<Subscriber<M>, SubscribeError> {
        if self.tx.receiver_count() >= self.max_subscribers {
            return Err(SubscribeError::TooManySubscribers);
        }
        let rx = self.tx.subscribe();
        self.metrics.subscribers.inc();
        Ok(Subscriber {
            rx,
            metrics: Arc::clone(&self.metrics),
        })
    }

    /// Sends `message` to the topic's channel. On success, returns the receiver
    /// count reported by the channel.
    ///
    /// # Errors
    ///
    /// Returns [`PublishError::NoSubscribers`] when the channel rejects the
    /// message; the message is counted in `msgs_dropped`.
    pub fn publish(&self, message: M) -> Result<usize, PublishError> {
        match self.tx.send(message) {
            Ok(receivers) => {
                self.metrics.msgs_sent.inc();
                Ok(receivers)
            }
            Err(_) => {
                self.metrics.msgs_dropped.inc();
                Err(PublishError::NoSubscribers)
            }
        }
    }
}
#[derive(Debug, Clone)]
pub struct Broker<M: Message, T: TopicId> {
opts: Opts,
topics: Cache<T, Arc<Topic<M>>, RandomState>,
metrics: Arc<Metrics>,
}
impl<M: Message, T: TopicId> Broker<M, T> {
pub fn new(opts: Opts, metrics: Metrics) -> Self {
let metrics = Arc::new(metrics);
let topics = CacheBuilder::new(opts.max_topics)
.time_to_idle(opts.idle_timeout)
.build_with_hasher(RandomState::new());
Self {
opts,
topics,
metrics,
}
}
pub fn topic_get(&self, topic: &T) -> Option<Arc<Topic<M>>> {
self.topics.get(topic)
}
pub fn topic_get_or_create(&self, topic: &T) -> Arc<Topic<M>> {
self.topics.get_with_by_ref(topic, || {
Arc::new(Topic::new(
self.opts.buffer_size,
self.metrics.clone(),
self.opts.max_subscribers,
))
})
}
pub fn topic_exists(&self, topic: &T) -> bool {
self.topics.contains_key(topic)
}
pub fn topic_remove(&self, topic: &T) {
self.topics.invalidate(topic);
self.topics.run_pending_tasks();
}
pub fn subscribe(&self, topic: &T) -> Result<Subscriber<M>, SubscribeError> {
let topic = self.topic_get_or_create(topic);
topic.subscribe()
}
pub fn publish(&self, topic: &T, message: M) -> Result<usize, PublishError> {
let Some(topic) = self.topic_get(topic) else {
self.metrics.msgs_dropped.inc();
return Err(PublishError::TopicDoesNotExist);
};
topic.publish(message)
}
}
/// Fluent builder for a [`Broker`].
///
/// Starts from [`Opts::default`] and from metrics registered in a freshly
/// created [`Registry`] that the builder does not keep. Use
/// [`BrokerBuilder::with_metric_registry`] or [`BrokerBuilder::with_metrics`]
/// to report into a registry of your choice.
pub struct BrokerBuilder<M, T> {
    opts: Opts,
    metrics: Metrics,
    _types: PhantomData<(M, T)>,
}

impl<M: Message, T: TopicId> Default for BrokerBuilder<M, T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<M: Message, T: TopicId> BrokerBuilder<M, T> {
    /// Creates a builder with default options and metrics in a fresh registry.
    pub fn new() -> Self {
        let opts = Opts::default();
        let registry = Registry::new();
        Self {
            opts,
            metrics: Metrics::new(&registry),
            _types: PhantomData,
        }
    }

    /// Sets [`Opts::max_topics`].
    pub const fn with_max_topics(mut self, max_topics: u64) -> Self {
        self.opts.max_topics = max_topics;
        self
    }

    /// Sets [`Opts::idle_timeout`].
    pub const fn with_idle_timeout(mut self, idle_timeout: Duration) -> Self {
        self.opts.idle_timeout = idle_timeout;
        self
    }

    /// Sets [`Opts::buffer_size`].
    pub const fn with_buffer_size(mut self, buffer_size: usize) -> Self {
        self.opts.buffer_size = buffer_size;
        self
    }

    /// Sets [`Opts::max_subscribers`].
    pub const fn with_max_subscribers(mut self, max_subscribers: usize) -> Self {
        self.opts.max_subscribers = max_subscribers;
        self
    }

    /// Uses an already constructed set of metrics.
    pub fn with_metrics(mut self, metrics: Metrics) -> Self {
        self.metrics = metrics;
        self
    }

    /// Creates the metrics in `registry`. See [`Metrics::new`] for when this panics.
    pub fn with_metric_registry(self, registry: &Registry) -> Self {
        self.with_metrics(Metrics::new(registry))
    }

    /// Consumes the builder and creates the [`Broker`].
    pub fn build(self) -> Broker<M, T> {
        let Self { opts, metrics, .. } = self;
        Broker::new(opts, metrics)
    }
}
#[cfg(test)]
mod test {
use super::*;
// End-to-end exercise of the broker. It covers missing topics, the subscriber
// limit, delivery order, lagging receivers, and channel closure after removal.
#[tokio::test]
async fn test_pubsub() {
// A tiny buffer and a one-subscriber limit make both lagging and the limit easy to hit.
let b: Broker<String, String> = BrokerBuilder::new()
.with_buffer_size(3)
.with_max_subscribers(1)
.build();
let topic1 = "foo".to_string();
let topic2 = "dead".to_string();
// Publishing never creates a topic, so both calls fail and are counted as dropped.
assert_eq!(
b.publish(&topic1, "".into()),
Err(PublishError::TopicDoesNotExist)
);
assert_eq!(
b.publish(&topic2, "".into()),
Err(PublishError::TopicDoesNotExist)
);
assert_eq!(b.metrics.topics.get(), 0);
assert_eq!(b.metrics.msgs_dropped.get(), 2);
// Subscribing creates each topic on demand.
let mut t1_sub = b.subscribe(&topic1).unwrap();
let mut t2_sub = b.subscribe(&topic2).unwrap();
assert!(b.topic_exists(&topic1));
assert!(b.topic_exists(&topic2));
assert_eq!(b.metrics.topics.get(), 2);
// max_subscribers is 1: the second subscription to topic1 is rejected and not counted.
assert_eq!(
b.subscribe(&topic1).unwrap_err(),
SubscribeError::TooManySubscribers
);
assert_eq!(b.metrics.subscribers.get(), 2);
// Three messages per topic fit in the buffer; each reaches exactly one receiver.
assert_eq!(b.publish(&topic1, "bar1".into()), Ok(1));
assert_eq!(b.publish(&topic2, "beef1".into()), Ok(1));
assert_eq!(b.publish(&topic1, "bar2".into()), Ok(1));
assert_eq!(b.publish(&topic2, "beef2".into()), Ok(1));
assert_eq!(b.publish(&topic1, "bar3".into()), Ok(1));
assert_eq!(b.publish(&topic2, "beef3".into()), Ok(1));
assert_eq!(b.metrics.msgs_sent.get(), 6);
// Messages are received in publish order.
assert_eq!(t1_sub.recv().await.unwrap(), "bar1");
assert_eq!(t2_sub.recv().await.unwrap(), "beef1");
assert_eq!(t1_sub.recv().await.unwrap(), "bar2");
assert_eq!(t2_sub.recv().await.unwrap(), "beef2");
assert_eq!(t1_sub.recv().await.unwrap(), "bar3");
assert_eq!(t2_sub.recv().await.unwrap(), "beef3");
// Five unread messages per topic overflow the buffer.
assert_eq!(b.publish(&topic1, "bar1".into()), Ok(1));
assert_eq!(b.publish(&topic2, "beef1".into()), Ok(1));
assert_eq!(b.publish(&topic1, "bar2".into()), Ok(1));
assert_eq!(b.publish(&topic2, "beef2".into()), Ok(1));
assert_eq!(b.publish(&topic1, "bar3".into()), Ok(1));
assert_eq!(b.publish(&topic2, "beef3".into()), Ok(1));
assert_eq!(b.publish(&topic1, "bar4".into()), Ok(1));
assert_eq!(b.publish(&topic2, "beef4".into()), Ok(1));
assert_eq!(b.publish(&topic1, "bar5".into()), Ok(1));
assert_eq!(b.publish(&topic2, "beef5".into()), Ok(1));
// The first recv after the overflow reports that messages were skipped.
assert!(matches!(
t1_sub.recv().await.unwrap_err(),
RecvError::Lagged(_)
));
assert!(matches!(
t2_sub.recv().await.unwrap_err(),
RecvError::Lagged(_)
));
// Four messages (bar2..bar5) are still readable, one more than buffer_size.
// Presumably tokio rounds the capacity 3 up to 4 (a power of two); confirm.
assert_eq!(t1_sub.recv().await.unwrap(), "bar2");
assert_eq!(t2_sub.recv().await.unwrap(), "beef2");
assert_eq!(t1_sub.recv().await.unwrap(), "bar3");
assert_eq!(t2_sub.recv().await.unwrap(), "beef3");
assert_eq!(t1_sub.recv().await.unwrap(), "bar4");
assert_eq!(t2_sub.recv().await.unwrap(), "beef4");
assert_eq!(t1_sub.recv().await.unwrap(), "bar5");
assert_eq!(t2_sub.recv().await.unwrap(), "beef5");
// Dropping subscribers decrements the gauge, but the topics remain cached.
drop(t1_sub);
drop(t2_sub);
assert_eq!(b.metrics.subscribers.get(), 0);
assert_eq!(b.metrics.topics.get(), 2);
// With no receivers left, publishing fails with NoSubscribers.
assert_eq!(
b.publish(&topic1, "".into()).unwrap_err(),
PublishError::NoSubscribers
);
assert_eq!(
b.publish(&topic2, "".into()).unwrap_err(),
PublishError::NoSubscribers
);
// Topic handles can also be used directly, without going through the broker.
let t1 = b.topic_get_or_create(&topic1);
let t2 = b.topic_get_or_create(&topic2);
let mut t1_sub = t1.subscribe().unwrap();
let mut t2_sub = t2.subscribe().unwrap();
assert_eq!(t1.publish("foo".into()).unwrap(), 1);
assert_eq!(t2.publish("bar".into()).unwrap(), 1);
assert_eq!(t1_sub.recv().await.unwrap(), "foo");
assert_eq!(t2_sub.recv().await.unwrap(), "bar");
// After removal from the cache and dropping the local handles, no sender is
// left, so the subscribers observe Closed.
b.topic_remove(&topic1);
b.topic_remove(&topic2);
drop(t1);
drop(t2);
assert_eq!(t1_sub.recv().await.unwrap_err(), RecvError::Closed);
assert_eq!(t2_sub.recv().await.unwrap_err(), RecvError::Closed);
}
}