//! `xtra_addons/broker.rs`: a publish/subscribe [`Broker`] actor that clones
//! each published message to every subscribed message channel.

1use crate::*;
2use async_trait::async_trait;
3use std::collections::BTreeMap;
4use std::marker::PhantomData;
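
// XXX: Use `WeakMessageChannel` and `WeakAddress` for the broker. Currently the
// broker stores each subscriber as a boxed `MessageChannel` built from a cloned
// `Address`, and every `Subscription` holds a full `Address<Broker<T>>`.
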
/// Opaque key identifying one subscriber inside a [`Broker`]'s subscription map.
///
/// Ids are minted from the broker's incrementing counter, so they are assigned
/// in subscription order.
#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)]
struct SubscriptionId(usize);

11struct Subscribe<T>
12where
13    T: Message<Result = ()> + Clone,
14{
15    channel: Box<dyn MessageChannel<T>>,
16}
17
18struct Unsubscribe<T>
19where
20    T: Message<Result = ()> + Clone,
21{
22    id: SubscriptionId,
23    _marker: PhantomData<T>,
24}
25
26struct Publish<T: Message<Result = ()> + Clone>(T);
27
28pub struct Subscription<T>
29where
30    T: Message<Result = ()> + Clone,
31{
32    id: SubscriptionId,
33    broker_addr: Address<Broker<T>>,
34}
35
36impl<T> Subscription<T>
37where
38    T: Message<Result = ()> + Clone,
39{
40    pub async fn unsubscribe(self) -> Result<(), ()> {
41        let Self { id, broker_addr } = self;
42        broker_addr
43            .send(Unsubscribe {
44                id,
45                _marker: PhantomData,
46            })
47            .await
48            .unwrap()
49    }
50}
51
52impl<T> Message for Subscribe<T>
53where
54    T: Message<Result = ()> + Clone,
55{
56    type Result = Result<Subscription<T>, ()>;
57}
58
59impl<T> Message for Unsubscribe<T>
60where
61    T: Message<Result = ()> + Clone,
62{
63    type Result = Result<(), ()>;
64}
65
66impl<T> Message for Publish<T>
67where
68    T: Message<Result = ()> + Clone,
69{
70    type Result = ();
71}
72
73pub struct Broker<T>
74where
75    T: Message<Result = ()> + Clone,
76{
77    next_id: usize,
78    subscriptions: BTreeMap<SubscriptionId, Box<dyn MessageChannel<T>>>,
79}
80
81impl<T> Broker<T>
82where
83    T: Message<Result = ()> + Clone,
84{
85    pub fn new() -> Self {
86        Self {
87            next_id: 0,
88            subscriptions: BTreeMap::new(),
89        }
90    }
91
92    pub async fn subscribe<A: Actor + Handler<T>>(
93        subscriber: Address<A>,
94    ) -> Result<Subscription<T>, ()> {
95        let broker = Self::from_registry().await;
96        subscriber.subscribe(broker).await
97    }
98
99    pub async fn publish(message: T) -> Result<(), xtra::Disconnected> {
100        let broker = Self::from_registry().await;
101        broker.publish(message).await
102    }
103}
104
105impl<T> Default for Broker<T>
106where
107    T: Message<Result = ()> + Clone,
108{
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114#[async_trait]
115impl<T> Actor for Broker<T> where T: Message<Result = ()> + Clone {}
116
117#[async_trait]
118impl<T> Handler<Subscribe<T>> for Broker<T>
119where
120    T: Message<Result = ()> + Clone,
121{
122    async fn handle(
123        &mut self,
124        message: Subscribe<T>,
125        ctx: &mut Context<Self>,
126    ) -> <Subscribe<T> as Message>::Result {
127        let broker_addr = ctx.address().map_err(|_| ())?;
128
129        let id = SubscriptionId(self.next_id);
130        self.next_id += 1;
131
132        self.subscriptions.insert(id, message.channel);
133
134        Ok(Subscription { id, broker_addr })
135    }
136}
137
138#[async_trait]
139impl<T> Handler<Unsubscribe<T>> for Broker<T>
140where
141    T: Message<Result = ()> + Clone,
142{
143    async fn handle(
144        &mut self,
145        message: Unsubscribe<T>,
146        _ctx: &mut Context<Self>,
147    ) -> <Unsubscribe<T> as Message>::Result {
148        match self.subscriptions.remove(&message.id) {
149            Some(_) => Ok(()),
150            None => Err(()),
151        }
152    }
153}
154
155#[async_trait]
156impl<T> Handler<Publish<T>> for Broker<T>
157where
158    T: Message<Result = ()> + Clone,
159{
160    async fn handle(&mut self, Publish(message): Publish<T>, _ctx: &mut Context<Self>) {
161        let mut disconnected: Vec<SubscriptionId> = Vec::new();
162
163        for (&id, subscriber) in &self.subscriptions {
164            match subscriber.do_send(message.clone()) {
165                Ok(()) => {}
166                Err(xtra::Disconnected) => {
167                    disconnected.push(id);
168                }
169            }
170        }
171
172        for id in disconnected {
173            self.subscriptions.remove(&id);
174        }
175    }
176}
177
178#[async_trait]
179pub trait SubscribeExt<M>
180where
181    M: Message<Result = ()> + Clone,
182{
183    async fn subscribe(&self, broker: Address<Broker<M>>) -> Result<Subscription<M>, ()>;
184}
185
186#[async_trait]
187impl<T, M> SubscribeExt<M> for Address<T>
188where
189    T: Handler<M>,
190    M: Message<Result = ()> + Clone,
191{
192    async fn subscribe(&self, broker: Address<Broker<M>>) -> Result<Subscription<M>, ()> {
193        broker
194            .send(Subscribe {
195                channel: Box::new(self.clone()),
196            })
197            .await
198            .map_err(|_| ())?
199    }
200}
201
202#[async_trait]
203pub trait PublishExt<M>
204where
205    M: Message<Result = ()> + Clone,
206{
207    async fn publish(&self, message: M) -> Result<(), xtra::Disconnected>;
208}
209
210#[async_trait]
211impl<M> PublishExt<M> for Address<Broker<M>>
212where
213    M: Message<Result = ()> + Clone,
214{
215    async fn publish(&self, message: M) -> Result<(), xtra::Disconnected> {
216        self.send(Publish(message)).await
217    }
218}
219
/// End-to-end check of subscribe / publish / unsubscribe on an explicitly
/// spawned broker: two subscribers forward tagged copies of each published
/// `Msg` to a shared collector.
#[cfg(test)]
#[async_std::test]
async fn test_broker() {
    use xtra::spawn::AsyncStd;

    #[derive(Clone)]
    struct Msg {
        msg: String,
    }

    impl Message for Msg {
        type Result = ();
    }

    struct RetrieveMessages;

    impl Message for RetrieveMessages {
        type Result = Vec<String>;
    }

    struct Collector {
        messages: Vec<String>,
    }

    impl Collector {
        fn new() -> Self {
            Self { messages: vec![] }
        }
    }

    impl Actor for Collector {}

    struct SubscriberA {
        collector: Address<Collector>,
    }

    impl Actor for SubscriberA {}

    struct SubscriberB {
        collector: Address<Collector>,
    }

    impl Actor for SubscriberB {}

    #[async_trait]
    impl Handler<Msg> for Collector {
        async fn handle(&mut self, Msg { msg }: Msg, _ctx: &mut Context<Self>) {
            self.messages.push(msg);
        }
    }

    #[async_trait]
    impl Handler<RetrieveMessages> for Collector {
        async fn handle(
            &mut self,
            _: RetrieveMessages,
            _ctx: &mut Context<Self>,
        ) -> <RetrieveMessages as Message>::Result {
            // Hand back everything collected so far and start again from empty.
            std::mem::take(&mut self.messages)
        }
    }

    #[async_trait]
    impl Handler<Msg> for SubscriberA {
        async fn handle(&mut self, msg: Msg, _ctx: &mut Context<Self>) {
            self.collector
                .do_send(Msg {
                    msg: format!("{} from SubscriberA", msg.msg),
                })
                .unwrap();
        }
    }

    #[async_trait]
    impl Handler<Msg> for SubscriberB {
        async fn handle(&mut self, msg: Msg, _ctx: &mut Context<Self>) {
            self.collector
                .do_send(Msg {
                    msg: format!("{} from SubscriberB", msg.msg),
                })
                .unwrap();
        }
    }

    /// Publishes `text` through `broker`, waits for the fan-out to settle and
    /// returns what `collector` received, sorted so the assertions do not depend
    /// on which subscriber happened to run first.
    async fn publish_and_collect(
        broker: &Address<Broker<Msg>>,
        collector: &Address<Collector>,
        text: &str,
    ) -> Vec<String> {
        broker
            .publish(Msg {
                msg: text.to_string(),
            })
            .await
            .unwrap();

        // Give enough time for the messages to be published to all receivers
        async_std::task::sleep(std::time::Duration::from_millis(100)).await;

        let mut received = collector.send(RetrieveMessages).await.unwrap();
        received.sort();
        received
    }

    let broker = Broker::<Msg>::new().create(None).spawn(&mut AsyncStd);
    let collector = Collector::new().create(None).spawn(&mut AsyncStd);

    let subscriber_a = SubscriberA {
        collector: collector.clone(),
    }
    .create(None)
    .spawn(&mut AsyncStd);
    let subscriber_b = SubscriberB {
        collector: collector.clone(),
    }
    .create(None)
    .spawn(&mut AsyncStd);

    // Initially empty
    assert!(collector.send(RetrieveMessages).await.unwrap().is_empty());

    let subscription_a = subscriber_a.subscribe(broker.clone()).await.unwrap();
    let subscription_b = subscriber_b.subscribe(broker.clone()).await.unwrap();

    // Both subscribed: each subscriber forwards one tagged copy.
    assert_eq!(
        publish_and_collect(&broker, &collector, "1").await,
        vec![
            "1 from SubscriberA".to_string(),
            "1 from SubscriberB".to_string(),
        ],
    );

    // Only A remains subscribed.
    subscription_b.unsubscribe().await.unwrap();
    assert_eq!(
        publish_and_collect(&broker, &collector, "2").await,
        vec!["2 from SubscriberA".to_string()],
    );

    // Nobody is subscribed any more, so nothing is forwarded.
    subscription_a.unsubscribe().await.unwrap();
    assert_eq!(
        publish_and_collect(&broker, &collector, "3").await,
        Vec::<String>::new(),
    );
}

/// Smoke test of the registry-based `Broker::subscribe` / `Broker::publish`
/// helpers followed by an explicit unsubscribe.
#[cfg(test)]
#[async_std::test]
async fn test_broker_using_registry() {
    #[derive(Clone)]
    struct Msg {
        msg: String,
    }

    impl Message for Msg {
        type Result = ();
    }

    /// Subscriber that accepts every `Msg` and does nothing with it.
    #[derive(Default)]
    struct MyActor;

    impl Actor for MyActor {}

    #[async_trait]
    impl Handler<Msg> for MyActor {
        async fn handle(&mut self, _: Msg, _ctx: &mut Context<Self>) {}
    }

    // Both the subscriber and the broker are obtained from the registry.
    let subscription = Broker::<Msg>::subscribe(MyActor::from_registry().await)
        .await
        .unwrap();

    let payload = Msg { msg: "123".into() };
    Broker::<Msg>::publish(payload).await.unwrap();

    subscription.unsubscribe().await.unwrap();
}