1use std::collections::HashMap;
2
3#[cfg(feature = "websocket-transport")]
4use tokio_tungstenite::tungstenite::{
5 Error as WsError,
6 protocol::Message as WsMessage,
7};
8#[cfg(feature = "websocket-transport")]
9use futures_util::sink::{Sink, SinkExt};
10#[cfg(feature = "gloo-websocket")]
11use futures_util::sink::{Sink, SinkExt};
12
13use modrpc::{
14 Packet,
15 PacketBundle, ShatterPacketBundle,
16};
17
18pub struct InPacket {
19 pub transport: TransportIndex,
20 pub channel_id: u32,
21 pub packet: Packet,
22}
23
/// Boxed `Send` sink used to write outgoing tungstenite WebSocket messages.
#[cfg(feature = "websocket-transport")]
pub type WsSinkBox = Box<dyn Sink<WsMessage, Error = WsError> + Send + Unpin>;

/// Boxed sink used to write outgoing gloo-net WebSocket messages (no `Send` bound).
#[cfg(feature = "gloo-websocket")]
pub type GlooWsSinkBox = Box<
    dyn Sink<gloo_net::websocket::Message, Error = gloo_net::websocket::WebSocketError>
        + Unpin,
>;
28
29enum BroadcasterRequest {
30 #[cfg(feature = "tcp-transport")]
31 AddTcp {
32 stream: tokio::net::tcp::OwnedWriteHalf,
33 response_tx: oneshot::Sender<TransportIndex>,
34 },
35 #[cfg(feature = "websocket-transport")]
36 AddWs {
37 ws_tx: WsSinkBox,
38 response_tx: oneshot::Sender<TransportIndex>,
39 },
40 #[cfg(feature = "gloo-websocket")]
41 AddGlooWs {
42 ws_tx: GlooWsSinkBox,
43 response_tx: oneshot::Sender<TransportIndex>,
44 },
45 AddLocal {
46 tx: localq::mpsc::Sender<Packet>,
47 response_tx: oneshot::Sender<TransportIndex>,
48 },
49 Remove {
50 transport: TransportIndex,
51 response_tx: oneshot::Sender<()>,
52 },
53 AddNextHopToChannels {
54 next_hop_transport: TransportIndex,
55 channel_ids: Vec<(ChannelId, ChannelId)>, response_tx: oneshot::Sender<()>,
57 },
58}
59
/// State for one TCP transport: the write half bundles are written to.
#[cfg(feature = "tcp-transport")]
struct TcpTransport {
    /// Outgoing half of the TCP connection.
    stream: tokio::net::tcp::OwnedWriteHalf,
}
64
/// State for one tungstenite WebSocket transport.
#[cfg(feature = "websocket-transport")]
struct WsTransport {
    /// Sink that outgoing binary bundle messages are sent into.
    ws_tx: WsSinkBox,
}
69
/// State for one gloo-net WebSocket transport.
#[cfg(feature = "gloo-websocket")]
struct GlooWsTransport {
    /// Sink that outgoing byte bundle messages are sent into.
    ws_tx: GlooWsSinkBox,
}
74
75struct LocalTransport {
76 tx: localq::mpsc::Sender<Packet>,
77}
78
79type TransportKey = slotmap::DefaultKey;
80
/// Kind of a transport; selects which slot map a `TransportKey` belongs to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum TransportType {
    /// Write half of a tokio TCP stream.
    #[cfg(feature = "tcp-transport")]
    Tcp,
    /// tungstenite WebSocket sink.
    #[cfg(feature = "websocket-transport")]
    WebSocket,
    /// gloo-net WebSocket sink.
    #[cfg(feature = "gloo-websocket")]
    GlooWebSocket,
    /// In-process `localq` queue.
    Local,
}
91
92#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
93pub struct TransportIndex {
94 transport_type: TransportType,
95 transport: TransportKey,
96}
97
/// Numeric channel identifier, used both for local channels and for the remote
/// channel id written into outgoing bundle headers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ChannelId {
    /// Raw channel number.
    pub channel_id: u32,
}
102
103#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
104struct NextHop {
105 remote_channel_id: ChannelId,
106 transport: TransportIndex,
107}
108
109const BUNDLE_HEADER_LEN: usize = <PacketBundle as mproto::BaseLen>::BASE_LEN;
110
111pub struct Broadcaster {
112 in_packet_receiver: localq::mpsc::Receiver<InPacket>,
113 in_packet_sender: localq::mpsc::Sender<InPacket>,
114
115 next_hops: HashMap<ChannelId, Vec<NextHop>>,
117 transport_local_channel_ids: HashMap<TransportIndex, Vec<ChannelId>>,
118
119 #[cfg(feature = "tcp-transport")]
120 tcp_transports: slotmap::SlotMap<TransportKey, TcpTransport>,
121 #[cfg(feature = "websocket-transport")]
122 ws_transports: slotmap::SlotMap<TransportKey, WsTransport>,
123 #[cfg(feature = "gloo-websocket")]
124 gloo_ws_transports: slotmap::SlotMap<TransportKey, GlooWsTransport>,
125 local_transports: slotmap::SlotMap<TransportKey, LocalTransport>,
126
127 request_tx: localq::mpsc::Sender<BroadcasterRequest>,
128 request_rx: localq::mpsc::Receiver<BroadcasterRequest>,
129}
130
131impl Broadcaster {
132 pub fn new(packet_queue_capacity: usize) -> Self {
133 let (in_packet_sender, in_packet_receiver) = localq::mpsc::channel(packet_queue_capacity);
134 let (request_tx, request_rx) = localq::mpsc::channel(16);
135
136 Self {
137 in_packet_receiver,
138 in_packet_sender,
139
140 next_hops: HashMap::new(),
141 transport_local_channel_ids: HashMap::new(),
142
143 local_transports: slotmap::SlotMap::new(),
144 #[cfg(feature = "tcp-transport")]
145 tcp_transports: slotmap::SlotMap::new(),
146 #[cfg(feature = "websocket-transport")]
147 ws_transports: slotmap::SlotMap::new(),
148 #[cfg(feature = "gloo-websocket")]
149 gloo_ws_transports: slotmap::SlotMap::new(),
150
151 request_tx,
152 request_rx,
153 }
154 }
155
156 pub fn handle(&self) -> BroadcasterHandle {
157 BroadcasterHandle {
158 in_packet_sender: self.in_packet_sender.clone(),
159 request: self.request_tx.clone(),
160 }
161 }
162
163 pub fn add_local_transport(&mut self, tx: localq::mpsc::Sender<Packet>) -> TransportIndex {
164 let key = self.local_transports.insert(LocalTransport { tx });
165 log::debug!("Added Local transport {:?}", key);
166 TransportIndex {
167 transport_type: TransportType::Local,
168 transport: key,
169 }
170 }
171
172 pub async fn run(&mut self) {
173 use futures_util::FutureExt;
174
175 loop {
176 futures_util::select! {
177 in_packet = self.in_packet_receiver.recv().fuse() => {
178 let Ok(in_packet) = in_packet else { break; };
179 self.handle_in_packet(in_packet).await;
180 }
181 request = self.request_rx.recv().fuse() => {
182 let Ok(request) = request else { break; };
183 self.handle_request(request).await;
184 }
185 };
186 }
187 }
188
189 async fn handle_in_packet(&mut self, in_packet: InPacket) {
190 let local_channel_id = ChannelId {
191 channel_id: in_packet.channel_id,
192 };
193
194 log::trace!(
195 "in packet - channel_id={} transport={:?} len={}",
196 local_channel_id.channel_id,
197 in_packet.transport,
198 in_packet.packet.len(),
199 );
200
201 if let Some(next_hops) = self.next_hops.get(&local_channel_id) {
202 if let Err(_) = Self::broadcast(
203 in_packet,
204 next_hops,
205 #[cfg(feature = "tcp-transport")]
206 &mut self.tcp_transports,
207 #[cfg(feature = "websocket-transport")]
208 &mut self.ws_transports,
209 #[cfg(feature = "gloo-websocket")]
210 &mut self.gloo_ws_transports,
211 &mut self.local_transports,
212 ).await {
213 }
215 } else {
216 log::trace!(
217 "No next-hops for local-channel-id={:?}",
218 local_channel_id,
219 );
220 };
221 }
222
223 async fn remove_transport(&mut self, transport: TransportIndex) {
224 log::info!("removing transport {:?}", transport);
225
226 match transport.transport_type {
227 #[cfg(feature = "tcp-transport")]
228 TransportType::Tcp => {
229 self.tcp_transports.remove(transport.transport);
230 }
231 #[cfg(feature = "websocket-transport")]
232 TransportType::WebSocket => {
233 self.ws_transports.remove(transport.transport);
234 }
235 #[cfg(feature = "gloo-websocket")]
236 TransportType::GlooWebSocket => {
237 self.gloo_ws_transports.remove(transport.transport);
238 }
239 TransportType::Local => {
240 self.local_transports.remove(transport.transport);
241 }
242 }
243
244 if let Some(local_channel_ids) =
245 self.transport_local_channel_ids.remove(&transport)
246 {
247 for local_channel_id in local_channel_ids {
248 log::debug!(
249 "removing channel {:?} next_hops for transport {:?}",
250 local_channel_id,
251 transport,
252 );
253 if let Some(next_hops) = self.next_hops.get_mut(&local_channel_id) {
254 next_hops.retain(|next_hop| next_hop.transport != transport);
257 } else {
258 }
260 }
261 } else {
262 }
264 }
265
266 async fn handle_request(&mut self, request: BroadcasterRequest) {
267 match request {
268 #[cfg(feature = "tcp-transport")]
269 BroadcasterRequest::AddTcp { stream, response_tx } => {
270 let key = self.tcp_transports.insert(TcpTransport {
271 stream,
272 });
273 log::debug!("Added TCP transport {:?}", key);
274 let _ = response_tx.send(TransportIndex {
275 transport_type: TransportType::Tcp,
276 transport: key,
277 });
278 }
279 #[cfg(feature = "websocket-transport")]
280 BroadcasterRequest::AddWs { ws_tx, response_tx } => {
281 let key = self.ws_transports.insert(WsTransport { ws_tx });
282 log::debug!("Added WebSocket transport {:?}", key);
283 let _ = response_tx.send(TransportIndex {
284 transport_type: TransportType::WebSocket,
285 transport: key,
286 });
287 }
288 #[cfg(feature = "gloo-websocket")]
289 BroadcasterRequest::AddGlooWs { ws_tx, response_tx } => {
290 let key = self.gloo_ws_transports.insert(GlooWsTransport { ws_tx });
291 log::debug!("Added Gloo WebSocket transport {:?}", key);
292 let _ = response_tx.send(TransportIndex {
293 transport_type: TransportType::GlooWebSocket,
294 transport: key,
295 });
296 }
297 BroadcasterRequest::AddLocal { tx, response_tx } => {
298 let key = self.local_transports.insert(LocalTransport { tx });
299 log::debug!("Added Local transport {:?}", key);
300 let _ = response_tx.send(TransportIndex {
301 transport_type: TransportType::Local,
302 transport: key,
303 });
304 }
305 BroadcasterRequest::Remove { transport, response_tx } => {
306 self.remove_transport(transport).await;
307 log::debug!("TransportHub removed transport {:?}", transport);
308 let _ = response_tx.send(());
309 }
310 BroadcasterRequest::AddNextHopToChannels {
311 next_hop_transport, channel_ids, response_tx,
312 } => {
313 log::debug!(
314 "Adding next hop to channels transport={:?}, channel_ids={:?}",
315 next_hop_transport,
316 channel_ids,
317 );
318 for &(local_channel_id, remote_channel_id) in &channel_ids {
319 let next_hops =
320 self.next_hops.entry(local_channel_id).or_insert(Vec::new());
321 next_hops.push(NextHop {
322 remote_channel_id,
323 transport: next_hop_transport,
324 });
325 }
326
327 self.transport_local_channel_ids
328 .entry(next_hop_transport)
329 .or_insert(Vec::new())
330 .extend(channel_ids.iter().map(|(local_channel_id, _)| local_channel_id));
331
332 let _ = response_tx.send(());
334 }
335 }
336 }
337
338 async fn broadcast(
339 in_packet: InPacket,
340 next_hops: &[NextHop],
341 #[cfg(feature = "tcp-transport")]
342 tcp_transports: &mut slotmap::SlotMap<TransportKey, TcpTransport>,
343 #[cfg(feature = "websocket-transport")]
344 ws_transports: &mut slotmap::SlotMap<TransportKey, WsTransport>,
345 #[cfg(feature = "gloo-websocket")]
346 gloo_ws_transports: &mut slotmap::SlotMap<TransportKey, GlooWsTransport>,
347 local_transports: &mut slotmap::SlotMap<TransportKey, LocalTransport>,
348 ) -> std::io::Result<()> {
349 for next_hop in next_hops {
350 let transport_index = next_hop.transport;
351
352 if transport_index == in_packet.transport {
353 continue;
355 }
356
357 log::trace!(
358 "[transmitter] Sending to next-hop - transport={:?} channel={} length={}",
359 transport_index,
360 next_hop.remote_channel_id.channel_id,
361 in_packet.packet.len(),
362 );
363
364 match transport_index.transport_type {
365 #[cfg(feature = "tcp-transport")]
366 TransportType::Tcp => {
367 let bundle_payload = &in_packet.packet[..];
368
369 let mut bundle_header_buf = [0u8; BUNDLE_HEADER_LEN];
371 mproto::encode_value(
372 PacketBundle {
373 channel_id: next_hop.remote_channel_id.channel_id,
374 length: bundle_payload.len() as u16,
375 },
376 &mut bundle_header_buf,
377 );
378
379 if let Some(tcp_transport) = tcp_transports.get_mut(transport_index.transport) {
380 if let Err(_) =
381 Self::write_tcp_bundle(
382 &mut tcp_transport.stream,
383 &bundle_header_buf,
384 bundle_payload,
385 ).await
386 {
387 log::debug!("TransportHub tcp transport closed: {:?}", transport_index);
388 tcp_transports.remove(transport_index.transport);
390 }
391 }
392 }
393 #[cfg(feature = "websocket-transport")]
394 TransportType::WebSocket => {
395 let bundle_payload = &in_packet.packet[..];
396 let mut message = vec![0u8; BUNDLE_HEADER_LEN + bundle_payload.len()];
397
398 mproto::encode_value(
400 PacketBundle {
401 channel_id: next_hop.remote_channel_id.channel_id,
402 length: bundle_payload.len() as u16,
403 },
404 &mut message[..BUNDLE_HEADER_LEN],
405 );
406
407 message[BUNDLE_HEADER_LEN..].copy_from_slice(bundle_payload);
408
409 if let Some(ws_transport) = ws_transports.get_mut(transport_index.transport) {
410 if let Err(_) = ws_transport.ws_tx.send(WsMessage::Binary(message.into())).await {
411 log::debug!("WebSocket transport closed: {:?}", transport_index);
412 ws_transports.remove(transport_index.transport);
414 }
415 }
416 }
417 #[cfg(feature = "gloo-websocket")]
418 TransportType::GlooWebSocket => {
419 let bundle_payload = &in_packet.packet[..];
420 let mut message = vec![0u8; BUNDLE_HEADER_LEN + bundle_payload.len()];
421
422 mproto::encode_value(
424 PacketBundle {
425 channel_id: next_hop.remote_channel_id.channel_id,
426 length: bundle_payload.len() as u16,
427 },
428 &mut message[..BUNDLE_HEADER_LEN],
429 );
430
431 message[BUNDLE_HEADER_LEN..].copy_from_slice(bundle_payload);
432
433 if let Some(ws_transport) = gloo_ws_transports.get_mut(transport_index.transport) {
434 if let Err(_) =
435 ws_transport.ws_tx.send(
436 gloo_net::websocket::Message::Bytes(message)
437 )
438 .await
439 {
440 log::debug!("Gloo WebSocket transport closed: {:?}", transport_index);
441 gloo_ws_transports.remove(transport_index.transport);
443 }
444 }
445 }
446 TransportType::Local => {
447 let Some(local_transport) =
448 local_transports.get(transport_index.transport)
449 else {
450 continue;
451 };
452
453 for packet in ShatterPacketBundle::new(&in_packet.packet) {
454 if let Err(_) = local_transport.tx.send(packet).await {
455 local_transports.remove(transport_index.transport);
456 break;
457 }
458 }
459 }
460 }
461 }
462
463 Ok(())
464 }
465
466 #[cfg(feature = "tcp-transport")]
467 async fn write_tcp_bundle(
468 stream: &mut tokio::net::tcp::OwnedWriteHalf,
469 header: &[u8],
470 payload: &[u8],
471 ) -> std::io::Result<()> {
472 use tokio::io::AsyncWriteExt;
473
474 stream.write_all(header).await?;
476 stream.write_all(payload).await?;
477
478 Ok(())
479 }
480}
481
482#[derive(Clone)]
483pub struct BroadcasterHandle {
484 in_packet_sender: localq::mpsc::Sender<InPacket>,
485 request: localq::mpsc::Sender<BroadcasterRequest>,
486}
487
488impl BroadcasterHandle {
489 pub fn in_packet_sender(&self) -> &localq::mpsc::Sender<InPacket> {
490 &self.in_packet_sender
491 }
492
493 #[cfg(feature = "tcp-transport")]
494 pub async fn add_tcp(
495 &self,
496 stream: tokio::net::tcp::OwnedWriteHalf,
497 ) -> TransportIndex {
498 let (response_tx, response_rx) = oneshot::channel();
499
500 self.request.send(BroadcasterRequest::AddTcp {
501 stream,
502 response_tx,
503 })
504 .await
505 .unwrap();
506 let transport_index = response_rx.await.unwrap();
507
508 transport_index
509 }
510
511 #[cfg(feature = "websocket-transport")]
512 pub async fn add_ws(
513 &self,
514 ws_tx: WsSinkBox,
515 ) -> TransportIndex {
516 let (response_tx, response_rx) = oneshot::channel();
517
518 self.request.send(BroadcasterRequest::AddWs {
519 ws_tx,
520 response_tx,
521 })
522 .await
523 .unwrap();
524 let transport_index = response_rx.await.unwrap();
525
526 transport_index
527 }
528
529 #[cfg(feature = "gloo-websocket")]
530 pub async fn add_gloo_ws(
531 &self,
532 ws_tx: GlooWsSinkBox,
533 ) -> TransportIndex {
534 let (response_tx, response_rx) = oneshot::channel();
535
536 self.request.send(BroadcasterRequest::AddGlooWs {
537 ws_tx,
538 response_tx,
539 })
540 .await
541 .unwrap();
542 let transport_index = response_rx.await.unwrap();
543
544 transport_index
545 }
546
547 pub async fn add_local(
548 &self,
549 tx: localq::mpsc::Sender<Packet>,
550 ) -> TransportIndex {
551 let (response_tx, response_rx) = oneshot::channel();
552
553 self.request.send(BroadcasterRequest::AddLocal {
554 tx,
555 response_tx,
556 })
557 .await
558 .unwrap();
559 let transport_index = response_rx.await.unwrap();
560
561 transport_index
562 }
563
564 pub async fn add_next_hop_to_channels(
565 &self,
566 next_hop_transport: TransportIndex,
567 channel_ids: Vec<(ChannelId, ChannelId)>,
568 ) {
569 let (response_tx, response_rx) = oneshot::channel();
570
571 self.request.send(BroadcasterRequest::AddNextHopToChannels {
572 next_hop_transport,
573 channel_ids,
574 response_tx,
575 })
576 .await
577 .unwrap();
578 let _ = response_rx.await.unwrap();
579 }
580
581 pub async fn remove_transport(
582 &self,
583 transport: TransportIndex,
584 ) {
585 let (response_tx, response_rx) = oneshot::channel();
586
587 self.request.send(BroadcasterRequest::Remove {
588 transport,
589 response_tx,
590 })
591 .await
592 .unwrap();
593 let _ = response_rx.await.unwrap();
594 }
595}
596