use std::collections::HashSet;
use tokio::sync::broadcast;
use crate::notify::Notification;
/// Broadcast buffer size (in notifications) used by `NotificationDispatcher::new`.
const DEFAULT_CAPACITY: usize = 1 << 8;
/// Broadcasts [`Notification`]s to any number of subscribers over a tokio
/// broadcast channel, and separately tracks a set of channel names.
pub struct NotificationDispatcher {
    /// Sending half of the broadcast channel; every receiver is derived from it.
    sender: broadcast::Sender<Notification>,
    /// Registered channel names.
    channels: HashSet<String>,
}
impl NotificationDispatcher {
    /// Creates a dispatcher whose broadcast buffer holds [`DEFAULT_CAPACITY`]
    /// notifications.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_CAPACITY)
    }

    /// Creates a dispatcher backed by a broadcast channel of the given `capacity`.
    ///
    /// `capacity` bounds how many notifications the channel retains for
    /// receivers that have not read them yet. A receiver that falls further
    /// behind loses the oldest messages. tokio may round the capacity up
    /// internally.
    ///
    /// # Panics
    ///
    /// Panics (inside [`broadcast::channel`]) if `capacity` is `0` or greater
    /// than `usize::MAX / 2`.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        // The receiver created alongside the sender is discarded on purpose:
        // receivers are handed out on demand through `subscribe`.
        let (sender, _) = broadcast::channel(capacity);
        Self {
            sender,
            channels: HashSet::new(),
        }
    }

    /// Returns a new receiver that observes every notification dispatched
    /// after this call. Notifications sent earlier are not replayed.
    #[must_use = "a dropped receiver never observes any notification"]
    pub fn subscribe(&self) -> NotificationReceiver {
        NotificationReceiver {
            receiver: self.sender.subscribe(),
        }
    }

    /// Broadcasts `notification` to all current subscribers.
    ///
    /// Returns the number of receivers subscribed at the time of sending.
    /// A receiver may still miss the message if it lags or is dropped first.
    /// When there are no subscribers, the notification is dropped and `0` is
    /// returned instead of an error.
    ///
    /// NOTE(review): `channels` is not consulted here; every notification goes
    /// to every subscriber regardless of the registered channel names.
    pub fn dispatch(&self, notification: Notification) -> usize {
        // `send` fails only when no receiver exists; report that as "nobody".
        self.sender.send(notification).unwrap_or(0)
    }

    /// Adds `channel` to the set of registered channel names.
    /// Adding a name that is already present has no effect.
    pub fn add_channel(&mut self, channel: String) {
        self.channels.insert(channel);
    }

    /// Removes `channel` from the set of registered channel names.
    /// Removing a name that is not present has no effect.
    pub fn remove_channel(&mut self, channel: &str) {
        self.channels.remove(channel);
    }

    /// Returns the set of registered channel names.
    pub fn channels(&self) -> &HashSet<String> {
        &self.channels
    }

    /// Returns the number of receivers currently subscribed to this dispatcher.
    pub fn subscriber_count(&self) -> usize {
        self.sender.receiver_count()
    }
}
impl Default for NotificationDispatcher {
fn default() -> Self {
Self::new()
}
}
pub struct NotificationReceiver {
receiver: broadcast::Receiver<Notification>,
}
impl NotificationReceiver {
pub async fn recv(&mut self) -> Option<Notification> {
loop {
match self.receiver.recv().await {
Ok(notification) => return Some(notification),
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!(count = n, "notification receiver lagged, skipped messages");
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
}
}