use crate::ServerMessage;
use std::{
collections::{BTreeMap, HashMap, HashSet},
future::Future,
pin::Pin,
sync::{Arc, Mutex},
};
use tokio::sync::broadcast;
/// Default per-topic buffer size for the in-process broadcast channel; a
/// subscriber that falls more than this many messages behind observes
/// `PubSubReceiveError::Lagged`.
const DEFAULT_TOPIC_CAPACITY: usize = 1024;
/// How far a broadcast message can travel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PubSubDeliveryScope {
    /// Messages reach only subscribers inside the current process.
    LocalProcess,
    /// Messages fan out across every node of a cluster.
    Cluster,
}
/// Ordering guarantee a backend provides for message delivery.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PubSubOrdering {
    /// Subscribers of one topic see messages in publish order.
    PerTopicOrdered,
    /// No ordering guarantee; messages may arrive out of order.
    BestEffort,
}
/// Whether a backend requires clients to stay pinned to one stateful session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionAffinityRequirement {
    /// Clients may reconnect anywhere.
    None,
    /// Clients must keep talking to the node holding their session state.
    StatefulSessionRequired,
}
/// Self-description of a [`PubSubBackend`], used by callers to adapt to the
/// guarantees the configured backend actually provides.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubSubCapabilities {
    /// Human-readable backend identifier (e.g. `"in_process"`).
    pub backend: String,
    /// Reach of a broadcast: local process or whole cluster.
    pub delivery_scope: PubSubDeliveryScope,
    /// Delivery-order guarantee.
    pub ordering: PubSubOrdering,
    /// Whether clients must stay pinned to a stateful session.
    pub session_affinity: SessionAffinityRequirement,
    /// True when `register_presence`/`presence_snapshot` are meaningful.
    pub presence_tracking: bool,
}
impl PubSubCapabilities {
    /// Capability descriptor for the built-in single-process backend:
    /// local-only delivery, per-topic ordering, presence tracking enabled,
    /// and sticky stateful sessions required.
    fn in_process() -> Self {
        Self {
            backend: String::from("in_process"),
            presence_tracking: true,
            delivery_scope: PubSubDeliveryScope::LocalProcess,
            ordering: PubSubOrdering::PerTopicOrdered,
            session_affinity: SessionAffinityRequirement::StatefulSessionRequired,
        }
    }
}
/// Point-in-time view of which sessions are present on a topic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubSubPresenceSnapshot {
    /// Topic this snapshot describes.
    pub topic: String,
    /// Total session count across all nodes.
    pub total_sessions: usize,
    /// Session count per node id; `BTreeMap` keeps a stable iteration order.
    pub by_node: BTreeMap<String, usize>,
}
impl PubSubPresenceSnapshot {
    /// Snapshot for a topic with no tracked sessions at all.
    fn empty(topic: &str) -> Self {
        Self {
            topic: topic.to_owned(),
            total_sessions: 0,
            by_node: BTreeMap::default(),
        }
    }
}
/// Error returned by [`PubSubSubscription::recv`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubSubReceiveError {
    /// The topic's channel was dropped; no further messages will arrive.
    Closed,
    /// The subscriber fell behind and the given number of messages were skipped.
    Lagged(u64),
}
/// Boxed future produced by [`PubSubSubscriptionHandle::recv`]; boxing keeps
/// the trait object-safe so backends can be swapped behind `dyn`.
type PubSubRecvFuture<'a> =
    Pin<Box<dyn Future<Output = Result<PubSubMessage, PubSubReceiveError>> + Send + 'a>>;

/// Backend-specific receive half of a subscription.
pub trait PubSubSubscriptionHandle: Send {
    /// Waits for the next message on the subscribed topic.
    fn recv(&mut self) -> PubSubRecvFuture<'_>;
}
/// Subscription handle backed by a `tokio::sync::broadcast` receiver.
struct BroadcastSubscriptionHandle {
    receiver: broadcast::Receiver<PubSubMessage>,
}
impl PubSubSubscriptionHandle for BroadcastSubscriptionHandle {
fn recv(&mut self) -> PubSubRecvFuture<'_> {
Box::pin(async move {
self.receiver.recv().await.map_err(|err| match err {
broadcast::error::RecvError::Closed => PubSubReceiveError::Closed,
broadcast::error::RecvError::Lagged(skipped) => PubSubReceiveError::Lagged(skipped),
})
})
}
}
/// Pluggable transport behind [`PubSub`]. Implementations may be local
/// (in-process broadcast) or distributed (e.g. a message broker).
pub trait PubSubBackend: Send + Sync {
    /// Opens a new subscription on `topic`.
    fn subscribe(&self, topic: &str) -> PubSubSubscription;
    /// Publishes `messages` to `topic`, returning how many subscribers received them.
    fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize;
    /// Describes the guarantees this backend provides.
    fn capabilities(&self) -> PubSubCapabilities;
    /// Records that `session_id` on `node_id` joined `topic`. No-op by default
    /// for backends without presence tracking.
    fn register_presence(&self, _topic: &str, _session_id: &str, _node_id: &str) {}
    /// Records that `session_id` on `node_id` left `topic`. No-op by default.
    fn unregister_presence(&self, _topic: &str, _session_id: &str, _node_id: &str) {}
    /// Returns current presence for `topic`; empty by default.
    fn presence_snapshot(&self, topic: &str) -> PubSubPresenceSnapshot {
        PubSubPresenceSnapshot::empty(topic)
    }
}
/// Default backend: one `tokio::sync::broadcast` channel per topic, plus an
/// in-memory presence map of `topic -> node_id -> set of session_ids`.
#[derive(Debug)]
struct InProcessPubSubBackend {
    // Lazily-created broadcast channel per topic name.
    topics: Arc<Mutex<HashMap<String, broadcast::Sender<PubSubMessage>>>>,
    // topic -> node_id -> session ids currently registered.
    presence: Arc<Mutex<HashMap<String, HashMap<String, HashSet<String>>>>>,
    // Buffer size used when creating each topic channel.
    topic_capacity: usize,
}
impl InProcessPubSubBackend {
fn new(topic_capacity: usize) -> Self {
Self {
topics: Arc::new(Mutex::new(HashMap::new())),
presence: Arc::new(Mutex::new(HashMap::new())),
topic_capacity,
}
}
fn sender_for(&self, topic: &str) -> broadcast::Sender<PubSubMessage> {
let mut topics = self.topics.lock().expect("pubsub topic mutex poisoned");
topics
.entry(topic.to_string())
.or_insert_with(|| {
let (sender, _) = broadcast::channel(self.topic_capacity);
sender
})
.clone()
}
}
impl PubSubBackend for InProcessPubSubBackend {
    /// Opens a subscription on `topic`, lazily creating its channel.
    fn subscribe(&self, topic: &str) -> PubSubSubscription {
        let sender = self.sender_for(topic);
        PubSubSubscription::new(BroadcastSubscriptionHandle {
            receiver: sender.subscribe(),
        })
    }

    /// Publishes `messages` to `topic`, returning the number of subscribers
    /// that received the payload.
    ///
    /// Unlike `subscribe`, this deliberately does NOT create a channel for an
    /// unknown topic: publishing to a topic nobody subscribed to previously
    /// inserted a permanent entry into the topic map, so arbitrary topic
    /// names from publishers grew it without bound. Looking the sender up
    /// read-only is observably identical to callers — a channel with no
    /// receivers already reported zero deliveries, and a later `subscribe`
    /// creates the channel on demand either way.
    fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize {
        let sender = {
            let topics = self.topics.lock().expect("pubsub topic mutex poisoned");
            topics.get(topic).cloned()
        };
        let Some(sender) = sender else {
            // No channel means no subscribers ever existed for this topic.
            return 0;
        };
        sender
            .send(PubSubMessage {
                topic: topic.to_string(),
                messages,
            })
            // `send` errs only when there are zero live receivers; report
            // that as zero deliveries rather than an error.
            .unwrap_or_default()
    }

    fn capabilities(&self) -> PubSubCapabilities {
        PubSubCapabilities::in_process()
    }

    /// Records `session_id` on `node_id` as present on `topic`.
    fn register_presence(&self, topic: &str, session_id: &str, node_id: &str) {
        let mut presence = self
            .presence
            .lock()
            .expect("pubsub presence mutex poisoned");
        presence
            .entry(topic.to_string())
            .or_default()
            .entry(node_id.to_string())
            .or_default()
            .insert(session_id.to_string());
    }

    /// Removes `session_id` from `topic`/`node_id`, pruning empty node and
    /// topic entries so the presence map does not accumulate dead keys.
    fn unregister_presence(&self, topic: &str, session_id: &str, node_id: &str) {
        let mut presence = self
            .presence
            .lock()
            .expect("pubsub presence mutex poisoned");
        let mut remove_topic = false;
        if let Some(by_node) = presence.get_mut(topic) {
            if let Some(sessions) = by_node.get_mut(node_id) {
                sessions.remove(session_id);
                if sessions.is_empty() {
                    by_node.remove(node_id);
                }
            }
            remove_topic = by_node.is_empty();
        }
        if remove_topic {
            presence.remove(topic);
        }
    }

    /// Builds a per-node session-count snapshot for `topic`; empty snapshot
    /// when the topic has no registered presence.
    fn presence_snapshot(&self, topic: &str) -> PubSubPresenceSnapshot {
        let presence = self
            .presence
            .lock()
            .expect("pubsub presence mutex poisoned");
        let Some(by_node) = presence.get(topic) else {
            return PubSubPresenceSnapshot::empty(topic);
        };
        let by_node: BTreeMap<String, usize> = by_node
            .iter()
            .map(|(node_id, sessions)| (node_id.clone(), sessions.len()))
            .collect();
        PubSubPresenceSnapshot {
            topic: topic.to_string(),
            total_sessions: by_node.values().sum(),
            by_node,
        }
    }
}
/// Cheaply-clonable handle to a pub/sub backend; clones share the same
/// underlying backend via `Arc`.
#[derive(Clone)]
pub struct PubSub {
    backend: Arc<dyn PubSubBackend>,
}
// Manual Debug: the boxed backend trait object is not Debug, so we print the
// backend's capability summary instead.
impl std::fmt::Debug for PubSub {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PubSub")
            .field("capabilities", &self.capabilities())
            .finish()
    }
}
impl Default for PubSub {
    /// In-process backend with the default per-topic capacity.
    fn default() -> Self {
        Self::new(DEFAULT_TOPIC_CAPACITY)
    }
}
impl PubSub {
    /// Creates a hub backed by the in-process broadcast backend, buffering up
    /// to `topic_capacity` messages per topic.
    pub fn new(topic_capacity: usize) -> Self {
        Self::with_backend(InProcessPubSubBackend::new(topic_capacity))
    }

    /// Wraps an arbitrary backend implementation (e.g. a cluster transport).
    pub fn with_backend<B>(backend: B) -> Self
    where
        B: PubSubBackend + 'static,
    {
        Self {
            backend: Arc::new(backend),
        }
    }

    /// Opens a subscription on `topic`.
    pub fn subscribe(&self, topic: impl Into<String>) -> PubSubSubscription {
        self.backend.subscribe(&topic.into())
    }

    /// Publishes `messages` to `topic`; returns the number of subscribers reached.
    pub fn broadcast(&self, topic: impl Into<String>, messages: Vec<ServerMessage>) -> usize {
        self.backend.broadcast(&topic.into(), messages)
    }

    /// Reports the configured backend's delivery guarantees.
    pub fn capabilities(&self) -> PubSubCapabilities {
        self.backend.capabilities()
    }

    /// Marks `session_id` on `node_id` as present on `topic`.
    pub fn register_presence(
        &self,
        topic: impl Into<String>,
        session_id: impl Into<String>,
        node_id: impl Into<String>,
    ) {
        self.backend
            .register_presence(&topic.into(), &session_id.into(), &node_id.into());
    }

    /// Removes `session_id` on `node_id` from `topic`'s presence set.
    pub fn unregister_presence(
        &self,
        topic: impl Into<String>,
        session_id: impl Into<String>,
        node_id: impl Into<String>,
    ) {
        self.backend
            .unregister_presence(&topic.into(), &session_id.into(), &node_id.into());
    }

    /// Current presence counts for `topic`, grouped by node.
    pub fn presence_snapshot(&self, topic: impl Into<String>) -> PubSubPresenceSnapshot {
        self.backend.presence_snapshot(&topic.into())
    }
}
/// Payload delivered to subscribers: the topic it was published on plus the
/// batch of server messages.
#[derive(Debug, Clone, PartialEq)]
pub struct PubSubMessage {
    pub topic: String,
    pub messages: Vec<ServerMessage>,
}
/// Owned subscription; wraps a backend-specific handle behind a trait object
/// so callers are independent of the transport.
pub struct PubSubSubscription {
    inner: Box<dyn PubSubSubscriptionHandle>,
}
impl PubSubSubscription {
    /// Wraps a backend-specific handle into the transport-agnostic subscription type.
    pub fn new<H>(handle: H) -> Self
    where
        H: PubSubSubscriptionHandle + 'static,
    {
        Self {
            inner: Box::new(handle),
        }
    }

    /// Waits for the next message on this subscription.
    ///
    /// Returns `Err(PubSubReceiveError::Closed)` when the topic channel is
    /// gone, or `Err(PubSubReceiveError::Lagged(n))` when this subscriber
    /// fell behind and `n` messages were dropped.
    pub async fn recv(&mut self) -> Result<PubSubMessage, PubSubReceiveError> {
        self.inner.recv().await
    }
}
/// Serializable command form of the pub/sub operations, suitable for relaying
/// requests (e.g. over a wire protocol) to a hub.
#[derive(Debug, Clone, PartialEq)]
pub enum PubSubCommand {
    /// Request a subscription to `topic`.
    Subscribe {
        topic: String,
    },
    /// Publish `messages` to `topic`.
    Broadcast {
        topic: String,
        messages: Vec<ServerMessage>,
    },
}
#[cfg(test)]
mod tests {
    use super::{
        BroadcastSubscriptionHandle, PubSub, PubSubBackend, PubSubCapabilities,
        PubSubDeliveryScope, PubSubMessage, PubSubOrdering, PubSubSubscription,
        SessionAffinityRequirement,
    };
    use crate::ServerMessage;
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex},
    };
    use tokio::sync::broadcast;

    // A single broadcast reaches every subscriber of the topic, and each
    // subscriber receives the full payload.
    #[tokio::test]
    async fn in_process_pubsub_broadcasts_to_subscribers() {
        let pubsub = PubSub::default();
        let mut first = pubsub.subscribe("chat:lobby");
        let mut second = pubsub.subscribe("chat:lobby");
        // Two live subscribers -> broadcast reports 2 deliveries.
        assert_eq!(
            pubsub.broadcast(
                "chat:lobby",
                vec![ServerMessage::Redirect {
                    to: "/ok".to_string()
                }]
            ),
            2
        );
        assert_eq!(first.recv().await.unwrap().topic, "chat:lobby");
        assert_eq!(
            second.recv().await.unwrap().messages,
            vec![ServerMessage::Redirect {
                to: "/ok".to_string()
            }]
        );
    }

    // The in-process backend advertises local-only ordered delivery with
    // presence tracking, and presence counts follow register/unregister.
    #[test]
    fn in_process_pubsub_reports_cluster_capabilities_and_presence() {
        let pubsub = PubSub::default();
        let capabilities = pubsub.capabilities();
        assert_eq!(capabilities.backend, "in_process");
        assert_eq!(
            capabilities.delivery_scope,
            PubSubDeliveryScope::LocalProcess
        );
        assert_eq!(capabilities.ordering, PubSubOrdering::PerTopicOrdered);
        assert_eq!(
            capabilities.session_affinity,
            SessionAffinityRequirement::StatefulSessionRequired
        );
        assert!(capabilities.presence_tracking);
        // Three sessions split across two nodes.
        pubsub.register_presence("chat:lobby", "s1", "node-a");
        pubsub.register_presence("chat:lobby", "s2", "node-a");
        pubsub.register_presence("chat:lobby", "s3", "node-b");
        let snapshot = pubsub.presence_snapshot("chat:lobby");
        assert_eq!(snapshot.topic, "chat:lobby");
        assert_eq!(snapshot.total_sessions, 3);
        assert_eq!(snapshot.by_node.get("node-a"), Some(&2));
        assert_eq!(snapshot.by_node.get("node-b"), Some(&1));
        // Removing one session on node-a decrements both totals.
        pubsub.unregister_presence("chat:lobby", "s2", "node-a");
        let after = pubsub.presence_snapshot("chat:lobby");
        assert_eq!(after.total_sessions, 2);
        assert_eq!(after.by_node.get("node-a"), Some(&1));
    }

    // Shared topic registry that two mock "nodes" point at, simulating a
    // cluster-wide transport.
    #[derive(Debug, Clone)]
    struct SharedHub {
        topics: Arc<Mutex<HashMap<String, broadcast::Sender<PubSubMessage>>>>,
    }

    impl SharedHub {
        fn new() -> Self {
            Self {
                topics: Arc::new(Mutex::new(HashMap::new())),
            }
        }

        // Lazily creates the shared channel for `topic`.
        fn sender_for(&self, topic: &str) -> broadcast::Sender<PubSubMessage> {
            let mut topics = self.topics.lock().expect("hub mutex poisoned");
            topics
                .entry(topic.to_string())
                .or_insert_with(|| {
                    let (tx, _) = broadcast::channel(256);
                    tx
                })
                .clone()
        }
    }

    // Minimal custom backend routing through the shared hub; advertises
    // cluster scope with best-effort ordering and no presence tracking.
    #[derive(Debug, Clone)]
    struct MockClusterBackend {
        hub: SharedHub,
    }

    impl PubSubBackend for MockClusterBackend {
        fn subscribe(&self, topic: &str) -> PubSubSubscription {
            let receiver = self.hub.sender_for(topic).subscribe();
            PubSubSubscription::new(BroadcastSubscriptionHandle { receiver })
        }

        fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize {
            self.hub
                .sender_for(topic)
                .send(PubSubMessage {
                    topic: topic.to_string(),
                    messages,
                })
                .unwrap_or_default()
        }

        fn capabilities(&self) -> PubSubCapabilities {
            PubSubCapabilities {
                backend: "mock_cluster".to_string(),
                delivery_scope: PubSubDeliveryScope::Cluster,
                ordering: PubSubOrdering::BestEffort,
                session_affinity: SessionAffinityRequirement::StatefulSessionRequired,
                presence_tracking: false,
            }
        }
    }

    // A message broadcast on node_b is delivered to a subscriber on node_a
    // because both PubSub instances share the same hub.
    #[tokio::test]
    async fn custom_backend_can_fanout_across_multiple_pubsub_instances() {
        let hub = SharedHub::new();
        let node_a = PubSub::with_backend(MockClusterBackend { hub: hub.clone() });
        let node_b = PubSub::with_backend(MockClusterBackend { hub });
        let mut subscription = node_a.subscribe("cluster:lobby");
        assert_eq!(
            node_b.broadcast(
                "cluster:lobby",
                vec![ServerMessage::Error {
                    message: "hello".to_string(),
                    code: Some("cluster".to_string()),
                }]
            ),
            1
        );
        let delivered = subscription.recv().await.unwrap();
        assert_eq!(delivered.topic, "cluster:lobby");
        assert_eq!(delivered.messages.len(), 1);
        match &delivered.messages[0] {
            ServerMessage::Error { message, code } => {
                assert_eq!(message, "hello");
                assert_eq!(code.as_deref(), Some("cluster"));
            }
            other => panic!("unexpected payload: {other:?}"),
        }
    }
}