commonware_p2p/utils/
mux.rs

//! This utility wraps a [Sender] and [Receiver], multiplexing them into lightweight subchannels
//! keyed by [Channel].
3//!
4//! Usage:
5//! - Call [Muxer::new] to obtain a ([Muxer], [MuxHandle]) pair.
6//! - Call [Muxer::start] or run [Muxer::run] in a background task to demux incoming messages into
7//!   per-subchannel queues.
8//! - Call [MuxHandle::register] to obtain a ([SubSender], [SubReceiver]) pair for that subchannel,
9//!   even if the muxer is already running.
10
11use crate::{Channel, CheckedSender, LimitedSender, Message, Receiver, Recipients, Sender};
12use bytes::{BufMut, Bytes, BytesMut};
13use commonware_codec::{varint::UInt, EncodeSize, Error as CodecError, ReadExt, Write};
14use commonware_macros::select_loop;
15use commonware_runtime::{spawn_cell, ContextCell, Handle, Spawner};
16use futures::{
17    channel::{mpsc, oneshot},
18    SinkExt, StreamExt,
19};
20use std::{collections::HashMap, fmt::Debug, time::SystemTime};
21use thiserror::Error;
22use tracing::debug;
23
24/// Errors that can occur when interacting with a [SubReceiver] or [MuxHandle].
25#[derive(Error, Debug)]
26pub enum Error {
27    #[error("subchannel already registered: {0}")]
28    AlreadyRegistered(Channel),
29    #[error("muxer is closed")]
30    Closed,
31    #[error("recv failed")]
32    RecvFailed,
33}
34
35/// Parse a muxed message into its subchannel and payload.
36pub fn parse(mut bytes: Bytes) -> Result<(Channel, Bytes), CodecError> {
37    let subchannel: Channel = UInt::read(&mut bytes)?.into();
38    Ok((subchannel, bytes))
39}
40
41/// Control messages for the [Muxer].
42enum Control<R: Receiver> {
43    Register {
44        subchannel: Channel,
45        sender: oneshot::Sender<mpsc::Receiver<Message<R::PublicKey>>>,
46    },
47    Deregister {
48        subchannel: Channel,
49    },
50}
51
52/// Thread-safe routing table mapping each [Channel] to the [mpsc::Sender] for [`Message<P>`].
53type Routes<P> = HashMap<Channel, mpsc::Sender<Message<P>>>;
54
55/// A backup channel response, with a [SubSender] to respond, the [Channel] that wasn't registered,
56/// and the [Message] received.
57type BackupResponse<P> = (Channel, Message<P>);
58
59/// A multiplexer of p2p channels into subchannels.
60pub struct Muxer<E: Spawner, S: Sender, R: Receiver> {
61    context: ContextCell<E>,
62    sender: S,
63    receiver: R,
64    mailbox_size: usize,
65    control_rx: mpsc::UnboundedReceiver<Control<R>>,
66    routes: Routes<R::PublicKey>,
67    backup: Option<mpsc::Sender<BackupResponse<R::PublicKey>>>,
68}
69
70impl<E: Spawner, S: Sender, R: Receiver> Muxer<E, S, R> {
71    /// Create a multiplexed wrapper around a [Sender] and [Receiver] pair, and return a ([Muxer],
72    /// [MuxHandle]) pair that can be used to register routes dynamically.
73    pub fn new(context: E, sender: S, receiver: R, mailbox_size: usize) -> (Self, MuxHandle<S, R>) {
74        Self::builder(context, sender, receiver, mailbox_size).build()
75    }
76
77    /// Creates a [MuxerBuilder] that can be used to configure and build a [Muxer].
78    pub fn builder(
79        context: E,
80        sender: S,
81        receiver: R,
82        mailbox_size: usize,
83    ) -> MuxerBuilder<E, S, R> {
84        let (control_tx, control_rx) = mpsc::unbounded();
85        let mux = Self {
86            context: ContextCell::new(context),
87            sender,
88            receiver,
89            mailbox_size,
90            control_rx,
91            routes: HashMap::new(),
92            backup: None,
93        };
94
95        let mux_handle = MuxHandle {
96            sender: mux.sender.clone(),
97            control_tx,
98        };
99
100        MuxerBuilder { mux, mux_handle }
101    }
102
103    /// Start the demuxer using the given spawner.
104    pub fn start(mut self) -> Handle<Result<(), R::Error>> {
105        spawn_cell!(self.context, self.run().await)
106    }
107
108    /// Drive demultiplexing of messages into per-subchannel receivers.
109    ///
110    /// Callers should run this in a background task for as long as the underlying `Receiver` is
111    /// expected to receive traffic.
112    pub async fn run(mut self) -> Result<(), R::Error> {
113        select_loop! {
114            self.context,
115            on_stopped => {
116                debug!("context shutdown, stopping muxer");
117            },
118            // Prefer control messages because network messages will
119            // already block when full (providing backpressure).
120            control = self.control_rx.next() => {
121                match control {
122                    Some(Control::Register { subchannel, sender }) => {
123                        // If the subchannel is already registered, drop the sender.
124                        if self.routes.contains_key(&subchannel) {
125                            continue;
126                        }
127
128                        // Otherwise, create a new subchannel and send the receiver to the caller.
129                        let (tx, rx) = mpsc::channel(self.mailbox_size);
130                        self.routes.insert(subchannel, tx);
131                        let _ = sender.send(rx);
132                    },
133                    Some(Control::Deregister { subchannel }) => {
134                        // Remove the route.
135                        self.routes.remove(&subchannel);
136                    },
137                    None => {
138                        // If the control channel is closed, we can shut down since there must
139                        // be no more registrations, and all receivers must have been dropped.
140                        return Ok(());
141                    }
142                }
143            },
144            // Process network messages.
145            message = self.receiver.recv() => {
146                // Decode the message.
147                let (pk, bytes) = message?;
148                let (subchannel, bytes) = match parse(bytes) {
149                    Ok(parsed) => parsed,
150                    Err(_) => {
151                        debug!(?pk, "invalid message: missing subchannel");
152                        continue;
153                    }
154                };
155
156                // Get the route for the subchannel.
157                let Some(sender) = self.routes.get_mut(&subchannel) else {
158                    // Attempt to use the backup channel if available.
159                    if let Some(backup) = &mut self.backup {
160                        if let Err(e) = backup.send((subchannel, (pk, bytes))).await {
161                            debug!(?subchannel, ?e, "failed to send message to backup channel");
162                        }
163                    }
164
165                    // Drops the message if the subchannel is not found or the backup
166                    // channel was not used.
167                    continue;
168                };
169
170                // Send the message to the subchannel, blocking if the queue is full.
171                if let Err(e) = sender.send((pk, bytes)).await {
172                    // Remove the route for the subchannel.
173                    self.routes.remove(&subchannel);
174
175                    // Failure, drop the sender since the receiver is no longer interested.
176                    debug!(?subchannel, ?e, "failed to send message to subchannel");
177
178                    // NOTE: The channel is deregistered, but it wasn't when the message was received.
179                    // The backup channel is not used in this case.
180                }
181            }
182        }
183
184        Ok(())
185    }
186}
187
188/// A clonable handle that allows registering routes at any time, even after the [Muxer] is running.
189#[derive(Clone)]
190pub struct MuxHandle<S: Sender, R: Receiver> {
191    sender: S,
192    control_tx: mpsc::UnboundedSender<Control<R>>,
193}
194
195impl<S: Sender, R: Receiver> MuxHandle<S, R> {
196    /// Open a `subchannel`. Returns a ([SubSender], [SubReceiver]) pair that can be used to send
197    /// and receive messages for that subchannel.
198    ///
199    /// Panics if the subchannel is already registered at any point.
200    pub async fn register(
201        &mut self,
202        subchannel: Channel,
203    ) -> Result<(SubSender<S>, SubReceiver<R>), Error> {
204        let (tx, rx) = oneshot::channel();
205        self.control_tx
206            .send(Control::Register {
207                subchannel,
208                sender: tx,
209            })
210            .await
211            .map_err(|_| Error::Closed)?;
212        let receiver = rx.await.map_err(|_| Error::AlreadyRegistered(subchannel))?;
213
214        Ok((
215            SubSender {
216                subchannel,
217                inner: GlobalSender::new(self.sender.clone()),
218            },
219            SubReceiver {
220                receiver,
221                control_tx: Some(self.control_tx.clone()),
222                subchannel,
223            },
224        ))
225    }
226}
227
228/// Sender that routes messages to the `subchannel`.
229#[derive(Clone, Debug)]
230pub struct SubSender<S: Sender> {
231    inner: GlobalSender<S>,
232    subchannel: Channel,
233}
234
235impl<S: Sender> LimitedSender for SubSender<S> {
236    type PublicKey = S::PublicKey;
237    type Checked<'a> = CheckedGlobalSender<'a, S>;
238
239    async fn check(
240        &mut self,
241        recipients: Recipients<Self::PublicKey>,
242    ) -> Result<Self::Checked<'_>, SystemTime> {
243        self.inner
244            .check(recipients)
245            .await
246            .map(|checked| checked.with_subchannel(self.subchannel))
247    }
248}
249
250/// Receiver that yields messages for a specific subchannel.
251pub struct SubReceiver<R: Receiver> {
252    receiver: mpsc::Receiver<Message<R::PublicKey>>,
253    control_tx: Option<mpsc::UnboundedSender<Control<R>>>,
254    subchannel: Channel,
255}
256
257impl<R: Receiver> Receiver for SubReceiver<R> {
258    type Error = Error;
259    type PublicKey = R::PublicKey;
260
261    async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Self::Error> {
262        self.receiver.next().await.ok_or(Error::RecvFailed)
263    }
264}
265
266impl<R: Receiver> Debug for SubReceiver<R> {
267    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268        write!(f, "SubReceiver({})", self.subchannel)
269    }
270}
271
272impl<R: Receiver> Drop for SubReceiver<R> {
273    fn drop(&mut self) {
274        // Take the control channel to avoid cloning.
275        let control_tx = self
276            .control_tx
277            .take()
278            .expect("SubReceiver::drop called twice");
279
280        // Deregister the subchannel immediately.
281        let _ = control_tx.unbounded_send(Control::Deregister {
282            subchannel: self.subchannel,
283        });
284    }
285}
286
287/// Sender that can send messages over any sub [Channel].
288#[derive(Clone, Debug)]
289pub struct GlobalSender<S: Sender> {
290    inner: S,
291}
292
293impl<S: Sender> GlobalSender<S> {
294    /// Create a new [GlobalSender] wrapping the given [Sender].
295    pub const fn new(inner: S) -> Self {
296        Self { inner }
297    }
298
299    /// Send a message over the given `subchannel`.
300    pub async fn send(
301        &mut self,
302        subchannel: Channel,
303        recipients: Recipients<S::PublicKey>,
304        payload: Bytes,
305        priority: bool,
306    ) -> Result<Vec<S::PublicKey>, <S::Checked<'_> as CheckedSender>::Error> {
307        match self.check(recipients).await {
308            Ok(checked) => {
309                checked
310                    .with_subchannel(subchannel)
311                    .send(payload, priority)
312                    .await
313            }
314            Err(_) => Ok(Vec::new()),
315        }
316    }
317}
318
319impl<S: Sender> LimitedSender for GlobalSender<S> {
320    type PublicKey = S::PublicKey;
321    type Checked<'a> = CheckedGlobalSender<'a, S>;
322
323    async fn check(
324        &mut self,
325        recipients: Recipients<Self::PublicKey>,
326    ) -> Result<Self::Checked<'_>, SystemTime> {
327        self.inner
328            .check(recipients)
329            .await
330            .map(|checked| CheckedGlobalSender {
331                subchannel: None,
332                inner: checked,
333            })
334    }
335}
336
337/// A checked sender for a [GlobalSender].
338pub struct CheckedGlobalSender<'a, S: Sender> {
339    subchannel: Option<Channel>,
340    inner: S::Checked<'a>,
341}
342
343impl<'a, S: Sender> CheckedGlobalSender<'a, S> {
344    /// Set the subchannel for this sender.
345    pub const fn with_subchannel(mut self, subchannel: Channel) -> Self {
346        self.subchannel = Some(subchannel);
347        self
348    }
349}
350
351impl<'a, S: Sender> CheckedSender for CheckedGlobalSender<'a, S> {
352    type PublicKey = S::PublicKey;
353    type Error = <S::Checked<'a> as CheckedSender>::Error;
354
355    async fn send(
356        self,
357        message: Bytes,
358        priority: bool,
359    ) -> Result<Vec<Self::PublicKey>, Self::Error> {
360        let subchannel = UInt(self.subchannel.expect("subchannel not set"));
361        let mut buf = BytesMut::with_capacity(subchannel.encode_size() + message.len());
362        subchannel.write(&mut buf);
363        buf.put_slice(&message);
364        self.inner.send(buf.freeze(), priority).await
365    }
366}
367
/// A generic builder interface.
pub trait Builder {
    /// The value produced once the builder is finished.
    type Output;

    /// Consume the builder and produce its [Builder::Output].
    fn build(self) -> Self::Output;
}
376
377/// A builder that constructs a [Muxer].
378pub struct MuxerBuilder<E: Spawner, S: Sender, R: Receiver> {
379    mux: Muxer<E, S, R>,
380    mux_handle: MuxHandle<S, R>,
381}
382
383impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilder<E, S, R> {
384    type Output = (Muxer<E, S, R>, MuxHandle<S, R>);
385
386    fn build(self) -> Self::Output {
387        (self.mux, self.mux_handle)
388    }
389}
390
391impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilder<E, S, R> {
392    /// Registers a backup channel with the muxer.
393    pub fn with_backup(mut self) -> MuxerBuilderWithBackup<E, S, R> {
394        let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
395        self.mux.backup = Some(tx);
396
397        MuxerBuilderWithBackup {
398            mux: self.mux,
399            mux_handle: self.mux_handle,
400            backup_rx: rx,
401        }
402    }
403
404    /// Registers a global sender with the muxer.
405    pub fn with_global_sender(self) -> MuxerBuilderWithGlobalSender<E, S, R> {
406        let global_sender = GlobalSender::new(self.mux.sender.clone());
407
408        MuxerBuilderWithGlobalSender {
409            mux: self.mux,
410            mux_handle: self.mux_handle,
411            global_sender,
412        }
413    }
414}
415
416/// A builder that constructs a [Muxer] with a backup channel.
417pub struct MuxerBuilderWithBackup<E: Spawner, S: Sender, R: Receiver> {
418    mux: Muxer<E, S, R>,
419    mux_handle: MuxHandle<S, R>,
420    backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
421}
422
423impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithBackup<E, S, R> {
424    /// Registers a global sender with the muxer.
425    pub fn with_global_sender(self) -> MuxerBuilderAllOpts<E, S, R> {
426        let global_sender = GlobalSender::new(self.mux.sender.clone());
427
428        MuxerBuilderAllOpts {
429            mux: self.mux,
430            mux_handle: self.mux_handle,
431            backup_rx: self.backup_rx,
432            global_sender,
433        }
434    }
435}
436
437impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithBackup<E, S, R> {
438    type Output = (
439        Muxer<E, S, R>,
440        MuxHandle<S, R>,
441        mpsc::Receiver<BackupResponse<R::PublicKey>>,
442    );
443
444    fn build(self) -> Self::Output {
445        (self.mux, self.mux_handle, self.backup_rx)
446    }
447}
448
449/// A builder that constructs a [Muxer] with a [GlobalSender].
450pub struct MuxerBuilderWithGlobalSender<E: Spawner, S: Sender, R: Receiver> {
451    mux: Muxer<E, S, R>,
452    mux_handle: MuxHandle<S, R>,
453    global_sender: GlobalSender<S>,
454}
455
456impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithGlobalSender<E, S, R> {
457    /// Registers a backup channel with the muxer.
458    pub fn with_backup(mut self) -> MuxerBuilderAllOpts<E, S, R> {
459        let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
460        self.mux.backup = Some(tx);
461
462        MuxerBuilderAllOpts {
463            mux: self.mux,
464            mux_handle: self.mux_handle,
465            backup_rx: rx,
466            global_sender: self.global_sender,
467        }
468    }
469}
470
471impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithGlobalSender<E, S, R> {
472    type Output = (Muxer<E, S, R>, MuxHandle<S, R>, GlobalSender<S>);
473
474    fn build(self) -> Self::Output {
475        (self.mux, self.mux_handle, self.global_sender)
476    }
477}
478
479/// A builder that constructs a [Muxer] with a [GlobalSender] and backup channel.
480pub struct MuxerBuilderAllOpts<E: Spawner, S: Sender, R: Receiver> {
481    mux: Muxer<E, S, R>,
482    mux_handle: MuxHandle<S, R>,
483    backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
484    global_sender: GlobalSender<S>,
485}
486
487impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderAllOpts<E, S, R> {
488    type Output = (
489        Muxer<E, S, R>,
490        MuxHandle<S, R>,
491        mpsc::Receiver<BackupResponse<R::PublicKey>>,
492        GlobalSender<S>,
493    );
494
495    fn build(self) -> Self::Output {
496        (
497            self.mux,
498            self.mux_handle,
499            self.backup_rx,
500            self.global_sender,
501        )
502    }
503}
504
505#[cfg(test)]
506mod tests {
507    use super::*;
508    use crate::{
509        simulated::{self, Link, Network, Oracle},
510        Recipients,
511    };
512    use bytes::Bytes;
513    use commonware_cryptography::{
514        ed25519::{PrivateKey, PublicKey},
515        Signer,
516    };
517    use commonware_macros::{select, test_traced};
518    use commonware_runtime::{deterministic, Metrics, Quota, Runner};
519    use std::{num::NonZeroU32, time::Duration};
520
521    const LINK: Link = Link {
522        latency: Duration::from_millis(0),
523        jitter: Duration::from_millis(0),
524        success_rate: 1.0,
525    };
526    const CAPACITY: usize = 5usize;
527
528    /// Default rate limit set high enough to not interfere with normal operation
529    const TEST_QUOTA: Quota = Quota::per_second(NonZeroU32::MAX);
530
531    /// Start the network and return the oracle.
532    fn start_network(context: deterministic::Context) -> Oracle<PublicKey, deterministic::Context> {
533        let (network, oracle) = Network::new(
534            context.with_label("network"),
535            simulated::Config {
536                max_size: 1024 * 1024,
537                disconnect_on_block: true,
538                tracked_peer_sets: None,
539            },
540        );
541        network.start();
542        oracle
543    }
544
545    /// Create a public key from a seed.
546    fn pk(seed: u64) -> PublicKey {
547        PrivateKey::from_seed(seed).public_key()
548    }
549
550    /// Link two peers bidirectionally.
551    async fn link_bidirectional(
552        oracle: &mut Oracle<PublicKey, deterministic::Context>,
553        a: PublicKey,
554        b: PublicKey,
555    ) {
556        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
557        oracle.add_link(b, a, LINK).await.unwrap();
558    }
559
560    /// Create a peer and register it with the oracle.
561    async fn create_peer(
562        context: &deterministic::Context,
563        oracle: &mut Oracle<PublicKey, deterministic::Context>,
564        seed: u64,
565    ) -> (
566        PublicKey,
567        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
568    ) {
569        let pubkey = pk(seed);
570        let (sender, receiver) = oracle
571            .control(pubkey.clone())
572            .register(0, TEST_QUOTA)
573            .await
574            .unwrap();
575        let (mux, handle) = Muxer::new(context.with_label("mux"), sender, receiver, CAPACITY);
576        mux.start();
577        (pubkey, handle)
578    }
579
580    /// Create a peer and register it with the oracle.
581    async fn create_peer_with_backup_and_global_sender(
582        context: &deterministic::Context,
583        oracle: &mut Oracle<PublicKey, deterministic::Context>,
584        seed: u64,
585    ) -> (
586        PublicKey,
587        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
588        mpsc::Receiver<BackupResponse<PublicKey>>,
589        GlobalSender<simulated::Sender<PublicKey, deterministic::Context>>,
590    ) {
591        let pubkey = pk(seed);
592        let (sender, receiver) = oracle
593            .control(pubkey.clone())
594            .register(0, TEST_QUOTA)
595            .await
596            .unwrap();
597        let (mux, handle, backup, global_sender) =
598            Muxer::builder(context.with_label("mux"), sender, receiver, CAPACITY)
599                .with_backup()
600                .with_global_sender()
601                .build();
602        mux.start();
603        (pubkey, handle, backup, global_sender)
604    }
605
606    /// Send a burst of messages to a list of senders.
607    async fn send_burst<S: Sender>(txs: &mut [SubSender<S>], count: usize) {
608        for i in 0..count {
609            let payload = Bytes::from(vec![i as u8]);
610            for tx in txs.iter_mut() {
611                let _ = tx
612                    .send(Recipients::All, payload.clone(), false)
613                    .await
614                    .unwrap();
615            }
616        }
617    }
618
619    /// Wait for `n` messages to be received on the receiver.
620    async fn expect_n_messages(
621        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
622        n: usize,
623    ) {
624        let mut count = 0;
625        loop {
626            select! {
627                res = rx.recv() => {
628                    res.expect("should have received message");
629                    count += 1;
630                },
631            }
632
633            if count >= n {
634                break;
635            }
636        }
637        assert_eq!(n, count);
638    }
639
640    /// Wait for `n` messages to be received on the receiver + backup receiver.
641    async fn expect_n_messages_with_backup(
642        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
643        backup_rx: &mut mpsc::Receiver<BackupResponse<PublicKey>>,
644        n: usize,
645        n_backup: usize,
646    ) {
647        let mut count_std = 0;
648        let mut count_backup = 0;
649        loop {
650            select! {
651                res = rx.recv() => {
652                    res.expect("should have received message");
653                    count_std += 1;
654                },
655                res = backup_rx.next() => {
656                    res.expect("should have received message");
657                    count_backup += 1;
658                },
659            }
660
661            if count_std >= n && count_backup >= n_backup {
662                break;
663            }
664        }
665        assert_eq!(n, count_std);
666        assert_eq!(n_backup, count_backup);
667    }
668
669    #[test]
670    fn test_basic_routing() {
671        // Can register a subchannel and send messages to it.
672        let executor = deterministic::Runner::default();
673        executor.start(|context| async move {
674            let mut oracle = start_network(context.clone());
675
676            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
677            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
678            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
679
680            let (_, mut sub_rx1) = handle1.register(7).await.unwrap();
681            let (mut sub_tx2, _) = handle2.register(7).await.unwrap();
682
683            // Send and receive
684            let payload = Bytes::from_static(b"hello");
685            let _ = sub_tx2
686                .send(Recipients::One(pk1.clone()), payload.clone(), false)
687                .await
688                .unwrap();
689            let (from, bytes) = sub_rx1.recv().await.unwrap();
690            assert_eq!(from, pk2);
691            assert_eq!(bytes, payload);
692        });
693    }
694
695    #[test]
696    fn test_multiple_routes() {
697        // Can register multiple subchannels and send messages to each.
698        let executor = deterministic::Runner::default();
699        executor.start(|context| async move {
700            let mut oracle = start_network(context.clone());
701
702            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
703            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
704            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
705
706            let (_, mut rx_a) = handle1.register(10).await.unwrap();
707            let (_, mut rx_b) = handle1.register(20).await.unwrap();
708
709            let (mut tx2_a, _) = handle2.register(10).await.unwrap();
710            let (mut tx2_b, _) = handle2.register(20).await.unwrap();
711
712            let payload_a = Bytes::from_static(b"A");
713            let payload_b = Bytes::from_static(b"B");
714            let _ = tx2_a
715                .send(Recipients::One(pk1.clone()), payload_a.clone(), false)
716                .await
717                .unwrap();
718            let _ = tx2_b
719                .send(Recipients::One(pk1.clone()), payload_b.clone(), false)
720                .await
721                .unwrap();
722
723            let (from_a, bytes_a) = rx_a.recv().await.unwrap();
724            assert_eq!(from_a, pk2);
725            assert_eq!(bytes_a, payload_a);
726
727            let (from_b, bytes_b) = rx_b.recv().await.unwrap();
728            assert_eq!(from_b, pk2);
729            assert_eq!(bytes_b, payload_b);
730        });
731    }
732
733    #[test_traced]
734    fn test_mailbox_capacity_blocks() {
735        // If a single subchannel is full, messages are blocked for all subchannels.
736        let executor = deterministic::Runner::default();
737        executor.start(|context| async move {
738            let mut oracle = start_network(context.clone());
739
740            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
741            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
742            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
743
744            // Register the subchannels.
745            let (tx1, _) = handle1.register(99).await.unwrap();
746            let (tx2, _) = handle1.register(100).await.unwrap();
747            let (_, mut rx1) = handle2.register(99).await.unwrap();
748            let (_, mut rx2) = handle2.register(100).await.unwrap();
749
750            // Send 10 messages to each subchannel from pk1 to pk2.
751            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
752
753            // Try receiving all messages from the second subchannel.
754            expect_n_messages(&mut rx2, CAPACITY).await;
755
756            // Try receiving from the first subchannel.
757            expect_n_messages(&mut rx1, CAPACITY * 2).await;
758
759            // The second subchannel should be unblocked and receive the rest of the messages.
760            expect_n_messages(&mut rx2, CAPACITY).await;
761        });
762    }
763
764    #[test]
765    fn test_drop_a_full_subchannel() {
766        // Drops the subchannel receiver while the sender is blocked.
767        let executor = deterministic::Runner::default();
768        executor.start(|context| async move {
769            let mut oracle = start_network(context.clone());
770
771            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
772            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
773            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
774
775            // Register the subchannels.
776            let (tx1, _) = handle1.register(99).await.unwrap();
777            let (tx2, _) = handle1.register(100).await.unwrap();
778            let (_, rx1) = handle2.register(99).await.unwrap();
779            let (_, mut rx2) = handle2.register(100).await.unwrap();
780
781            // Send 10 messages to each subchannel from pk1 to pk2.
782            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
783
784            // Try receiving all messages from the second subchannel.
785            expect_n_messages(&mut rx2, CAPACITY).await;
786
787            // Drop the first subchannel, erroring the sender and dropping it.
788            drop(rx1);
789
790            // The second subchannel should be unblocked and receive the rest of the messages.
791            expect_n_messages(&mut rx2, CAPACITY).await;
792        });
793    }
794
795    #[test]
796    fn test_drop_messages_for_unregistered_subchannel() {
797        // Messages are dropped if the subchannel they are for is not registered.
798        let executor = deterministic::Runner::default();
799        executor.start(|context| async move {
800            let mut oracle = start_network(context.clone());
801
802            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
803            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
804            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
805
806            // Register the subchannels.
807            let (tx1, _) = handle1.register(1).await.unwrap();
808            let (tx2, _) = handle1.register(2).await.unwrap();
809            // Do not register the first subchannel on the second peer.
810            let (_, mut rx2) = handle2.register(2).await.unwrap();
811
812            // Send 10 messages to each subchannel from pk1 to pk2.
813            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
814
815            // Try receiving all messages from the second subchannel.
816            expect_n_messages(&mut rx2, CAPACITY * 2).await;
817        });
818    }
819
820    #[test]
821    fn test_backup_for_unregistered_subchannel() {
822        // Messages are forwarded to the backup channel if the subchannel they are for
823        // is not registered.
824        let executor = deterministic::Runner::default();
825        executor.start(|context| async move {
826            let mut oracle = start_network(context.clone());
827
828            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
829            let (pk2, mut handle2, mut backup2, _) =
830                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
831            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
832
833            // Register the subchannels.
834            let (tx1, _) = handle1.register(1).await.unwrap();
835            let (tx2, _) = handle1.register(2).await.unwrap();
836            // Do not register the first subchannel on the second peer.
837            let (_, mut rx2) = handle2.register(2).await.unwrap();
838
839            // Send 10 messages to each subchannel from pk1 to pk2.
840            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
841
842            // Try receiving all messages from the second subchannel and backup channel.
843            // All 20 messages sent should be received.
844            expect_n_messages_with_backup(&mut rx2, &mut backup2, CAPACITY * 2, CAPACITY * 2).await;
845        });
846    }
847
848    #[test]
849    fn test_backup_for_unregistered_subchannel_response() {
850        // Messages are forwarded to the backup channel if the subchannel they are for
851        // is not registered.
852        let executor = deterministic::Runner::default();
853        executor.start(|context| async move {
854            let mut oracle = start_network(context.clone());
855
856            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
857            let (pk2, _handle2, mut backup2, mut global_sender2) =
858                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
859            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
860
861            // Register the subchannels.
862            let (tx1, mut rx1) = handle1.register(1).await.unwrap();
863            // Do not register any subchannels on the second peer.
864
865            // Send 1 message to each subchannel from pk1 to pk2.
866            send_burst(&mut [tx1], 1).await;
867
868            // Get the message from pk2's backup channel and respond.
869            let (subchannel, (from, _)) = backup2.next().await.unwrap();
870            assert_eq!(subchannel, 1);
871            assert_eq!(from, pk1);
872            global_sender2
873                .send(
874                    subchannel,
875                    Recipients::One(pk1),
876                    b"TEST".to_vec().into(),
877                    true,
878                )
879                .await
880                .unwrap();
881
882            // Receive the response with pk1's receiver.
883            let (from, bytes) = rx1.recv().await.unwrap();
884            assert_eq!(from, pk2);
885            assert_eq!(bytes.as_ref(), b"TEST");
886        });
887    }
888
889    #[test]
890    fn test_message_dropped_for_closed_subchannel() {
891        // Messages are dropped if the subchannel they are for is registered, but has been closed.
892        //
893        // NOTE: This case should be exceedingly rare in practice due to `SubReceiver` deregistering
894        // the subchannel on drop, but is included for completeness.
895        let executor = deterministic::Runner::default();
896        executor.start(|context| async move {
897            let mut oracle = start_network(context.clone());
898
899            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
900            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
901            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
902
903            // Register the subchannels.
904            let (tx1, _) = handle1.register(1).await.unwrap();
905            let (tx2, _) = handle1.register(2).await.unwrap();
906            let (_, mut rx1) = handle2.register(1).await.unwrap();
907            let (_, mut rx2) = handle2.register(2).await.unwrap();
908
909            // Send 10 messages to subchannel 1 from pk1 to pk2.
910            send_burst(&mut [tx1.clone()], CAPACITY * 2).await;
911
912            // Try receiving all messages from the first subchannel.
913            expect_n_messages(&mut rx1, CAPACITY * 2).await;
914
915            // Send 10 messages to subchannel 2 from pk1 to pk2.
916            send_burst(&mut [tx2.clone()], CAPACITY * 2).await;
917
918            // Try receiving all messages from the first subchannel.
919            expect_n_messages(&mut rx2, CAPACITY * 2).await;
920
921            // Explicitly close the underlying receiver for the first subchannel.
922            rx1.receiver.close();
923
924            // Send 10 messages to each subchannel from pk1 to pk2.
925            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
926
927            // Try receiving all messages from the second subchannel.
928            expect_n_messages(&mut rx2, CAPACITY * 2).await;
929        });
930    }
931
932    #[test]
933    fn test_dropped_backup_channel_doesnt_block() {
934        let executor = deterministic::Runner::default();
935        executor.start(|context| async move {
936            let mut oracle = start_network(context.clone());
937
938            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
939            let (pk2, mut handle2, backup2, _) =
940                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
941            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
942
943            // Explicitly drop the backup receiver.
944            drop(backup2);
945
946            // Register the subchannels.
947            let (tx1, _) = handle1.register(1).await.unwrap();
948            let (tx2, _) = handle1.register(2).await.unwrap();
949            // Do not register the first subchannel on the second peer.
950            let (_, mut rx2) = handle2.register(2).await.unwrap();
951
952            // Send 10 messages to each subchannel from pk1 to pk2.
953            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
954
955            // Try receiving all messages from the second subchannel.
956            expect_n_messages(&mut rx2, CAPACITY * 2).await;
957        });
958    }
959
960    #[test]
961    fn test_duplicate_registration() {
962        // Returns an error if the subchannel is already registered.
963        let executor = deterministic::Runner::default();
964        executor.start(|context| async move {
965            let mut oracle = start_network(context.clone());
966
967            let (_pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
968
969            // Register the subchannel.
970            let (_, _rx) = handle1.register(7).await.unwrap();
971
972            // Registering again should return an error.
973            assert!(matches!(
974                handle1.register(7).await,
975                Err(Error::AlreadyRegistered(_))
976            ));
977        });
978    }
979
980    #[test]
981    fn test_register_after_deregister() {
982        // Can register a channel after it has been deregistered.
983        let executor = deterministic::Runner::default();
984        executor.start(|context| async move {
985            let mut oracle = start_network(context.clone());
986
987            let (_, mut handle) = create_peer(&context, &mut oracle, 0).await;
988            let (_, rx) = handle.register(7).await.unwrap();
989            drop(rx);
990
991            // Registering again should not return an error.
992            handle.register(7).await.unwrap();
993        });
994    }
995}