pipeworks_core/
bus.rs

1use once_cell::sync::OnceCell;
2use std::{
3    any::{TypeId, type_name},
4    collections::HashMap,
5    fmt::Debug,
6    sync::Arc,
7    time::SystemTime,
8};
9use tokio::{
10    sync::{broadcast, mpsc},
11    task::JoinHandle,
12};
13use tracing::trace;
14
15use crate::{
16    channel::{Channel, SerdePayload},
17    event::BusEvent,
18    node_id::NODE_ID,
19    reg::{BusType, TypeRegFn},
20};
21
22#[derive(Clone)]
23pub struct Bus {
24    channels: Arc<HashMap<TypeId, Channel>>,
25}
26
27impl Bus {
28    pub fn new() -> Self {
29        let channels = inventory::iter::<TypeRegFn>
30            .into_iter()
31            .map(|type_reg_fn| {
32                let type_reg = type_reg_fn.0();
33                (type_reg.type_id, Channel::from_type_reg(type_reg.clone()))
34            })
35            .collect();
36
37        Self {
38            channels: Arc::new(channels),
39        }
40    }
41
42    pub fn get_default() -> Self {
43        static INSTANCE: OnceCell<Bus> = OnceCell::new();
44        INSTANCE.get_or_init(|| Bus::new()).clone()
45    }
46
47    pub fn send<T: BusType>(&self, msg: T) {
48        let time = SystemTime::now();
49        match self.channels.get(&TypeId::of::<T>()) {
50            Some(channel) => channel.send_event(BusEvent {
51                source: *NODE_ID,
52                time,
53                msg,
54            }),
55            None => panic!("Type {} was never registered in the bus", type_name::<T>()),
56        }
57    }
58
59    pub fn send_event<T: BusType>(&self, event: BusEvent<T>) {
60        match self.channels.get(&TypeId::of::<T>()) {
61            Some(channel) => channel.send_event(event),
62            None => panic!("Type {} was never registered in the bus", type_name::<T>()),
63        }
64    }
65
66    pub fn send_serde_event(&self, type_id: &TypeId, event: BusEvent<SerdePayload>) {
67        match self.channels.get(type_id) {
68            Some(channel) => channel.send_serde_event(event),
69            None => panic!("TypeID was never registered in the bus"),
70        }
71    }
72
73    /// Provides a convenient way to get the latest value, update it, and re-push it to the bus.
74    /// Value will be created from Default::default if it has never been seen in the bus before.
75    ///
76    /// Only provides an atomic update if this is the only way this type is pushed to the bus. Then
77    /// it is atomic even if called from multiple threads (the lock on latest is not released until
78    /// the new value is pushed to the bus).
79    pub fn update_latest<T: Default + BusType>(&self, f: impl FnOnce(T) -> T) {
80        match self.channels.get(&TypeId::of::<T>()) {
81            Some(channel) => channel.update_latest(f),
82            None => panic!("TypeID was never registered in the bus"),
83        }
84    }
85
86    pub fn subscribe<T: BusType>(&self, prefix_latest: bool) -> mpsc::Receiver<BusEvent<T>> {
87        match self.channels.get(&TypeId::of::<T>()) {
88            Some(channel) => channel.subscribe(prefix_latest),
89            None => panic!("Type {} was never registered in the bus", type_name::<T>()),
90        }
91    }
92
93    pub fn subscribe_serde(&self, type_id: &TypeId) -> broadcast::Receiver<BusEvent<SerdePayload>> {
94        match self.channels.get(type_id) {
95            Some(channel) => channel.subscribe_serde(),
96            None => panic!("TypeID was never registered in the bus"),
97        }
98    }
99
100    pub fn get_latest<T: BusType>(&self) -> Option<BusEvent<T>> {
101        match self.channels.get(&TypeId::of::<T>()) {
102            Some(channel) => channel.get_latest(),
103            None => panic!("Type {} was never registered in the bus", type_name::<T>()),
104        }
105    }
106
107    pub fn trace_events<T: BusType + Debug>(&self) -> JoinHandle<()> {
108        // Trace is assumed to be all future events (no prefix of `latest`).
109        let mut rx = self.subscribe::<T>(false);
110        tokio::spawn(async move {
111            while let Some(event) = rx.recv().await {
112                trace!("{}: {:#?}", event.source, event.msg);
113            }
114        })
115    }
116}