Skip to main content

modrpc_hub/
broadcaster.rs

1use std::collections::HashMap;
2
3#[cfg(feature = "websocket-transport")]
4use tokio_tungstenite::tungstenite::{
5    Error as WsError,
6    protocol::Message as WsMessage,
7};
8#[cfg(feature = "websocket-transport")]
9use futures_util::sink::{Sink, SinkExt};
10#[cfg(feature = "gloo-websocket")]
11use futures_util::sink::{Sink, SinkExt};
12
13use modrpc::{
14    Packet,
15    PacketBundle, ShatterPacketBundle,
16};
17
18pub struct InPacket {
19    pub transport: TransportIndex,
20    pub channel_id: u32,
21    pub packet: Packet,
22}
23
/// Boxed, sendable sink for outgoing tungstenite WebSocket messages.
#[cfg(feature = "websocket-transport")]
pub type WsSinkBox = Box<dyn Sink<WsMessage, Error = WsError> + Send + Unpin>;

/// Boxed sink for outgoing `gloo_net` WebSocket messages.
///
/// Unlike `WsSinkBox`, this alias does not require `Send`.
#[cfg(feature = "gloo-websocket")]
pub type GlooWsSinkBox = Box<
    dyn Sink<gloo_net::websocket::Message, Error = gloo_net::websocket::WebSocketError>
        + Unpin,
>;
28
29enum BroadcasterRequest {
30    #[cfg(feature = "tcp-transport")]
31    AddTcp {
32        stream: tokio::net::tcp::OwnedWriteHalf,
33        response_tx: oneshot::Sender<TransportIndex>,
34    },
35    #[cfg(feature = "websocket-transport")]
36    AddWs {
37        ws_tx: WsSinkBox,
38        response_tx: oneshot::Sender<TransportIndex>,
39    },
40    #[cfg(feature = "gloo-websocket")]
41    AddGlooWs {
42        ws_tx: GlooWsSinkBox,
43        response_tx: oneshot::Sender<TransportIndex>,
44    },
45    AddLocal {
46        tx: localq::mpsc::Sender<Packet>,
47        response_tx: oneshot::Sender<TransportIndex>,
48    },
49    Remove {
50        transport: TransportIndex,
51        response_tx: oneshot::Sender<()>,
52    },
53    AddNextHopToChannels {
54        next_hop_transport: TransportIndex,
55        channel_ids: Vec<(ChannelId, ChannelId)>, // [(local channel ID, remote channel ID)]
56        response_tx: oneshot::Sender<()>,
57    },
58}
59
60#[cfg(feature = "tcp-transport")]
61struct TcpTransport {
62    stream: tokio::net::tcp::OwnedWriteHalf,
63}
64
65#[cfg(feature = "websocket-transport")]
66struct WsTransport {
67    ws_tx: WsSinkBox,
68}
69
70#[cfg(feature = "gloo-websocket")]
71struct GlooWsTransport {
72    ws_tx: GlooWsSinkBox,
73}
74
75struct LocalTransport {
76    tx: localq::mpsc::Sender<Packet>,
77}
78
79type TransportKey = slotmap::DefaultKey;
80
/// Which kind of transport (and therefore which slot map) a [`TransportIndex`] refers to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum TransportType {
    /// Raw TCP stream.
    #[cfg(feature = "tcp-transport")]
    Tcp,
    /// tungstenite WebSocket.
    #[cfg(feature = "websocket-transport")]
    WebSocket,
    /// gloo WebSocket.
    #[cfg(feature = "gloo-websocket")]
    GlooWebSocket,
    /// In-process channel.
    Local,
}
91
92#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
93pub struct TransportIndex {
94    transport_type: TransportType,
95    transport: TransportKey,
96}
97
/// Numeric channel identifier.
///
/// It is used both for local channel IDs (routing keys) and for remote channel IDs
/// (written into outgoing bundle headers).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ChannelId {
    pub channel_id: u32,
}
102
103#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
104struct NextHop {
105    remote_channel_id: ChannelId,
106    transport: TransportIndex,
107}
108
109const BUNDLE_HEADER_LEN: usize = <PacketBundle as mproto::BaseLen>::BASE_LEN;
110
111pub struct Broadcaster {
112    in_packet_receiver: localq::mpsc::Receiver<InPacket>,
113    in_packet_sender: localq::mpsc::Sender<InPacket>,
114
115    // Map local channel_id to list of transports to broadcast packet bundles to
116    next_hops: HashMap<ChannelId, Vec<NextHop>>,
117    transport_local_channel_ids: HashMap<TransportIndex, Vec<ChannelId>>,
118
119    #[cfg(feature = "tcp-transport")]
120    tcp_transports: slotmap::SlotMap<TransportKey, TcpTransport>,
121    #[cfg(feature = "websocket-transport")]
122    ws_transports: slotmap::SlotMap<TransportKey, WsTransport>,
123    #[cfg(feature = "gloo-websocket")]
124    gloo_ws_transports: slotmap::SlotMap<TransportKey, GlooWsTransport>,
125    local_transports: slotmap::SlotMap<TransportKey, LocalTransport>,
126
127    request_tx: localq::mpsc::Sender<BroadcasterRequest>,
128    request_rx: localq::mpsc::Receiver<BroadcasterRequest>,
129}
130
131impl Broadcaster {
132    pub fn new(packet_queue_capacity: usize) -> Self {
133        let (in_packet_sender, in_packet_receiver) = localq::mpsc::channel(packet_queue_capacity);
134        let (request_tx, request_rx) = localq::mpsc::channel(16);
135
136        Self {
137            in_packet_receiver,
138            in_packet_sender,
139
140            next_hops: HashMap::new(),
141            transport_local_channel_ids: HashMap::new(),
142
143            local_transports: slotmap::SlotMap::new(),
144            #[cfg(feature = "tcp-transport")]
145            tcp_transports: slotmap::SlotMap::new(),
146            #[cfg(feature = "websocket-transport")]
147            ws_transports: slotmap::SlotMap::new(),
148            #[cfg(feature = "gloo-websocket")]
149            gloo_ws_transports: slotmap::SlotMap::new(),
150
151            request_tx,
152            request_rx,
153        }
154    }
155
156    pub fn handle(&self) -> BroadcasterHandle {
157        BroadcasterHandle {
158            in_packet_sender: self.in_packet_sender.clone(),
159            request: self.request_tx.clone(),
160        }
161    }
162
163    pub fn add_local_transport(&mut self, tx: localq::mpsc::Sender<Packet>) -> TransportIndex {
164        let key = self.local_transports.insert(LocalTransport { tx });
165        log::debug!("Added Local transport {:?}", key);
166        TransportIndex {
167            transport_type: TransportType::Local,
168            transport: key,
169        }
170    }
171
172    pub async fn run(&mut self) {
173        use futures_util::FutureExt;
174
175        loop {
176            futures_util::select! {
177                in_packet = self.in_packet_receiver.recv().fuse() => {
178                    let Ok(in_packet) = in_packet else { break; };
179                    self.handle_in_packet(in_packet).await;
180                }
181                request = self.request_rx.recv().fuse() => {
182                    let Ok(request) = request else { break; };
183                    self.handle_request(request).await;
184                }
185            };
186        }
187    }
188
189    async fn handle_in_packet(&mut self, in_packet: InPacket) {
190        let local_channel_id = ChannelId {
191            channel_id: in_packet.channel_id,
192        };
193
194        log::trace!(
195            "in packet - channel_id={} transport={:?} len={}",
196            local_channel_id.channel_id,
197            in_packet.transport,
198            in_packet.packet.len(),
199        );
200
201        if let Some(next_hops) = self.next_hops.get(&local_channel_id) {
202            if let Err(_) = Self::broadcast(
203                in_packet,
204                next_hops,
205                #[cfg(feature = "tcp-transport")]
206                &mut self.tcp_transports,
207                #[cfg(feature = "websocket-transport")]
208                &mut self.ws_transports,
209                #[cfg(feature = "gloo-websocket")]
210                &mut self.gloo_ws_transports,
211                &mut self.local_transports,
212            ).await {
213                // TODO can this even fail?
214            }
215        } else {
216            log::trace!(
217                "No next-hops for local-channel-id={:?}",
218                local_channel_id,
219            );
220        };
221    }
222
223    async fn remove_transport(&mut self, transport: TransportIndex) {
224        log::info!("removing transport {:?}", transport);
225
226        match transport.transport_type {
227            #[cfg(feature = "tcp-transport")]
228            TransportType::Tcp => {
229                self.tcp_transports.remove(transport.transport);
230            }
231            #[cfg(feature = "websocket-transport")]
232            TransportType::WebSocket => {
233                self.ws_transports.remove(transport.transport);
234            }
235            #[cfg(feature = "gloo-websocket")]
236            TransportType::GlooWebSocket => {
237                self.gloo_ws_transports.remove(transport.transport);
238            }
239            TransportType::Local => {
240                self.local_transports.remove(transport.transport);
241            }
242        }
243
244        if let Some(local_channel_ids) =
245            self.transport_local_channel_ids.remove(&transport)
246        {
247            for local_channel_id in local_channel_ids {
248                log::debug!(
249                    "removing channel {:?} next_hops for transport {:?}",
250                    local_channel_id,
251                    transport,
252                );
253                if let Some(next_hops) = self.next_hops.get_mut(&local_channel_id) {
254                    // Remove this transport as a next-hop from all of the channels it
255                    // participated in.
256                    next_hops.retain(|next_hop| next_hop.transport != transport);
257                } else {
258                    // TODO warning?
259                }
260            }
261        } else {
262            // TODO warning?
263        }
264    }
265
266    async fn handle_request(&mut self, request: BroadcasterRequest) {
267        match request {
268            #[cfg(feature = "tcp-transport")]
269            BroadcasterRequest::AddTcp { stream, response_tx } => {
270                let key = self.tcp_transports.insert(TcpTransport {
271                    stream,
272                });
273                log::debug!("Added TCP transport {:?}", key);
274                let _ = response_tx.send(TransportIndex {
275                    transport_type: TransportType::Tcp,
276                    transport: key,
277                });
278            }
279            #[cfg(feature = "websocket-transport")]
280            BroadcasterRequest::AddWs { ws_tx, response_tx } => {
281                let key = self.ws_transports.insert(WsTransport { ws_tx });
282                log::debug!("Added WebSocket transport {:?}", key);
283                let _ = response_tx.send(TransportIndex {
284                    transport_type: TransportType::WebSocket,
285                    transport: key,
286                });
287            }
288            #[cfg(feature = "gloo-websocket")]
289            BroadcasterRequest::AddGlooWs { ws_tx, response_tx } => {
290                let key = self.gloo_ws_transports.insert(GlooWsTransport { ws_tx });
291                log::debug!("Added Gloo WebSocket transport {:?}", key);
292                let _ = response_tx.send(TransportIndex {
293                    transport_type: TransportType::GlooWebSocket,
294                    transport: key,
295                });
296            }
297            BroadcasterRequest::AddLocal { tx, response_tx } => {
298                let key = self.local_transports.insert(LocalTransport { tx });
299                log::debug!("Added Local transport {:?}", key);
300                let _ = response_tx.send(TransportIndex {
301                    transport_type: TransportType::Local,
302                    transport: key,
303                });
304            }
305            BroadcasterRequest::Remove { transport, response_tx } => {
306                self.remove_transport(transport).await;
307                log::debug!("TransportHub removed transport {:?}", transport);
308                let _ = response_tx.send(());
309            }
310            BroadcasterRequest::AddNextHopToChannels {
311                next_hop_transport, channel_ids, response_tx,
312            } => {
313                log::debug!(
314                    "Adding next hop to channels transport={:?}, channel_ids={:?}",
315                    next_hop_transport,
316                    channel_ids,
317                );
318                for &(local_channel_id, remote_channel_id) in &channel_ids {
319                    let next_hops =
320                        self.next_hops.entry(local_channel_id).or_insert(Vec::new());
321                    next_hops.push(NextHop {
322                        remote_channel_id,
323                        transport: next_hop_transport,
324                    });
325                }
326
327                self.transport_local_channel_ids
328                    .entry(next_hop_transport)
329                    .or_insert(Vec::new())
330                    .extend(channel_ids.iter().map(|(local_channel_id, _)| local_channel_id));
331
332                // Don't care if requester hung up
333                let _ = response_tx.send(());
334            }
335        }
336    }
337
338    async fn broadcast(
339        in_packet: InPacket,
340        next_hops: &[NextHop],
341        #[cfg(feature = "tcp-transport")]
342        tcp_transports: &mut slotmap::SlotMap<TransportKey, TcpTransport>,
343        #[cfg(feature = "websocket-transport")]
344        ws_transports: &mut slotmap::SlotMap<TransportKey, WsTransport>,
345        #[cfg(feature = "gloo-websocket")]
346        gloo_ws_transports: &mut slotmap::SlotMap<TransportKey, GlooWsTransport>,
347        local_transports: &mut slotmap::SlotMap<TransportKey, LocalTransport>,
348    ) -> std::io::Result<()> {
349        for next_hop in next_hops {
350            let transport_index = next_hop.transport;
351
352            if transport_index == in_packet.transport {
353                // Don't send bundles back to transport they originated from.
354                continue;
355            }
356
357            log::trace!(
358                "[transmitter]   Sending to next-hop - transport={:?} channel={} length={}",
359                transport_index,
360                next_hop.remote_channel_id.channel_id,
361                in_packet.packet.len(),
362            );
363
364            match transport_index.transport_type {
365                #[cfg(feature = "tcp-transport")]
366                TransportType::Tcp => {
367                    let bundle_payload = &in_packet.packet[..];
368
369                    // Fill bundle header
370                    let mut bundle_header_buf = [0u8; BUNDLE_HEADER_LEN];
371                    mproto::encode_value(
372                        PacketBundle {
373                            channel_id: next_hop.remote_channel_id.channel_id,
374                            length: bundle_payload.len() as u16,
375                        },
376                        &mut bundle_header_buf,
377                    );
378
379                    if let Some(tcp_transport) = tcp_transports.get_mut(transport_index.transport) {
380                        if let Err(_) =
381                            Self::write_tcp_bundle(
382                                &mut tcp_transport.stream,
383                                &bundle_header_buf,
384                                bundle_payload,
385                            ).await
386                        {
387                            log::debug!("TransportHub tcp transport closed: {:?}", transport_index);
388                            // Remove transport
389                            tcp_transports.remove(transport_index.transport);
390                        }
391                    }
392                }
393                #[cfg(feature = "websocket-transport")]
394                TransportType::WebSocket => {
395                    let bundle_payload = &in_packet.packet[..];
396                    let mut message = vec![0u8; BUNDLE_HEADER_LEN + bundle_payload.len()];
397
398                    // Fill bundle header
399                    mproto::encode_value(
400                        PacketBundle {
401                            channel_id: next_hop.remote_channel_id.channel_id,
402                            length: bundle_payload.len() as u16,
403                        },
404                        &mut message[..BUNDLE_HEADER_LEN],
405                    );
406
407                    message[BUNDLE_HEADER_LEN..].copy_from_slice(bundle_payload);
408
409                    if let Some(ws_transport) = ws_transports.get_mut(transport_index.transport) {
410                        if let Err(_) = ws_transport.ws_tx.send(WsMessage::Binary(message.into())).await {
411                            log::debug!("WebSocket transport closed: {:?}", transport_index);
412                            // Remove transport
413                            ws_transports.remove(transport_index.transport);
414                        }
415                    }
416                }
417                #[cfg(feature = "gloo-websocket")]
418                TransportType::GlooWebSocket => {
419                    let bundle_payload = &in_packet.packet[..];
420                    let mut message = vec![0u8; BUNDLE_HEADER_LEN + bundle_payload.len()];
421
422                    // Fill bundle header
423                    mproto::encode_value(
424                        PacketBundle {
425                            channel_id: next_hop.remote_channel_id.channel_id,
426                            length: bundle_payload.len() as u16,
427                        },
428                        &mut message[..BUNDLE_HEADER_LEN],
429                    );
430
431                    message[BUNDLE_HEADER_LEN..].copy_from_slice(bundle_payload);
432
433                    if let Some(ws_transport) = gloo_ws_transports.get_mut(transport_index.transport) {
434                        if let Err(_) =
435                            ws_transport.ws_tx.send(
436                                gloo_net::websocket::Message::Bytes(message)
437                            )
438                            .await
439                        {
440                            log::debug!("Gloo WebSocket transport closed: {:?}", transport_index);
441                            // Remove transport
442                            gloo_ws_transports.remove(transport_index.transport);
443                        }
444                    }
445                }
446                TransportType::Local => {
447                    let Some(local_transport) =
448                        local_transports.get(transport_index.transport)
449                    else {
450                        continue;
451                    };
452
453                    for packet in ShatterPacketBundle::new(&in_packet.packet) {
454                        if let Err(_) = local_transport.tx.send(packet).await {
455                            local_transports.remove(transport_index.transport);
456                            break;
457                        }
458                    }
459                }
460            }
461        }
462
463        Ok(())
464    }
465
466    #[cfg(feature = "tcp-transport")]
467    async fn write_tcp_bundle(
468        stream: &mut tokio::net::tcp::OwnedWriteHalf,
469        header: &[u8],
470        payload: &[u8],
471    ) -> std::io::Result<()> {
472        use tokio::io::AsyncWriteExt;
473
474        // TODO vectored write
475        stream.write_all(header).await?;
476        stream.write_all(payload).await?;
477
478        Ok(())
479    }
480}
481
482#[derive(Clone)]
483pub struct BroadcasterHandle {
484    in_packet_sender: localq::mpsc::Sender<InPacket>,
485    request: localq::mpsc::Sender<BroadcasterRequest>,
486}
487
488impl BroadcasterHandle {
489    pub fn in_packet_sender(&self) -> &localq::mpsc::Sender<InPacket> {
490        &self.in_packet_sender
491    }
492
493    #[cfg(feature = "tcp-transport")]
494    pub async fn add_tcp(
495        &self,
496        stream: tokio::net::tcp::OwnedWriteHalf,
497    ) -> TransportIndex {
498        let (response_tx, response_rx) = oneshot::channel();
499
500        self.request.send(BroadcasterRequest::AddTcp {
501            stream,
502            response_tx,
503        })
504        .await
505        .unwrap();
506        let transport_index = response_rx.await.unwrap();
507
508        transport_index
509    }
510
511    #[cfg(feature = "websocket-transport")]
512    pub async fn add_ws(
513        &self,
514        ws_tx: WsSinkBox,
515    ) -> TransportIndex {
516        let (response_tx, response_rx) = oneshot::channel();
517
518        self.request.send(BroadcasterRequest::AddWs {
519            ws_tx,
520            response_tx,
521        })
522        .await
523        .unwrap();
524        let transport_index = response_rx.await.unwrap();
525
526        transport_index
527    }
528
529    #[cfg(feature = "gloo-websocket")]
530    pub async fn add_gloo_ws(
531        &self,
532        ws_tx: GlooWsSinkBox,
533    ) -> TransportIndex {
534        let (response_tx, response_rx) = oneshot::channel();
535
536        self.request.send(BroadcasterRequest::AddGlooWs {
537            ws_tx,
538            response_tx,
539        })
540        .await
541        .unwrap();
542        let transport_index = response_rx.await.unwrap();
543
544        transport_index
545    }
546
547    pub async fn add_local(
548        &self,
549        tx: localq::mpsc::Sender<Packet>,
550    ) -> TransportIndex {
551        let (response_tx, response_rx) = oneshot::channel();
552
553        self.request.send(BroadcasterRequest::AddLocal {
554            tx,
555            response_tx,
556        })
557        .await
558        .unwrap();
559        let transport_index = response_rx.await.unwrap();
560
561        transport_index
562    }
563
564    pub async fn add_next_hop_to_channels(
565        &self,
566        next_hop_transport: TransportIndex,
567        channel_ids: Vec<(ChannelId, ChannelId)>,
568    ) {
569        let (response_tx, response_rx) = oneshot::channel();
570
571        self.request.send(BroadcasterRequest::AddNextHopToChannels {
572            next_hop_transport,
573            channel_ids,
574            response_tx,
575        })
576        .await
577        .unwrap();
578        let _ = response_rx.await.unwrap();
579    }
580
581    pub async fn remove_transport(
582        &self,
583        transport: TransportIndex,
584    ) {
585        let (response_tx, response_rx) = oneshot::channel();
586
587        self.request.send(BroadcasterRequest::Remove {
588            transport,
589            response_tx,
590        })
591        .await
592        .unwrap();
593        let _ = response_rx.await.unwrap();
594    }
595}
596