// shelly/pubsub.rs
1use crate::ServerMessage;
2use std::{
3    collections::{BTreeMap, HashMap, HashSet},
4    future::Future,
5    pin::Pin,
6    sync::{Arc, Mutex},
7};
8use tokio::sync::broadcast;
9
/// Default bounded capacity of each topic's broadcast channel; once a slow
/// receiver falls more than this many messages behind it observes `Lagged`.
const DEFAULT_TOPIC_CAPACITY: usize = 1024;
/// Topic name -> broadcast sender used for in-process fanout.
type TopicSenders = HashMap<String, broadcast::Sender<PubSubMessage>>;
/// Node id -> set of session ids present on that node.
type NodePresenceMap = HashMap<String, HashSet<String>>;
/// Topic name -> per-node presence map.
type TopicPresenceMap = HashMap<String, NodePresenceMap>;
/// Delivery topology for one PubSub backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PubSubDeliveryScope {
    /// Fanout stays inside one process/runtime instance.
    LocalProcess,
    /// Fanout can be shared across multiple runtime instances.
    Cluster,
}
23
/// Ordering contract for one PubSub backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PubSubOrdering {
    /// Messages are delivered in topic order.
    PerTopicOrdered,
    /// Ordering is best effort and may be reordered by backend behavior.
    BestEffort,
}
32
/// Session-affinity contract for one PubSub backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionAffinityRequirement {
    /// Session affinity is not required for backend fanout.
    None,
    /// Session affinity is required for stateful reconnect/session continuity.
    StatefulSessionRequired,
}
41
/// Cluster capabilities advertised by a PubSub backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubSubCapabilities {
    /// Backend identifier, e.g. `"in_process"`.
    pub backend: String,
    /// Whether fanout stays in-process or spans runtime instances.
    pub delivery_scope: PubSubDeliveryScope,
    /// Ordering contract for delivered messages.
    pub ordering: PubSubOrdering,
    /// Session-affinity requirement for stateful session continuity.
    pub session_affinity: SessionAffinityRequirement,
    /// Whether the backend tracks per-topic presence.
    pub presence_tracking: bool,
}
51
52impl PubSubCapabilities {
53    fn in_process() -> Self {
54        Self {
55            backend: "in_process".to_string(),
56            delivery_scope: PubSubDeliveryScope::LocalProcess,
57            ordering: PubSubOrdering::PerTopicOrdered,
58            session_affinity: SessionAffinityRequirement::StatefulSessionRequired,
59            presence_tracking: true,
60        }
61    }
62}
63
/// Presence snapshot for one topic across nodes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubSubPresenceSnapshot {
    /// Topic the snapshot was taken for.
    pub topic: String,
    /// Total session count summed over all nodes.
    pub total_sessions: usize,
    /// Per-node session counts; `BTreeMap` gives deterministic ordering.
    pub by_node: BTreeMap<String, usize>,
}
71
72impl PubSubPresenceSnapshot {
73    fn empty(topic: &str) -> Self {
74        Self {
75            topic: topic.to_string(),
76            total_sessions: 0,
77            by_node: BTreeMap::new(),
78        }
79    }
80}
81
/// Errors produced while receiving subscription fanout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubSubReceiveError {
    /// The topic channel closed (all senders dropped).
    Closed,
    /// The receiver fell behind; payload is the number of skipped messages.
    Lagged(u64),
}
88
/// Boxed future returned by [`PubSubSubscriptionHandle::recv`].
type PubSubRecvFuture<'a> =
    Pin<Box<dyn Future<Output = Result<PubSubMessage, PubSubReceiveError>> + Send + 'a>>;

/// Backend-owned receiver trait for one subscription.
pub trait PubSubSubscriptionHandle: Send {
    /// Await the next message delivered to this subscription.
    fn recv(&mut self) -> PubSubRecvFuture<'_>;
}
96
97struct BroadcastSubscriptionHandle {
98    receiver: broadcast::Receiver<PubSubMessage>,
99}
100
101impl PubSubSubscriptionHandle for BroadcastSubscriptionHandle {
102    fn recv(&mut self) -> PubSubRecvFuture<'_> {
103        Box::pin(async move {
104            self.receiver.recv().await.map_err(|err| match err {
105                broadcast::error::RecvError::Closed => PubSubReceiveError::Closed,
106                broadcast::error::RecvError::Lagged(skipped) => PubSubReceiveError::Lagged(skipped),
107            })
108        })
109    }
110}
111
/// Backend interface for adapter-executed PubSub commands.
pub trait PubSubBackend: Send + Sync {
    /// Open a new subscription on `topic`.
    fn subscribe(&self, topic: &str) -> PubSubSubscription;
    /// Deliver `messages` to all subscribers of `topic`; returns the number
    /// of receivers that observed the send.
    fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize;
    /// Cluster/delivery contract advertised by this backend.
    fn capabilities(&self) -> PubSubCapabilities;

    /// Record a session as present on a node. Default is a no-op for
    /// backends without presence tracking.
    fn register_presence(&self, _topic: &str, _session_id: &str, _node_id: &str) {}

    /// Remove a session's presence. Default is a no-op.
    fn unregister_presence(&self, _topic: &str, _session_id: &str, _node_id: &str) {}

    /// Presence counts for `topic`. Default reports an empty snapshot.
    fn presence_snapshot(&self, topic: &str) -> PubSubPresenceSnapshot {
        PubSubPresenceSnapshot::empty(topic)
    }
}
126
/// Built-in single-process backend backed by tokio broadcast channels.
#[derive(Debug)]
struct InProcessPubSubBackend {
    // Lazily created broadcast sender per topic.
    topics: Arc<Mutex<TopicSenders>>,
    // Presence bookkeeping: topic -> node -> session ids.
    presence: Arc<Mutex<TopicPresenceMap>>,
    // Capacity used when creating each topic channel.
    topic_capacity: usize,
}
133
134impl InProcessPubSubBackend {
135    fn new(topic_capacity: usize) -> Self {
136        Self {
137            topics: Arc::new(Mutex::new(HashMap::new())),
138            presence: Arc::new(Mutex::new(HashMap::new())),
139            topic_capacity,
140        }
141    }
142
143    fn sender_for(&self, topic: &str) -> broadcast::Sender<PubSubMessage> {
144        let mut topics = self.topics.lock().expect("pubsub topic mutex poisoned");
145        topics
146            .entry(topic.to_string())
147            .or_insert_with(|| {
148                let (sender, _) = broadcast::channel(self.topic_capacity);
149                sender
150            })
151            .clone()
152    }
153}
154
155impl PubSubBackend for InProcessPubSubBackend {
156    fn subscribe(&self, topic: &str) -> PubSubSubscription {
157        let sender = self.sender_for(topic);
158        PubSubSubscription::new(BroadcastSubscriptionHandle {
159            receiver: sender.subscribe(),
160        })
161    }
162
163    fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize {
164        let sender = self.sender_for(topic);
165        sender
166            .send(PubSubMessage {
167                topic: topic.to_string(),
168                messages,
169            })
170            .unwrap_or_default()
171    }
172
173    fn capabilities(&self) -> PubSubCapabilities {
174        PubSubCapabilities::in_process()
175    }
176
177    fn register_presence(&self, topic: &str, session_id: &str, node_id: &str) {
178        let mut presence = self
179            .presence
180            .lock()
181            .expect("pubsub presence mutex poisoned");
182        presence
183            .entry(topic.to_string())
184            .or_default()
185            .entry(node_id.to_string())
186            .or_default()
187            .insert(session_id.to_string());
188    }
189
190    fn unregister_presence(&self, topic: &str, session_id: &str, node_id: &str) {
191        let mut presence = self
192            .presence
193            .lock()
194            .expect("pubsub presence mutex poisoned");
195        let mut remove_topic = false;
196        if let Some(by_node) = presence.get_mut(topic) {
197            if let Some(sessions) = by_node.get_mut(node_id) {
198                sessions.remove(session_id);
199                if sessions.is_empty() {
200                    by_node.remove(node_id);
201                }
202            }
203            remove_topic = by_node.is_empty();
204        }
205        if remove_topic {
206            presence.remove(topic);
207        }
208    }
209
210    fn presence_snapshot(&self, topic: &str) -> PubSubPresenceSnapshot {
211        let presence = self
212            .presence
213            .lock()
214            .expect("pubsub presence mutex poisoned");
215        let Some(by_node) = presence.get(topic) else {
216            return PubSubPresenceSnapshot::empty(topic);
217        };
218        let mut snapshot = PubSubPresenceSnapshot {
219            topic: topic.to_string(),
220            total_sessions: 0,
221            by_node: BTreeMap::new(),
222        };
223        for (node_id, sessions) in by_node {
224            snapshot.total_sessions += sessions.len();
225            snapshot.by_node.insert(node_id.clone(), sessions.len());
226        }
227        snapshot
228    }
229}
230
/// Adapter-owned PubSub runtime abstraction.
#[derive(Clone)]
pub struct PubSub {
    // Shared so clones of `PubSub` fan out through the same backend.
    backend: Arc<dyn PubSubBackend>,
}

// Manual `Debug`: `dyn PubSubBackend` has no `Debug` bound, so we show the
// backend's capability contract instead of its internals.
impl std::fmt::Debug for PubSub {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PubSub")
            .field("capabilities", &self.capabilities())
            .finish()
    }
}
244
impl Default for PubSub {
    /// In-process backend with the default per-topic capacity.
    fn default() -> Self {
        Self::new(DEFAULT_TOPIC_CAPACITY)
    }
}
250
251impl PubSub {
252    /// Create the built-in in-process backend.
253    pub fn new(topic_capacity: usize) -> Self {
254        Self::with_backend(InProcessPubSubBackend::new(topic_capacity))
255    }
256
257    /// Wrap a custom backend implementation (for clustered fanout backends).
258    pub fn with_backend<B>(backend: B) -> Self
259    where
260        B: PubSubBackend + 'static,
261    {
262        Self {
263            backend: Arc::new(backend),
264        }
265    }
266
267    /// Subscribe to a topic.
268    pub fn subscribe(&self, topic: impl Into<String>) -> PubSubSubscription {
269        let topic = topic.into();
270        self.backend.subscribe(&topic)
271    }
272
273    /// Broadcast one payload to all subscribers on a topic.
274    pub fn broadcast(&self, topic: impl Into<String>, messages: Vec<ServerMessage>) -> usize {
275        let topic = topic.into();
276        self.backend.broadcast(&topic, messages)
277    }
278
279    /// Backend cluster/delivery capability contract.
280    pub fn capabilities(&self) -> PubSubCapabilities {
281        self.backend.capabilities()
282    }
283
284    /// Register one session as present for a topic on one cluster node.
285    pub fn register_presence(
286        &self,
287        topic: impl Into<String>,
288        session_id: impl Into<String>,
289        node_id: impl Into<String>,
290    ) {
291        let topic = topic.into();
292        let session_id = session_id.into();
293        let node_id = node_id.into();
294        self.backend
295            .register_presence(&topic, &session_id, &node_id);
296    }
297
298    /// Unregister one session presence from a topic on one cluster node.
299    pub fn unregister_presence(
300        &self,
301        topic: impl Into<String>,
302        session_id: impl Into<String>,
303        node_id: impl Into<String>,
304    ) {
305        let topic = topic.into();
306        let session_id = session_id.into();
307        let node_id = node_id.into();
308        self.backend
309            .unregister_presence(&topic, &session_id, &node_id);
310    }
311
312    /// Return presence counts for one topic.
313    pub fn presence_snapshot(&self, topic: impl Into<String>) -> PubSubPresenceSnapshot {
314        let topic = topic.into();
315        self.backend.presence_snapshot(&topic)
316    }
317}
318
/// Message delivered by one PubSub backend.
#[derive(Debug, Clone, PartialEq)]
pub struct PubSubMessage {
    /// Topic the batch was broadcast on.
    pub topic: String,
    /// Payload batch delivered to every subscriber.
    pub messages: Vec<ServerMessage>,
}
325
326/// Live subscription receiver for one topic.
327pub struct PubSubSubscription {
328    inner: Box<dyn PubSubSubscriptionHandle>,
329}
330
331impl PubSubSubscription {
332    /// Create one subscription from a custom backend receiver handle.
333    pub fn new<H>(handle: H) -> Self
334    where
335        H: PubSubSubscriptionHandle + 'static,
336    {
337        Self {
338            inner: Box::new(handle),
339        }
340    }
341
342    pub async fn recv(&mut self) -> Result<PubSubMessage, PubSubReceiveError> {
343        self.inner.recv().await
344    }
345}
346
/// Internal commands collected from `Context` and executed by the adapter.
#[derive(Debug, Clone, PartialEq)]
pub enum PubSubCommand {
    /// Open a subscription on `topic`.
    Subscribe {
        topic: String,
    },
    /// Broadcast `messages` to all subscribers of `topic`.
    Broadcast {
        topic: String,
        messages: Vec<ServerMessage>,
    },
}
358
#[cfg(test)]
mod tests {
    use super::{
        BroadcastSubscriptionHandle, PubSub, PubSubBackend, PubSubCapabilities,
        PubSubDeliveryScope, PubSubMessage, PubSubOrdering, PubSubReceiveError, PubSubSubscription,
        SessionAffinityRequirement,
    };
    use crate::ServerMessage;
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex},
    };
    use tokio::sync::broadcast;

    // Two subscribers on the same topic both receive one broadcast, and the
    // send count reflects both receivers.
    #[tokio::test]
    async fn in_process_pubsub_broadcasts_to_subscribers() {
        let pubsub = PubSub::default();
        let mut first = pubsub.subscribe("chat:lobby");
        let mut second = pubsub.subscribe("chat:lobby");

        assert_eq!(
            pubsub.broadcast(
                "chat:lobby",
                vec![ServerMessage::Redirect {
                    to: "/ok".to_string()
                }]
            ),
            2
        );

        assert_eq!(first.recv().await.unwrap().topic, "chat:lobby");
        assert_eq!(
            second.recv().await.unwrap().messages,
            vec![ServerMessage::Redirect {
                to: "/ok".to_string()
            }]
        );
    }

    // Capability contract of the in-process backend, plus presence counting
    // across two nodes and decrement on unregister.
    #[test]
    fn in_process_pubsub_reports_cluster_capabilities_and_presence() {
        let pubsub = PubSub::default();
        let capabilities = pubsub.capabilities();
        assert_eq!(capabilities.backend, "in_process");
        assert_eq!(
            capabilities.delivery_scope,
            PubSubDeliveryScope::LocalProcess
        );
        assert_eq!(capabilities.ordering, PubSubOrdering::PerTopicOrdered);
        assert_eq!(
            capabilities.session_affinity,
            SessionAffinityRequirement::StatefulSessionRequired
        );
        assert!(capabilities.presence_tracking);

        pubsub.register_presence("chat:lobby", "s1", "node-a");
        pubsub.register_presence("chat:lobby", "s2", "node-a");
        pubsub.register_presence("chat:lobby", "s3", "node-b");
        let snapshot = pubsub.presence_snapshot("chat:lobby");
        assert_eq!(snapshot.topic, "chat:lobby");
        assert_eq!(snapshot.total_sessions, 3);
        assert_eq!(snapshot.by_node.get("node-a"), Some(&2));
        assert_eq!(snapshot.by_node.get("node-b"), Some(&1));

        pubsub.unregister_presence("chat:lobby", "s2", "node-a");
        let after = pubsub.presence_snapshot("chat:lobby");
        assert_eq!(after.total_sessions, 2);
        assert_eq!(after.by_node.get("node-a"), Some(&1));
    }

    // Shared topic registry that stands in for a cluster transport: multiple
    // backend instances holding clones fan out through the same channels.
    #[derive(Debug, Clone)]
    struct SharedHub {
        topics: Arc<Mutex<HashMap<String, broadcast::Sender<PubSubMessage>>>>,
    }

    impl SharedHub {
        fn new() -> Self {
            Self {
                topics: Arc::new(Mutex::new(HashMap::new())),
            }
        }

        // Lazily create the channel for `topic`, mirroring the production
        // backend's `sender_for`.
        fn sender_for(&self, topic: &str) -> broadcast::Sender<PubSubMessage> {
            let mut topics = self.topics.lock().expect("hub mutex poisoned");
            topics
                .entry(topic.to_string())
                .or_insert_with(|| {
                    let (tx, _) = broadcast::channel(256);
                    tx
                })
                .clone()
        }
    }

    // Minimal custom backend used to exercise `PubSub::with_backend`.
    #[derive(Debug, Clone)]
    struct MockClusterBackend {
        hub: SharedHub,
    }

    impl PubSubBackend for MockClusterBackend {
        fn subscribe(&self, topic: &str) -> PubSubSubscription {
            let receiver = self.hub.sender_for(topic).subscribe();
            PubSubSubscription::new(BroadcastSubscriptionHandle { receiver })
        }

        fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize {
            self.hub
                .sender_for(topic)
                .send(PubSubMessage {
                    topic: topic.to_string(),
                    messages,
                })
                .unwrap_or_default()
        }

        fn capabilities(&self) -> PubSubCapabilities {
            PubSubCapabilities {
                backend: "mock_cluster".to_string(),
                delivery_scope: PubSubDeliveryScope::Cluster,
                ordering: PubSubOrdering::BestEffort,
                session_affinity: SessionAffinityRequirement::StatefulSessionRequired,
                presence_tracking: false,
            }
        }
    }

    // A broadcast from one PubSub instance reaches a subscriber on another
    // instance when both share the same hub.
    #[tokio::test]
    async fn custom_backend_can_fanout_across_multiple_pubsub_instances() {
        let hub = SharedHub::new();
        let node_a = PubSub::with_backend(MockClusterBackend { hub: hub.clone() });
        let node_b = PubSub::with_backend(MockClusterBackend { hub });

        let mut subscription = node_a.subscribe("cluster:lobby");
        assert_eq!(
            node_b.broadcast(
                "cluster:lobby",
                vec![ServerMessage::Error {
                    message: "hello".to_string(),
                    code: Some("cluster".to_string()),
                }]
            ),
            1
        );

        let delivered = subscription.recv().await.unwrap();
        assert_eq!(delivered.topic, "cluster:lobby");
        assert_eq!(delivered.messages.len(), 1);
        match &delivered.messages[0] {
            ServerMessage::Error { message, code } => {
                assert_eq!(message, "hello");
                assert_eq!(code.as_deref(), Some("cluster"));
            }
            other => panic!("unexpected payload: {other:?}"),
        }
    }

    // Debug formatting surfaces capabilities, and unregistering the last
    // session leaves an empty snapshot (dead keys are pruned).
    #[test]
    fn pubsub_debug_and_presence_cleanup_cover_additional_branches() {
        let pubsub = PubSub::default();
        assert!(format!("{pubsub:?}").contains("capabilities"));

        pubsub.register_presence("chat:lobby", "session-1", "node-a");
        pubsub.unregister_presence("chat:lobby", "session-1", "node-a");
        let snapshot = pubsub.presence_snapshot("chat:lobby");
        assert_eq!(snapshot.total_sessions, 0);
        assert!(snapshot.by_node.is_empty());
    }

    // Backend that opts out of presence tracking entirely; relies on the
    // trait's default no-op presence methods.
    #[derive(Debug, Clone, Default)]
    struct NoPresenceBackend;

    impl PubSubBackend for NoPresenceBackend {
        fn subscribe(&self, _topic: &str) -> PubSubSubscription {
            let (_sender, receiver) = broadcast::channel(8);
            PubSubSubscription::new(BroadcastSubscriptionHandle { receiver })
        }

        fn broadcast(&self, _topic: &str, _messages: Vec<ServerMessage>) -> usize {
            0
        }

        fn capabilities(&self) -> PubSubCapabilities {
            PubSubCapabilities {
                backend: "no_presence".to_string(),
                delivery_scope: PubSubDeliveryScope::LocalProcess,
                ordering: PubSubOrdering::BestEffort,
                session_affinity: SessionAffinityRequirement::None,
                presence_tracking: false,
            }
        }
    }

    #[test]
    fn default_presence_methods_return_empty_snapshot() {
        let pubsub = PubSub::with_backend(NoPresenceBackend);
        pubsub.register_presence("topic", "session-1", "node-a");
        pubsub.unregister_presence("topic", "session-1", "node-a");
        let snapshot = pubsub.presence_snapshot("topic");
        assert_eq!(snapshot.topic, "topic");
        assert_eq!(snapshot.total_sessions, 0);
        assert!(snapshot.by_node.is_empty());
    }

    // Capacity-1 channel overflows after two sends -> Lagged; dropping the
    // sender before any receive -> Closed.
    #[tokio::test]
    async fn broadcast_subscription_maps_lagged_and_closed_recv_errors() {
        let (sender, receiver) = broadcast::channel(1);
        let mut lagged = PubSubSubscription::new(BroadcastSubscriptionHandle { receiver });
        sender
            .send(PubSubMessage {
                topic: "topic".to_string(),
                messages: vec![ServerMessage::Redirect {
                    to: "/a".to_string(),
                }],
            })
            .unwrap();
        sender
            .send(PubSubMessage {
                topic: "topic".to_string(),
                messages: vec![ServerMessage::Redirect {
                    to: "/b".to_string(),
                }],
            })
            .unwrap();

        assert!(matches!(
            lagged.recv().await,
            Err(PubSubReceiveError::Lagged(_))
        ));

        let (sender2, receiver2) = broadcast::channel(1);
        drop(sender2);
        let mut closed = PubSubSubscription::new(BroadcastSubscriptionHandle {
            receiver: receiver2,
        });
        assert!(matches!(
            closed.recv().await,
            Err(PubSubReceiveError::Closed)
        ));
    }
}