1use std::collections::HashMap;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::Arc;
19
20use parking_lot::RwLock;
21
22use atomr_core::actor::{ActorRef, UntypedActorRef};
23
24#[derive(Default)]
25pub struct DistributedPubSub {
26 topics: RwLock<HashMap<String, Vec<TypedSubscriber>>>,
27 groups: RwLock<HashMap<(String, String), Group>>,
28}
29
/// Type-erased delivery hook stored per subscriber.
///
/// It receives a `&dyn Any` and returns `true` if the value was of the subscriber's
/// message type and was handed to it.
type DeliverAnyFn = Box<dyn Fn(&dyn std::any::Any) -> bool + Sync + Send>;
/// Decoder registered per wire `type_id`.
///
/// It receives the raw bytes of a forwarded message; the meaning of the returned
/// `bool` is up to the decoder.
type CodecFn = Box<dyn Fn(&[u8]) -> bool + Sync + Send>;
32
33struct TypedSubscriber {
37 untyped: UntypedActorRef,
38 deliver_any: DeliverAnyFn,
39}
40
41#[derive(Default)]
42struct Group {
43 members: Vec<TypedSubscriber>,
44 cursor: AtomicUsize,
45}
46
47impl DistributedPubSub {
48 pub fn new() -> Arc<Self> {
49 Arc::new(Self::default())
50 }
51
52 pub fn subscribe<M: Clone + Send + 'static>(&self, topic: impl Into<String>, subscriber: ActorRef<M>) {
55 let typed = TypedSubscriber::new(subscriber);
56 self.topics.write().entry(topic.into()).or_default().push(typed);
57 }
58
59 pub fn subscribe_to_group<M: Clone + Send + 'static>(
62 &self,
63 topic: impl Into<String>,
64 group: impl Into<String>,
65 subscriber: ActorRef<M>,
66 ) {
67 let typed = TypedSubscriber::new(subscriber);
68 self.groups.write().entry((topic.into(), group.into())).or_default().members.push(typed);
69 }
70
71 pub fn unsubscribe(&self, topic: &str, subscriber_path: &atomr_core::actor::ActorPath) {
73 if let Some(v) = self.topics.write().get_mut(topic) {
74 v.retain(|s| s.untyped.path() != subscriber_path);
75 }
76 }
77
78 pub fn publish(&self, topic: &str) -> Vec<UntypedActorRef> {
81 self.topics
82 .read()
83 .get(topic)
84 .map(|v| v.iter().map(|s| s.untyped.clone()).collect())
85 .unwrap_or_default()
86 }
87
88 pub fn publish_msg<M: Clone + Send + 'static>(&self, topic: &str, msg: M) -> usize {
91 let subs = self.topics.read();
92 let Some(list) = subs.get(topic) else {
93 return 0;
94 };
95 let mut delivered = 0;
96 let any: &dyn std::any::Any = &msg;
97 for s in list {
98 if (s.deliver_any)(any) {
99 delivered += 1;
100 }
101 }
102 let _ = msg; delivered
107 }
108
109 pub fn send_to_group<M: Clone + Send + 'static>(&self, topic: &str, group: &str, msg: M) -> bool {
112 let groups = self.groups.read();
113 let Some(g) = groups.get(&(topic.to_string(), group.to_string())) else {
114 return false;
115 };
116 if g.members.is_empty() {
117 return false;
118 }
119 let i = g.cursor.fetch_add(1, Ordering::Relaxed) % g.members.len();
120 let any: &dyn std::any::Any = &msg;
121 let r = (g.members[i].deliver_any)(any);
122 let _ = msg;
123 r
124 }
125
126 pub fn topic_count(&self) -> usize {
127 self.topics.read().len()
128 }
129
130 pub fn group_count(&self) -> usize {
131 self.groups.read().len()
132 }
133}
134
135use std::collections::HashSet;
140
141pub trait MediatorTransport: Send + Sync + 'static {
147 fn send(&self, target_node: &str, pdu: MediatorPdu);
148}
149
/// Wire messages exchanged between cluster mediators.
///
/// `PartialEq`/`Eq` are derived so PDUs can be compared, e.g. in transport tests.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum MediatorPdu {
    /// Complete list of topics node `from` hosts locally.
    ///
    /// The receiver replaces whatever it previously learned from `from` with this list.
    TopicAnnounce { from: String, topics: Vec<String> },
    /// An encoded message published on `topic`.
    ///
    /// `type_id` selects the decoder on the receiving node.
    Forward { topic: String, msg_blob: Vec<u8>, type_id: String },
}
160
161pub struct ClusterPubSub {
167 local: Arc<DistributedPubSub>,
168 self_node: String,
169 remote_topics: RwLock<HashMap<String, HashSet<String>>>,
171 transport: Arc<dyn MediatorTransport>,
172 codecs: RwLock<HashMap<String, CodecFn>>,
173}
174
175impl ClusterPubSub {
176 pub fn new(
177 local: Arc<DistributedPubSub>,
178 self_node: impl Into<String>,
179 transport: Arc<dyn MediatorTransport>,
180 ) -> Arc<Self> {
181 Arc::new(Self {
182 local,
183 self_node: self_node.into(),
184 remote_topics: RwLock::new(HashMap::new()),
185 transport,
186 codecs: RwLock::new(HashMap::new()),
187 })
188 }
189
190 pub fn register_decoder<F>(&self, type_id: impl Into<String>, decode: F)
195 where
196 F: Fn(&[u8]) -> bool + Send + Sync + 'static,
197 {
198 self.codecs.write().insert(type_id.into(), Box::new(decode));
199 }
200
201 pub fn announce_to(&self, target_node: &str) {
204 let topics: Vec<String> = self.local.topics.read().keys().cloned().collect();
205 self.transport.send(target_node, MediatorPdu::TopicAnnounce { from: self.self_node.clone(), topics });
206 }
207
208 pub fn apply_pdu(&self, pdu: MediatorPdu) {
210 match pdu {
211 MediatorPdu::TopicAnnounce { from, topics } => {
212 let mut g = self.remote_topics.write();
213 for set in g.values_mut() {
215 set.remove(&from);
216 }
217 for t in topics {
218 g.entry(t).or_default().insert(from.clone());
219 }
220 }
221 MediatorPdu::Forward { topic, msg_blob, type_id } => {
222 let codecs = self.codecs.read();
223 if let Some(decode) = codecs.get(&type_id) {
224 let _ = decode(&msg_blob);
225 let _ = topic;
229 }
230 }
231 }
232 }
233
234 pub fn publish_remote<M, S>(&self, topic: &str, msg: M, type_id: impl Into<String>, encode: S) -> usize
238 where
239 M: Clone + Send + 'static,
240 S: FnOnce(&M) -> Vec<u8>,
241 {
242 let local_n = self.local.publish_msg(topic, msg.clone());
243 let remote = self.remote_topics.read();
244 let Some(nodes) = remote.get(topic) else { return local_n };
245 let blob = encode(&msg);
246 let type_id = type_id.into();
247 let mut forwarded = 0;
248 for node in nodes {
249 if node == &self.self_node {
250 continue;
251 }
252 self.transport.send(
253 node,
254 MediatorPdu::Forward {
255 topic: topic.into(),
256 msg_blob: blob.clone(),
257 type_id: type_id.clone(),
258 },
259 );
260 forwarded += 1;
261 }
262 local_n + forwarded
263 }
264
265 pub fn known_remote_topics(&self) -> usize {
266 self.remote_topics.read().len()
267 }
268
269 pub fn nodes_for(&self, topic: &str) -> Vec<String> {
270 self.remote_topics.read().get(topic).map(|s| s.iter().cloned().collect()).unwrap_or_default()
271 }
272}
273
274impl TypedSubscriber {
275 fn new<M: Clone + Send + 'static>(r: ActorRef<M>) -> Self {
276 let untyped = r.as_untyped();
277 let r2 = r.clone();
278 let deliver_any: DeliverAnyFn = Box::new(move |any| {
279 if let Some(m) = any.downcast_ref::<M>() {
280 r2.tell(m.clone());
281 true
282 } else {
283 false
284 }
285 });
286 Self { untyped, deliver_any }
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_core::actor::Inbox;
    use std::time::Duration;

    /// In-memory transport that records `(target, pdu)` pairs instead of sending them.
    #[derive(Default, Clone)]
    struct CapturingTransport {
        sent: Arc<parking_lot::Mutex<Vec<(String, MediatorPdu)>>>,
    }

    impl MediatorTransport for CapturingTransport {
        fn send(&self, target: &str, pdu: MediatorPdu) {
            let record = (target.to_string(), pdu);
            self.sent.lock().push(record);
        }
    }

    impl CapturingTransport {
        /// Pops the most recently captured PDU; panics if nothing was sent.
        fn take_last(&self) -> (String, MediatorPdu) {
            self.sent.lock().pop().unwrap()
        }
    }

    #[test]
    fn subscribe_and_publish_returns_subscriber_list() {
        let bus = DistributedPubSub::new();
        let listener = Inbox::<u32>::new("s");
        bus.subscribe("greetings", listener.actor_ref().clone());
        assert_eq!(bus.publish("greetings").len(), 1);
    }

    #[tokio::test]
    async fn typed_publish_delivers_to_each_subscriber() {
        let bus = DistributedPubSub::new();
        let mut first = Inbox::<u32>::new("a");
        let mut second = Inbox::<u32>::new("b");
        for inbox in [&first, &second] {
            bus.subscribe("nums", inbox.actor_ref().clone());
        }

        let delivered = bus.publish_msg("nums", 7u32);
        assert_eq!(delivered, 2);

        let timeout = Duration::from_millis(50);
        assert_eq!(first.receive(timeout).await.unwrap(), 7);
        assert_eq!(second.receive(timeout).await.unwrap(), 7);
    }

    #[tokio::test]
    async fn publish_to_unknown_topic_delivers_zero() {
        let bus = DistributedPubSub::new();
        assert_eq!(bus.publish_msg("nope", 1u32), 0);
    }

    #[tokio::test]
    async fn group_send_round_robins_one_member() {
        let bus = DistributedPubSub::new();
        let mut left = Inbox::<u32>::new("ga");
        let mut right = Inbox::<u32>::new("gb");
        bus.subscribe_to_group("work", "G1", left.actor_ref().clone());
        bus.subscribe_to_group("work", "G1", right.actor_ref().clone());

        // Four sends across two members: each member should receive exactly two.
        for payload in 0..4u32 {
            assert!(bus.send_to_group("work", "G1", payload));
        }
        let (mut left_hits, mut right_hits) = (0, 0);
        for _round in 0..2 {
            left.receive(Duration::from_millis(20)).await.unwrap();
            left_hits += 1;
            right.receive(Duration::from_millis(20)).await.unwrap();
            right_hits += 1;
        }
        assert_eq!((left_hits, right_hits), (2, 2));
    }

    #[tokio::test]
    async fn cluster_pub_sub_announce_and_forward_round_trip() {
        let bus_a = DistributedPubSub::new();
        let bus_b = DistributedPubSub::new();
        let mut subscriber = Inbox::<u32>::new("sub");
        bus_b.subscribe("nums", subscriber.actor_ref().clone());

        let wire = CapturingTransport::default();
        let transport: Arc<dyn MediatorTransport> = Arc::new(wire.clone());
        let node_a = ClusterPubSub::new(bus_a.clone(), "node-a", transport.clone());
        let node_b = ClusterPubSub::new(bus_b.clone(), "node-b", transport);

        // node-b advertises its topics; node-a learns that "nums" lives on node-b.
        node_b.announce_to("node-a");
        let (_, announce) = wire.take_last();
        node_a.apply_pdu(announce);
        assert_eq!(node_a.known_remote_topics(), 1);
        assert_eq!(node_a.nodes_for("nums"), vec!["node-b".to_string()]);

        // node-b's decoder turns little-endian bytes back into a u32 and republishes it locally.
        let republish_on_b = bus_b.clone();
        node_b.register_decoder("u32", move |bytes| {
            let value = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
            republish_on_b.publish_msg::<u32>("nums", value) > 0
        });

        let reached = node_a.publish_remote::<u32, _>("nums", 42, "u32", |m| m.to_le_bytes().to_vec());
        assert_eq!(reached, 1);

        let (target, forwarded) = wire.take_last();
        assert_eq!(target, "node-b");
        node_b.apply_pdu(forwarded);
        assert_eq!(subscriber.receive(Duration::from_millis(50)).await.unwrap(), 42);
    }

    #[test]
    fn group_count_tracks_distinct_buckets() {
        let bus = DistributedPubSub::new();
        let member = Inbox::<u32>::new("g");
        for (topic, group) in [("t1", "G1"), ("t1", "G2"), ("t2", "G1")] {
            bus.subscribe_to_group(topic, group, member.actor_ref().clone());
        }
        assert_eq!(bus.group_count(), 3);
    }
}