1use std::{marker::PhantomData, sync::Arc, time::Duration};
2
3use ahash::RandomState;
4use ic_bn_lib_common::traits::pubsub::{Message, TopicId};
5use moka::sync::{Cache, CacheBuilder};
6use prometheus::{
7 IntCounter, IntGauge, Registry, register_int_counter_with_registry,
8 register_int_gauge_with_registry,
9};
10use tokio::sync::broadcast::{Receiver, Sender, error::RecvError};
11
/// Tuning knobs for a [`Broker`].
#[derive(Clone, Debug)]
pub struct Opts {
    /// Capacity of the broker's topic cache, i.e. the maximum number of topics kept.
    pub max_topics: u64,
    /// Time-to-idle of a cached topic: how long it may go unaccessed before eviction.
    pub idle_timeout: Duration,
    /// Capacity of each topic's broadcast channel, in messages.
    pub buffer_size: usize,
    /// Maximum number of subscribers a single topic accepts at a time.
    pub max_subscribers: usize,
}
27
28impl Default for Opts {
29 fn default() -> Self {
30 Self {
31 max_topics: 1_000_000,
32 idle_timeout: Duration::from_secs(600),
33 buffer_size: 10_000,
34 max_subscribers: 10_000,
35 }
36 }
37}
38
39#[derive(Debug, Clone, Eq, PartialEq, thiserror::Error)]
41pub enum PublishError {
42 #[error("Topic does not exist")]
43 TopicDoesNotExist,
44 #[error("Topic has no subscribers")]
45 NoSubscribers,
46}
47
48#[derive(Debug, Clone, Eq, PartialEq, thiserror::Error)]
50pub enum SubscribeError {
51 #[error("Too many subscribers")]
52 TooManySubscribers,
53}
54
55#[derive(Debug, Clone)]
57pub struct Metrics {
58 topics: IntGauge,
59 subscribers: IntGauge,
60 msgs_sent: IntCounter,
61 msgs_dropped: IntCounter,
62}
63
64impl Metrics {
65 pub fn new(registry: &Registry) -> Self {
66 Self {
67 topics: register_int_gauge_with_registry!(
68 format!("pubsub_topics"),
69 format!("Number of topics currently active"),
70 registry
71 )
72 .unwrap(),
73
74 msgs_sent: register_int_counter_with_registry!(
75 format!("pubsub_msgs_published"),
76 format!("Number of messages published"),
77 registry
78 )
79 .unwrap(),
80
81 msgs_dropped: register_int_counter_with_registry!(
82 format!("pubsub_msgs_dropped"),
83 format!("Number of messages dropped"),
84 registry
85 )
86 .unwrap(),
87
88 subscribers: register_int_gauge_with_registry!(
89 format!("pubsub_subscribers"),
90 format!("Number of subscribers currently active"),
91 registry
92 )
93 .unwrap(),
94 }
95 }
96}
97
98#[derive(Debug)]
100pub struct Subscriber<M: Message> {
101 rx: Receiver<M>,
102 metrics: Arc<Metrics>,
103}
104
105impl<M: Message> Subscriber<M> {
106 pub async fn recv(&mut self) -> Result<M, RecvError> {
108 self.rx.recv().await
109 }
110}
111
112impl<M: Message> Drop for Subscriber<M> {
113 fn drop(&mut self) {
114 self.metrics.subscribers.dec();
116 }
117}
118
119#[derive(Debug, Clone)]
121pub struct Topic<M: Message> {
122 tx: Sender<M>,
123 max_subscribers: usize,
124 metrics: Arc<Metrics>,
125}
126
127impl<M: Message> Topic<M> {
128 fn new(capacity: usize, metrics: Arc<Metrics>, max_subscribers: usize) -> Self {
129 metrics.topics.inc();
130
131 Self {
132 tx: Sender::new(capacity),
133 max_subscribers,
134 metrics,
135 }
136 }
137
138 pub fn subscriber_count(&self) -> usize {
140 self.tx.receiver_count()
141 }
142
143 pub fn subscribe(&self) -> Result<Subscriber<M>, SubscribeError> {
146 if self.tx.receiver_count() >= self.max_subscribers {
148 return Err(SubscribeError::TooManySubscribers);
149 }
150
151 self.metrics.subscribers.inc();
152 Ok(Subscriber {
153 rx: self.tx.subscribe(),
154 metrics: self.metrics.clone(),
155 })
156 }
157
158 pub fn publish(&self, message: M) -> Result<usize, PublishError> {
160 self.tx.send(message).map_or_else(
161 |_| {
162 self.metrics.msgs_dropped.inc();
163 Err(PublishError::NoSubscribers)
164 },
165 |v| {
166 self.metrics.msgs_sent.inc();
167 Ok(v)
168 },
169 )
170 }
171}
172
173impl<M: Message> Drop for Topic<M> {
174 fn drop(&mut self) {
175 self.metrics.topics.dec();
177 }
178}
179
180#[derive(Debug, Clone)]
182pub struct Broker<M: Message, T: TopicId> {
183 opts: Opts,
184 topics: Cache<T, Arc<Topic<M>>, RandomState>,
185 metrics: Arc<Metrics>,
186}
187
188impl<M: Message, T: TopicId> Broker<M, T> {
189 pub fn new(opts: Opts, metrics: Metrics) -> Self {
191 let metrics = Arc::new(metrics);
192
193 let topics = CacheBuilder::new(opts.max_topics)
194 .time_to_idle(opts.idle_timeout)
195 .build_with_hasher(RandomState::new());
196
197 Self {
198 opts,
199 topics,
200 metrics,
201 }
202 }
203
204 pub fn topic_get(&self, topic: &T) -> Option<Arc<Topic<M>>> {
206 self.topics.get(topic)
207 }
208
209 pub fn topic_get_or_create(&self, topic: &T) -> Arc<Topic<M>> {
211 self.topics.get_with_by_ref(topic, || {
212 Arc::new(Topic::new(
213 self.opts.buffer_size,
214 self.metrics.clone(),
215 self.opts.max_subscribers,
216 ))
217 })
218 }
219
220 pub fn topic_exists(&self, topic: &T) -> bool {
222 self.topics.contains_key(topic)
223 }
224
225 pub fn topic_remove(&self, topic: &T) {
228 self.topics.invalidate(topic);
229 self.topics.run_pending_tasks();
230 }
231
232 pub fn subscribe(&self, topic: &T) -> Result<Subscriber<M>, SubscribeError> {
234 let topic = self.topic_get_or_create(topic);
235 topic.subscribe()
236 }
237
238 pub fn publish(&self, topic: &T, message: M) -> Result<usize, PublishError> {
240 let Some(topic) = self.topic_get(topic) else {
242 self.metrics.msgs_dropped.inc();
243 return Err(PublishError::TopicDoesNotExist);
244 };
245
246 topic.publish(message)
247 }
248}
249
250pub struct BrokerBuilder<M, T> {
252 opts: Opts,
253 metrics: Metrics,
254 _m: PhantomData<M>,
255 _t: PhantomData<T>,
256}
257
258impl<M: Message, T: TopicId> Default for BrokerBuilder<M, T> {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264impl<M: Message, T: TopicId> BrokerBuilder<M, T> {
265 pub fn new() -> Self {
267 Self {
268 opts: Opts::default(),
269 metrics: Metrics::new(&Registry::new()),
270 _m: PhantomData,
271 _t: PhantomData,
272 }
273 }
274
275 pub const fn with_max_topics(mut self, max_topics: u64) -> Self {
277 self.opts.max_topics = max_topics;
278 self
279 }
280
281 pub const fn with_idle_timeout(mut self, idle_timeout: Duration) -> Self {
283 self.opts.idle_timeout = idle_timeout;
284 self
285 }
286
287 pub const fn with_buffer_size(mut self, buffer_size: usize) -> Self {
289 self.opts.buffer_size = buffer_size;
290 self
291 }
292
293 pub const fn with_max_subscribers(mut self, max_subscribers: usize) -> Self {
295 self.opts.max_subscribers = max_subscribers;
296 self
297 }
298
299 pub fn with_metrics(mut self, metrics: Metrics) -> Self {
301 self.metrics = metrics;
302 self
303 }
304
305 pub fn with_metric_registry(mut self, registry: &Registry) -> Self {
307 self.metrics = Metrics::new(registry);
308 self
309 }
310
311 pub fn build(self) -> Broker<M, T> {
313 Broker::new(self.opts, self.metrics)
314 }
315}
316
#[cfg(test)]
mod test {
    use super::*;

    // End-to-end exercise of the broker, checking metrics along the way:
    // - publishing to missing topics;
    // - the subscriber limit;
    // - delivery order;
    // - lagging on buffer overflow;
    // - `NoSubscribers` once everyone unsubscribed;
    // - direct `Topic` use;
    // - channel closure after topic removal.
    #[tokio::test]
    async fn test_pubsub() {
        // Tiny buffer and a single allowed subscriber per topic, so both limits are easy to hit.
        let b: Broker<String, String> = BrokerBuilder::new()
            .with_buffer_size(3)
            .with_max_subscribers(1)
            .build();

        let topic1 = "foo".to_string();
        let topic2 = "dead".to_string();

        // `publish` never creates topics, so both messages are rejected and counted as dropped.
        assert_eq!(
            b.publish(&topic1, "".into()),
            Err(PublishError::TopicDoesNotExist)
        );
        assert_eq!(
            b.publish(&topic2, "".into()),
            Err(PublishError::TopicDoesNotExist)
        );
        assert_eq!(b.metrics.topics.get(), 0);
        assert_eq!(b.metrics.msgs_dropped.get(), 2);

        // `subscribe` creates the topics on demand.
        let mut t1_sub = b.subscribe(&topic1).unwrap();
        let mut t2_sub = b.subscribe(&topic2).unwrap();
        assert!(b.topic_exists(&topic1));
        assert!(b.topic_exists(&topic2));
        assert_eq!(b.metrics.topics.get(), 2);

        // A second subscriber exceeds `max_subscribers == 1` and is not counted.
        assert_eq!(
            b.subscribe(&topic1).unwrap_err(),
            SubscribeError::TooManySubscribers
        );
        assert_eq!(b.metrics.subscribers.get(), 2);

        // Three messages per topic fit in the buffer; each reaches exactly one receiver.
        assert_eq!(b.publish(&topic1, "bar1".into()), Ok(1));
        assert_eq!(b.publish(&topic2, "beef1".into()), Ok(1));
        assert_eq!(b.publish(&topic1, "bar2".into()), Ok(1));
        assert_eq!(b.publish(&topic2, "beef2".into()), Ok(1));
        assert_eq!(b.publish(&topic1, "bar3".into()), Ok(1));
        assert_eq!(b.publish(&topic2, "beef3".into()), Ok(1));
        assert_eq!(b.metrics.msgs_sent.get(), 6);

        // Messages arrive in publish order, independently per topic.
        assert_eq!(t1_sub.recv().await.unwrap(), "bar1");
        assert_eq!(t2_sub.recv().await.unwrap(), "beef1");
        assert_eq!(t1_sub.recv().await.unwrap(), "bar2");
        assert_eq!(t2_sub.recv().await.unwrap(), "beef2");
        assert_eq!(t1_sub.recv().await.unwrap(), "bar3");
        assert_eq!(t2_sub.recv().await.unwrap(), "beef3");

        // Five unread messages per topic overflow the buffer. Publishing still
        // succeeds; the oldest messages are overwritten.
        assert_eq!(b.publish(&topic1, "bar1".into()), Ok(1));
        assert_eq!(b.publish(&topic2, "beef1".into()), Ok(1));
        assert_eq!(b.publish(&topic1, "bar2".into()), Ok(1));
        assert_eq!(b.publish(&topic2, "beef2".into()), Ok(1));
        assert_eq!(b.publish(&topic1, "bar3".into()), Ok(1));
        assert_eq!(b.publish(&topic2, "beef3".into()), Ok(1));
        assert_eq!(b.publish(&topic1, "bar4".into()), Ok(1));
        assert_eq!(b.publish(&topic2, "beef4".into()), Ok(1));
        assert_eq!(b.publish(&topic1, "bar5".into()), Ok(1));
        assert_eq!(b.publish(&topic2, "beef5".into()), Ok(1));

        // The first receive reports the lag. Reading then resumes at bar2/beef2, so
        // four messages were retained. That is consistent with tokio rounding the
        // capacity of 3 up to a power of two.
        assert!(matches!(
            t1_sub.recv().await.unwrap_err(),
            RecvError::Lagged(_)
        ));
        assert!(matches!(
            t2_sub.recv().await.unwrap_err(),
            RecvError::Lagged(_)
        ));
        assert_eq!(t1_sub.recv().await.unwrap(), "bar2");
        assert_eq!(t2_sub.recv().await.unwrap(), "beef2");
        assert_eq!(t1_sub.recv().await.unwrap(), "bar3");
        assert_eq!(t2_sub.recv().await.unwrap(), "beef3");
        assert_eq!(t1_sub.recv().await.unwrap(), "bar4");
        assert_eq!(t2_sub.recv().await.unwrap(), "beef4");
        assert_eq!(t1_sub.recv().await.unwrap(), "bar5");
        assert_eq!(t2_sub.recv().await.unwrap(), "beef5");

        // Dropping subscribers updates the gauge, but the topics stay cached.
        drop(t1_sub);
        drop(t2_sub);
        assert_eq!(b.metrics.subscribers.get(), 0);
        assert_eq!(b.metrics.topics.get(), 2);

        // Existing topics without receivers reject publishes.
        assert_eq!(
            b.publish(&topic1, "".into()).unwrap_err(),
            PublishError::NoSubscribers
        );
        assert_eq!(
            b.publish(&topic2, "".into()).unwrap_err(),
            PublishError::NoSubscribers
        );

        // Working with the `Topic` handles directly instead of through the broker.
        let t1 = b.topic_get_or_create(&topic1);
        let t2 = b.topic_get_or_create(&topic2);
        let mut t1_sub = t1.subscribe().unwrap();
        let mut t2_sub = t2.subscribe().unwrap();

        assert_eq!(t1.publish("foo".into()).unwrap(), 1);
        assert_eq!(t2.publish("bar".into()).unwrap(), 1);
        assert_eq!(t1_sub.recv().await.unwrap(), "foo");
        assert_eq!(t2_sub.recv().await.unwrap(), "bar");

        // Removing the topics from the cache and dropping the local handles
        // releases the senders, which closes the channels for subscribers.
        b.topic_remove(&topic1);
        b.topic_remove(&topic2);
        drop(t1);
        drop(t2);

        assert_eq!(t1_sub.recv().await.unwrap_err(), RecvError::Closed);
        assert_eq!(t2_sub.recv().await.unwrap_err(), RecvError::Closed);
    }
}