1use std::collections::HashMap;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20
21use parking_lot::RwLock;
22
23use rakka_core::actor::{ActorRef, UntypedActorRef};
24
/// In-process publish/subscribe registry.
///
/// Holds two independent tables, each behind its own `RwLock`:
/// plain topic subscriptions (every subscriber gets each published message)
/// and named groups within a topic (one member per message, chosen in rotation).
#[derive(Default)]
pub struct DistributedPubSub {
    /// Topic name -> subscribers registered via `subscribe`.
    topics: RwLock<HashMap<String, Vec<TypedSubscriber>>>,
    /// `(topic, group)` -> group state registered via `subscribe_to_group`.
    groups: RwLock<HashMap<(String, String), Group>>,
}
30
/// Type-erased delivery callback: receives a message as `&dyn Any` and
/// returns a `bool` (the callbacks built in this file return `true` when the
/// value had the expected concrete type and was handed to the actor).
type DeliverAnyFn = Box<dyn Fn(&dyn std::any::Any) -> bool + Send + Sync>;
/// Decoder callback for forwarded messages: receives the raw payload bytes
/// and returns a `bool` whose meaning is defined by whoever registers it.
type CodecFn = Box<dyn Fn(&[u8]) -> bool + Send + Sync>;
33
/// One registered subscriber: an untyped handle for identity/lookup plus a
/// type-erased closure that performs the typed delivery.
struct TypedSubscriber {
    /// Untyped view of the subscriber; its `path()` is what `unsubscribe` matches on.
    untyped: UntypedActorRef,
    /// Delivers a message passed as `&dyn Any` to the underlying typed actor ref.
    deliver_any: DeliverAnyFn,
}
41
/// Members of one `(topic, group)` bucket together with the rotation counter
/// used to pick the next recipient.
#[derive(Default)]
struct Group {
    /// Subscribers in registration order.
    members: Vec<TypedSubscriber>,
    /// Monotonically incremented on every group send; taken modulo
    /// `members.len()` to select the recipient.
    cursor: AtomicUsize,
}
47
48impl DistributedPubSub {
49 pub fn new() -> Arc<Self> {
50 Arc::new(Self::default())
51 }
52
53 pub fn subscribe<M: Clone + Send + 'static>(&self, topic: impl Into<String>, subscriber: ActorRef<M>) {
56 let typed = TypedSubscriber::new(subscriber);
57 self.topics.write().entry(topic.into()).or_default().push(typed);
58 }
59
60 pub fn subscribe_to_group<M: Clone + Send + 'static>(
63 &self,
64 topic: impl Into<String>,
65 group: impl Into<String>,
66 subscriber: ActorRef<M>,
67 ) {
68 let typed = TypedSubscriber::new(subscriber);
69 self.groups.write().entry((topic.into(), group.into())).or_default().members.push(typed);
70 }
71
72 pub fn unsubscribe(&self, topic: &str, subscriber_path: &rakka_core::actor::ActorPath) {
74 if let Some(v) = self.topics.write().get_mut(topic) {
75 v.retain(|s| s.untyped.path() != subscriber_path);
76 }
77 }
78
79 pub fn publish(&self, topic: &str) -> Vec<UntypedActorRef> {
82 self.topics
83 .read()
84 .get(topic)
85 .map(|v| v.iter().map(|s| s.untyped.clone()).collect())
86 .unwrap_or_default()
87 }
88
89 pub fn publish_msg<M: Clone + Send + 'static>(&self, topic: &str, msg: M) -> usize {
92 let subs = self.topics.read();
93 let Some(list) = subs.get(topic) else {
94 return 0;
95 };
96 let mut delivered = 0;
97 let any: &dyn std::any::Any = &msg;
98 for s in list {
99 if (s.deliver_any)(any) {
100 delivered += 1;
101 }
102 }
103 let _ = msg; delivered
108 }
109
110 pub fn send_to_group<M: Clone + Send + 'static>(&self, topic: &str, group: &str, msg: M) -> bool {
113 let groups = self.groups.read();
114 let Some(g) = groups.get(&(topic.to_string(), group.to_string())) else {
115 return false;
116 };
117 if g.members.is_empty() {
118 return false;
119 }
120 let i = g.cursor.fetch_add(1, Ordering::Relaxed) % g.members.len();
121 let any: &dyn std::any::Any = &msg;
122 let r = (g.members[i].deliver_any)(any);
123 let _ = msg;
124 r
125 }
126
127 pub fn topic_count(&self) -> usize {
128 self.topics.read().len()
129 }
130
131 pub fn group_count(&self) -> usize {
132 self.groups.read().len()
133 }
134}
135
136use std::collections::HashSet;
141
/// Outbound link used to hand mediator PDUs to other nodes.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait MediatorTransport: Send + Sync + 'static {
    /// Hands `pdu` to the transport, addressed to the node named `target_node`.
    ///
    /// The method returns nothing, so delivery failures are not reported
    /// through this interface.
    fn send(&self, target_node: &str, pdu: MediatorPdu);
}
150
/// Messages exchanged between mediators on different nodes.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum MediatorPdu {
    /// Node `from` advertises the topics it currently has; built by
    /// `ClusterPubSub::announce_to` and consumed by `ClusterPubSub::apply_pdu`.
    TopicAnnounce { from: String, topics: Vec<String> },
    /// An encoded message for `topic`. `type_id` selects the decoder on the
    /// receiving side and `msg_blob` is the byte payload handed to it.
    Forward { topic: String, msg_blob: Vec<u8>, type_id: String },
}
161
/// Cluster-facing layer on top of a local [`DistributedPubSub`]: tracks which
/// remote nodes announced which topics and forwards published messages to
/// them through a [`MediatorTransport`].
pub struct ClusterPubSub {
    /// Registry for subscribers living on this node.
    local: Arc<DistributedPubSub>,
    /// Name of this node; used as the `from` of announcements and skipped
    /// when forwarding.
    self_node: String,
    /// Topic name -> names of nodes that announced the topic.
    remote_topics: RwLock<HashMap<String, HashSet<String>>>,
    /// Outbound link to other nodes.
    transport: Arc<dyn MediatorTransport>,
    /// `type_id` -> decoder invoked for incoming `Forward` PDUs.
    codecs: RwLock<HashMap<String, CodecFn>>,
}
175
176impl ClusterPubSub {
177 pub fn new(
178 local: Arc<DistributedPubSub>,
179 self_node: impl Into<String>,
180 transport: Arc<dyn MediatorTransport>,
181 ) -> Arc<Self> {
182 Arc::new(Self {
183 local,
184 self_node: self_node.into(),
185 remote_topics: RwLock::new(HashMap::new()),
186 transport,
187 codecs: RwLock::new(HashMap::new()),
188 })
189 }
190
191 pub fn register_decoder<F>(&self, type_id: impl Into<String>, decode: F)
196 where
197 F: Fn(&[u8]) -> bool + Send + Sync + 'static,
198 {
199 self.codecs.write().insert(type_id.into(), Box::new(decode));
200 }
201
202 pub fn announce_to(&self, target_node: &str) {
205 let topics: Vec<String> = self.local.topics.read().keys().cloned().collect();
206 self.transport.send(target_node, MediatorPdu::TopicAnnounce { from: self.self_node.clone(), topics });
207 }
208
209 pub fn apply_pdu(&self, pdu: MediatorPdu) {
211 match pdu {
212 MediatorPdu::TopicAnnounce { from, topics } => {
213 let mut g = self.remote_topics.write();
214 for set in g.values_mut() {
216 set.remove(&from);
217 }
218 for t in topics {
219 g.entry(t).or_default().insert(from.clone());
220 }
221 }
222 MediatorPdu::Forward { topic, msg_blob, type_id } => {
223 let codecs = self.codecs.read();
224 if let Some(decode) = codecs.get(&type_id) {
225 let _ = decode(&msg_blob);
226 let _ = topic;
230 }
231 }
232 }
233 }
234
235 pub fn publish_remote<M, S>(&self, topic: &str, msg: M, type_id: impl Into<String>, encode: S) -> usize
239 where
240 M: Clone + Send + 'static,
241 S: FnOnce(&M) -> Vec<u8>,
242 {
243 let local_n = self.local.publish_msg(topic, msg.clone());
244 let remote = self.remote_topics.read();
245 let Some(nodes) = remote.get(topic) else { return local_n };
246 let blob = encode(&msg);
247 let type_id = type_id.into();
248 let mut forwarded = 0;
249 for node in nodes {
250 if node == &self.self_node {
251 continue;
252 }
253 self.transport.send(
254 node,
255 MediatorPdu::Forward {
256 topic: topic.into(),
257 msg_blob: blob.clone(),
258 type_id: type_id.clone(),
259 },
260 );
261 forwarded += 1;
262 }
263 local_n + forwarded
264 }
265
266 pub fn known_remote_topics(&self) -> usize {
267 self.remote_topics.read().len()
268 }
269
270 pub fn nodes_for(&self, topic: &str) -> Vec<String> {
271 self.remote_topics.read().get(topic).map(|s| s.iter().cloned().collect()).unwrap_or_default()
272 }
273}
274
275impl TypedSubscriber {
276 fn new<M: Clone + Send + 'static>(r: ActorRef<M>) -> Self {
277 let untyped = r.as_untyped();
278 let r2 = r.clone();
279 let deliver_any: DeliverAnyFn = Box::new(move |any| {
280 if let Some(m) = any.downcast_ref::<M>() {
281 r2.tell(m.clone());
282 true
283 } else {
284 false
285 }
286 });
287 Self { untyped, deliver_any }
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use rakka_core::actor::Inbox;
    use std::time::Duration;

    /// Transport double that records every `(target, pdu)` pair instead of
    /// sending it anywhere.
    #[derive(Default, Clone)]
    struct CapturingTransport {
        sent: Arc<parking_lot::Mutex<Vec<(String, MediatorPdu)>>>,
    }

    impl MediatorTransport for CapturingTransport {
        fn send(&self, target: &str, pdu: MediatorPdu) {
            let mut log = self.sent.lock();
            log.push((target.to_string(), pdu));
        }
    }

    // `publish` only lists subscribers; it must report the one we registered.
    #[test]
    fn subscribe_and_publish_returns_subscriber_list() {
        let bus = DistributedPubSub::new();
        let listener = Inbox::<u32>::new("s");
        bus.subscribe("greetings", listener.actor_ref().clone());

        assert_eq!(bus.publish("greetings").len(), 1);
    }

    // Every subscriber of a topic receives its own copy of the message.
    #[tokio::test]
    async fn typed_publish_delivers_to_each_subscriber() {
        let wait = Duration::from_millis(50);
        let bus = DistributedPubSub::new();
        let mut a = Inbox::<u32>::new("a");
        let mut b = Inbox::<u32>::new("b");
        bus.subscribe("nums", a.actor_ref().clone());
        bus.subscribe("nums", b.actor_ref().clone());

        let delivered = bus.publish_msg("nums", 7u32);
        assert_eq!(delivered, 2);

        assert_eq!(a.receive(wait).await.unwrap(), 7);
        assert_eq!(b.receive(wait).await.unwrap(), 7);
    }

    // Publishing to a topic nobody subscribed to is a no-op.
    #[tokio::test]
    async fn publish_to_unknown_topic_delivers_zero() {
        let bus = DistributedPubSub::new();

        assert_eq!(bus.publish_msg("nope", 1u32), 0);
    }

    // Four sends into a two-member group must land two on each member.
    #[tokio::test]
    async fn group_send_round_robins_one_member() {
        let wait = Duration::from_millis(20);
        let bus = DistributedPubSub::new();
        let mut a = Inbox::<u32>::new("ga");
        let mut b = Inbox::<u32>::new("gb");
        bus.subscribe_to_group("work", "G1", a.actor_ref().clone());
        bus.subscribe_to_group("work", "G1", b.actor_ref().clone());

        for payload in 0..4u32 {
            assert!(bus.send_to_group("work", "G1", payload));
        }

        let (mut seen_by_a, mut seen_by_b) = (0, 0);
        for _round in 0..2 {
            a.receive(wait).await.unwrap();
            seen_by_a += 1;
            b.receive(wait).await.unwrap();
            seen_by_b += 1;
        }
        assert_eq!(seen_by_a, 2);
        assert_eq!(seen_by_b, 2);
    }

    // Full loop: B announces its topic to A, A forwards a publish to B, and
    // B's decoder re-publishes it to the local subscriber.
    #[tokio::test]
    async fn cluster_pub_sub_announce_and_forward_round_trip() {
        let bus_a = DistributedPubSub::new();
        let bus_b = DistributedPubSub::new();
        let mut subscriber = Inbox::<u32>::new("sub");
        bus_b.subscribe("nums", subscriber.actor_ref().clone());

        let wire = CapturingTransport::default();
        let shared_wire: Arc<dyn MediatorTransport> = Arc::new(wire.clone());
        let node_a = ClusterPubSub::new(Arc::clone(&bus_a), "node-a", Arc::clone(&shared_wire));
        let node_b = ClusterPubSub::new(Arc::clone(&bus_b), "node-b", shared_wire);

        // B -> A: topic announcement, relayed by hand through the captured log.
        node_b.announce_to("node-a");
        let (_, announce) = wire.sent.lock().pop().unwrap();
        node_a.apply_pdu(announce);
        assert_eq!(node_a.known_remote_topics(), 1);
        assert_eq!(node_a.nodes_for("nums"), vec!["node-b".to_string()]);

        // B decodes forwarded u32 payloads and republishes them locally.
        let sink = Arc::clone(&bus_b);
        node_b.register_decoder("u32", move |bytes| {
            let value = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
            sink.publish_msg::<u32>("nums", value) > 0
        });

        // A has no local subscribers, so the only delivery is the forward to B.
        let reached = node_a.publish_remote::<u32, _>("nums", 42, "u32", |m| m.to_le_bytes().to_vec());
        assert_eq!(reached, 1);

        let (target, forward) = wire.sent.lock().pop().unwrap();
        assert_eq!(target, "node-b");
        node_b.apply_pdu(forward);
        assert_eq!(subscriber.receive(Duration::from_millis(50)).await.unwrap(), 42);
    }

    // Buckets are keyed by the (topic, group) pair, not by either part alone.
    #[test]
    fn group_count_tracks_distinct_buckets() {
        let bus = DistributedPubSub::new();
        let member = Inbox::<u32>::new("g");
        for (topic, group) in [("t1", "G1"), ("t1", "G2"), ("t2", "G1")] {
            bus.subscribe_to_group(topic, group, member.actor_ref().clone());
        }

        assert_eq!(bus.group_count(), 3);
    }
}