ripb/
lib.rs

//! This crate provides an implementation of a lock-free, type-safe, in-process bus.
//!
//! # Guarantees
//!
//! - In-order delivery: messages published on the bus are received in the same order by a given
//!   subscriber. This is not guaranteed across multiple subscribers, i.e. if you send messages m1
//!   and m2 on a bus with subscribers s1 and s2, both s1 and s2 will receive the messages in the
//!   order m1, m2, but s1 may receive m2 before s2 has received m1.
//!
//! # Implementation
//!
//! The current implementation uses [`crossbeam-channel`]s and a fixed number of worker threads;
//! [`Any`] and [`TypeId`] are used to expose a type-safe API.
//!
//! [`crossbeam-channel`]: https://docs.rs/crossbeam-channel
15use crossbeam_channel::{Receiver, RecvError, RecvTimeoutError, Select, Sender, TryRecvError};
16use std::any::{Any, TypeId};
17use std::cell::Cell;
18use std::collections::HashMap;
19use std::fmt::Debug;
20use std::num::NonZeroUsize;
21use std::sync::atomic::{AtomicUsize, Ordering};
22use std::sync::Arc;
23use std::thread;
24
25/// An in process bus.
26///
27/// You can create a bus with the [`new`] method, then create new [`Subscriber`]s with the
28/// [`create_subscriber`] method, and push some messages on the bus with the [`publish`] method.
29/// If you need to send messages from multiple threads, [`clone`] the bus and use one instance per
30/// thread
31#[derive(Clone)]
32pub struct Bus {
33    control: Sender<BusTask>,
34    subscriber_id_source: Arc<AtomicUsize>,
35}
36
37/// A subscriber to a [`Bus`].
38///
39/// Subscribers are created with the method [`Bus.create_subscriber`]
40///
41/// Register new callbacks with the [`on_message`] method, callback will live until the subscriber
42/// is dropped. If you need more control on callback lifecycle use [`on_message_with_token`] that
43/// will give you a [`SubscriptionToken`] you can use to [`unsubscribe`] a callback.
44pub struct Subscriber {
45    subscriber_id: usize,
46    control: Sender<BusTask>,
47    callback_id_source: AtomicUsize,
48}
49
50#[derive(Debug)]
51enum BusTask {
52    Publish {
53        type_id: TypeId,
54        message: Arc<BoxedMessage>,
55        worker: Worker,
56    },
57    RegisterSubscriber {
58        subscriber: SubscriberState,
59        subscriber_id: usize,
60    },
61    UnregisterSubscriber {
62        subscriber_id: usize,
63    },
64    RegisterSubscriberCallback {
65        subscriber_id: usize,
66        callback_id: usize,
67        type_id: TypeId,
68        callback: BoxedCallback,
69    },
70    UnregisterSubscriberCallback {
71        subscriber_id: usize,
72        callback_id: usize,
73        type_id: TypeId,
74    },
75    Stop {
76        halted_tx: Sender<()>,
77    },
78}
79
/// Marker trait for types that can be sent on the bus.
///
/// A message must be [`Send`] and [`Sync`] to be sent on the bus: a published message is shared
/// between worker threads and handed by reference to every matching callback. The trait is
/// implemented automatically for every such type ([`Bus::publish`] additionally requires the
/// message to be `'static`).
pub trait Message: Send + Sync {}

impl<T: Send + Sync> Message for T {}
84
/// A published message with its concrete type erased; shared across threads.
type BoxedMessage = Box<dyn Any + Sync + Send>;
/// A `Callback<M>` with `M` erased; recovered by downcasting.
type BoxedCallback = Box<dyn Any + Send>;
87
88#[derive(Debug)]
89enum SubscriberTask {
90    Receive {
91        type_id: TypeId,
92        message: Arc<BoxedMessage>,
93        worker: Arc<Worker>,
94    },
95    RegisterCallback {
96        callback_id: usize,
97        type_id: TypeId,
98        callback: BoxedCallback,
99    },
100    UnregisterCallback {
101        callback_id: usize,
102        type_id: TypeId,
103    },
104}
105
106struct Worker {
107    worker: Box<dyn Fn(&BoxedCallback, Arc<BoxedMessage>) + Send + Sync>,
108}
109
110impl Debug for Worker {
111    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
112        write!(f, "<Worker>")
113    }
114}
115
116impl Worker {
117    pub fn of<T: Message + 'static>() -> Self {
118        Self {
119            worker: Box::new(|c: &BoxedCallback, payload: Arc<BoxedMessage>| {
120                let callback = c
121                    .downcast_ref::<Callback<T>>()
122                    .expect("Could not downcast_ref for callback, this is a bug in ripb");
123                let message = Any::downcast_ref::<T>(&**payload)
124                    .expect("Could not downcast_ref for message, this is a bug in ripb");
125                callback.call(message);
126            }),
127        }
128    }
129
130    pub fn call(&self, callback: &BoxedCallback, payload: Arc<BoxedMessage>) {
131        (self.worker)(callback, payload)
132    }
133}
134
135struct Callback<M: Message> {
136    callback: Box<dyn Fn(&M) -> () + Send>,
137}
138
139impl<M: Message> Callback<M> {
140    pub fn new<F: 'static>(handler: F) -> Callback<M>
141    where
142        F: Fn(&M) -> () + Send,
143    {
144        Callback {
145            callback: Box::new(handler),
146        }
147    }
148
149    pub fn call(&self, arg: &M) {
150        (self.callback)(arg)
151    }
152}
153
154struct BusState {
155    subs: HashMap<usize, Cell<Receiver<SubscriberState>>>,
156    tasks: Receiver<BusTask>,
157    thread_count: usize,
158}
159
160impl Debug for BusState {
161    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
162        write!(
163            f,
164            "BusState {{ subs: <{}>, tasks: {:?}, thread_count: {} }}",
165            self.subs.len(),
166            self.tasks,
167            self.thread_count
168        )
169    }
170}
171
172struct SubscriberState {
173    callbacks: HashMap<TypeId, Vec<(usize, BoxedCallback)>>,
174}
175
176impl Debug for SubscriberState {
177    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
178        write!(
179            f,
180            "SubscriberState {{ callbacks: {:?} }}",
181            self.callbacks
182                .iter()
183                .map(|it| (it.0.clone(), format!("<{} callbacks>", it.1.len())))
184                .collect::<HashMap<TypeId, String>>()
185        )
186    }
187}
188
189#[derive(Debug)]
190enum BusWorkerTask {
191    ManageBusState {
192        state: BusState,
193    },
194    ManageSubscriberState {
195        subscriber_id: usize,
196        state: Receiver<SubscriberState>,
197        task: SubscriberTask,
198        next_state: Sender<SubscriberState>,
199    },
200    ManageSlowSubscribersStates {
201        subscriber_ids: Vec<usize>,
202        states: Vec<Receiver<SubscriberState>>,
203        tasks: Vec<SubscriberTask>,
204        next_states: Vec<Sender<SubscriberState>>,
205    },
206    Stop {
207        state: BusState,
208        halted_tx: Option<Sender<()>>,
209    },
210}
211
212struct BusWorker {
213    id: usize,
214    tasks: Receiver<BusWorkerTask>,
215    backlog: Sender<BusWorkerTask>,
216}
217
218impl BusWorker {
219    fn run(&self) {
220        log::info!("worker {} started", self.id);
221        loop {
222            match self.tasks.recv() {
223                Ok(task) => {
224                    if !self.handle_task(task) {
225                        break;
226                    }
227                }
228                Err(RecvError {}) => panic!("Task/backlog channel closed, this should not happen"),
229            }
230        }
231        log::info!("worker {} fisnished", self.id);
232    }
233
234    fn handle_task(&self, task: BusWorkerTask) -> bool {
235        log::debug!("bus worker {} handling {:?}", self.id, task);
236        return match task {
237            BusWorkerTask::ManageBusState { state } => self.manage_bus_state(state),
238            BusWorkerTask::ManageSubscriberState {
239                subscriber_id,
240                state,
241                task,
242                next_state,
243            } => self.manage_subscriber_state(subscriber_id, state, task, next_state),
244            BusWorkerTask::ManageSlowSubscribersStates {
245                subscriber_ids,
246                states,
247                tasks,
248                next_states,
249            } => self.manage_slow_subscribers_states(subscriber_ids, states, tasks, next_states),
250            BusWorkerTask::Stop { state, halted_tx } => self.handle_stop(state, halted_tx),
251        };
252    }
253
254    fn handle_stop(&self, mut state: BusState, halted_tx: Option<Sender<()>>) -> bool {
255        log::debug!("bus worker {} handling a stop task", self.id);
256
257        let should_continue = self.tasks.len() > 0;
258
259        if !should_continue {
260            state.thread_count -= 1;
261            log::debug!("threadcount is now {}", state.thread_count);
262        }
263
264        log::debug!(
265            "bus worker {} remaining subs : {}",
266            self.id,
267            state.subs.len(),
268        );
269
270        if state.thread_count > 0 {
271            // we own a receiver so sending should not fail
272            self.backlog
273                .send(BusWorkerTask::Stop { state, halted_tx })
274                .unwrap();
275        } else {
276            log::debug!("bus is done");
277            drop(state);
278            if let Some(halted_tx) = halted_tx {
279                halted_tx.send(()).unwrap();
280            }
281        }
282
283        should_continue
284    }
285
286    fn manage_bus_state(&self, state: BusState) -> bool {
287        log::debug!("bus worker {} managing the bus state", self.id);
288
289        let BusState {
290            mut subs,
291            tasks,
292            thread_count,
293        } = state;
294        let mut should_continue = true;
295        let mut halted_tx_opt = None;
296
297        let task = tasks
298            .recv()
299            .expect("bus task channel was closed without sending a stop command");
300
301        log::debug!(
302            "bus worker {} managing the bus state with task {:?}",
303            self.id,
304            task,
305        );
306
307        match task {
308            BusTask::Publish {
309                type_id,
310                message,
311                worker,
312            } => {
313                let worker = Arc::new(worker);
314
315                for (subscriber_id, sub) in subs.iter_mut() {
316                    let (next_state, new_sub) = crossbeam_channel::bounded(1);
317                    let task = SubscriberTask::Receive {
318                        type_id,
319                        message: Arc::clone(&message),
320                        worker: Arc::clone(&worker),
321                    };
322                    let state = sub.replace(new_sub);
323                    // we own a receiver so sending should not fail
324                    self.backlog
325                        .send(BusWorkerTask::ManageSubscriberState {
326                            subscriber_id: *subscriber_id,
327                            state,
328                            task,
329                            next_state,
330                        })
331                        .unwrap();
332                }
333            }
334            BusTask::RegisterSubscriber {
335                subscriber,
336                subscriber_id,
337            } => {
338                let (next_state, new_sub) = crossbeam_channel::bounded(1);
339                // receiver is still alive as it is in the scope
340                next_state.send(subscriber).unwrap();
341                subs.insert(subscriber_id, Cell::new(new_sub));
342            }
343            BusTask::UnregisterSubscriber { subscriber_id } => {
344                // make sure all tasks for this subscriber have finished executing
345                // the recv here is blocking, if it poses some performance problems it can be
346                // sub in a new bus task
347                subs.remove(&subscriber_id)
348                    .expect("trying to remove a non existing subscriber")
349                    .get_mut()
350                    .recv()
351                    .expect("subscriber channel should not be close");
352            }
353            BusTask::RegisterSubscriberCallback {
354                subscriber_id,
355                callback_id,
356                type_id,
357                callback,
358            } => {
359                let (next_state, new_sub) = crossbeam_channel::bounded(1);
360                if let Some(state) = subs.insert(subscriber_id, Cell::new(new_sub)) {
361                    let task = SubscriberTask::RegisterCallback {
362                        callback_id,
363                        type_id,
364                        callback,
365                    };
366                    // we own a receiver so sending should not fail
367                    self.backlog
368                        .send(BusWorkerTask::ManageSubscriberState {
369                            subscriber_id,
370                            state: state.into_inner(),
371                            task,
372                            next_state,
373                        })
374                        .unwrap();
375                } else {
376                    panic!("trying to register a callback for an unknown subscriber")
377                }
378            }
379            BusTask::UnregisterSubscriberCallback {
380                subscriber_id,
381                callback_id,
382                type_id,
383            } => {
384                let (next_state, new_sub) = crossbeam_channel::bounded(1);
385                if let Some(state) = subs.insert(subscriber_id, Cell::new(new_sub)) {
386                    let task = SubscriberTask::UnregisterCallback {
387                        callback_id,
388                        type_id,
389                    };
390                    // we own a receiver so sending should not fail
391                    self.backlog
392                        .send(BusWorkerTask::ManageSubscriberState {
393                            subscriber_id,
394                            state: state.into_inner(),
395                            task,
396                            next_state,
397                        })
398                        .unwrap();
399                } else {
400                    panic!("trying to unregister a callback for an unknown subscriber")
401                }
402            }
403            BusTask::Stop { halted_tx } => {
404                should_continue = false;
405                halted_tx_opt = Some(halted_tx)
406            }
407        }
408
409        let state = BusState {
410            subs,
411            tasks,
412            thread_count,
413        };
414
415        log::debug!(
416            "bus worker {} remaining subs : {}",
417            self.id,
418            state.subs.len(),
419        );
420
421        if should_continue {
422            // we own a receiver so sending should not fail
423            self.backlog
424                .send(BusWorkerTask::ManageBusState { state })
425                .unwrap();
426        } else {
427            // we own a receiver so sending should not fail
428            self.backlog
429                .send(BusWorkerTask::Stop {
430                    state,
431                    halted_tx: halted_tx_opt,
432                })
433                .unwrap()
434        }
435        return true;
436    }
437
438    fn manage_subscriber_state(
439        &self,
440        subscriber_id: usize,
441        state: Receiver<SubscriberState>,
442        task: SubscriberTask,
443        next_state: Sender<SubscriberState>,
444    ) -> bool {
445        log::debug!(
446            "bus worker {} trying to manage state of subscriber {}",
447            self.id,
448            subscriber_id,
449        );
450        match state.try_recv() {
451            Ok(state) => self.perform_subscriber_task(subscriber_id, state, task, next_state),
452            // If we simply repost the task there are some cases where we'll en up with a busy loop
453            // (e.g. if one subscriber is taking its time and there are only tasks for it in the
454            // backlog). We don't want that so we need to handle slow subscribers a little more
455            // carefully
456            Err(TryRecvError::Empty) => self
457                .backlog
458                .send(BusWorkerTask::ManageSlowSubscribersStates {
459                    subscriber_ids: vec![subscriber_id],
460                    states: vec![state],
461                    tasks: vec![task],
462                    next_states: vec![next_state],
463                })
464                .expect("backlog channel was disconnected"),
465            Err(TryRecvError::Disconnected) => {
466                panic!("Channel for subscriber state is disconnected")
467            }
468        }
469        return true;
470    }
471
472    fn perform_subscriber_task(
473        &self,
474        subscriber_id: usize,
475        mut state: SubscriberState,
476        task: SubscriberTask,
477        next_state: Sender<SubscriberState>,
478    ) {
479        log::debug!(
480            "bus worker {} performing task {:?} for subscriber {} with state {:?}",
481            self.id,
482            task,
483            subscriber_id,
484            state,
485        );
486        match task {
487            SubscriberTask::Receive {
488                type_id,
489                message,
490                worker,
491            } => {
492                state.callbacks.get(&type_id).map(|its| {
493                    for (_, it) in its {
494                        worker.call(it, Arc::clone(&message))
495                    }
496                });
497            }
498            SubscriberTask::RegisterCallback {
499                type_id,
500                callback_id,
501                callback,
502            } => {
503                state
504                    .callbacks
505                    .entry(type_id)
506                    .or_insert_with(|| vec![])
507                    .push((callback_id, callback));
508            }
509            SubscriberTask::UnregisterCallback {
510                type_id,
511                callback_id,
512            } => {
513                state
514                    .callbacks
515                    .get_mut(&type_id)
516                    .expect("Trying to unregister a callback on a type not seen yet")
517                    .retain(|(it, _)| *it != callback_id);
518            }
519        }
520
521        next_state
522            .send(state)
523            .expect("state channel for subscriber should not be disconnected");
524    }
525
526    fn manage_slow_subscribers_states(
527        &self,
528        mut subscriber_ids: Vec<usize>,
529        mut states: Vec<Receiver<SubscriberState>>,
530        mut tasks: Vec<SubscriberTask>,
531        mut next_states: Vec<Sender<SubscriberState>>,
532    ) -> bool {
533        log::debug!(
534            "bus worker {} trying to manage states of {} slow subscribers",
535            self.id,
536            subscriber_ids.len(),
537        );
538        enum Action {
539            Exec {
540                state: SubscriberState,
541                index: usize,
542            },
543            Merge {
544                other_ids: Vec<usize>,
545                other_states: Vec<Receiver<SubscriberState>>,
546                other_tasks: Vec<SubscriberTask>,
547                other_next_states: Vec<Sender<SubscriberState>>,
548            },
549            Other {
550                task: BusWorkerTask,
551            },
552        }
553
554        let action = {
555            let mut select = Select::new();
556            for state in &states {
557                select.recv(state);
558            }
559
560            select.recv(&self.tasks);
561
562            let oper = select.select();
563            let index = oper.index();
564
565            if index == states.len() {
566                let task = oper.recv(&self.tasks);
567                match task {
568                    Ok(BusWorkerTask::ManageSlowSubscribersStates {
569                        subscriber_ids: other_ids,
570                        states: other_states,
571                        tasks: other_tasks,
572                        next_states: other_next_states,
573                    }) => Action::Merge {
574                        other_ids,
575                        other_states,
576                        other_tasks,
577                        other_next_states,
578                    },
579                    Ok(task) => Action::Other { task },
580                    Err(RecvError {}) => {
581                        panic!("Task/backlog channel closed, this should not happen")
582                    }
583                }
584            } else {
585                let state = oper.recv(&states[index]);
586                match state {
587                    Ok(state) => Action::Exec { state, index },
588                    Err(RecvError {}) => {
589                        panic!("state channel for slow subscriber should not be disconnected")
590                    }
591                }
592            }
593        };
594
595        return match action {
596            Action::Exec { state, index } => {
597                states.remove(index);
598                let subscriber_id = subscriber_ids.remove(index);
599                log::debug!(
600                    "bus worker {} managing state of subscriber {}",
601                    self.id,
602                    subscriber_id,
603                );
604                let task = tasks.remove(index);
605                let next_state = next_states.remove(index);
606                if states.len() > 0 {
607                    // we own a receiver so sending should not fail
608                    self.backlog
609                        .send(BusWorkerTask::ManageSlowSubscribersStates {
610                            subscriber_ids,
611                            states,
612                            tasks,
613                            next_states,
614                        })
615                        .expect("backlog channel was disconnected");
616                }
617                self.perform_subscriber_task(subscriber_id, state, task, next_state);
618                true
619            }
620            Action::Merge {
621                mut other_ids,
622                mut other_states,
623                mut other_tasks,
624                mut other_next_states,
625            } => {
626                log::debug!(
627                    "bus worker {} adding {} subscribers to task already containing {}",
628                    self.id,
629                    other_ids.len(),
630                    subscriber_ids.len(),
631                );
632                subscriber_ids.append(&mut other_ids);
633                states.append(&mut other_states);
634                tasks.append(&mut other_tasks);
635                next_states.append(&mut other_next_states);
636                // we own a receiver so sending should not fail
637                self.backlog
638                    .send(BusWorkerTask::ManageSlowSubscribersStates {
639                        subscriber_ids,
640                        states,
641                        tasks,
642                        next_states,
643                    })
644                    .expect("backlog channel was disconnected");
645                true
646            }
647            Action::Other { task } => {
648                log::debug!("bus worker {} executing another task", self.id);
649                // we own a receiver so sending should not fail
650                self.backlog
651                    .send(BusWorkerTask::ManageSlowSubscribersStates {
652                        subscriber_ids,
653                        states,
654                        tasks,
655                        next_states,
656                    })
657                    .expect("backlog channel was disconnected");
658                self.handle_task(task)
659            }
660        };
661    }
662}
663
/// Error returned when a bus operation is impossible because the bus is dead.
///
/// Getting this error means the bus no longer accepts tasks: its control channel has been
/// closed, so retrying the operation cannot succeed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeadBusError;

impl std::fmt::Display for DeadBusError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        write!(
            f,
            "A new callback was registered on a ripb subscriber for a dropped bus"
        )
    }
}

impl std::error::Error for DeadBusError {}
678
679impl Bus {
680    /// Create a new bus, with a thread count equal to the number of CPUs
681    pub fn new() -> Self {
682        Self::with_thread_count(
683            NonZeroUsize::new(num_cpus::get()).expect("Number of CPU should be non zero"),
684        )
685    }
686
687    /// Create a new bus with the given thread count
688    pub fn with_thread_count(thread_count: NonZeroUsize) -> Self {
689        let (backlog, tasks) = crossbeam_channel::unbounded();
690        let (control, bus_tasks) = crossbeam_channel::unbounded();
691        let thread_count = thread_count.get();
692
693        for id in 0..thread_count {
694            let worker = BusWorker {
695                id,
696                tasks: tasks.clone(),
697                backlog: backlog.clone(),
698            };
699            let _ = thread::Builder::new()
700                .name(format!("ripb.worker{}", id))
701                .spawn(move || worker.run());
702        }
703
704        // receiver still in scope so sending should not fail
705        backlog
706            .send(BusWorkerTask::ManageBusState {
707                state: BusState {
708                    subs: HashMap::new(),
709                    tasks: bus_tasks,
710                    thread_count,
711                },
712            })
713            .expect("backlog channel was disconnected");
714
715        Bus {
716            control,
717            subscriber_id_source: Arc::new(AtomicUsize::from(0)),
718        }
719    }
720
721    /// Create a new subscriber for this bus
722    pub fn create_subscriber(&self) -> Subscriber {
723        let subscriber_id = self.subscriber_id_source.fetch_add(1, Ordering::Relaxed);
724        let control = self.control.clone();
725        let callback_id_source = AtomicUsize::from(0);
726        self.control
727            .send(BusTask::RegisterSubscriber {
728                subscriber: SubscriberState {
729                    callbacks: HashMap::new(),
730                },
731                subscriber_id,
732            })
733            .expect("could not communicate with the bus, did a worker thread panic ?");
734        Subscriber {
735            subscriber_id,
736            control,
737            callback_id_source,
738        }
739    }
740
741    /// Publish a new message on this bus
742    pub fn publish<M: Message + 'static>(&self, message: M) {
743        self.control
744            .send(BusTask::Publish {
745                type_id: TypeId::of::<M>(),
746                message: Arc::new(Box::new(message)),
747                worker: Worker::of::<M>(),
748            })
749            .expect("could not communicate with the bus, did a worker thread panic ?")
750    }
751}
752
753impl Drop for Bus {
754    fn drop(&mut self) {
755        let arc_count = Arc::strong_count(&self.subscriber_id_source);
756        log::debug!("Dropping a Bus, arc count is {}", arc_count);
757        if arc_count == 1 {
758            log::debug!("This is the last instance for this bus, killing it");
759            let (halted_tx, halted_rx) = crossbeam_channel::bounded(1);
760            self.control
761                .send(BusTask::Stop { halted_tx })
762                .expect("control channel was disconnected");
763            match halted_rx.recv_timeout(std::time::Duration::from_secs(5)) {
764                Err(RecvTimeoutError::Timeout) => {
765                    panic!("bus didn't stop properly after 5 seconds")
766                }
767                Err(RecvTimeoutError::Disconnected) => panic!("stop channel was closed"),
768                Ok(()) => {}
769            }
770        }
771    }
772}
773
774impl Subscriber {
775    fn on_message_inner<F, M>(&self, callback: F) -> Result<usize, DeadBusError>
776    where
777        F: Fn(&M) + Send + 'static,
778        M: Message + 'static,
779    {
780        let callback_id = self.callback_id_source.fetch_add(1, Ordering::Relaxed);
781        self.control
782            .send(BusTask::RegisterSubscriberCallback {
783                subscriber_id: self.subscriber_id,
784                callback_id,
785                callback: Box::new(Callback::new(callback)),
786                type_id: TypeId::of::<M>(),
787            })
788            .map_err(|_| DeadBusError)?;
789        Ok(callback_id)
790    }
791
792    /// Register a new callback to be called each time a message of the given type is published on
793    /// the bus, callback lives as until the `Subscriber` is dropped
794    pub fn on_message<F, M>(&self, callback: F) -> Result<(), DeadBusError>
795    where
796        F: Fn(&M) + Send + 'static,
797        M: Message + 'static,
798    {
799        self.on_message_inner(callback)?;
800        Ok(())
801    }
802
803    /// Register a new callback to be called each time a message of the given type is published on
804    /// the bus, callback lives as until the `SubscriptionToken` is dropped, `unsubscribe` is
805    /// called on it or the `Subscriber` is dropped
806    pub fn on_message_with_token<F, M>(
807        &self,
808        callback: F,
809    ) -> Result<SubscriptionToken, DeadBusError>
810    where
811        F: Fn(&M) + Send + 'static,
812        M: Message + 'static,
813    {
814        Ok(SubscriptionToken {
815            subscriber_id: self.subscriber_id,
816            callback_id: self.on_message_inner(callback)?,
817            type_id: TypeId::of::<M>(),
818            subscriber: &self,
819        })
820    }
821}
822
823impl Drop for Subscriber {
824    fn drop(&mut self) {
825        // this may fail if the bus is already stopped, but in that case we don't care
826        let _ = self.control.send(BusTask::UnregisterSubscriber {
827            subscriber_id: self.subscriber_id,
828        });
829    }
830}
831
832#[must_use]
833pub struct SubscriptionToken<'a> {
834    type_id: TypeId,
835    subscriber_id: usize,
836    callback_id: usize,
837    subscriber: &'a Subscriber,
838}
839
840impl<'a> SubscriptionToken<'a> {
841    pub fn unsubscribe(self) {
842        // unsubscription is done when drop occurs
843        drop(self)
844    }
845}
846
847impl<'a> Drop for SubscriptionToken<'a> {
848    fn drop(&mut self) {
849        // this may fail if the bus is already stopped, but in that case we don't care
850        let _ = self
851            .subscriber
852            .control
853            .send(BusTask::UnregisterSubscriberCallback {
854                type_id: self.type_id,
855                callback_id: self.callback_id,
856                subscriber_id: self.subscriber_id,
857            });
858    }
859}
860
#[cfg(test)]
mod tests {
    // Negative assertions below wait one second on the receiving end; getting no message in that
    // window is taken as evidence that the callback was not invoked.
    use super::*;
    use std::sync::mpsc::channel;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn can_send_a_simple_message_to_a_subscriber() {
        let bus = Bus::new();
        let subscriber = bus.create_subscriber();

        let (tx, rx) = channel();

        subscriber
            .on_message(move |_: &()| tx.send(()).unwrap())
            .unwrap();

        bus.publish(());

        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
    }

    #[test]
    fn can_send_a_simple_message_to_2_subscribers() {
        let bus = Bus::new();
        let subscriber = bus.create_subscriber();
        let subscriber2 = bus.create_subscriber();

        let (tx, rx) = channel();
        let (tx2, rx2) = channel();

        subscriber
            .on_message(move |_: &()| tx.send(()).unwrap())
            .unwrap();
        subscriber2
            .on_message(move |_: &()| tx2.send(()).unwrap())
            .unwrap();

        bus.publish(());

        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx2.recv_timeout(Duration::from_secs(1)).is_ok());
    }

    #[test]
    fn can_send_a_complex_message_to_a_subscriber() {
        let bus = Bus::new();
        let subscriber = bus.create_subscriber();

        let (tx, rx) = channel();

        struct Message {
            payload: String,
        }

        subscriber
            .on_message(move |m: &Message| tx.send(m.payload.clone()).unwrap())
            .unwrap();

        bus.publish(Message {
            payload: "hello world".into(),
        });

        let result = rx.recv_timeout(Duration::from_secs(1));

        assert_eq!(result, Ok("hello world".to_string()));
    }

    #[test]
    fn can_send_simple_messages_to_a_subscriber_from_multiple_threads() {
        let bus = Bus::new();
        let subscriber = bus.create_subscriber();

        let (tx, rx) = channel();

        subscriber
            .on_message(move |_: &()| tx.send(()).unwrap())
            .unwrap();

        bus.publish(());
        let publisher = thread::spawn(move || bus.publish(()));

        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
        // Surface a panic in the publishing thread instead of silently losing it.
        publisher.join().expect("publisher thread panicked");
    }

    #[test]
    fn can_send_simple_messages_to_a_subscriber_from_cloned_instance() {
        let bus = Bus::new();
        let subscriber = bus.create_subscriber();

        let (tx, rx) = channel();

        subscriber
            .on_message(move |_: &()| tx.send(()).unwrap())
            .unwrap();

        bus.publish(());
        let bus2 = bus.clone();
        bus2.publish(());

        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
    }

    #[test]
    fn can_unsubscribe_callbacks() {
        let bus = Bus::new();
        let subscriber = bus.create_subscriber();

        let (tx, rx) = channel();

        let token = subscriber
            .on_message_with_token(move |_: &()| tx.send(()).unwrap())
            .unwrap();
        bus.publish(());
        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());

        token.unsubscribe();
        bus.publish(());
        assert!(rx.recv_timeout(Duration::from_secs(1)).is_err());
    }

    #[test]
    fn cannot_receive_messages_in_a_dropped_subscriber() {
        let bus = Bus::new();
        let (tx, rx) = channel();
        {
            let subscriber = bus.create_subscriber();

            subscriber
                .on_message(move |_: &()| tx.send(()).unwrap())
                .unwrap();
            bus.publish(());
            assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
            // subscriber is dropped here
        }

        bus.publish(());

        assert!(rx.recv_timeout(Duration::from_secs(1)).is_err());
    }

    #[test]
    fn dropping_subscribers_drops_the_corresponding_subscription() {
        let bus = Bus::new();
        let (tx, rx) = channel();
        let (tx2, rx2) = channel();
        let (tx3, rx3) = channel();

        let subscriber = bus.create_subscriber();
        subscriber
            .on_message(move |_: &()| tx.send(()).unwrap())
            .unwrap();
        let subscriber2 = bus.create_subscriber();
        subscriber2
            .on_message(move |_: &()| tx2.send(()).unwrap())
            .unwrap();
        let subscriber3 = bus.create_subscriber();
        subscriber3
            .on_message(move |_: &()| tx3.send(()).unwrap())
            .unwrap();

        bus.publish(());
        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx2.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx3.recv_timeout(Duration::from_secs(1)).is_ok());
        drop(subscriber);
        drop(subscriber3);

        bus.publish(());

        assert!(rx.recv_timeout(Duration::from_secs(1)).is_err());
        assert!(rx2.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx3.recv_timeout(Duration::from_secs(1)).is_err());
    }

    #[test]
    fn can_register_multiple_message_handlers_for_the_same_message_on_a_single_receiver() {
        let bus = Bus::new();
        let (tx, rx) = channel();
        let (tx2, rx2) = channel();
        let (tx3, rx3) = channel();

        let subscriber = bus.create_subscriber();
        let t = subscriber
            .on_message_with_token(move |_: &()| tx.send(()).unwrap())
            .unwrap();
        subscriber
            .on_message(move |_: &()| tx2.send(()).unwrap())
            .unwrap();
        let t3 = subscriber
            .on_message_with_token(move |_: &()| tx3.send(()).unwrap())
            .unwrap();

        bus.publish(());
        assert!(rx.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx2.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx3.recv_timeout(Duration::from_secs(1)).is_ok());

        // Exercise both ways of ending a subscription: explicit unsubscribe and plain drop.
        t.unsubscribe();
        drop(t3);

        bus.publish(());

        assert!(rx.recv_timeout(Duration::from_secs(1)).is_err());
        assert!(rx2.recv_timeout(Duration::from_secs(1)).is_ok());
        assert!(rx3.recv_timeout(Duration::from_secs(1)).is_err());
    }

    #[test]
    fn subscriber_on_a_dropped_bus_should_generate_dead_bus_error_on_subscribe() {
        let bus = Bus::new();
        let subscriber = bus.create_subscriber();
        drop(bus);
        let r = subscriber.on_message(|_: &()| {});
        assert!(r.is_err())
    }

    #[test]
    fn can_drop_the_bus_while_it_is_still_working() {
        let bus = Bus::new();
        let subscriber = bus.create_subscriber();
        let (tx, rx) = crossbeam_channel::unbounded();
        subscriber
            .on_message(move |_: &()| tx.send(()).unwrap())
            .unwrap();

        for _ in 0..100 {
            bus.publish(())
        }

        drop(bus);

        assert_eq!(rx.len(), 100)
    }

    #[test]
    #[ignore]
    fn no_busyloop() {
        // Watch cpu usage during the first 2 seconds of executing this, it should be nearly 0%.
        // You may want to increase the sleep time to make this more visible.

        let bus = Bus::with_thread_count(NonZeroUsize::new(4).unwrap());

        let subscriber = bus.create_subscriber();
        let (tx, rx) = channel();

        subscriber
            .on_message(move |_: &()| {
                thread::sleep(Duration::from_secs(1));
                tx.send(()).unwrap();
            })
            .unwrap();

        bus.publish(());
        bus.publish(());
        bus.publish(());

        // The received values are `()`; only successful delivery matters.
        rx.recv().expect("recv 1");
        rx.recv().expect("recv 2");
        rx.recv().expect("recv 3");
    }
}