commonware_p2p/utils/
mux.rs

1//! This utility wraps a [Sender] and [Receiver], providing lightweight sub-channels keyed by
2//! [Channel].
3//!
4//! Usage:
5//! - Call [Muxer::new] to obtain a ([Muxer], [MuxHandle]) pair.
6//! - Call [Muxer::start] or run [Muxer::run] in a background task to demux incoming messages into
7//!   per-subchannel queues.
8//! - Call [MuxHandle::register] to obtain a ([SubSender], [SubReceiver]) pair for that subchannel,
9//!   even if the muxer is already running.
10
11use crate::{Channel, Message, Receiver, Recipients, Sender};
12use bytes::{BufMut, Bytes, BytesMut};
13use commonware_codec::{varint::UInt, EncodeSize, ReadExt, Write};
14use commonware_macros::select;
15use commonware_runtime::{Handle, Spawner};
16use futures::{
17    channel::{mpsc, oneshot},
18    SinkExt, StreamExt,
19};
20use std::{collections::HashMap, fmt::Debug};
21use thiserror::Error;
22use tracing::debug;
23
24/// Errors that can occur when interacting with a [SubReceiver] or [MuxHandle].
25#[derive(Error, Debug)]
26pub enum Error {
27    #[error("subchannel already registered: {0}")]
28    AlreadyRegistered(Channel),
29    #[error("muxer is closed")]
30    Closed,
31    #[error("recv failed")]
32    RecvFailed,
33}
34
35/// Control messages for the [Muxer].
36enum Control<R: Receiver> {
37    Register {
38        subchannel: Channel,
39        sender: oneshot::Sender<mpsc::Receiver<Message<R::PublicKey>>>,
40    },
41    Deregister {
42        subchannel: Channel,
43    },
44}
45
46/// Thread-safe routing table mapping each [Channel] to the [mpsc::Sender] for [`Message<P>`].
47type Routes<P> = HashMap<Channel, mpsc::Sender<Message<P>>>;
48
49/// A multiplexer of p2p channels into subchannels.
50pub struct Muxer<E: Spawner, S: Sender, R: Receiver> {
51    context: E,
52    sender: S,
53    receiver: R,
54    mailbox_size: usize,
55    control_rx: mpsc::Receiver<Control<R>>,
56    routes: Routes<R::PublicKey>,
57}
58
59impl<E: Spawner, S: Sender, R: Receiver> Muxer<E, S, R> {
60    /// Create a multiplexed wrapper around a [Sender] and [Receiver] pair, and return a ([Muxer],
61    /// [MuxHandle]) pair that can be used to register routes dynamically.
62    pub fn new(
63        context: E,
64        sender: S,
65        receiver: R,
66        mailbox_size: usize,
67    ) -> (Self, MuxHandle<E, S, R>) {
68        let (control_tx, control_rx) = mpsc::channel(mailbox_size);
69        let mux = Self {
70            context: context.clone(),
71            sender,
72            receiver,
73            mailbox_size,
74            control_rx,
75            routes: HashMap::new(),
76        };
77
78        let handle = MuxHandle {
79            context,
80            sender: mux.sender.clone(),
81            control_tx,
82        };
83
84        (mux, handle)
85    }
86
87    /// Start the demuxer using the given spawner.
88    pub fn start(mut self) -> Handle<Result<(), R::Error>> {
89        self.context.spawn_ref()(self.run())
90    }
91
92    /// Drive demultiplexing of messages into per-subchannel receivers.
93    ///
94    /// Callers should run this in a background task for as long as the underlying `Receiver` is
95    /// expected to receive traffic.
96    pub async fn run(mut self) -> Result<(), R::Error> {
97        loop {
98            select! {
99                // Control messages (registration/deregistration)
100                control = self.control_rx.next() => {
101                    match control {
102                        Some(Control::Register { subchannel, sender }) => {
103                            // If the subchannel is already registered, drop the sender.
104                            if self.routes.contains_key(&subchannel) {
105                                continue;
106                            }
107
108                            // Otherwise, create a new subchannel and send the receiver to the caller.
109                            let (tx, rx) = mpsc::channel(self.mailbox_size);
110                            self.routes.insert(subchannel, tx);
111                            let _ = sender.send(rx);
112                        },
113                        Some(Control::Deregister { subchannel }) => {
114                            // Remove the route.
115                            self.routes.remove(&subchannel);
116                        },
117                        None => {
118                            // If the control channel is closed, we can shut down since there must
119                            // be no more registrations, and all receivers must have been dropped.
120                            return Ok(());
121                        }
122                    }
123                },
124                // Network messages
125                message = self.receiver.recv() => {
126                    let (pk, mut bytes) = message?;
127
128                    // Decode message: varint(subchannel) || bytes
129                    let subchannel: Channel = match UInt::read(&mut bytes) {
130                        Ok(v) => v.into(),
131                        Err(_) => {
132                            debug!(?pk, "invalid message: missing subchannel");
133                            continue;
134                        }
135                    };
136
137                    // Get the route for the subchannel.
138                    let Some(sender) = self.routes.get_mut(&subchannel) else {
139                        // Drops the message if the subchannel is not found
140                        continue;
141                    };
142
143                    // Send the message to the subchannel, blocking if the queue is full.
144                    if let Err(e) = sender.send((pk, bytes)).await {
145                        // Remove the route for the subchannel.
146                        self.routes.remove(&subchannel);
147
148                        // Failure, drop the sender since the receiver is no longer interested.
149                        debug!(?subchannel, ?e, "failed to send message to subchannel");
150                    }
151                }
152            }
153        }
154    }
155}
156
157/// A clonable handle that allows registering routes at any time, even after the [Muxer] is running.
158#[derive(Clone)]
159pub struct MuxHandle<E: Spawner, S: Sender, R: Receiver> {
160    context: E,
161    sender: S,
162    control_tx: mpsc::Sender<Control<R>>,
163}
164
165impl<E: Spawner, S: Sender, R: Receiver> MuxHandle<E, S, R> {
166    /// Open a `subchannel`. Returns a ([SubSender], [SubReceiver]) pair that can be used to send
167    /// and receive messages for that subchannel.
168    ///
169    /// Panics if the subchannel is already registered at any point.
170    pub async fn register(
171        &mut self,
172        subchannel: Channel,
173    ) -> Result<(SubSender<S>, SubReceiver<E, R>), Error> {
174        let (tx, rx) = oneshot::channel();
175        self.control_tx
176            .send(Control::Register {
177                subchannel,
178                sender: tx,
179            })
180            .await
181            .map_err(|_| Error::Closed)?;
182        let receiver = rx.await.map_err(|_| Error::AlreadyRegistered(subchannel))?;
183
184        Ok((
185            SubSender {
186                subchannel,
187                inner: self.sender.clone(),
188            },
189            SubReceiver {
190                context: self.context.clone(),
191                receiver,
192                control_tx: Some(self.control_tx.clone()),
193                subchannel,
194            },
195        ))
196    }
197}
198
199/// Sender that routes messages to the `subchannel`.
200#[derive(Clone, Debug)]
201pub struct SubSender<S: Sender> {
202    inner: S,
203    subchannel: Channel,
204}
205
206impl<S: Sender> Sender for SubSender<S> {
207    type Error = S::Error;
208    type PublicKey = S::PublicKey;
209
210    async fn send(
211        &mut self,
212        recipients: Recipients<S::PublicKey>,
213        payload: Bytes,
214        priority: bool,
215    ) -> Result<Vec<S::PublicKey>, S::Error> {
216        let subchannel = UInt(self.subchannel);
217        let mut buf = BytesMut::with_capacity(subchannel.encode_size() + payload.len());
218        subchannel.write(&mut buf);
219        buf.put_slice(&payload);
220        self.inner.send(recipients, buf.freeze(), priority).await
221    }
222}
223
224/// Receiver that yields messages for a specific subchannel.
225pub struct SubReceiver<E: Spawner, R: Receiver> {
226    context: E,
227    receiver: mpsc::Receiver<Message<R::PublicKey>>,
228    control_tx: Option<mpsc::Sender<Control<R>>>,
229    subchannel: Channel,
230}
231
232impl<E: Spawner, R: Receiver> Receiver for SubReceiver<E, R> {
233    type Error = Error;
234    type PublicKey = R::PublicKey;
235
236    async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Self::Error> {
237        self.receiver.next().await.ok_or(Error::RecvFailed)
238    }
239}
240
241impl<E: Spawner, R: Receiver> Debug for SubReceiver<E, R> {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        write!(f, "SubReceiver({})", self.subchannel)
244    }
245}
246
247impl<E: Spawner, R: Receiver> Drop for SubReceiver<E, R> {
248    fn drop(&mut self) {
249        // Take the control channel to avoid cloning.
250        let mut control_tx = self
251            .control_tx
252            .take()
253            .expect("SubReceiver::drop called twice");
254
255        // If the control channel is not full, deregister the subchannel immediately.
256        let subchannel = self.subchannel;
257        if control_tx
258            .try_send(Control::Deregister { subchannel })
259            .is_ok()
260        {
261            return;
262        }
263
264        // Otherwise, spawn a task to deregister the subchannel.
265        self.context.spawn_ref()(async move {
266            let _ = control_tx.send(Control::Deregister { subchannel }).await;
267        });
268    }
269}
270
271#[cfg(test)]
272mod tests {
273    use super::*;
274    use crate::{
275        simulated::{Config as SimConfig, Link, Network, Oracle},
276        Recipients,
277    };
278    use bytes::Bytes;
279    use commonware_cryptography::{ed25519::PrivateKey, PrivateKeyExt, Signer};
280    use commonware_macros::{select, test_traced};
281    use commonware_runtime::{deterministic, Clock, Metrics, Runner};
282    use std::time::Duration;
283
284    type Pk = commonware_cryptography::ed25519::PublicKey;
285
286    const LINK: Link = Link {
287        latency: Duration::from_millis(0),
288        jitter: Duration::from_millis(0),
289        success_rate: 1.0,
290    };
291    const CAPACITY: usize = 5usize;
292
293    /// Start the network and return the oracle.
294    fn start_network(context: deterministic::Context) -> Oracle<Pk> {
295        let (network, oracle) = Network::new(
296            context.with_label("network"),
297            SimConfig {
298                max_size: 1024 * 1024,
299            },
300        );
301        network.start();
302        oracle
303    }
304
305    /// Create a public key from a seed.
306    fn pk(seed: u64) -> Pk {
307        PrivateKey::from_seed(seed).public_key()
308    }
309
310    /// Link two peers bidirectionally.
311    async fn link_bidirectional(oracle: &mut Oracle<Pk>, a: Pk, b: Pk) {
312        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
313        oracle.add_link(b, a, LINK).await.unwrap();
314    }
315
316    /// Create a peer and register it with the oracle.
317    async fn create_peer<E: Spawner>(
318        context: &E,
319        oracle: &mut Oracle<Pk>,
320        seed: u64,
321    ) -> (
322        Pk,
323        MuxHandle<E, impl Sender<PublicKey = Pk>, impl Receiver<PublicKey = Pk>>,
324    ) {
325        let pubkey = pk(seed);
326        let (sender, receiver) = oracle.register(pubkey.clone(), 0).await.unwrap();
327        let (mux, handle) = Muxer::new(context.clone(), sender, receiver, CAPACITY);
328        mux.start();
329        (pubkey, handle)
330    }
331
332    /// Send a burst of messages to a list of senders.
333    async fn send_burst<S: Sender>(txs: &mut [SubSender<S>], count: usize) {
334        for i in 0..count {
335            let payload = Bytes::from(vec![i as u8]);
336            for tx in txs.iter_mut() {
337                let _ = tx
338                    .send(Recipients::All, payload.clone(), false)
339                    .await
340                    .unwrap();
341            }
342        }
343    }
344
345    /// Wait for `n` messages to be received on the receiver.
346    async fn expect_n_messages<E: Spawner + Clock>(
347        rx: &mut SubReceiver<E, impl Receiver<PublicKey = Pk>>,
348        n: usize,
349        context: &E,
350    ) {
351        let mut count = 0;
352        loop {
353            select! {
354                res = rx.recv() => {
355                    res.expect("should have received message");
356                    count += 1;
357                },
358                _ = context.sleep(Duration::from_millis(100)) => { break; },
359            }
360        }
361        assert_eq!(n, count);
362    }
363
364    #[test]
365    fn test_basic_routing() {
366        // Can register a subchannel and send messages to it.
367        let executor = deterministic::Runner::default();
368        executor.start(|context| async move {
369            let mut oracle = start_network(context.clone());
370
371            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
372            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
373            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
374
375            let (_, mut sub_rx1) = handle1.register(7).await.unwrap();
376            let (mut sub_tx2, _) = handle2.register(7).await.unwrap();
377
378            // Send and receive
379            let payload = Bytes::from_static(b"hello");
380            let _ = sub_tx2
381                .send(Recipients::One(pk1.clone()), payload.clone(), false)
382                .await
383                .unwrap();
384            let (from, bytes) = sub_rx1.recv().await.unwrap();
385            assert_eq!(from, pk2);
386            assert_eq!(bytes, payload);
387        });
388    }
389
390    #[test]
391    fn test_multiple_routes() {
392        // Can register multiple subchannels and send messages to each.
393        let executor = deterministic::Runner::default();
394        executor.start(|context| async move {
395            let mut oracle = start_network(context.clone());
396
397            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
398            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
399            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
400
401            let (_, mut rx_a) = handle1.register(10).await.unwrap();
402            let (_, mut rx_b) = handle1.register(20).await.unwrap();
403
404            let (mut tx2_a, _) = handle2.register(10).await.unwrap();
405            let (mut tx2_b, _) = handle2.register(20).await.unwrap();
406
407            let payload_a = Bytes::from_static(b"A");
408            let payload_b = Bytes::from_static(b"B");
409            let _ = tx2_a
410                .send(Recipients::One(pk1.clone()), payload_a.clone(), false)
411                .await
412                .unwrap();
413            let _ = tx2_b
414                .send(Recipients::One(pk1.clone()), payload_b.clone(), false)
415                .await
416                .unwrap();
417
418            let (from_a, bytes_a) = rx_a.recv().await.unwrap();
419            assert_eq!(from_a, pk2);
420            assert_eq!(bytes_a, payload_a);
421
422            let (from_b, bytes_b) = rx_b.recv().await.unwrap();
423            assert_eq!(from_b, pk2);
424            assert_eq!(bytes_b, payload_b);
425        });
426    }
427
428    #[test_traced]
429    fn test_mailbox_capacity_blocks() {
430        // If a single subchannel is full, messages are blocked for all subchannels.
431        let executor = deterministic::Runner::default();
432        executor.start(|context| async move {
433            let mut oracle = start_network(context.clone());
434
435            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
436            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
437            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
438
439            // Register the subchannels.
440            let (tx1, _) = handle1.register(99).await.unwrap();
441            let (tx2, _) = handle1.register(100).await.unwrap();
442            let (_, mut rx1) = handle2.register(99).await.unwrap();
443            let (_, mut rx2) = handle2.register(100).await.unwrap();
444
445            // Send 10 messages to each subchannel from pk1 to pk2.
446            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
447
448            // Try receiving all messages from the second subchannel.
449            expect_n_messages(&mut rx2, CAPACITY, &context).await;
450
451            // Try receiving from the first subchannel.
452            expect_n_messages(&mut rx1, CAPACITY * 2, &context).await;
453
454            // The second subchannel should be unblocked and receive the rest of the messages.
455            expect_n_messages(&mut rx2, CAPACITY, &context).await;
456        });
457    }
458
459    #[test]
460    fn test_drop_a_full_subchannel() {
461        // Drops the subchannel receiver while the sender is blocked.
462        let executor = deterministic::Runner::default();
463        executor.start(|context| async move {
464            let mut oracle = start_network(context.clone());
465
466            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
467            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
468            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
469
470            // Register the subchannels.
471            let (tx1, _) = handle1.register(99).await.unwrap();
472            let (tx2, _) = handle1.register(100).await.unwrap();
473            let (_, rx1) = handle2.register(99).await.unwrap();
474            let (_, mut rx2) = handle2.register(100).await.unwrap();
475
476            // Send 10 messages to each subchannel from pk1 to pk2.
477            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
478
479            // Give the demuxers a moment to process messages.
480            context.sleep(Duration::from_millis(100)).await;
481
482            // Try receiving all messages from the second subchannel.
483            expect_n_messages(&mut rx2, CAPACITY, &context).await;
484
485            // Drop the first subchannel, erroring the sender and dropping it.
486            drop(rx1);
487
488            // The second subchannel should be unblocked and receive the rest of the messages.
489            expect_n_messages(&mut rx2, CAPACITY, &context).await;
490        });
491    }
492
493    #[test]
494    fn test_drop_messages_for_unregistered_subchannel() {
495        // Messages are dropped if the subchannel they are for is not registered.
496        let executor = deterministic::Runner::default();
497        executor.start(|context| async move {
498            let mut oracle = start_network(context.clone());
499
500            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
501            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
502            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
503
504            // Register the subchannels.
505            let (tx1, _) = handle1.register(1).await.unwrap();
506            let (tx2, _) = handle1.register(2).await.unwrap();
507            // Do not register the first subchannel on the second peer.
508            let (_, mut rx2) = handle2.register(2).await.unwrap();
509
510            // Send 10 messages to each subchannel from pk1 to pk2.
511            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
512
513            // Give the demuxers a moment to process messages.
514            context.sleep(Duration::from_millis(100)).await;
515
516            // Try receiving all messages from the second subchannel.
517            expect_n_messages(&mut rx2, CAPACITY * 2, &context).await;
518        });
519    }
520
521    #[test]
522    fn test_duplicate_registration() {
523        // Returns an error if the subchannel is already registered.
524        let executor = deterministic::Runner::default();
525        executor.start(|context| async move {
526            let mut oracle = start_network(context.clone());
527
528            let (_pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
529
530            // Register the subchannel.
531            let (_, _rx) = handle1.register(7).await.unwrap();
532
533            // Registering again should return an error.
534            assert!(matches!(
535                handle1.register(7).await,
536                Err(Error::AlreadyRegistered(_))
537            ));
538        });
539    }
540
541    #[test]
542    fn test_register_after_deregister() {
543        // Can register a channel after it has been deregistered.
544        let executor = deterministic::Runner::default();
545        executor.start(|context| async move {
546            let mut oracle = start_network(context.clone());
547
548            let (_, mut handle) = create_peer(&context, &mut oracle, 0).await;
549            let (_, rx) = handle.register(7).await.unwrap();
550            drop(rx);
551
552            // Registering again should not return an error.
553            handle.register(7).await.unwrap();
554        });
555    }
556}