Skip to main content

rakka_cluster_tools/
pub_sub.rs

1//! `DistributedPubSub.Mediator` (local-topic subset).
2//! akka.net: `Akka.Cluster.Tools/PublishSubscribe/DistributedPubSubMediator.cs`.
3//!
4//! Phase 7 of `docs/full-port-plan.md`. The mediator owns a local
5//! per-node topic table; cross-node gossip plugs in once Phase 6's
6//! gossip transport lands. This sub-step adds:
7//!
8//! * **Typed publish** — `publish_msg::<M>(topic, msg)` actually
9//!   delivers the message to each subscribed `ActorRef<M>` (the
10//!   prior API only returned the subscriber list).
11//! * **Group routing** — `subscribe_to_group(topic, group, ref)`
12//!   buckets subscribers; `send_to_group(topic, group, msg)` picks
13//!   one round-robin recipient per call.
//! * **send_to_one(path)** — *not yet implemented in this module*:
//!   single recipient by path, akka.net's
//!   `DistributedPubSubMediator.Send` semantics.
16
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20
21use parking_lot::RwLock;
22
23use rakka_core::actor::{ActorRef, UntypedActorRef};
24
25#[derive(Default)]
26pub struct DistributedPubSub {
27    topics: RwLock<HashMap<String, Vec<TypedSubscriber>>>,
28    groups: RwLock<HashMap<(String, String), Group>>,
29}
30
/// Type-erased delivery hook: receives a borrowed payload and reports
/// whether it was accepted.
type DeliverAnyFn = Box<dyn Fn(&dyn std::any::Any) -> bool + Send + Sync>;
/// Decoder for an inbound serialized payload; reports whether any local
/// delivery happened.
type CodecFn = Box<dyn Fn(&[u8]) -> bool + Send + Sync>;
33
34/// A subscriber that knows how to deliver `M` by holding a typed
35/// closure. Stored type-erased in the mediator so the topic table
36/// is a homogeneous `Vec`.
37struct TypedSubscriber {
38    untyped: UntypedActorRef,
39    deliver_any: DeliverAnyFn,
40}
41
42#[derive(Default)]
43struct Group {
44    members: Vec<TypedSubscriber>,
45    cursor: AtomicUsize,
46}
47
48impl DistributedPubSub {
49    pub fn new() -> Arc<Self> {
50        Arc::new(Self::default())
51    }
52
53    /// Subscribe `subscriber: ActorRef<M>` to `topic`. Future
54    /// `publish_msg::<M>(topic, msg)` calls deliver to it.
55    pub fn subscribe<M: Clone + Send + 'static>(&self, topic: impl Into<String>, subscriber: ActorRef<M>) {
56        let typed = TypedSubscriber::new(subscriber);
57        self.topics.write().entry(topic.into()).or_default().push(typed);
58    }
59
60    /// Subscribe to a `(topic, group)` bucket. `send_to_group`
61    /// rotates through bucket members.
62    pub fn subscribe_to_group<M: Clone + Send + 'static>(
63        &self,
64        topic: impl Into<String>,
65        group: impl Into<String>,
66        subscriber: ActorRef<M>,
67    ) {
68        let typed = TypedSubscriber::new(subscriber);
69        self.groups.write().entry((topic.into(), group.into())).or_default().members.push(typed);
70    }
71
72    /// Drop a subscriber by path from a topic.
73    pub fn unsubscribe(&self, topic: &str, subscriber_path: &rakka_core::actor::ActorPath) {
74        if let Some(v) = self.topics.write().get_mut(topic) {
75            v.retain(|s| s.untyped.path() != subscriber_path);
76        }
77    }
78
79    /// Snapshot of subscriber refs for a topic. Useful for tests +
80    /// the legacy "discover, then send" pattern.
81    pub fn publish(&self, topic: &str) -> Vec<UntypedActorRef> {
82        self.topics
83            .read()
84            .get(topic)
85            .map(|v| v.iter().map(|s| s.untyped.clone()).collect())
86            .unwrap_or_default()
87    }
88
89    /// Typed broadcast. Delivers `msg` (cloned) to every subscriber
90    /// of `topic`. Returns the number of successful deliveries.
91    pub fn publish_msg<M: Clone + Send + 'static>(&self, topic: &str, msg: M) -> usize {
92        let subs = self.topics.read();
93        let Some(list) = subs.get(topic) else {
94            return 0;
95        };
96        let mut delivered = 0;
97        let any: &dyn std::any::Any = &msg;
98        for s in list {
99            if (s.deliver_any)(any) {
100                delivered += 1;
101            }
102        }
103        // Clone-per-recipient happens inside deliver_any, so we
104        // can't move `msg`. The first deliver is a borrow; subsequent
105        // delivers re-borrow the same `Any`.
106        let _ = msg; // keep alive
107        delivered
108    }
109
110    /// Pick one member of `(topic, group)` round-robin and deliver
111    /// `msg`. Returns `true` if a recipient was found.
112    pub fn send_to_group<M: Clone + Send + 'static>(&self, topic: &str, group: &str, msg: M) -> bool {
113        let groups = self.groups.read();
114        let Some(g) = groups.get(&(topic.to_string(), group.to_string())) else {
115            return false;
116        };
117        if g.members.is_empty() {
118            return false;
119        }
120        let i = g.cursor.fetch_add(1, Ordering::Relaxed) % g.members.len();
121        let any: &dyn std::any::Any = &msg;
122        let r = (g.members[i].deliver_any)(any);
123        let _ = msg;
124        r
125    }
126
127    pub fn topic_count(&self) -> usize {
128        self.topics.read().len()
129    }
130
131    pub fn group_count(&self) -> usize {
132        self.groups.read().len()
133    }
134}
135
136// -----------------------------------------------------------------------
137// Phase 7.B — cross-node mediator.
138// -----------------------------------------------------------------------
139
140use std::collections::HashSet;
141
142/// Pluggable transport for the cross-node mediator. Sends an outbound
143/// `MediatorPdu` to a peer node, identified by an opaque string node id
144/// (typically `Address::to_string()`). The transport is responsible for
145/// the wire round-trip; on the receiver side, the inbound PDU is fed
146/// back into the local mediator via [`ClusterPubSub::apply_pdu`].
147pub trait MediatorTransport: Send + Sync + 'static {
148    fn send(&self, target_node: &str, pdu: MediatorPdu);
149}
150
/// Wire shape of a cross-node mediator exchange.
///
/// Derives `PartialEq`/`Eq` so PDUs can be compared (tests, de-duplication).
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum MediatorPdu {
    /// Full snapshot of the topics node `from` has at least one subscriber
    /// for; replaces anything `from` announced earlier.
    TopicAnnounce { from: String, topics: Vec<String> },
    /// Forward `msg_blob` (already serialized) to every local subscriber
    /// of `topic` on the receiving node. `type_id` selects the decoder.
    Forward { topic: String, msg_blob: Vec<u8>, type_id: String },
}
161
162/// Mediator that augments a local [`DistributedPubSub`] with a
163/// cross-node topic table + transport. Clusters publish via
164/// [`ClusterPubSub::publish_remote`] which fans out to all nodes that
165/// have advertised the topic; receivers route the payload to local
166/// subscribers using the codec registry.
167pub struct ClusterPubSub {
168    local: Arc<DistributedPubSub>,
169    self_node: String,
170    /// `topic -> set of advertising node-ids`.
171    remote_topics: RwLock<HashMap<String, HashSet<String>>>,
172    transport: Arc<dyn MediatorTransport>,
173    codecs: RwLock<HashMap<String, CodecFn>>,
174}
175
176impl ClusterPubSub {
177    pub fn new(
178        local: Arc<DistributedPubSub>,
179        self_node: impl Into<String>,
180        transport: Arc<dyn MediatorTransport>,
181    ) -> Arc<Self> {
182        Arc::new(Self {
183            local,
184            self_node: self_node.into(),
185            remote_topics: RwLock::new(HashMap::new()),
186            transport,
187            codecs: RwLock::new(HashMap::new()),
188        })
189    }
190
191    /// Register a per-message-type decoder for inbound `Forward` PDUs.
192    /// `type_id` typically matches `std::any::type_name::<M>()`; the
193    /// decoder must deliver to local subscribers (and return `true` if
194    /// any delivery happened).
195    pub fn register_decoder<F>(&self, type_id: impl Into<String>, decode: F)
196    where
197        F: Fn(&[u8]) -> bool + Send + Sync + 'static,
198    {
199        self.codecs.write().insert(type_id.into(), Box::new(decode));
200    }
201
202    /// Announce currently-subscribed topics to a peer node. Caller drives
203    /// this on a tick (similar to `ClusterDaemon`).
204    pub fn announce_to(&self, target_node: &str) {
205        let topics: Vec<String> = self.local.topics.read().keys().cloned().collect();
206        self.transport.send(target_node, MediatorPdu::TopicAnnounce { from: self.self_node.clone(), topics });
207    }
208
209    /// Apply an inbound PDU received from the transport.
210    pub fn apply_pdu(&self, pdu: MediatorPdu) {
211        match pdu {
212            MediatorPdu::TopicAnnounce { from, topics } => {
213                let mut g = self.remote_topics.write();
214                // Drop prior announcements from this node.
215                for set in g.values_mut() {
216                    set.remove(&from);
217                }
218                for t in topics {
219                    g.entry(t).or_default().insert(from.clone());
220                }
221            }
222            MediatorPdu::Forward { topic, msg_blob, type_id } => {
223                let codecs = self.codecs.read();
224                if let Some(decode) = codecs.get(&type_id) {
225                    let _ = decode(&msg_blob);
226                    // Local fan-out: the decoder publishes to this node's
227                    // local mediator. The topic is implicit in the codec's
228                    // closure body. We also stash the topic for diagnostics.
229                    let _ = topic;
230                }
231            }
232        }
233    }
234
235    /// Cross-node publish. Locally fan-out via the wrapped mediator,
236    /// then forward the serialized payload to every remote node that has
237    /// announced this topic.
238    pub fn publish_remote<M, S>(&self, topic: &str, msg: M, type_id: impl Into<String>, encode: S) -> usize
239    where
240        M: Clone + Send + 'static,
241        S: FnOnce(&M) -> Vec<u8>,
242    {
243        let local_n = self.local.publish_msg(topic, msg.clone());
244        let remote = self.remote_topics.read();
245        let Some(nodes) = remote.get(topic) else { return local_n };
246        let blob = encode(&msg);
247        let type_id = type_id.into();
248        let mut forwarded = 0;
249        for node in nodes {
250            if node == &self.self_node {
251                continue;
252            }
253            self.transport.send(
254                node,
255                MediatorPdu::Forward {
256                    topic: topic.into(),
257                    msg_blob: blob.clone(),
258                    type_id: type_id.clone(),
259                },
260            );
261            forwarded += 1;
262        }
263        local_n + forwarded
264    }
265
266    pub fn known_remote_topics(&self) -> usize {
267        self.remote_topics.read().len()
268    }
269
270    pub fn nodes_for(&self, topic: &str) -> Vec<String> {
271        self.remote_topics.read().get(topic).map(|s| s.iter().cloned().collect()).unwrap_or_default()
272    }
273}
274
275impl TypedSubscriber {
276    fn new<M: Clone + Send + 'static>(r: ActorRef<M>) -> Self {
277        let untyped = r.as_untyped();
278        let r2 = r.clone();
279        let deliver_any: DeliverAnyFn = Box::new(move |any| {
280            if let Some(m) = any.downcast_ref::<M>() {
281                r2.tell(m.clone());
282                true
283            } else {
284                false
285            }
286        });
287        Self { untyped, deliver_any }
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use rakka_core::actor::Inbox;
    use std::time::Duration;

    #[test]
    fn subscribe_and_publish_returns_subscriber_list() {
        let bus = DistributedPubSub::new();
        let inbox = Inbox::<u32>::new("s");
        bus.subscribe("greetings", inbox.actor_ref().clone());
        let subs = bus.publish("greetings");
        assert_eq!(subs.len(), 1);
    }

    #[tokio::test]
    async fn typed_publish_delivers_to_each_subscriber() {
        let bus = DistributedPubSub::new();
        let mut a = Inbox::<u32>::new("a");
        let mut b = Inbox::<u32>::new("b");
        bus.subscribe("nums", a.actor_ref().clone());
        bus.subscribe("nums", b.actor_ref().clone());

        let n = bus.publish_msg("nums", 7u32);
        assert_eq!(n, 2);

        assert_eq!(a.receive(Duration::from_millis(50)).await.unwrap(), 7);
        assert_eq!(b.receive(Duration::from_millis(50)).await.unwrap(), 7);
    }

    #[tokio::test]
    async fn publish_to_unknown_topic_delivers_zero() {
        let bus = DistributedPubSub::new();
        let n = bus.publish_msg("nope", 1u32);
        assert_eq!(n, 0);
    }

    #[test]
    fn unsubscribe_removes_subscriber() {
        let bus = DistributedPubSub::new();
        let inbox = Inbox::<u32>::new("u");
        bus.subscribe("nums", inbox.actor_ref().clone());
        let subs = bus.publish("nums");
        bus.unsubscribe("nums", subs[0].path());
        assert!(bus.publish("nums").is_empty());
        assert_eq!(bus.publish_msg("nums", 1u32), 0);
    }

    #[tokio::test]
    async fn group_send_round_robins_one_member() {
        let bus = DistributedPubSub::new();
        let mut a = Inbox::<u32>::new("ga");
        let mut b = Inbox::<u32>::new("gb");
        bus.subscribe_to_group("work", "G1", a.actor_ref().clone());
        bus.subscribe_to_group("work", "G1", b.actor_ref().clone());

        // The cursor starts at 0 and members are in subscription order, so
        // sends alternate a, b, a, b: `a` gets the even payloads, `b` the odd.
        for i in 0..4u32 {
            assert!(bus.send_to_group("work", "G1", i));
        }
        let timeout = Duration::from_millis(20);
        assert_eq!(a.receive(timeout).await.unwrap(), 0);
        assert_eq!(a.receive(timeout).await.unwrap(), 2);
        assert_eq!(b.receive(timeout).await.unwrap(), 1);
        assert_eq!(b.receive(timeout).await.unwrap(), 3);
    }

    #[test]
    fn send_to_unknown_group_returns_false() {
        let bus = DistributedPubSub::new();
        assert!(!bus.send_to_group("work", "missing", 1u32));
    }

    #[derive(Default, Clone)]
    struct CapturingTransport {
        sent: Arc<parking_lot::Mutex<Vec<(String, MediatorPdu)>>>,
    }
    impl MediatorTransport for CapturingTransport {
        fn send(&self, target: &str, pdu: MediatorPdu) {
            self.sent.lock().push((target.to_string(), pdu));
        }
    }

    #[tokio::test]
    async fn cluster_pub_sub_announce_and_forward_round_trip() {
        let local_a = DistributedPubSub::new();
        let local_b = DistributedPubSub::new();
        let mut subscriber = Inbox::<u32>::new("sub");
        local_b.subscribe("nums", subscriber.actor_ref().clone());
        let net = CapturingTransport::default();
        let net_arc: Arc<dyn MediatorTransport> = Arc::new(net.clone());
        let a = ClusterPubSub::new(local_a.clone(), "node-a", net_arc.clone());
        let b = ClusterPubSub::new(local_b.clone(), "node-b", net_arc);

        // B announces its topics.
        b.announce_to("node-a");
        let pdu = net.sent.lock().pop().unwrap().1;
        a.apply_pdu(pdu);
        assert_eq!(a.known_remote_topics(), 1);
        assert_eq!(a.nodes_for("nums"), vec!["node-b".to_string()]);

        // B installs a decoder that publishes locally.
        let local_b2 = local_b.clone();
        b.register_decoder("u32", move |bytes| {
            let n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
            local_b2.publish_msg::<u32>("nums", n) > 0
        });

        // A publishes — it forwards to B.
        let n = a.publish_remote::<u32, _>("nums", 42, "u32", |m| m.to_le_bytes().to_vec());
        assert_eq!(n, 1);
        let (target, fwd) = net.sent.lock().pop().unwrap();
        assert_eq!(target, "node-b");
        b.apply_pdu(fwd);
        assert_eq!(subscriber.receive(Duration::from_millis(50)).await.unwrap(), 42);
    }

    #[test]
    fn re_announcement_replaces_previous_topics() {
        let net = CapturingTransport::default();
        let net_arc: Arc<dyn MediatorTransport> = Arc::new(net);
        let a = ClusterPubSub::new(DistributedPubSub::new(), "node-a", net_arc);
        a.apply_pdu(MediatorPdu::TopicAnnounce { from: "node-b".into(), topics: vec!["x".into()] });
        a.apply_pdu(MediatorPdu::TopicAnnounce { from: "node-b".into(), topics: vec!["y".into()] });
        assert!(a.nodes_for("x").is_empty());
        assert_eq!(a.nodes_for("y"), vec!["node-b".to_string()]);
    }

    #[test]
    fn publish_remote_never_forwards_to_self() {
        let net = CapturingTransport::default();
        let net_arc: Arc<dyn MediatorTransport> = Arc::new(net.clone());
        let a = ClusterPubSub::new(DistributedPubSub::new(), "node-a", net_arc);
        a.apply_pdu(MediatorPdu::TopicAnnounce { from: "node-a".into(), topics: vec!["t".into()] });
        let n = a.publish_remote::<u32, _>("t", 1, "u32", |m| m.to_le_bytes().to_vec());
        assert_eq!(n, 0);
        assert!(net.sent.lock().is_empty());
    }

    #[test]
    fn group_count_tracks_distinct_buckets() {
        let bus = DistributedPubSub::new();
        let inbox = Inbox::<u32>::new("g");
        bus.subscribe_to_group("t1", "G1", inbox.actor_ref().clone());
        bus.subscribe_to_group("t1", "G2", inbox.actor_ref().clone());
        bus.subscribe_to_group("t2", "G1", inbox.actor_ref().clone());
        assert_eq!(bus.group_count(), 3);
    }
}