pipeworks_core/
channel.rs

1use std::{
2    any::TypeId,
3    sync::{Arc, RwLock},
4    time::SystemTime,
5};
6
7use tokio::sync::{broadcast, mpsc};
8use tracing::{error, warn};
9
10use crate::{
11    event::{AnyBusEvent, BusEvent},
12    node_id::NODE_ID,
13    reg::{BusType, TypeReg},
14};
15
16/// A type-erased broadcast channel for a single type, with optional support for out-of-band
17/// (de)serialization to/from another internal broadcast channel.
18pub struct Channel {
19    /// The type registration used to create this channel.
20    pub type_reg: TypeReg,
21
22    /// The broadcast sender for local type-erased events. This bus will optionally feed the
23    /// `serde` bus if serialization is enabled.
24    local: broadcast::Sender<AnyBusEvent>,
25
26    /// The serialized bus.
27    /// - Remote events pushed to this broadcast will result in the message being deserialized and
28    ///   pushed to the `local` broadcast.
29    /// - Non-remote events pushed to `local` will be (lazily) serialized and pushed to this
30    ///   broadcast.
31    serde: broadcast::Sender<BusEvent<SerdePayload>>,
32
33    /// The latest value (if any) received by the channel.
34    latest: Arc<RwLock<Option<AnyBusEvent>>>,
35}
36
/// Serialized form of a bus event's payload, as carried on a channel's serde bus.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SerdePayload {
    /// Bytes produced by / consumed by the registration's bitcode (de)serializer
    /// (`TypeReg::bitcode_support`).
    Bitcode(Vec<u8>),
}
41
42impl Channel {
43    pub fn from_type_reg(type_reg: TypeReg) -> Self {
44        let (local_tx, mut local_rx) = broadcast::channel::<AnyBusEvent>(type_reg.buffer_cap);
45        let (serde_tx, mut serde_rx) =
46            broadcast::channel::<BusEvent<SerdePayload>>(type_reg.buffer_cap);
47        let latest: Arc<RwLock<Option<AnyBusEvent>>> = Arc::new(RwLock::new(None));
48
49        if let Some((to_bytes, from_bytes)) = type_reg.bitcode_support {
50            // Spawn the local -> serde task
51            let serde_tx_clone = serde_tx.clone();
52            tokio::spawn(async move {
53                loop {
54                    let Ok(AnyBusEvent { source, time, msg }) = local_rx.recv().await else {
55                        warn!(
56                            "TODO(Send to CAS) Serde serialize was unable to keep up with local bus ingres {}",
57                            type_reg.type_name
58                        );
59                        continue;
60                    };
61
62                    // Never serialize remote events, or events when there is no receiver
63                    if !source.is_me() || serde_tx_clone.receiver_count() == 0 {
64                        continue;
65                    }
66
67                    let bytes = to_bytes(msg);
68
69                    // Ignore errors (it just means there is no receiver, which can happen if it was
70                    // dropped after the short-circuit check above).
71                    let _ = serde_tx_clone.send(BusEvent {
72                        source,
73                        time,
74                        msg: SerdePayload::Bitcode(bytes),
75                    });
76                }
77            });
78
79            // Spawn the serde -> local task
80            let local_tx_clone = local_tx.clone();
81            let latest_clone = latest.clone();
82            tokio::spawn(async move {
83                loop {
84                    let Ok(BusEvent { source, time, msg }) = serde_rx.recv().await else {
85                        warn!(
86                            "TODO(Send to CAS) Serde deserialize was unable to keep up with serde ingres {}",
87                            type_reg.type_name
88                        );
89                        continue;
90                    };
91                    // Never deserialize local events
92                    if source.is_me() {
93                        continue;
94                    }
95
96                    let msg = match msg {
97                        SerdePayload::Bitcode(bytes) => {
98                            match from_bytes(&bytes) {
99                                Ok(v) => v,
100                                Err(err) => {
101                                    // TODO: Send to some kind of a CAS
102                                    error!(
103                                        "TODO(push to CAS): Failed to deserialize bytes into bus channel type {}: {}",
104                                        type_reg.type_name, err
105                                    );
106                                    continue;
107                                }
108                            }
109                        }
110                    };
111
112                    let event = AnyBusEvent { source, time, msg };
113                    *latest_clone.write().unwrap() = Some(event.clone());
114
115                    // Ignore errors. It just means no one is listening.
116                    let _ = local_tx_clone.send(event);
117                }
118            });
119        }
120
121        Self {
122            type_reg,
123            local: local_tx,
124            serde: serde_tx,
125            latest,
126        }
127    }
128
129    /// Panics if type T is not the same as the underlying erased type.
130    pub fn send_event<T: BusType>(&self, event: BusEvent<T>) {
131        assert!(TypeId::of::<T>() == self.type_reg.type_id);
132        let any_event = event.type_erase();
133        *self.latest.write().unwrap() = Some(any_event.clone());
134        let _ = self.local.send(any_event);
135    }
136
137    pub fn send_serde_event(&self, event: BusEvent<SerdePayload>) {
138        let _ = self.serde.send(event);
139    }
140
141    pub fn update_latest<T: Default + BusType>(&self, f: impl FnOnce(T) -> T) {
142        assert!(TypeId::of::<T>() == self.type_reg.type_id);
143
144        let mut latest_guard = self.latest.write().unwrap();
145        let value = latest_guard
146            .as_ref()
147            .map(|e| e.downcast_cloned::<T>().msg)
148            .unwrap_or_default();
149
150        let event = BusEvent {
151            source: *NODE_ID,
152            time: SystemTime::now(),
153            msg: f(value),
154        };
155
156        let any_event = event.type_erase();
157        *latest_guard = Some(any_event.clone());
158        let _ = self.local.send(any_event);
159    }
160
161    pub fn subscribe<T: BusType>(&self, prefix_latest: bool) -> mpsc::Receiver<BusEvent<T>> {
162        assert!(TypeId::of::<T>() == self.type_reg.type_id);
163        let mut any_event_rx = self.local.subscribe();
164        let type_reg = self.type_reg.clone();
165
166        let (tx, rx) = mpsc::channel(1);
167
168        // This can safely be queried now, as we already started the subscription to the broadcast
169        // above.
170        let push_latest = if prefix_latest {
171            self.get_latest::<T>()
172        } else {
173            None
174        };
175
176        tokio::spawn(async move {
177            if let Some(latest) = push_latest {
178                if let Err(_) = tx.send(latest).await {
179                    // Receiver was dropped.
180                    return;
181                }
182            }
183
184            loop {
185                match any_event_rx.recv().await {
186                    Ok(any_event) => {
187                        if let Err(_) = tx.send(any_event.downcast_cloned::<T>()).await {
188                            // Receiver was dropped.
189                            break;
190                        }
191                    }
192                    Err(err) => {
193                        match err {
194                            // The whole channel was dropped.
195                            broadcast::error::RecvError::Closed => break,
196                            // The consumer could not keep up.
197                            broadcast::error::RecvError::Lagged(count) => {
198                                warn!(
199                                    "TODO(push to CAS) Channel receiver {} lagged {} events",
200                                    type_reg.type_name, count
201                                );
202                            }
203                        }
204                    }
205                }
206            }
207        });
208
209        rx
210    }
211
212    pub fn subscribe_serde(&self) -> broadcast::Receiver<BusEvent<SerdePayload>> {
213        self.serde.subscribe()
214    }
215
216    pub fn get_latest<T: BusType>(&self) -> Option<BusEvent<T>> {
217        self.latest
218            .read()
219            .unwrap()
220            .as_ref()
221            .map(|any_event| any_event.downcast_cloned::<T>())
222    }
223}