Skip to main content

commonware_p2p/utils/
mux.rs

1//! This utility wraps a [Sender] and [Receiver], providing lightweight sub-channels keyed by
2//! [Channel].
3//!
4//! Usage:
5//! - Call [Muxer::new] to obtain a ([Muxer], [MuxHandle]) pair.
6//! - Call [Muxer::start] or run [Muxer::run] in a background task to demux incoming messages into
7//!   per-subchannel queues.
8//! - Call [MuxHandle::register] to obtain a ([SubSender], [SubReceiver]) pair for that subchannel,
9//!   even if the muxer is already running.
10
11use crate::{Channel, CheckedSender, LimitedSender, Message, Receiver, Recipients, Sender};
12use commonware_codec::{varint::UInt, Encode, Error as CodecError, ReadExt};
13use commonware_macros::select_loop;
14use commonware_runtime::{spawn_cell, BufMut, ContextCell, Handle, IoBuf, IoBufMut, Spawner};
15use commonware_utils::channel::{
16    fallible::FallibleExt,
17    mpsc::{self, error::TrySendError},
18    oneshot,
19};
20use std::{collections::HashMap, fmt::Debug, time::SystemTime};
21use thiserror::Error;
22use tracing::debug;
23
24/// Errors that can occur when interacting with a [SubReceiver] or [MuxHandle].
25#[derive(Error, Debug)]
26pub enum Error {
27    #[error("subchannel already registered: {0}")]
28    AlreadyRegistered(Channel),
29    #[error("muxer is closed")]
30    Closed,
31    #[error("recv failed")]
32    RecvFailed,
33}
34
35/// Parse a muxed message into its subchannel and payload.
36pub fn parse(mut buf: IoBuf) -> Result<(Channel, IoBuf), CodecError> {
37    let subchannel: Channel = UInt::read(&mut buf)?.into();
38    Ok((subchannel, buf))
39}
40
41/// Control messages for the [Muxer].
42enum Control<R: Receiver> {
43    Register {
44        subchannel: Channel,
45        sender: oneshot::Sender<mpsc::Receiver<Message<R::PublicKey>>>,
46    },
47    Deregister {
48        subchannel: Channel,
49    },
50}
51
52/// Thread-safe routing table mapping each [Channel] to the [mpsc::Sender] for [`Message<P>`].
53type Routes<P> = HashMap<Channel, mpsc::Sender<Message<P>>>;
54
55/// A backup channel response, with a [SubSender] to respond, the [Channel] that wasn't registered,
56/// and the [Message] received.
57type BackupResponse<P> = (Channel, Message<P>);
58
59/// A multiplexer of p2p channels into subchannels.
60pub struct Muxer<E: Spawner, S: Sender, R: Receiver> {
61    context: ContextCell<E>,
62    sender: S,
63    receiver: R,
64    mailbox_size: usize,
65    control_rx: mpsc::UnboundedReceiver<Control<R>>,
66    routes: Routes<R::PublicKey>,
67    backup: Option<mpsc::Sender<BackupResponse<R::PublicKey>>>,
68}
69
70impl<E: Spawner, S: Sender, R: Receiver> Muxer<E, S, R> {
71    /// Create a multiplexed wrapper around a [Sender] and [Receiver] pair, and return a ([Muxer],
72    /// [MuxHandle]) pair that can be used to register routes dynamically.
73    pub fn new(context: E, sender: S, receiver: R, mailbox_size: usize) -> (Self, MuxHandle<S, R>) {
74        Self::builder(context, sender, receiver, mailbox_size).build()
75    }
76
77    /// Creates a [MuxerBuilder] that can be used to configure and build a [Muxer].
78    pub fn builder(
79        context: E,
80        sender: S,
81        receiver: R,
82        mailbox_size: usize,
83    ) -> MuxerBuilder<E, S, R> {
84        let (control_tx, control_rx) = mpsc::unbounded_channel();
85        let mux = Self {
86            context: ContextCell::new(context),
87            sender,
88            receiver,
89            mailbox_size,
90            control_rx,
91            routes: HashMap::new(),
92            backup: None,
93        };
94
95        let mux_handle = MuxHandle {
96            sender: mux.sender.clone(),
97            control_tx,
98        };
99
100        MuxerBuilder { mux, mux_handle }
101    }
102
103    /// Start the demuxer using the given spawner.
104    pub fn start(mut self) -> Handle<Result<(), R::Error>> {
105        spawn_cell!(self.context, self.run().await)
106    }
107
108    /// Drive demultiplexing of messages into per-subchannel receivers.
109    ///
110    /// Callers should run this in a background task for as long as the underlying `Receiver` is
111    /// expected to receive traffic.
112    pub async fn run(mut self) -> Result<(), R::Error> {
113        select_loop! {
114            self.context,
115            on_stopped => {
116                debug!("context shutdown, stopping muxer");
117            },
118            // Prefer control messages because network messages will
119            // already block when full (providing backpressure).
120            Some(control) = self.control_rx.recv() else {
121                // If the control channel is closed, we can shut down since there must
122                // be no more registrations, and all receivers must have been dropped.
123                return Ok(());
124            } => match control {
125                Control::Register { subchannel, sender } => {
126                    // If the subchannel is already registered, drop the sender.
127                    if self.routes.contains_key(&subchannel) {
128                        continue;
129                    }
130
131                    // Otherwise, create a new subchannel and send the receiver to the caller.
132                    let (tx, rx) = mpsc::channel(self.mailbox_size);
133                    self.routes.insert(subchannel, tx);
134                    let _ = sender.send(rx);
135                }
136                Control::Deregister { subchannel } => {
137                    // Remove the route.
138                    self.routes.remove(&subchannel);
139                }
140            },
141            // Process network messages.
142            message = self.receiver.recv() => {
143                // Decode the message.
144                let (pk, bytes) = message?;
145                let (subchannel, bytes) = match parse(bytes) {
146                    Ok(parsed) => parsed,
147                    Err(_) => {
148                        debug!(?pk, "invalid message: missing subchannel");
149                        continue;
150                    }
151                };
152
153                // Get the route for the subchannel.
154                let Some(sender) = self.routes.get_mut(&subchannel) else {
155                    // Attempt to use the backup channel if available.
156                    if let Some(backup) = &mut self.backup {
157                        if let Err(e) = backup.try_send((subchannel, (pk, bytes))) {
158                            debug!(?subchannel, ?e, "failed to send message to backup channel");
159                        }
160                    }
161
162                    // Drops the message if the subchannel is not found or the backup
163                    // channel was not used.
164                    continue;
165                };
166
167                // Send the message to the subchannel using non-blocking try_send
168                // to avoid head-of-line blocking when one subchannel is slow.
169                if let Err(e) = sender.try_send((pk, bytes)) {
170                    // Check if the channel is disconnected (receiver dropped)
171                    if matches!(e, TrySendError::Closed(_)) {
172                        // Remove the route for the subchannel.
173                        self.routes.remove(&subchannel);
174                        debug!(?subchannel, "subchannel receiver dropped, removing route");
175                    } else {
176                        // Channel is full, drop the message
177                        debug!(?subchannel, "subchannel full, dropping message");
178                    }
179                }
180            },
181        }
182
183        Ok(())
184    }
185}
186
187/// A clonable handle that allows registering routes at any time, even after the [Muxer] is running.
188#[derive(Clone)]
189pub struct MuxHandle<S: Sender, R: Receiver> {
190    sender: S,
191    control_tx: mpsc::UnboundedSender<Control<R>>,
192}
193
194impl<S: Sender, R: Receiver> MuxHandle<S, R> {
195    /// Open a `subchannel`. Returns a ([SubSender], [SubReceiver]) pair that can be used to send
196    /// and receive messages for that subchannel.
197    ///
198    /// Panics if the subchannel is already registered at any point.
199    pub async fn register(
200        &mut self,
201        subchannel: Channel,
202    ) -> Result<(SubSender<S>, SubReceiver<R>), Error> {
203        let (tx, rx) = oneshot::channel();
204        self.control_tx
205            .send(Control::Register {
206                subchannel,
207                sender: tx,
208            })
209            .map_err(|_| Error::Closed)?;
210        let receiver = rx.await.map_err(|_| Error::AlreadyRegistered(subchannel))?;
211
212        Ok((
213            SubSender {
214                subchannel,
215                inner: GlobalSender::new(self.sender.clone()),
216            },
217            SubReceiver {
218                receiver,
219                control_tx: Some(self.control_tx.clone()),
220                subchannel,
221            },
222        ))
223    }
224}
225
226/// Sender that routes messages to the `subchannel`.
227#[derive(Clone, Debug)]
228pub struct SubSender<S: Sender> {
229    inner: GlobalSender<S>,
230    subchannel: Channel,
231}
232
233impl<S: Sender> LimitedSender for SubSender<S> {
234    type PublicKey = S::PublicKey;
235    type Checked<'a> = CheckedGlobalSender<'a, S>;
236
237    async fn check(
238        &mut self,
239        recipients: Recipients<Self::PublicKey>,
240    ) -> Result<Self::Checked<'_>, SystemTime> {
241        self.inner
242            .check(recipients)
243            .await
244            .map(|checked| checked.with_subchannel(self.subchannel))
245    }
246}
247
248/// Receiver that yields messages for a specific subchannel.
249pub struct SubReceiver<R: Receiver> {
250    receiver: mpsc::Receiver<Message<R::PublicKey>>,
251    control_tx: Option<mpsc::UnboundedSender<Control<R>>>,
252    subchannel: Channel,
253}
254
255impl<R: Receiver> Receiver for SubReceiver<R> {
256    type Error = Error;
257    type PublicKey = R::PublicKey;
258
259    async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Self::Error> {
260        self.receiver.recv().await.ok_or(Error::RecvFailed)
261    }
262}
263
264impl<R: Receiver> Debug for SubReceiver<R> {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        write!(f, "SubReceiver({})", self.subchannel)
267    }
268}
269
270impl<R: Receiver> Drop for SubReceiver<R> {
271    fn drop(&mut self) {
272        // Take the control channel to avoid cloning.
273        let control_tx = self
274            .control_tx
275            .take()
276            .expect("SubReceiver::drop called twice");
277
278        // Deregister the subchannel immediately.
279        control_tx.send_lossy(Control::Deregister {
280            subchannel: self.subchannel,
281        });
282    }
283}
284
285/// Sender that can send messages over any sub [Channel].
286#[derive(Clone, Debug)]
287pub struct GlobalSender<S: Sender> {
288    inner: S,
289}
290
291impl<S: Sender> GlobalSender<S> {
292    /// Create a new [GlobalSender] wrapping the given [Sender].
293    pub const fn new(inner: S) -> Self {
294        Self { inner }
295    }
296
297    /// Send a message over the given `subchannel`.
298    pub async fn send(
299        &mut self,
300        subchannel: Channel,
301        recipients: Recipients<S::PublicKey>,
302        payload: impl Into<IoBufMut> + Send,
303        priority: bool,
304    ) -> Result<Vec<S::PublicKey>, <S::Checked<'_> as CheckedSender>::Error> {
305        match self.check(recipients).await {
306            Ok(checked) => {
307                checked
308                    .with_subchannel(subchannel)
309                    .send(payload, priority)
310                    .await
311            }
312            Err(_) => Ok(Vec::new()),
313        }
314    }
315}
316
317impl<S: Sender> LimitedSender for GlobalSender<S> {
318    type PublicKey = S::PublicKey;
319    type Checked<'a> = CheckedGlobalSender<'a, S>;
320
321    async fn check(
322        &mut self,
323        recipients: Recipients<Self::PublicKey>,
324    ) -> Result<Self::Checked<'_>, SystemTime> {
325        self.inner
326            .check(recipients)
327            .await
328            .map(|checked| CheckedGlobalSender {
329                subchannel: None,
330                inner: checked,
331            })
332    }
333}
334
335/// A checked sender for a [GlobalSender].
336pub struct CheckedGlobalSender<'a, S: Sender> {
337    subchannel: Option<Channel>,
338    inner: S::Checked<'a>,
339}
340
341impl<'a, S: Sender> CheckedGlobalSender<'a, S> {
342    /// Set the subchannel for this sender.
343    pub const fn with_subchannel(mut self, subchannel: Channel) -> Self {
344        self.subchannel = Some(subchannel);
345        self
346    }
347}
348
349impl<'a, S: Sender> CheckedSender for CheckedGlobalSender<'a, S> {
350    type PublicKey = S::PublicKey;
351    type Error = <S::Checked<'a> as CheckedSender>::Error;
352
353    async fn send(
354        self,
355        message: impl Into<IoBufMut> + Send,
356        priority: bool,
357    ) -> Result<Vec<Self::PublicKey>, Self::Error> {
358        let subchannel = UInt(self.subchannel.expect("subchannel not set"));
359        let subchannel_bytes = subchannel.encode();
360        let message = message.into();
361        let mut combined = IoBufMut::with_capacity(subchannel_bytes.len() + message.len());
362        combined.put_slice(subchannel_bytes.as_ref());
363        combined.put_slice(message.as_ref());
364        self.inner.send(combined, priority).await
365    }
366}
367
/// A generic builder interface: consumes a configuration and yields the finished value.
pub trait Builder {
    /// The value produced by [Builder::build].
    type Output;

    /// Consume `self` and produce the finished [Builder::Output].
    fn build(self) -> Self::Output;
}
376
377/// A builder that constructs a [Muxer].
378pub struct MuxerBuilder<E: Spawner, S: Sender, R: Receiver> {
379    mux: Muxer<E, S, R>,
380    mux_handle: MuxHandle<S, R>,
381}
382
383impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilder<E, S, R> {
384    type Output = (Muxer<E, S, R>, MuxHandle<S, R>);
385
386    fn build(self) -> Self::Output {
387        (self.mux, self.mux_handle)
388    }
389}
390
391impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilder<E, S, R> {
392    /// Registers a backup channel with the muxer.
393    pub fn with_backup(mut self) -> MuxerBuilderWithBackup<E, S, R> {
394        let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
395        self.mux.backup = Some(tx);
396
397        MuxerBuilderWithBackup {
398            mux: self.mux,
399            mux_handle: self.mux_handle,
400            backup_rx: rx,
401        }
402    }
403
404    /// Registers a global sender with the muxer.
405    pub fn with_global_sender(self) -> MuxerBuilderWithGlobalSender<E, S, R> {
406        let global_sender = GlobalSender::new(self.mux.sender.clone());
407
408        MuxerBuilderWithGlobalSender {
409            mux: self.mux,
410            mux_handle: self.mux_handle,
411            global_sender,
412        }
413    }
414}
415
416/// A builder that constructs a [Muxer] with a backup channel.
417pub struct MuxerBuilderWithBackup<E: Spawner, S: Sender, R: Receiver> {
418    mux: Muxer<E, S, R>,
419    mux_handle: MuxHandle<S, R>,
420    backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
421}
422
423impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithBackup<E, S, R> {
424    /// Registers a global sender with the muxer.
425    pub fn with_global_sender(self) -> MuxerBuilderAllOpts<E, S, R> {
426        let global_sender = GlobalSender::new(self.mux.sender.clone());
427
428        MuxerBuilderAllOpts {
429            mux: self.mux,
430            mux_handle: self.mux_handle,
431            backup_rx: self.backup_rx,
432            global_sender,
433        }
434    }
435}
436
437impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithBackup<E, S, R> {
438    type Output = (
439        Muxer<E, S, R>,
440        MuxHandle<S, R>,
441        mpsc::Receiver<BackupResponse<R::PublicKey>>,
442    );
443
444    fn build(self) -> Self::Output {
445        (self.mux, self.mux_handle, self.backup_rx)
446    }
447}
448
449/// A builder that constructs a [Muxer] with a [GlobalSender].
450pub struct MuxerBuilderWithGlobalSender<E: Spawner, S: Sender, R: Receiver> {
451    mux: Muxer<E, S, R>,
452    mux_handle: MuxHandle<S, R>,
453    global_sender: GlobalSender<S>,
454}
455
456impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithGlobalSender<E, S, R> {
457    /// Registers a backup channel with the muxer.
458    pub fn with_backup(mut self) -> MuxerBuilderAllOpts<E, S, R> {
459        let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
460        self.mux.backup = Some(tx);
461
462        MuxerBuilderAllOpts {
463            mux: self.mux,
464            mux_handle: self.mux_handle,
465            backup_rx: rx,
466            global_sender: self.global_sender,
467        }
468    }
469}
470
471impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithGlobalSender<E, S, R> {
472    type Output = (Muxer<E, S, R>, MuxHandle<S, R>, GlobalSender<S>);
473
474    fn build(self) -> Self::Output {
475        (self.mux, self.mux_handle, self.global_sender)
476    }
477}
478
479/// A builder that constructs a [Muxer] with a [GlobalSender] and backup channel.
480pub struct MuxerBuilderAllOpts<E: Spawner, S: Sender, R: Receiver> {
481    mux: Muxer<E, S, R>,
482    mux_handle: MuxHandle<S, R>,
483    backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
484    global_sender: GlobalSender<S>,
485}
486
487impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderAllOpts<E, S, R> {
488    type Output = (
489        Muxer<E, S, R>,
490        MuxHandle<S, R>,
491        mpsc::Receiver<BackupResponse<R::PublicKey>>,
492        GlobalSender<S>,
493    );
494
495    fn build(self) -> Self::Output {
496        (
497            self.mux,
498            self.mux_handle,
499            self.backup_rx,
500            self.global_sender,
501        )
502    }
503}
504
505#[cfg(test)]
506mod tests {
507    use super::*;
508    use crate::{
509        simulated::{self, Link, Network, Oracle},
510        Recipients,
511    };
512    use commonware_cryptography::{
513        ed25519::{PrivateKey, PublicKey},
514        Signer,
515    };
516    use commonware_macros::{select, test_traced};
517    use commonware_runtime::{deterministic, IoBuf, Metrics, Quota, Runner};
518    use std::{num::NonZeroU32, time::Duration};
519
520    const LINK: Link = Link {
521        latency: Duration::from_millis(0),
522        jitter: Duration::from_millis(0),
523        success_rate: 1.0,
524    };
525    const CAPACITY: usize = 5usize;
526
527    /// Default rate limit set high enough to not interfere with normal operation
528    const TEST_QUOTA: Quota = Quota::per_second(NonZeroU32::MAX);
529
530    /// Start the network and return the oracle.
531    fn start_network(context: deterministic::Context) -> Oracle<PublicKey, deterministic::Context> {
532        let (network, oracle) = Network::new(
533            context.with_label("network"),
534            simulated::Config {
535                max_size: 1024 * 1024,
536                disconnect_on_block: true,
537                tracked_peer_sets: None,
538            },
539        );
540        network.start();
541        oracle
542    }
543
544    /// Create a public key from a seed.
545    fn pk(seed: u64) -> PublicKey {
546        PrivateKey::from_seed(seed).public_key()
547    }
548
549    /// Link two peers bidirectionally.
550    async fn link_bidirectional(
551        oracle: &mut Oracle<PublicKey, deterministic::Context>,
552        a: PublicKey,
553        b: PublicKey,
554    ) {
555        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
556        oracle.add_link(b, a, LINK).await.unwrap();
557    }
558
559    /// Create a peer and register it with the oracle.
560    async fn create_peer(
561        context: &deterministic::Context,
562        oracle: &mut Oracle<PublicKey, deterministic::Context>,
563        seed: u64,
564    ) -> (
565        PublicKey,
566        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
567    ) {
568        let pubkey = pk(seed);
569        let (sender, receiver) = oracle
570            .control(pubkey.clone())
571            .register(0, TEST_QUOTA)
572            .await
573            .unwrap();
574        let (mux, handle) = Muxer::new(context.with_label("mux"), sender, receiver, CAPACITY);
575        mux.start();
576        (pubkey, handle)
577    }
578
579    /// Create a peer and register it with the oracle.
580    async fn create_peer_with_backup_and_global_sender(
581        context: &deterministic::Context,
582        oracle: &mut Oracle<PublicKey, deterministic::Context>,
583        seed: u64,
584    ) -> (
585        PublicKey,
586        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
587        mpsc::Receiver<BackupResponse<PublicKey>>,
588        GlobalSender<simulated::Sender<PublicKey, deterministic::Context>>,
589    ) {
590        let pubkey = pk(seed);
591        let (sender, receiver) = oracle
592            .control(pubkey.clone())
593            .register(0, TEST_QUOTA)
594            .await
595            .unwrap();
596        let (mux, handle, backup, global_sender) =
597            Muxer::builder(context.with_label("mux"), sender, receiver, CAPACITY)
598                .with_backup()
599                .with_global_sender()
600                .build();
601        mux.start();
602        (pubkey, handle, backup, global_sender)
603    }
604
605    /// Send a burst of messages to a list of senders.
606    async fn send_burst<S: Sender>(txs: &mut [SubSender<S>], count: usize) {
607        for i in 0..count {
608            let payload = IoBuf::from(vec![i as u8]);
609            for tx in txs.iter_mut() {
610                let _ = tx
611                    .send(Recipients::All, payload.clone(), false)
612                    .await
613                    .unwrap();
614            }
615        }
616    }
617
618    /// Wait for `n` messages to be received on the receiver.
619    async fn expect_n_messages(
620        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
621        n: usize,
622    ) {
623        let mut count = 0;
624        loop {
625            select! {
626                res = rx.recv() => {
627                    res.expect("should have received message");
628                    count += 1;
629                },
630            }
631
632            if count >= n {
633                break;
634            }
635        }
636        assert_eq!(n, count);
637    }
638
639    /// Wait for `n` messages to be received on the receiver + backup receiver.
640    async fn expect_n_messages_with_backup(
641        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
642        backup_rx: &mut mpsc::Receiver<BackupResponse<PublicKey>>,
643        n: usize,
644        n_backup: usize,
645    ) {
646        let mut count_std = 0;
647        let mut count_backup = 0;
648        loop {
649            select! {
650                res = rx.recv() => {
651                    res.expect("should have received message");
652                    count_std += 1;
653                },
654                res = backup_rx.recv() => {
655                    res.expect("should have received message");
656                    count_backup += 1;
657                },
658            }
659
660            if count_std >= n && count_backup >= n_backup {
661                break;
662            }
663        }
664        assert_eq!(n, count_std);
665        assert_eq!(n_backup, count_backup);
666    }
667
668    #[test]
669    fn test_basic_routing() {
670        // Can register a subchannel and send messages to it.
671        let executor = deterministic::Runner::default();
672        executor.start(|context| async move {
673            let mut oracle = start_network(context.clone());
674
675            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
676            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
677            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
678
679            let (_, mut sub_rx1) = handle1.register(7).await.unwrap();
680            let (mut sub_tx2, _) = handle2.register(7).await.unwrap();
681
682            // Send and receive
683            let payload = IoBuf::from(b"hello");
684            let _ = sub_tx2
685                .send(Recipients::One(pk1.clone()), payload.clone(), false)
686                .await
687                .unwrap();
688            let (from, bytes) = sub_rx1.recv().await.unwrap();
689            assert_eq!(from, pk2);
690            assert_eq!(bytes, payload);
691        });
692    }
693
694    #[test]
695    fn test_multiple_routes() {
696        // Can register multiple subchannels and send messages to each.
697        let executor = deterministic::Runner::default();
698        executor.start(|context| async move {
699            let mut oracle = start_network(context.clone());
700
701            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
702            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
703            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
704
705            let (_, mut rx_a) = handle1.register(10).await.unwrap();
706            let (_, mut rx_b) = handle1.register(20).await.unwrap();
707
708            let (mut tx2_a, _) = handle2.register(10).await.unwrap();
709            let (mut tx2_b, _) = handle2.register(20).await.unwrap();
710
711            let payload_a = IoBuf::from(b"A");
712            let payload_b = IoBuf::from(b"B");
713            let _ = tx2_a
714                .send(Recipients::One(pk1.clone()), payload_a.clone(), false)
715                .await
716                .unwrap();
717            let _ = tx2_b
718                .send(Recipients::One(pk1.clone()), payload_b.clone(), false)
719                .await
720                .unwrap();
721
722            let (from_a, bytes_a) = rx_a.recv().await.unwrap();
723            assert_eq!(from_a, pk2);
724            assert_eq!(bytes_a, payload_a);
725
726            let (from_b, bytes_b) = rx_b.recv().await.unwrap();
727            assert_eq!(from_b, pk2);
728            assert_eq!(bytes_b, payload_b);
729        });
730    }
731
732    #[test_traced]
733    fn test_mailbox_capacity_drops_when_full() {
734        // Messages are dropped (not blocked) when a subchannel buffer is full.
735        // This prevents head-of-line blocking where one slow subchannel blocks all others.
736        let executor = deterministic::Runner::default();
737        executor.start(|context| async move {
738            let mut oracle = start_network(context.clone());
739
740            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
741            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
742            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
743
744            // Register the subchannels.
745            let (tx1, _) = handle1.register(99).await.unwrap();
746            let (tx2, _) = handle1.register(100).await.unwrap();
747            let (_, mut rx1) = handle2.register(99).await.unwrap();
748            let (_, mut rx2) = handle2.register(100).await.unwrap();
749
750            // Send 10 messages to each subchannel from pk1 to pk2.
751            // With buffer size of CAPACITY=5, messages beyond that are dropped.
752            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
753
754            // Each subchannel should receive up to CAPACITY messages (the rest are dropped).
755            expect_n_messages(&mut rx1, CAPACITY).await;
756            expect_n_messages(&mut rx2, CAPACITY).await;
757        });
758    }
759
760    #[test]
761    fn test_drop_subchannel_receiver_deregisters_route() {
762        // Dropping a subchannel receiver deregisters the route, and subsequent
763        // messages to that subchannel are dropped.
764        let executor = deterministic::Runner::default();
765        executor.start(|context| async move {
766            let mut oracle = start_network(context.clone());
767
768            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
769            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
770            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
771
772            // Register the subchannels.
773            let (tx1, _) = handle1.register(99).await.unwrap();
774            let (tx2, _) = handle1.register(100).await.unwrap();
775            let (_, rx1) = handle2.register(99).await.unwrap();
776            let (_, mut rx2) = handle2.register(100).await.unwrap();
777
778            // Drop rx1 before any messages are sent - its route is now deregistered.
779            drop(rx1);
780
781            // Send messages to both subchannels. Messages to subchannel 99 will be dropped
782            // since its receiver was dropped.
783            send_burst(&mut [tx1, tx2], CAPACITY).await;
784
785            // rx2 should receive all CAPACITY messages sent to subchannel 100.
786            expect_n_messages(&mut rx2, CAPACITY).await;
787        });
788    }
789
790    #[test]
791    fn test_drop_messages_for_unregistered_subchannel() {
792        // Messages are dropped if the subchannel they are for is not registered.
793        // The unregistered subchannel does not affect the registered one.
794        let executor = deterministic::Runner::default();
795        executor.start(|context| async move {
796            let mut oracle = start_network(context.clone());
797
798            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
799            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
800            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
801
802            // Register the subchannels.
803            let (tx1, _) = handle1.register(1).await.unwrap();
804            let (tx2, _) = handle1.register(2).await.unwrap();
805            // Do not register the first subchannel on the second peer.
806            let (_, mut rx2) = handle2.register(2).await.unwrap();
807
808            // Send CAPACITY messages to each subchannel.
809            // Messages to subchannel 1 are dropped (unregistered).
810            // Messages to subchannel 2 fill the buffer.
811            send_burst(&mut [tx1, tx2], CAPACITY).await;
812
813            // Receive messages from subchannel 2.
814            expect_n_messages(&mut rx2, CAPACITY).await;
815        });
816    }
817
#[test]
fn test_backup_for_unregistered_subchannel() {
    // Inbound traffic for a subchannel the receiving peer never registered is
    // forwarded to that peer's backup channel.
    deterministic::Runner::default().start(|ctx| async move {
        let mut oracle = start_network(ctx.clone());

        let (local_pk, mut local_mux) = create_peer(&ctx, &mut oracle, 0).await;
        let (remote_pk, mut remote_mux, mut remote_backup, _) =
            create_peer_with_backup_and_global_sender(&ctx, &mut oracle, 1).await;
        link_bidirectional(&mut oracle, local_pk.clone(), remote_pk.clone()).await;

        // The local peer sends on subchannels 1 and 2; the remote peer only
        // listens on subchannel 2.
        let (unrouted_tx, _) = local_mux.register(1).await.unwrap();
        let (routed_tx, _) = local_mux.register(2).await.unwrap();
        let (_, mut routed_rx) = remote_mux.register(2).await.unwrap();

        // CAPACITY messages per subchannel: subchannel 1 traffic should reach the
        // backup, subchannel 2 traffic should reach `routed_rx`.
        send_burst(&mut [unrouted_tx, routed_tx], CAPACITY).await;

        // Each destination must observe exactly CAPACITY messages.
        expect_n_messages_with_backup(&mut routed_rx, &mut remote_backup, CAPACITY, CAPACITY)
            .await;
    });
}
845
#[test]
fn test_backup_for_unregistered_subchannel_response() {
    // A peer with no registered subchannels receives an inbound message via its
    // backup channel and can reply on the same subchannel through its global
    // sender. The reply is delivered to the originator's registered SubReceiver.
    let executor = deterministic::Runner::default();
    executor.start(|context| async move {
        let mut oracle = start_network(context.clone());

        let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
        // `_handle2` is bound (not discarded with `_`) so it lives for the whole test.
        let (pk2, _handle2, mut backup2, mut global_sender2) =
            create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
        link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

        // Register subchannel 1 on the first peer only.
        let (tx1, mut rx1) = handle1.register(1).await.unwrap();
        // Do not register any subchannels on the second peer.

        // Send a single message on subchannel 1 from pk1 to pk2.
        send_burst(&mut [tx1], 1).await;

        // The message is unroutable on pk2, so it must arrive on pk2's backup
        // channel tagged with its subchannel and sender.
        let (subchannel, (from, _)) = backup2.recv().await.unwrap();
        assert_eq!(subchannel, 1);
        assert_eq!(from, pk1);

        // Reply on the same subchannel via pk2's global sender.
        global_sender2
            .send(subchannel, Recipients::One(pk1), b"TEST", true)
            .await
            .unwrap();

        // The reply is demuxed into pk1's subchannel 1 receiver.
        let (from, bytes) = rx1.recv().await.unwrap();
        assert_eq!(from, pk2);
        assert_eq!(bytes, b"TEST");
    });
}
881
#[test]
fn test_message_dropped_for_closed_subchannel() {
    // If a subchannel is still registered but its underlying receiver has been
    // closed, traffic for it is discarded while other subchannels keep working.
    //
    // NOTE: Because dropping a `SubReceiver` deregisters its subchannel, reaching
    // this state in practice should be very rare; the test covers it for completeness.
    deterministic::Runner::default().start(|ctx| async move {
        let mut oracle = start_network(ctx.clone());

        let (sender_pk, mut sender_mux) = create_peer(&ctx, &mut oracle, 0).await;
        let (receiver_pk, mut receiver_mux) = create_peer(&ctx, &mut oracle, 1).await;
        link_bidirectional(&mut oracle, sender_pk.clone(), receiver_pk.clone()).await;

        // Both peers register subchannels 1 and 2.
        let (out_a, _) = sender_mux.register(1).await.unwrap();
        let (out_b, _) = sender_mux.register(2).await.unwrap();
        let (_, mut in_a) = receiver_mux.register(1).await.unwrap();
        let (_, mut in_b) = receiver_mux.register(2).await.unwrap();

        // Warm-up round: fill and fully drain each subchannel in turn.
        send_burst(&mut [out_a.clone()], CAPACITY).await;
        expect_n_messages(&mut in_a, CAPACITY).await;
        send_burst(&mut [out_b.clone()], CAPACITY).await;
        expect_n_messages(&mut in_b, CAPACITY).await;

        // Close subchannel 1's queue directly while leaving it registered.
        in_a.receiver.close();

        // Second round on both subchannels: only subchannel 2 should see traffic.
        send_burst(&mut [out_a, out_b], CAPACITY).await;
        expect_n_messages(&mut in_b, CAPACITY).await;
    });
}
921
#[test]
fn test_dropped_backup_channel_doesnt_block() {
    // With the backup receiver gone, traffic for unregistered subchannels is
    // discarded and registered subchannels still get everything addressed to them.
    deterministic::Runner::default().start(|ctx| async move {
        let mut oracle = start_network(ctx.clone());

        let (local_pk, mut local_mux) = create_peer(&ctx, &mut oracle, 0).await;
        let (remote_pk, mut remote_mux, remote_backup, _) =
            create_peer_with_backup_and_global_sender(&ctx, &mut oracle, 1).await;
        link_bidirectional(&mut oracle, local_pk.clone(), remote_pk.clone()).await;

        // Close the backup path before any traffic is sent.
        drop(remote_backup);

        let (unrouted_tx, _) = local_mux.register(1).await.unwrap();
        let (routed_tx, _) = local_mux.register(2).await.unwrap();
        // Only subchannel 2 is registered on the remote peer.
        let (_, mut routed_rx) = remote_mux.register(2).await.unwrap();

        // Subchannel 1 traffic has nowhere to go; subchannel 2 traffic reaches `routed_rx`.
        send_burst(&mut [unrouted_tx, routed_tx], CAPACITY).await;

        // The registered subchannel must still receive the full burst.
        expect_n_messages(&mut routed_rx, CAPACITY).await;
    });
}
953
#[test]
fn test_duplicate_registration() {
    // Registering a subchannel that is still live fails with `AlreadyRegistered`,
    // and the error reports the subchannel that was requested.
    let executor = deterministic::Runner::default();
    executor.start(|context| async move {
        let mut oracle = start_network(context.clone());

        let (_pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;

        // Register the subchannel. The receiver is bound to `_rx` (not `_`) so it
        // stays alive; dropping a SubReceiver deregisters its subchannel, which
        // would let the second registration succeed.
        let (_, _rx) = handle1.register(7).await.unwrap();

        // Registering again should return an error naming the same subchannel.
        assert!(matches!(
            handle1.register(7).await,
            Err(Error::AlreadyRegistered(subchannel)) if subchannel == 7
        ));
    });
}
973
#[test]
fn test_register_after_deregister() {
    // Dropping a subchannel's receiver releases its id, so the same id can be
    // registered again without error.
    deterministic::Runner::default().start(|ctx| async move {
        let mut oracle = start_network(ctx.clone());

        let (_, mut mux) = create_peer(&ctx, &mut oracle, 0).await;

        // Claim subchannel 7, then release it by dropping the receiver.
        let (_, first_rx) = mux.register(7).await.unwrap();
        drop(first_rx);

        // A fresh registration of the same id must succeed.
        mux.register(7).await.unwrap();
    });
}
989}