commonware_p2p/utils/
mux.rs

//! This utility wraps a [Sender] and [Receiver], providing lightweight sub-channels keyed by
//! [Channel].
//!
//! Usage:
//! - Call [Muxer::new] to obtain a ([Muxer], [MuxHandle]) pair, or [Muxer::builder] to also
//!   enable a backup channel for messages on unregistered subchannels and/or a [GlobalSender].
//! - Call [Muxer::start] or run [Muxer::run] in a background task to demux incoming messages into
//!   per-subchannel queues. A message for a subchannel whose queue is full is dropped rather
//!   than blocking the other subchannels.
//! - Call [MuxHandle::register] to obtain a ([SubSender], [SubReceiver]) pair for a subchannel,
//!   even if the muxer is already running.
10
11use crate::{Channel, CheckedSender, LimitedSender, Message, Receiver, Recipients, Sender};
12use bytes::{Buf, Bytes};
13use commonware_codec::{varint::UInt, Encode, Error as CodecError, ReadExt};
14use commonware_macros::select_loop;
15use commonware_runtime::{spawn_cell, ContextCell, Handle, Spawner};
16use commonware_utils::channels::fallible::FallibleExt;
17use futures::{
18    channel::{mpsc, oneshot},
19    SinkExt, StreamExt,
20};
21use std::{collections::HashMap, fmt::Debug, time::SystemTime};
22use thiserror::Error;
23use tracing::debug;
24
25/// Errors that can occur when interacting with a [SubReceiver] or [MuxHandle].
26#[derive(Error, Debug)]
27pub enum Error {
28    #[error("subchannel already registered: {0}")]
29    AlreadyRegistered(Channel),
30    #[error("muxer is closed")]
31    Closed,
32    #[error("recv failed")]
33    RecvFailed,
34}
35
36/// Parse a muxed message into its subchannel and payload.
37pub fn parse(mut bytes: Bytes) -> Result<(Channel, Bytes), CodecError> {
38    let subchannel: Channel = UInt::read(&mut bytes)?.into();
39    Ok((subchannel, bytes))
40}
41
42/// Control messages for the [Muxer].
43enum Control<R: Receiver> {
44    Register {
45        subchannel: Channel,
46        sender: oneshot::Sender<mpsc::Receiver<Message<R::PublicKey>>>,
47    },
48    Deregister {
49        subchannel: Channel,
50    },
51}
52
53/// Thread-safe routing table mapping each [Channel] to the [mpsc::Sender] for [`Message<P>`].
54type Routes<P> = HashMap<Channel, mpsc::Sender<Message<P>>>;
55
56/// A backup channel response, with a [SubSender] to respond, the [Channel] that wasn't registered,
57/// and the [Message] received.
58type BackupResponse<P> = (Channel, Message<P>);
59
60/// A multiplexer of p2p channels into subchannels.
61pub struct Muxer<E: Spawner, S: Sender, R: Receiver> {
62    context: ContextCell<E>,
63    sender: S,
64    receiver: R,
65    mailbox_size: usize,
66    control_rx: mpsc::UnboundedReceiver<Control<R>>,
67    routes: Routes<R::PublicKey>,
68    backup: Option<mpsc::Sender<BackupResponse<R::PublicKey>>>,
69}
70
71impl<E: Spawner, S: Sender, R: Receiver> Muxer<E, S, R> {
72    /// Create a multiplexed wrapper around a [Sender] and [Receiver] pair, and return a ([Muxer],
73    /// [MuxHandle]) pair that can be used to register routes dynamically.
74    pub fn new(context: E, sender: S, receiver: R, mailbox_size: usize) -> (Self, MuxHandle<S, R>) {
75        Self::builder(context, sender, receiver, mailbox_size).build()
76    }
77
78    /// Creates a [MuxerBuilder] that can be used to configure and build a [Muxer].
79    pub fn builder(
80        context: E,
81        sender: S,
82        receiver: R,
83        mailbox_size: usize,
84    ) -> MuxerBuilder<E, S, R> {
85        let (control_tx, control_rx) = mpsc::unbounded();
86        let mux = Self {
87            context: ContextCell::new(context),
88            sender,
89            receiver,
90            mailbox_size,
91            control_rx,
92            routes: HashMap::new(),
93            backup: None,
94        };
95
96        let mux_handle = MuxHandle {
97            sender: mux.sender.clone(),
98            control_tx,
99        };
100
101        MuxerBuilder { mux, mux_handle }
102    }
103
104    /// Start the demuxer using the given spawner.
105    pub fn start(mut self) -> Handle<Result<(), R::Error>> {
106        spawn_cell!(self.context, self.run().await)
107    }
108
109    /// Drive demultiplexing of messages into per-subchannel receivers.
110    ///
111    /// Callers should run this in a background task for as long as the underlying `Receiver` is
112    /// expected to receive traffic.
113    pub async fn run(mut self) -> Result<(), R::Error> {
114        select_loop! {
115            self.context,
116            on_stopped => {
117                debug!("context shutdown, stopping muxer");
118            },
119            // Prefer control messages because network messages will
120            // already block when full (providing backpressure).
121            control = self.control_rx.next() => {
122                match control {
123                    Some(Control::Register { subchannel, sender }) => {
124                        // If the subchannel is already registered, drop the sender.
125                        if self.routes.contains_key(&subchannel) {
126                            continue;
127                        }
128
129                        // Otherwise, create a new subchannel and send the receiver to the caller.
130                        let (tx, rx) = mpsc::channel(self.mailbox_size);
131                        self.routes.insert(subchannel, tx);
132                        let _ = sender.send(rx);
133                    },
134                    Some(Control::Deregister { subchannel }) => {
135                        // Remove the route.
136                        self.routes.remove(&subchannel);
137                    },
138                    None => {
139                        // If the control channel is closed, we can shut down since there must
140                        // be no more registrations, and all receivers must have been dropped.
141                        return Ok(());
142                    }
143                }
144            },
145            // Process network messages.
146            message = self.receiver.recv() => {
147                // Decode the message.
148                let (pk, bytes) = message?;
149                let (subchannel, bytes) = match parse(bytes) {
150                    Ok(parsed) => parsed,
151                    Err(_) => {
152                        debug!(?pk, "invalid message: missing subchannel");
153                        continue;
154                    }
155                };
156
157                // Get the route for the subchannel.
158                let Some(sender) = self.routes.get_mut(&subchannel) else {
159                    // Attempt to use the backup channel if available.
160                    if let Some(backup) = &mut self.backup {
161                        if let Err(e) = backup.try_send((subchannel, (pk, bytes))) {
162                            debug!(?subchannel, ?e, "failed to send message to backup channel");
163                        }
164                    }
165
166                    // Drops the message if the subchannel is not found or the backup
167                    // channel was not used.
168                    continue;
169                };
170
171                // Send the message to the subchannel using non-blocking try_send
172                // to avoid head-of-line blocking when one subchannel is slow.
173                if let Err(e) = sender.try_send((pk, bytes)) {
174                    // Check if the channel is disconnected (receiver dropped)
175                    if e.is_disconnected() {
176                        // Remove the route for the subchannel.
177                        self.routes.remove(&subchannel);
178                        debug!(?subchannel, "subchannel receiver dropped, removing route");
179                    } else {
180                        // Channel is full, drop the message
181                        debug!(?subchannel, "subchannel full, dropping message");
182                    }
183                }
184            }
185        }
186
187        Ok(())
188    }
189}
190
191/// A clonable handle that allows registering routes at any time, even after the [Muxer] is running.
192#[derive(Clone)]
193pub struct MuxHandle<S: Sender, R: Receiver> {
194    sender: S,
195    control_tx: mpsc::UnboundedSender<Control<R>>,
196}
197
198impl<S: Sender, R: Receiver> MuxHandle<S, R> {
199    /// Open a `subchannel`. Returns a ([SubSender], [SubReceiver]) pair that can be used to send
200    /// and receive messages for that subchannel.
201    ///
202    /// Panics if the subchannel is already registered at any point.
203    pub async fn register(
204        &mut self,
205        subchannel: Channel,
206    ) -> Result<(SubSender<S>, SubReceiver<R>), Error> {
207        let (tx, rx) = oneshot::channel();
208        self.control_tx
209            .send(Control::Register {
210                subchannel,
211                sender: tx,
212            })
213            .await
214            .map_err(|_| Error::Closed)?;
215        let receiver = rx.await.map_err(|_| Error::AlreadyRegistered(subchannel))?;
216
217        Ok((
218            SubSender {
219                subchannel,
220                inner: GlobalSender::new(self.sender.clone()),
221            },
222            SubReceiver {
223                receiver,
224                control_tx: Some(self.control_tx.clone()),
225                subchannel,
226            },
227        ))
228    }
229}
230
231/// Sender that routes messages to the `subchannel`.
232#[derive(Clone, Debug)]
233pub struct SubSender<S: Sender> {
234    inner: GlobalSender<S>,
235    subchannel: Channel,
236}
237
238impl<S: Sender> LimitedSender for SubSender<S> {
239    type PublicKey = S::PublicKey;
240    type Checked<'a> = CheckedGlobalSender<'a, S>;
241
242    async fn check(
243        &mut self,
244        recipients: Recipients<Self::PublicKey>,
245    ) -> Result<Self::Checked<'_>, SystemTime> {
246        self.inner
247            .check(recipients)
248            .await
249            .map(|checked| checked.with_subchannel(self.subchannel))
250    }
251}
252
253/// Receiver that yields messages for a specific subchannel.
254pub struct SubReceiver<R: Receiver> {
255    receiver: mpsc::Receiver<Message<R::PublicKey>>,
256    control_tx: Option<mpsc::UnboundedSender<Control<R>>>,
257    subchannel: Channel,
258}
259
260impl<R: Receiver> Receiver for SubReceiver<R> {
261    type Error = Error;
262    type PublicKey = R::PublicKey;
263
264    async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Self::Error> {
265        self.receiver.next().await.ok_or(Error::RecvFailed)
266    }
267}
268
269impl<R: Receiver> Debug for SubReceiver<R> {
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        write!(f, "SubReceiver({})", self.subchannel)
272    }
273}
274
275impl<R: Receiver> Drop for SubReceiver<R> {
276    fn drop(&mut self) {
277        // Take the control channel to avoid cloning.
278        let control_tx = self
279            .control_tx
280            .take()
281            .expect("SubReceiver::drop called twice");
282
283        // Deregister the subchannel immediately.
284        control_tx.send_lossy(Control::Deregister {
285            subchannel: self.subchannel,
286        });
287    }
288}
289
290/// Sender that can send messages over any sub [Channel].
291#[derive(Clone, Debug)]
292pub struct GlobalSender<S: Sender> {
293    inner: S,
294}
295
296impl<S: Sender> GlobalSender<S> {
297    /// Create a new [GlobalSender] wrapping the given [Sender].
298    pub const fn new(inner: S) -> Self {
299        Self { inner }
300    }
301
302    /// Send a message over the given `subchannel`.
303    pub async fn send(
304        &mut self,
305        subchannel: Channel,
306        recipients: Recipients<S::PublicKey>,
307        payload: impl Buf + Send,
308        priority: bool,
309    ) -> Result<Vec<S::PublicKey>, <S::Checked<'_> as CheckedSender>::Error> {
310        match self.check(recipients).await {
311            Ok(checked) => {
312                checked
313                    .with_subchannel(subchannel)
314                    .send(payload, priority)
315                    .await
316            }
317            Err(_) => Ok(Vec::new()),
318        }
319    }
320}
321
322impl<S: Sender> LimitedSender for GlobalSender<S> {
323    type PublicKey = S::PublicKey;
324    type Checked<'a> = CheckedGlobalSender<'a, S>;
325
326    async fn check(
327        &mut self,
328        recipients: Recipients<Self::PublicKey>,
329    ) -> Result<Self::Checked<'_>, SystemTime> {
330        self.inner
331            .check(recipients)
332            .await
333            .map(|checked| CheckedGlobalSender {
334                subchannel: None,
335                inner: checked,
336            })
337    }
338}
339
340/// A checked sender for a [GlobalSender].
341pub struct CheckedGlobalSender<'a, S: Sender> {
342    subchannel: Option<Channel>,
343    inner: S::Checked<'a>,
344}
345
346impl<'a, S: Sender> CheckedGlobalSender<'a, S> {
347    /// Set the subchannel for this sender.
348    pub const fn with_subchannel(mut self, subchannel: Channel) -> Self {
349        self.subchannel = Some(subchannel);
350        self
351    }
352}
353
354impl<'a, S: Sender> CheckedSender for CheckedGlobalSender<'a, S> {
355    type PublicKey = S::PublicKey;
356    type Error = <S::Checked<'a> as CheckedSender>::Error;
357
358    async fn send(
359        self,
360        message: impl Buf + Send,
361        priority: bool,
362    ) -> Result<Vec<Self::PublicKey>, Self::Error> {
363        let subchannel = UInt(self.subchannel.expect("subchannel not set"));
364        self.inner
365            .send(subchannel.encode().chain(message), priority)
366            .await
367    }
368}
369
/// Common interface of the [Muxer] builders: consume the builder and produce its output.
pub trait Builder {
    /// What [Builder::build] produces.
    type Output;

    /// Finish building, consuming `self`.
    fn build(self) -> Self::Output;
}
378
379/// A builder that constructs a [Muxer].
380pub struct MuxerBuilder<E: Spawner, S: Sender, R: Receiver> {
381    mux: Muxer<E, S, R>,
382    mux_handle: MuxHandle<S, R>,
383}
384
385impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilder<E, S, R> {
386    type Output = (Muxer<E, S, R>, MuxHandle<S, R>);
387
388    fn build(self) -> Self::Output {
389        (self.mux, self.mux_handle)
390    }
391}
392
393impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilder<E, S, R> {
394    /// Registers a backup channel with the muxer.
395    pub fn with_backup(mut self) -> MuxerBuilderWithBackup<E, S, R> {
396        let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
397        self.mux.backup = Some(tx);
398
399        MuxerBuilderWithBackup {
400            mux: self.mux,
401            mux_handle: self.mux_handle,
402            backup_rx: rx,
403        }
404    }
405
406    /// Registers a global sender with the muxer.
407    pub fn with_global_sender(self) -> MuxerBuilderWithGlobalSender<E, S, R> {
408        let global_sender = GlobalSender::new(self.mux.sender.clone());
409
410        MuxerBuilderWithGlobalSender {
411            mux: self.mux,
412            mux_handle: self.mux_handle,
413            global_sender,
414        }
415    }
416}
417
418/// A builder that constructs a [Muxer] with a backup channel.
419pub struct MuxerBuilderWithBackup<E: Spawner, S: Sender, R: Receiver> {
420    mux: Muxer<E, S, R>,
421    mux_handle: MuxHandle<S, R>,
422    backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
423}
424
425impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithBackup<E, S, R> {
426    /// Registers a global sender with the muxer.
427    pub fn with_global_sender(self) -> MuxerBuilderAllOpts<E, S, R> {
428        let global_sender = GlobalSender::new(self.mux.sender.clone());
429
430        MuxerBuilderAllOpts {
431            mux: self.mux,
432            mux_handle: self.mux_handle,
433            backup_rx: self.backup_rx,
434            global_sender,
435        }
436    }
437}
438
439impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithBackup<E, S, R> {
440    type Output = (
441        Muxer<E, S, R>,
442        MuxHandle<S, R>,
443        mpsc::Receiver<BackupResponse<R::PublicKey>>,
444    );
445
446    fn build(self) -> Self::Output {
447        (self.mux, self.mux_handle, self.backup_rx)
448    }
449}
450
451/// A builder that constructs a [Muxer] with a [GlobalSender].
452pub struct MuxerBuilderWithGlobalSender<E: Spawner, S: Sender, R: Receiver> {
453    mux: Muxer<E, S, R>,
454    mux_handle: MuxHandle<S, R>,
455    global_sender: GlobalSender<S>,
456}
457
458impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithGlobalSender<E, S, R> {
459    /// Registers a backup channel with the muxer.
460    pub fn with_backup(mut self) -> MuxerBuilderAllOpts<E, S, R> {
461        let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
462        self.mux.backup = Some(tx);
463
464        MuxerBuilderAllOpts {
465            mux: self.mux,
466            mux_handle: self.mux_handle,
467            backup_rx: rx,
468            global_sender: self.global_sender,
469        }
470    }
471}
472
473impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithGlobalSender<E, S, R> {
474    type Output = (Muxer<E, S, R>, MuxHandle<S, R>, GlobalSender<S>);
475
476    fn build(self) -> Self::Output {
477        (self.mux, self.mux_handle, self.global_sender)
478    }
479}
480
481/// A builder that constructs a [Muxer] with a [GlobalSender] and backup channel.
482pub struct MuxerBuilderAllOpts<E: Spawner, S: Sender, R: Receiver> {
483    mux: Muxer<E, S, R>,
484    mux_handle: MuxHandle<S, R>,
485    backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
486    global_sender: GlobalSender<S>,
487}
488
489impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderAllOpts<E, S, R> {
490    type Output = (
491        Muxer<E, S, R>,
492        MuxHandle<S, R>,
493        mpsc::Receiver<BackupResponse<R::PublicKey>>,
494        GlobalSender<S>,
495    );
496
497    fn build(self) -> Self::Output {
498        (
499            self.mux,
500            self.mux_handle,
501            self.backup_rx,
502            self.global_sender,
503        )
504    }
505}
506
507#[cfg(test)]
508mod tests {
509    use super::*;
510    use crate::{
511        simulated::{self, Link, Network, Oracle},
512        Recipients,
513    };
514    use bytes::Bytes;
515    use commonware_cryptography::{
516        ed25519::{PrivateKey, PublicKey},
517        Signer,
518    };
519    use commonware_macros::{select, test_traced};
520    use commonware_runtime::{deterministic, Metrics, Quota, Runner};
521    use std::{num::NonZeroU32, time::Duration};
522
523    const LINK: Link = Link {
524        latency: Duration::from_millis(0),
525        jitter: Duration::from_millis(0),
526        success_rate: 1.0,
527    };
528    const CAPACITY: usize = 5usize;
529
530    /// Default rate limit set high enough to not interfere with normal operation
531    const TEST_QUOTA: Quota = Quota::per_second(NonZeroU32::MAX);
532
533    /// Start the network and return the oracle.
534    fn start_network(context: deterministic::Context) -> Oracle<PublicKey, deterministic::Context> {
535        let (network, oracle) = Network::new(
536            context.with_label("network"),
537            simulated::Config {
538                max_size: 1024 * 1024,
539                disconnect_on_block: true,
540                tracked_peer_sets: None,
541            },
542        );
543        network.start();
544        oracle
545    }
546
547    /// Create a public key from a seed.
548    fn pk(seed: u64) -> PublicKey {
549        PrivateKey::from_seed(seed).public_key()
550    }
551
552    /// Link two peers bidirectionally.
553    async fn link_bidirectional(
554        oracle: &mut Oracle<PublicKey, deterministic::Context>,
555        a: PublicKey,
556        b: PublicKey,
557    ) {
558        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
559        oracle.add_link(b, a, LINK).await.unwrap();
560    }
561
562    /// Create a peer and register it with the oracle.
563    async fn create_peer(
564        context: &deterministic::Context,
565        oracle: &mut Oracle<PublicKey, deterministic::Context>,
566        seed: u64,
567    ) -> (
568        PublicKey,
569        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
570    ) {
571        let pubkey = pk(seed);
572        let (sender, receiver) = oracle
573            .control(pubkey.clone())
574            .register(0, TEST_QUOTA)
575            .await
576            .unwrap();
577        let (mux, handle) = Muxer::new(context.with_label("mux"), sender, receiver, CAPACITY);
578        mux.start();
579        (pubkey, handle)
580    }
581
582    /// Create a peer and register it with the oracle.
583    async fn create_peer_with_backup_and_global_sender(
584        context: &deterministic::Context,
585        oracle: &mut Oracle<PublicKey, deterministic::Context>,
586        seed: u64,
587    ) -> (
588        PublicKey,
589        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
590        mpsc::Receiver<BackupResponse<PublicKey>>,
591        GlobalSender<simulated::Sender<PublicKey, deterministic::Context>>,
592    ) {
593        let pubkey = pk(seed);
594        let (sender, receiver) = oracle
595            .control(pubkey.clone())
596            .register(0, TEST_QUOTA)
597            .await
598            .unwrap();
599        let (mux, handle, backup, global_sender) =
600            Muxer::builder(context.with_label("mux"), sender, receiver, CAPACITY)
601                .with_backup()
602                .with_global_sender()
603                .build();
604        mux.start();
605        (pubkey, handle, backup, global_sender)
606    }
607
608    /// Send a burst of messages to a list of senders.
609    async fn send_burst<S: Sender>(txs: &mut [SubSender<S>], count: usize) {
610        for i in 0..count {
611            let payload = Bytes::from(vec![i as u8]);
612            for tx in txs.iter_mut() {
613                let _ = tx
614                    .send(Recipients::All, payload.clone(), false)
615                    .await
616                    .unwrap();
617            }
618        }
619    }
620
621    /// Wait for `n` messages to be received on the receiver.
622    async fn expect_n_messages(
623        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
624        n: usize,
625    ) {
626        let mut count = 0;
627        loop {
628            select! {
629                res = rx.recv() => {
630                    res.expect("should have received message");
631                    count += 1;
632                },
633            }
634
635            if count >= n {
636                break;
637            }
638        }
639        assert_eq!(n, count);
640    }
641
642    /// Wait for `n` messages to be received on the receiver + backup receiver.
643    async fn expect_n_messages_with_backup(
644        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
645        backup_rx: &mut mpsc::Receiver<BackupResponse<PublicKey>>,
646        n: usize,
647        n_backup: usize,
648    ) {
649        let mut count_std = 0;
650        let mut count_backup = 0;
651        loop {
652            select! {
653                res = rx.recv() => {
654                    res.expect("should have received message");
655                    count_std += 1;
656                },
657                res = backup_rx.next() => {
658                    res.expect("should have received message");
659                    count_backup += 1;
660                },
661            }
662
663            if count_std >= n && count_backup >= n_backup {
664                break;
665            }
666        }
667        assert_eq!(n, count_std);
668        assert_eq!(n_backup, count_backup);
669    }
670
671    #[test]
672    fn test_basic_routing() {
673        // Can register a subchannel and send messages to it.
674        let executor = deterministic::Runner::default();
675        executor.start(|context| async move {
676            let mut oracle = start_network(context.clone());
677
678            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
679            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
680            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
681
682            let (_, mut sub_rx1) = handle1.register(7).await.unwrap();
683            let (mut sub_tx2, _) = handle2.register(7).await.unwrap();
684
685            // Send and receive
686            let payload = Bytes::from_static(b"hello");
687            let _ = sub_tx2
688                .send(Recipients::One(pk1.clone()), payload.clone(), false)
689                .await
690                .unwrap();
691            let (from, bytes) = sub_rx1.recv().await.unwrap();
692            assert_eq!(from, pk2);
693            assert_eq!(bytes, payload);
694        });
695    }
696
697    #[test]
698    fn test_multiple_routes() {
699        // Can register multiple subchannels and send messages to each.
700        let executor = deterministic::Runner::default();
701        executor.start(|context| async move {
702            let mut oracle = start_network(context.clone());
703
704            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
705            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
706            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
707
708            let (_, mut rx_a) = handle1.register(10).await.unwrap();
709            let (_, mut rx_b) = handle1.register(20).await.unwrap();
710
711            let (mut tx2_a, _) = handle2.register(10).await.unwrap();
712            let (mut tx2_b, _) = handle2.register(20).await.unwrap();
713
714            let payload_a = Bytes::from_static(b"A");
715            let payload_b = Bytes::from_static(b"B");
716            let _ = tx2_a
717                .send(Recipients::One(pk1.clone()), payload_a.clone(), false)
718                .await
719                .unwrap();
720            let _ = tx2_b
721                .send(Recipients::One(pk1.clone()), payload_b.clone(), false)
722                .await
723                .unwrap();
724
725            let (from_a, bytes_a) = rx_a.recv().await.unwrap();
726            assert_eq!(from_a, pk2);
727            assert_eq!(bytes_a, payload_a);
728
729            let (from_b, bytes_b) = rx_b.recv().await.unwrap();
730            assert_eq!(from_b, pk2);
731            assert_eq!(bytes_b, payload_b);
732        });
733    }
734
735    #[test_traced]
736    fn test_mailbox_capacity_drops_when_full() {
737        // Messages are dropped (not blocked) when a subchannel buffer is full.
738        // This prevents head-of-line blocking where one slow subchannel blocks all others.
739        let executor = deterministic::Runner::default();
740        executor.start(|context| async move {
741            let mut oracle = start_network(context.clone());
742
743            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
744            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
745            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
746
747            // Register the subchannels.
748            let (tx1, _) = handle1.register(99).await.unwrap();
749            let (tx2, _) = handle1.register(100).await.unwrap();
750            let (_, mut rx1) = handle2.register(99).await.unwrap();
751            let (_, mut rx2) = handle2.register(100).await.unwrap();
752
753            // Send 10 messages to each subchannel from pk1 to pk2.
754            // With buffer size of CAPACITY=5, messages beyond that are dropped.
755            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
756
757            // Each subchannel should receive up to CAPACITY messages (the rest are dropped).
758            expect_n_messages(&mut rx1, CAPACITY).await;
759            expect_n_messages(&mut rx2, CAPACITY).await;
760        });
761    }
762
763    #[test]
764    fn test_drop_subchannel_receiver_deregisters_route() {
765        // Dropping a subchannel receiver deregisters the route, and subsequent
766        // messages to that subchannel are dropped.
767        let executor = deterministic::Runner::default();
768        executor.start(|context| async move {
769            let mut oracle = start_network(context.clone());
770
771            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
772            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
773            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
774
775            // Register the subchannels.
776            let (tx1, _) = handle1.register(99).await.unwrap();
777            let (tx2, _) = handle1.register(100).await.unwrap();
778            let (_, rx1) = handle2.register(99).await.unwrap();
779            let (_, mut rx2) = handle2.register(100).await.unwrap();
780
781            // Drop rx1 before any messages are sent - its route is now deregistered.
782            drop(rx1);
783
784            // Send messages to both subchannels. Messages to subchannel 99 will be dropped
785            // since its receiver was dropped.
786            send_burst(&mut [tx1, tx2], CAPACITY).await;
787
788            // rx2 should receive all CAPACITY messages sent to subchannel 100.
789            expect_n_messages(&mut rx2, CAPACITY).await;
790        });
791    }
792
793    #[test]
794    fn test_drop_messages_for_unregistered_subchannel() {
795        // Messages are dropped if the subchannel they are for is not registered.
796        // The unregistered subchannel does not affect the registered one.
797        let executor = deterministic::Runner::default();
798        executor.start(|context| async move {
799            let mut oracle = start_network(context.clone());
800
801            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
802            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
803            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
804
805            // Register the subchannels.
806            let (tx1, _) = handle1.register(1).await.unwrap();
807            let (tx2, _) = handle1.register(2).await.unwrap();
808            // Do not register the first subchannel on the second peer.
809            let (_, mut rx2) = handle2.register(2).await.unwrap();
810
811            // Send CAPACITY messages to each subchannel.
812            // Messages to subchannel 1 are dropped (unregistered).
813            // Messages to subchannel 2 fill the buffer.
814            send_burst(&mut [tx1, tx2], CAPACITY).await;
815
816            // Receive messages from subchannel 2.
817            expect_n_messages(&mut rx2, CAPACITY).await;
818        });
819    }
820
821    #[test]
822    fn test_backup_for_unregistered_subchannel() {
823        // Messages are forwarded to the backup channel if the subchannel they are for
824        // is not registered.
825        let executor = deterministic::Runner::default();
826        executor.start(|context| async move {
827            let mut oracle = start_network(context.clone());
828
829            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
830            let (pk2, mut handle2, mut backup2, _) =
831                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
832            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
833
834            // Register the subchannels.
835            let (tx1, _) = handle1.register(1).await.unwrap();
836            let (tx2, _) = handle1.register(2).await.unwrap();
837            // Do not register the first subchannel on the second peer.
838            let (_, mut rx2) = handle2.register(2).await.unwrap();
839
840            // Send CAPACITY messages to each subchannel.
841            // Subchannel 1 messages go to backup, subchannel 2 messages go to rx2.
842            send_burst(&mut [tx1, tx2], CAPACITY).await;
843
844            // Both channels should receive CAPACITY messages each.
845            expect_n_messages_with_backup(&mut rx2, &mut backup2, CAPACITY, CAPACITY).await;
846        });
847    }
848
849    #[test]
850    fn test_backup_for_unregistered_subchannel_response() {
851        // Messages are forwarded to the backup channel if the subchannel they are for
852        // is not registered.
853        let executor = deterministic::Runner::default();
854        executor.start(|context| async move {
855            let mut oracle = start_network(context.clone());
856
857            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
858            let (pk2, _handle2, mut backup2, mut global_sender2) =
859                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
860            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
861
862            // Register the subchannels.
863            let (tx1, mut rx1) = handle1.register(1).await.unwrap();
864            // Do not register any subchannels on the second peer.
865
866            // Send 1 message to each subchannel from pk1 to pk2.
867            send_burst(&mut [tx1], 1).await;
868
869            // Get the message from pk2's backup channel and respond.
870            let (subchannel, (from, _)) = backup2.next().await.unwrap();
871            assert_eq!(subchannel, 1);
872            assert_eq!(from, pk1);
873            global_sender2
874                .send(subchannel, Recipients::One(pk1), &b"TEST"[..], true)
875                .await
876                .unwrap();
877
878            // Receive the response with pk1's receiver.
879            let (from, bytes) = rx1.recv().await.unwrap();
880            assert_eq!(from, pk2);
881            assert_eq!(bytes.as_ref(), b"TEST");
882        });
883    }
884
885    #[test]
886    fn test_message_dropped_for_closed_subchannel() {
887        // Messages are dropped if the subchannel they are for is registered, but has been closed.
888        //
889        // NOTE: This case should be exceedingly rare in practice due to `SubReceiver` deregistering
890        // the subchannel on drop, but is included for completeness.
891        let executor = deterministic::Runner::default();
892        executor.start(|context| async move {
893            let mut oracle = start_network(context.clone());
894
895            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
896            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
897            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
898
899            // Register the subchannels.
900            let (tx1, _) = handle1.register(1).await.unwrap();
901            let (tx2, _) = handle1.register(2).await.unwrap();
902            let (_, mut rx1) = handle2.register(1).await.unwrap();
903            let (_, mut rx2) = handle2.register(2).await.unwrap();
904
905            // Send CAPACITY messages to subchannel 1, then drain them.
906            send_burst(&mut [tx1.clone()], CAPACITY).await;
907            expect_n_messages(&mut rx1, CAPACITY).await;
908
909            // Send CAPACITY messages to subchannel 2, then drain them.
910            send_burst(&mut [tx2.clone()], CAPACITY).await;
911            expect_n_messages(&mut rx2, CAPACITY).await;
912
913            // Explicitly close the underlying receiver for the first subchannel.
914            rx1.receiver.close();
915
916            // Send CAPACITY messages to each subchannel.
917            // Messages to subchannel 1 are dropped (receiver closed).
918            send_burst(&mut [tx1, tx2], CAPACITY).await;
919
920            // Subchannel 2 should receive CAPACITY messages.
921            expect_n_messages(&mut rx2, CAPACITY).await;
922        });
923    }
924
925    #[test]
926    fn test_dropped_backup_channel_doesnt_block() {
927        // Dropping the backup receiver doesn't block message processing.
928        // Messages to unregistered subchannels are simply dropped.
929        let executor = deterministic::Runner::default();
930        executor.start(|context| async move {
931            let mut oracle = start_network(context.clone());
932
933            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
934            let (pk2, mut handle2, backup2, _) =
935                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
936            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
937
938            // Explicitly drop the backup receiver.
939            drop(backup2);
940
941            // Register the subchannels.
942            let (tx1, _) = handle1.register(1).await.unwrap();
943            let (tx2, _) = handle1.register(2).await.unwrap();
944            // Do not register the first subchannel on the second peer.
945            let (_, mut rx2) = handle2.register(2).await.unwrap();
946
947            // Send CAPACITY messages to each subchannel.
948            // Subchannel 1 messages are dropped (backup is closed).
949            // Subchannel 2 messages go to rx2.
950            send_burst(&mut [tx1, tx2], CAPACITY).await;
951
952            // rx2 should receive all CAPACITY messages.
953            expect_n_messages(&mut rx2, CAPACITY).await;
954        });
955    }
956
957    #[test]
958    fn test_duplicate_registration() {
959        // Returns an error if the subchannel is already registered.
960        let executor = deterministic::Runner::default();
961        executor.start(|context| async move {
962            let mut oracle = start_network(context.clone());
963
964            let (_pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
965
966            // Register the subchannel.
967            let (_, _rx) = handle1.register(7).await.unwrap();
968
969            // Registering again should return an error.
970            assert!(matches!(
971                handle1.register(7).await,
972                Err(Error::AlreadyRegistered(_))
973            ));
974        });
975    }
976
977    #[test]
978    fn test_register_after_deregister() {
979        // Can register a channel after it has been deregistered.
980        let executor = deterministic::Runner::default();
981        executor.start(|context| async move {
982            let mut oracle = start_network(context.clone());
983
984            let (_, mut handle) = create_peer(&context, &mut oracle, 0).await;
985            let (_, rx) = handle.register(7).await.unwrap();
986            drop(rx);
987
988            // Registering again should not return an error.
989            handle.register(7).await.unwrap();
990        });
991    }
992}