use crate::ServerMessage;
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use tokio::sync::broadcast;
/// Per-topic channel capacity used by `PubSub::default()`.
const DEFAULT_TOPIC_CAPACITY: usize = 1 << 10;
/// In-process publish/subscribe hub keyed by topic name.
///
/// The topic map sits behind an `Arc<Mutex<..>>`, so every clone of a hub
/// shares the same set of topics.
#[derive(Debug, Clone)]
pub struct PubSub {
/// One broadcast sender per topic name, shared by all clones of this hub.
topics: Arc<Mutex<HashMap<String, broadcast::Sender<PubSubMessage>>>>,
/// Capacity handed to `broadcast::channel` when a topic's channel is created.
topic_capacity: usize,
}
impl Default for PubSub {
fn default() -> Self {
Self::new(DEFAULT_TOPIC_CAPACITY)
}
}
impl PubSub {
    /// Creates a hub whose per-topic channels are created with `topic_capacity`.
    ///
    /// A capacity of `0` is raised to `1`: `tokio::sync::broadcast::channel`
    /// panics on a zero capacity, and that panic would otherwise only surface
    /// later, when the first topic channel is created.
    pub fn new(topic_capacity: usize) -> Self {
        Self {
            topics: Arc::new(Mutex::new(HashMap::new())),
            topic_capacity: topic_capacity.max(1),
        }
    }

    /// Subscribes to `topic`, creating its channel on first use.
    ///
    /// The returned subscription only observes messages broadcast after this
    /// call.
    ///
    /// # Panics
    ///
    /// Panics if the topic mutex is poisoned.
    pub fn subscribe(&self, topic: impl Into<String>) -> PubSubSubscription {
        // The receiver is attached while the lock is held so that a concurrent
        // `broadcast` cannot prune the topic between lookup and subscription.
        let mut topics = self.lock_topics();
        let sender = topics.entry(topic.into()).or_insert_with(|| {
            let (sender, _) = broadcast::channel(self.topic_capacity);
            sender
        });
        PubSubSubscription {
            receiver: sender.subscribe(),
        }
    }

    /// Sends `messages` to every current subscriber of `topic`.
    ///
    /// Returns the number of subscribers the batch was sent to, or `0` when
    /// the topic has no subscribers. Topics without subscribers are never
    /// created here, and a topic whose subscribers have all gone away is
    /// removed, so the topic map does not grow with dead topic names.
    ///
    /// # Panics
    ///
    /// Panics if the topic mutex is poisoned.
    pub fn broadcast(&self, topic: impl Into<String>, messages: Vec<ServerMessage>) -> usize {
        let topic = topic.into();
        let live_sender = {
            let mut topics = self.lock_topics();
            let live = topics
                .get(&topic)
                .filter(|sender| sender.receiver_count() > 0)
                .cloned();
            if live.is_none() {
                // Either the topic is unknown (no-op) or every receiver was
                // dropped; in both cases nothing should stay in the map.
                topics.remove(&topic);
            }
            live
        };
        // The send happens outside the lock to keep the critical section short.
        match live_sender {
            Some(sender) => sender
                .send(PubSubMessage { topic, messages })
                .unwrap_or_default(),
            None => 0,
        }
    }

    /// Locks the shared topic map.
    fn lock_topics(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<String, broadcast::Sender<PubSubMessage>>> {
        self.topics.lock().expect("pubsub topic mutex poisoned")
    }
}
/// A batch of server messages published to a single topic.
#[derive(Debug, Clone, PartialEq)]
pub struct PubSubMessage {
/// Name of the topic the batch was broadcast on.
pub topic: String,
/// The messages passed to `PubSub::broadcast` for this batch.
pub messages: Vec<ServerMessage>,
}
/// Receiving end of a topic subscription, created by `PubSub::subscribe`.
#[derive(Debug)]
pub struct PubSubSubscription {
    /// Broadcast receiver attached to the topic's channel.
    receiver: broadcast::Receiver<PubSubMessage>,
}
impl PubSubSubscription {
    /// Waits for the next message batch on this subscription.
    ///
    /// The result of the underlying broadcast receiver is returned unchanged,
    /// including its `RecvError` on failure.
    pub async fn recv(&mut self) -> Result<PubSubMessage, broadcast::error::RecvError> {
        let inbox = &mut self.receiver;
        inbox.recv().await
    }
}
/// A pub/sub operation expressed as data.
///
/// The variants mirror `PubSub::subscribe` and `PubSub::broadcast`.
/// NOTE(review): nothing in the visible part of this file consumes this enum;
/// presumably it is dispatched elsewhere — confirm against callers.
#[derive(Debug, Clone, PartialEq)]
pub enum PubSubCommand {
/// Request to subscribe to `topic`.
Subscribe {
topic: String,
},
/// Request to broadcast `messages` on `topic`.
Broadcast {
topic: String,
messages: Vec<ServerMessage>,
},
}
#[cfg(test)]
mod tests {
    use super::PubSub;
    use crate::ServerMessage;

    /// Two subscribers on the same topic both receive a broadcast batch.
    #[tokio::test]
    async fn in_process_pubsub_broadcasts_to_subscribers() {
        let topic = "chat:lobby";
        let redirect = || ServerMessage::Redirect {
            to: "/ok".to_string(),
        };

        let hub = PubSub::default();
        let mut first = hub.subscribe(topic);
        let mut second = hub.subscribe(topic);

        let delivered = hub.broadcast(topic, vec![redirect()]);
        assert_eq!(delivered, 2);

        let first_batch = first.recv().await.unwrap();
        assert_eq!(first_batch.topic, topic);

        let second_batch = second.recv().await.unwrap();
        assert_eq!(second_batch.messages, vec![redirect()]);
    }
}