pipeworks_core/
channel.rs1use std::{
2 any::TypeId,
3 sync::{Arc, RwLock},
4 time::SystemTime,
5};
6
7use tokio::sync::{broadcast, mpsc};
8use tracing::{error, warn};
9
10use crate::{
11 event::{AnyBusEvent, BusEvent},
12 node_id::NODE_ID,
13 reg::{BusType, TypeReg},
14};
15
16pub struct Channel {
19 pub type_reg: TypeReg,
21
22 local: broadcast::Sender<AnyBusEvent>,
25
26 serde: broadcast::Sender<BusEvent<SerdePayload>>,
32
33 latest: Arc<RwLock<Option<AnyBusEvent>>>,
35}
36
/// Serialized form of a bus message as carried on a channel's serde broadcast.
///
/// `Debug`/`PartialEq`/`Eq` are derived so payloads can be logged and compared
/// (e.g. in tests); the enum only wraps plain byte buffers, so these are cheap.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SerdePayload {
    /// Bytes produced by the channel type's registered bitcode encoder.
    Bitcode(Vec<u8>),
}
41
42impl Channel {
43 pub fn from_type_reg(type_reg: TypeReg) -> Self {
44 let (local_tx, mut local_rx) = broadcast::channel::<AnyBusEvent>(type_reg.buffer_cap);
45 let (serde_tx, mut serde_rx) =
46 broadcast::channel::<BusEvent<SerdePayload>>(type_reg.buffer_cap);
47 let latest: Arc<RwLock<Option<AnyBusEvent>>> = Arc::new(RwLock::new(None));
48
49 if let Some((to_bytes, from_bytes)) = type_reg.bitcode_support {
50 let serde_tx_clone = serde_tx.clone();
52 tokio::spawn(async move {
53 loop {
54 let Ok(AnyBusEvent { source, time, msg }) = local_rx.recv().await else {
55 warn!(
56 "TODO(Send to CAS) Serde serialize was unable to keep up with local bus ingres {}",
57 type_reg.type_name
58 );
59 continue;
60 };
61
62 if !source.is_me() || serde_tx_clone.receiver_count() == 0 {
64 continue;
65 }
66
67 let bytes = to_bytes(msg);
68
69 let _ = serde_tx_clone.send(BusEvent {
72 source,
73 time,
74 msg: SerdePayload::Bitcode(bytes),
75 });
76 }
77 });
78
79 let local_tx_clone = local_tx.clone();
81 let latest_clone = latest.clone();
82 tokio::spawn(async move {
83 loop {
84 let Ok(BusEvent { source, time, msg }) = serde_rx.recv().await else {
85 warn!(
86 "TODO(Send to CAS) Serde deserialize was unable to keep up with serde ingres {}",
87 type_reg.type_name
88 );
89 continue;
90 };
91 if source.is_me() {
93 continue;
94 }
95
96 let msg = match msg {
97 SerdePayload::Bitcode(bytes) => {
98 match from_bytes(&bytes) {
99 Ok(v) => v,
100 Err(err) => {
101 error!(
103 "TODO(push to CAS): Failed to deserialize bytes into bus channel type {}: {}",
104 type_reg.type_name, err
105 );
106 continue;
107 }
108 }
109 }
110 };
111
112 let event = AnyBusEvent { source, time, msg };
113 *latest_clone.write().unwrap() = Some(event.clone());
114
115 let _ = local_tx_clone.send(event);
117 }
118 });
119 }
120
121 Self {
122 type_reg,
123 local: local_tx,
124 serde: serde_tx,
125 latest,
126 }
127 }
128
129 pub fn send_event<T: BusType>(&self, event: BusEvent<T>) {
131 assert!(TypeId::of::<T>() == self.type_reg.type_id);
132 let any_event = event.type_erase();
133 *self.latest.write().unwrap() = Some(any_event.clone());
134 let _ = self.local.send(any_event);
135 }
136
137 pub fn send_serde_event(&self, event: BusEvent<SerdePayload>) {
138 let _ = self.serde.send(event);
139 }
140
141 pub fn update_latest<T: Default + BusType>(&self, f: impl FnOnce(T) -> T) {
142 assert!(TypeId::of::<T>() == self.type_reg.type_id);
143
144 let mut latest_guard = self.latest.write().unwrap();
145 let value = latest_guard
146 .as_ref()
147 .map(|e| e.downcast_cloned::<T>().msg)
148 .unwrap_or_default();
149
150 let event = BusEvent {
151 source: *NODE_ID,
152 time: SystemTime::now(),
153 msg: f(value),
154 };
155
156 let any_event = event.type_erase();
157 *latest_guard = Some(any_event.clone());
158 let _ = self.local.send(any_event);
159 }
160
161 pub fn subscribe<T: BusType>(&self, prefix_latest: bool) -> mpsc::Receiver<BusEvent<T>> {
162 assert!(TypeId::of::<T>() == self.type_reg.type_id);
163 let mut any_event_rx = self.local.subscribe();
164 let type_reg = self.type_reg.clone();
165
166 let (tx, rx) = mpsc::channel(1);
167
168 let push_latest = if prefix_latest {
171 self.get_latest::<T>()
172 } else {
173 None
174 };
175
176 tokio::spawn(async move {
177 if let Some(latest) = push_latest {
178 if let Err(_) = tx.send(latest).await {
179 return;
181 }
182 }
183
184 loop {
185 match any_event_rx.recv().await {
186 Ok(any_event) => {
187 if let Err(_) = tx.send(any_event.downcast_cloned::<T>()).await {
188 break;
190 }
191 }
192 Err(err) => {
193 match err {
194 broadcast::error::RecvError::Closed => break,
196 broadcast::error::RecvError::Lagged(count) => {
198 warn!(
199 "TODO(push to CAS) Channel receiver {} lagged {} events",
200 type_reg.type_name, count
201 );
202 }
203 }
204 }
205 }
206 }
207 });
208
209 rx
210 }
211
212 pub fn subscribe_serde(&self) -> broadcast::Receiver<BusEvent<SerdePayload>> {
213 self.serde.subscribe()
214 }
215
216 pub fn get_latest<T: BusType>(&self) -> Option<BusEvent<T>> {
217 self.latest
218 .read()
219 .unwrap()
220 .as_ref()
221 .map(|any_event| any_event.downcast_cloned::<T>())
222 }
223}