commonware_p2p/utils/
mux.rs

//! This utility wraps a [Sender] and [Receiver], providing lightweight sub-channels keyed by
//! [Channel].
//!
//! Usage:
//! - Call [Muxer::new] to obtain a ([Muxer], [MuxHandle]) pair, or [Muxer::builder] to also
//!   enable a backup channel for unregistered subchannels and/or a [GlobalSender].
//! - Call [Muxer::start] or run [Muxer::run] in a background task to demux incoming messages into
//!   per-subchannel queues.
//! - Call [MuxHandle::register] to obtain a ([SubSender], [SubReceiver]) pair for a subchannel,
//!   even if the muxer is already running.
10
11use crate::{Channel, Message, Receiver, Recipients, Sender};
12use bytes::{BufMut, Bytes, BytesMut};
13use commonware_codec::{varint::UInt, EncodeSize, ReadExt, Write};
14use commonware_macros::select;
15use commonware_runtime::{spawn_cell, ContextCell, Handle, Spawner};
16use futures::{
17    channel::{mpsc, oneshot},
18    SinkExt, StreamExt,
19};
20use std::{collections::HashMap, fmt::Debug};
21use thiserror::Error;
22use tracing::debug;
23
24/// Errors that can occur when interacting with a [SubReceiver] or [MuxHandle].
25#[derive(Error, Debug)]
26pub enum Error {
27    #[error("subchannel already registered: {0}")]
28    AlreadyRegistered(Channel),
29    #[error("muxer is closed")]
30    Closed,
31    #[error("recv failed")]
32    RecvFailed,
33}
34
35/// Control messages for the [Muxer].
36enum Control<R: Receiver> {
37    Register {
38        subchannel: Channel,
39        sender: oneshot::Sender<mpsc::Receiver<Message<R::PublicKey>>>,
40    },
41    Deregister {
42        subchannel: Channel,
43    },
44}
45
46/// Thread-safe routing table mapping each [Channel] to the [mpsc::Sender] for [`Message<P>`].
47type Routes<P> = HashMap<Channel, mpsc::Sender<Message<P>>>;
48
49/// A backup channel response, with a [SubSender] to respond, the [Channel] that wasn't registered,
50/// and the [Message] received.
51type BackupResponse<P> = (Channel, Message<P>);
52
53/// A multiplexer of p2p channels into subchannels.
54pub struct Muxer<E: Spawner, S: Sender, R: Receiver> {
55    context: ContextCell<E>,
56    sender: S,
57    receiver: R,
58    mailbox_size: usize,
59    control_rx: mpsc::UnboundedReceiver<Control<R>>,
60    routes: Routes<R::PublicKey>,
61    backup: Option<mpsc::Sender<BackupResponse<R::PublicKey>>>,
62}
63
64impl<E: Spawner, S: Sender, R: Receiver> Muxer<E, S, R> {
65    /// Create a multiplexed wrapper around a [Sender] and [Receiver] pair, and return a ([Muxer],
66    /// [MuxHandle]) pair that can be used to register routes dynamically.
67    pub fn new(context: E, sender: S, receiver: R, mailbox_size: usize) -> (Self, MuxHandle<S, R>) {
68        Self::builder(context, sender, receiver, mailbox_size).build()
69    }
70
71    /// Creates a [MuxerBuilder] that can be used to configure and build a [Muxer].
72    pub fn builder(
73        context: E,
74        sender: S,
75        receiver: R,
76        mailbox_size: usize,
77    ) -> MuxerBuilder<E, S, R> {
78        let (control_tx, control_rx) = mpsc::unbounded();
79        let mux = Self {
80            context: ContextCell::new(context),
81            sender,
82            receiver,
83            mailbox_size,
84            control_rx,
85            routes: HashMap::new(),
86            backup: None,
87        };
88
89        let mux_handle = MuxHandle {
90            sender: mux.sender.clone(),
91            control_tx,
92        };
93
94        MuxerBuilder { mux, mux_handle }
95    }
96
97    /// Start the demuxer using the given spawner.
98    pub fn start(mut self) -> Handle<Result<(), R::Error>> {
99        spawn_cell!(self.context, self.run().await)
100    }
101
102    /// Drive demultiplexing of messages into per-subchannel receivers.
103    ///
104    /// Callers should run this in a background task for as long as the underlying `Receiver` is
105    /// expected to receive traffic.
106    pub async fn run(mut self) -> Result<(), R::Error> {
107        loop {
108            select! {
109                // Prefer control messages because network messages will
110                // already block when full (providing backpressure).
111                control = self.control_rx.next() => {
112                    match control {
113                        Some(Control::Register { subchannel, sender }) => {
114                            // If the subchannel is already registered, drop the sender.
115                            if self.routes.contains_key(&subchannel) {
116                                continue;
117                            }
118
119                            // Otherwise, create a new subchannel and send the receiver to the caller.
120                            let (tx, rx) = mpsc::channel(self.mailbox_size);
121                            self.routes.insert(subchannel, tx);
122                            let _ = sender.send(rx);
123                        },
124                        Some(Control::Deregister { subchannel }) => {
125                            // Remove the route.
126                            self.routes.remove(&subchannel);
127                        },
128                        None => {
129                            // If the control channel is closed, we can shut down since there must
130                            // be no more registrations, and all receivers must have been dropped.
131                            return Ok(());
132                        }
133                    }
134                },
135                // Process network messages.
136                message = self.receiver.recv() => {
137                    let (pk, mut bytes) = message?;
138
139                    // Decode message: varint(subchannel) || bytes
140                    let subchannel: Channel = match UInt::read(&mut bytes) {
141                        Ok(v) => v.into(),
142                        Err(_) => {
143                            debug!(?pk, "invalid message: missing subchannel");
144                            continue;
145                        }
146                    };
147
148                    // Get the route for the subchannel.
149                    let Some(sender) = self.routes.get_mut(&subchannel) else {
150                        // Attempt to use the backup channel if available.
151                        if let Some(backup) = &mut self.backup {
152                            if let Err(e) = backup.send((subchannel, (pk, bytes))).await {
153                                debug!(?subchannel, ?e, "failed to send message to backup channel");
154                            }
155                        }
156
157                        // Drops the message if the subchannel is not found or the backup
158                        // channel was not used.
159                        continue;
160                    };
161
162                    // Send the message to the subchannel, blocking if the queue is full.
163                    if let Err(e) = sender.send((pk, bytes)).await {
164                        // Remove the route for the subchannel.
165                        self.routes.remove(&subchannel);
166
167                        // Failure, drop the sender since the receiver is no longer interested.
168                        debug!(?subchannel, ?e, "failed to send message to subchannel");
169
170                        // NOTE: The channel is deregistered, but it wasn't when the message was received.
171                        // The backup channel is not used in this case.
172                    }
173                }
174            }
175        }
176    }
177}
178
179/// A clonable handle that allows registering routes at any time, even after the [Muxer] is running.
180#[derive(Clone)]
181pub struct MuxHandle<S: Sender, R: Receiver> {
182    sender: S,
183    control_tx: mpsc::UnboundedSender<Control<R>>,
184}
185
186impl<S: Sender, R: Receiver> MuxHandle<S, R> {
187    /// Open a `subchannel`. Returns a ([SubSender], [SubReceiver]) pair that can be used to send
188    /// and receive messages for that subchannel.
189    ///
190    /// Panics if the subchannel is already registered at any point.
191    pub async fn register(
192        &mut self,
193        subchannel: Channel,
194    ) -> Result<(SubSender<S>, SubReceiver<R>), Error> {
195        let (tx, rx) = oneshot::channel();
196        self.control_tx
197            .send(Control::Register {
198                subchannel,
199                sender: tx,
200            })
201            .await
202            .map_err(|_| Error::Closed)?;
203        let receiver = rx.await.map_err(|_| Error::AlreadyRegistered(subchannel))?;
204
205        Ok((
206            SubSender {
207                subchannel,
208                inner: GlobalSender::new(self.sender.clone()),
209            },
210            SubReceiver {
211                receiver,
212                control_tx: Some(self.control_tx.clone()),
213                subchannel,
214            },
215        ))
216    }
217}
218
219/// Sender that routes messages to the `subchannel`.
220#[derive(Clone, Debug)]
221pub struct SubSender<S: Sender> {
222    inner: GlobalSender<S>,
223    subchannel: Channel,
224}
225
226impl<S: Sender> Sender for SubSender<S> {
227    type Error = S::Error;
228    type PublicKey = S::PublicKey;
229
230    async fn send(
231        &mut self,
232        recipients: Recipients<S::PublicKey>,
233        payload: Bytes,
234        priority: bool,
235    ) -> Result<Vec<S::PublicKey>, S::Error> {
236        self.inner
237            .send(self.subchannel, recipients, payload, priority)
238            .await
239    }
240}
241
242/// Receiver that yields messages for a specific subchannel.
243pub struct SubReceiver<R: Receiver> {
244    receiver: mpsc::Receiver<Message<R::PublicKey>>,
245    control_tx: Option<mpsc::UnboundedSender<Control<R>>>,
246    subchannel: Channel,
247}
248
249impl<R: Receiver> Receiver for SubReceiver<R> {
250    type Error = Error;
251    type PublicKey = R::PublicKey;
252
253    async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Self::Error> {
254        self.receiver.next().await.ok_or(Error::RecvFailed)
255    }
256}
257
258impl<R: Receiver> Debug for SubReceiver<R> {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        write!(f, "SubReceiver({})", self.subchannel)
261    }
262}
263
264impl<R: Receiver> Drop for SubReceiver<R> {
265    fn drop(&mut self) {
266        // Take the control channel to avoid cloning.
267        let control_tx = self
268            .control_tx
269            .take()
270            .expect("SubReceiver::drop called twice");
271
272        // Deregister the subchannel immediately.
273        let _ = control_tx.unbounded_send(Control::Deregister {
274            subchannel: self.subchannel,
275        });
276    }
277}
278
279/// Sender that can send messages over any sub [Channel].
280#[derive(Clone, Debug)]
281pub struct GlobalSender<S: Sender> {
282    inner: S,
283}
284
285impl<S: Sender> GlobalSender<S> {
286    /// Create a new [GlobalSender] wrapping the given [Sender].
287    pub fn new(inner: S) -> Self {
288        Self { inner }
289    }
290
291    /// Send a message over the given `subchannel`.
292    pub async fn send(
293        &mut self,
294        subchannel: Channel,
295        recipients: Recipients<S::PublicKey>,
296        payload: Bytes,
297        priority: bool,
298    ) -> Result<Vec<S::PublicKey>, S::Error> {
299        let subchannel = UInt(subchannel);
300        let mut buf = BytesMut::with_capacity(subchannel.encode_size() + payload.len());
301        subchannel.write(&mut buf);
302        buf.put_slice(&payload);
303        self.inner.send(recipients, buf.freeze(), priority).await
304    }
305}
306
/// A generic builder interface: consume a configuration value and produce the finished result.
pub trait Builder {
    /// The value produced by [Builder::build].
    type Output;

    /// Finish building, consuming `self`.
    fn build(self) -> Self::Output;
}
315
316/// A builder that constructs a [Muxer].
317pub struct MuxerBuilder<E: Spawner, S: Sender, R: Receiver> {
318    mux: Muxer<E, S, R>,
319    mux_handle: MuxHandle<S, R>,
320}
321
322impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilder<E, S, R> {
323    type Output = (Muxer<E, S, R>, MuxHandle<S, R>);
324
325    fn build(self) -> Self::Output {
326        (self.mux, self.mux_handle)
327    }
328}
329
330impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilder<E, S, R> {
331    /// Registers a backup channel with the muxer.
332    pub fn with_backup(mut self) -> MuxerBuilderWithBackup<E, S, R> {
333        let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
334        self.mux.backup = Some(tx);
335
336        MuxerBuilderWithBackup {
337            mux: self.mux,
338            mux_handle: self.mux_handle,
339            backup_rx: rx,
340        }
341    }
342
343    /// Registers a global sender with the muxer.
344    pub fn with_global_sender(self) -> MuxerBuilderWithGlobalSender<E, S, R> {
345        let global_sender = GlobalSender::new(self.mux.sender.clone());
346
347        MuxerBuilderWithGlobalSender {
348            mux: self.mux,
349            mux_handle: self.mux_handle,
350            global_sender,
351        }
352    }
353}
354
355/// A builder that constructs a [Muxer] with a backup channel.
356pub struct MuxerBuilderWithBackup<E: Spawner, S: Sender, R: Receiver> {
357    mux: Muxer<E, S, R>,
358    mux_handle: MuxHandle<S, R>,
359    backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
360}
361
362impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithBackup<E, S, R> {
363    /// Registers a global sender with the muxer.
364    pub fn with_global_sender(self) -> MuxerBuilderAllOpts<E, S, R> {
365        let global_sender = GlobalSender::new(self.mux.sender.clone());
366
367        MuxerBuilderAllOpts {
368            mux: self.mux,
369            mux_handle: self.mux_handle,
370            backup_rx: self.backup_rx,
371            global_sender,
372        }
373    }
374}
375
376impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithBackup<E, S, R> {
377    type Output = (
378        Muxer<E, S, R>,
379        MuxHandle<S, R>,
380        mpsc::Receiver<BackupResponse<R::PublicKey>>,
381    );
382
383    fn build(self) -> Self::Output {
384        (self.mux, self.mux_handle, self.backup_rx)
385    }
386}
387
388/// A builder that constructs a [Muxer] with a [GlobalSender].
389pub struct MuxerBuilderWithGlobalSender<E: Spawner, S: Sender, R: Receiver> {
390    mux: Muxer<E, S, R>,
391    mux_handle: MuxHandle<S, R>,
392    global_sender: GlobalSender<S>,
393}
394
395impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithGlobalSender<E, S, R> {
396    /// Registers a backup channel with the muxer.
397    pub fn with_backup(mut self) -> MuxerBuilderAllOpts<E, S, R> {
398        let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
399        self.mux.backup = Some(tx);
400
401        MuxerBuilderAllOpts {
402            mux: self.mux,
403            mux_handle: self.mux_handle,
404            backup_rx: rx,
405            global_sender: self.global_sender,
406        }
407    }
408}
409
410impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithGlobalSender<E, S, R> {
411    type Output = (Muxer<E, S, R>, MuxHandle<S, R>, GlobalSender<S>);
412
413    fn build(self) -> Self::Output {
414        (self.mux, self.mux_handle, self.global_sender)
415    }
416}
417
418/// A builder that constructs a [Muxer] with a [GlobalSender] and backup channel.
419pub struct MuxerBuilderAllOpts<E: Spawner, S: Sender, R: Receiver> {
420    mux: Muxer<E, S, R>,
421    mux_handle: MuxHandle<S, R>,
422    backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
423    global_sender: GlobalSender<S>,
424}
425
426impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderAllOpts<E, S, R> {
427    type Output = (
428        Muxer<E, S, R>,
429        MuxHandle<S, R>,
430        mpsc::Receiver<BackupResponse<R::PublicKey>>,
431        GlobalSender<S>,
432    );
433
434    fn build(self) -> Self::Output {
435        (
436            self.mux,
437            self.mux_handle,
438            self.backup_rx,
439            self.global_sender,
440        )
441    }
442}
443
444#[cfg(test)]
445mod tests {
446    use super::*;
447    use crate::{
448        simulated::{self, Link, Network, Oracle},
449        Recipients,
450    };
451    use bytes::Bytes;
452    use commonware_cryptography::{ed25519::PrivateKey, PrivateKeyExt, Signer};
453    use commonware_macros::{select, test_traced};
454    use commonware_runtime::{deterministic, Metrics, Runner};
455    use std::time::Duration;
456
457    type Pk = commonware_cryptography::ed25519::PublicKey;
458
459    const LINK: Link = Link {
460        latency: Duration::from_millis(0),
461        jitter: Duration::from_millis(0),
462        success_rate: 1.0,
463    };
464    const CAPACITY: usize = 5usize;
465
466    /// Start the network and return the oracle.
467    fn start_network(context: deterministic::Context) -> Oracle<Pk> {
468        let (network, oracle) = Network::new(
469            context.with_label("network"),
470            simulated::Config {
471                max_size: 1024 * 1024,
472                disconnect_on_block: true,
473                tracked_peer_sets: None,
474            },
475        );
476        network.start();
477        oracle
478    }
479
480    /// Create a public key from a seed.
481    fn pk(seed: u64) -> Pk {
482        PrivateKey::from_seed(seed).public_key()
483    }
484
485    /// Link two peers bidirectionally.
486    async fn link_bidirectional(oracle: &mut Oracle<Pk>, a: Pk, b: Pk) {
487        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
488        oracle.add_link(b, a, LINK).await.unwrap();
489    }
490
491    /// Create a peer and register it with the oracle.
492    async fn create_peer<E: Spawner + Metrics>(
493        context: &E,
494        oracle: &mut Oracle<Pk>,
495        seed: u64,
496    ) -> (
497        Pk,
498        MuxHandle<impl Sender<PublicKey = Pk>, impl Receiver<PublicKey = Pk>>,
499    ) {
500        let pubkey = pk(seed);
501        let (sender, receiver) = oracle.control(pubkey.clone()).register(0).await.unwrap();
502        let (mux, handle) = Muxer::new(context.with_label("mux"), sender, receiver, CAPACITY);
503        mux.start();
504        (pubkey, handle)
505    }
506
507    /// Create a peer and register it with the oracle.
508    async fn create_peer_with_backup_and_global_sender<E: Spawner + Metrics>(
509        context: &E,
510        oracle: &mut Oracle<Pk>,
511        seed: u64,
512    ) -> (
513        Pk,
514        MuxHandle<impl Sender<PublicKey = Pk>, impl Receiver<PublicKey = Pk>>,
515        mpsc::Receiver<BackupResponse<Pk>>,
516        GlobalSender<simulated::Sender<Pk>>,
517    ) {
518        let pubkey = pk(seed);
519        let (sender, receiver) = oracle.control(pubkey.clone()).register(0).await.unwrap();
520        let (mux, handle, backup, global_sender) =
521            Muxer::builder(context.with_label("mux"), sender, receiver, CAPACITY)
522                .with_backup()
523                .with_global_sender()
524                .build();
525        mux.start();
526        (pubkey, handle, backup, global_sender)
527    }
528
529    /// Send a burst of messages to a list of senders.
530    async fn send_burst<S: Sender>(txs: &mut [SubSender<S>], count: usize) {
531        for i in 0..count {
532            let payload = Bytes::from(vec![i as u8]);
533            for tx in txs.iter_mut() {
534                let _ = tx
535                    .send(Recipients::All, payload.clone(), false)
536                    .await
537                    .unwrap();
538            }
539        }
540    }
541
542    /// Wait for `n` messages to be received on the receiver.
543    async fn expect_n_messages(rx: &mut SubReceiver<impl Receiver<PublicKey = Pk>>, n: usize) {
544        let mut count = 0;
545        loop {
546            select! {
547                res = rx.recv() => {
548                    res.expect("should have received message");
549                    count += 1;
550                },
551            }
552
553            if count >= n {
554                break;
555            }
556        }
557        assert_eq!(n, count);
558    }
559
560    /// Wait for `n` messages to be received on the receiver + backup receiver.
561    async fn expect_n_messages_with_backup(
562        rx: &mut SubReceiver<impl Receiver<PublicKey = Pk>>,
563        backup_rx: &mut mpsc::Receiver<BackupResponse<Pk>>,
564        n: usize,
565        n_backup: usize,
566    ) {
567        let mut count_std = 0;
568        let mut count_backup = 0;
569        loop {
570            select! {
571                res = rx.recv() => {
572                    res.expect("should have received message");
573                    count_std += 1;
574                },
575                res = backup_rx.next() => {
576                    res.expect("should have received message");
577                    count_backup += 1;
578                },
579            }
580
581            if count_std >= n && count_backup >= n_backup {
582                break;
583            }
584        }
585        assert_eq!(n, count_std);
586        assert_eq!(n_backup, count_backup);
587    }
588
589    #[test]
590    fn test_basic_routing() {
591        // Can register a subchannel and send messages to it.
592        let executor = deterministic::Runner::default();
593        executor.start(|context| async move {
594            let mut oracle = start_network(context.clone());
595
596            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
597            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
598            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
599
600            let (_, mut sub_rx1) = handle1.register(7).await.unwrap();
601            let (mut sub_tx2, _) = handle2.register(7).await.unwrap();
602
603            // Send and receive
604            let payload = Bytes::from_static(b"hello");
605            let _ = sub_tx2
606                .send(Recipients::One(pk1.clone()), payload.clone(), false)
607                .await
608                .unwrap();
609            let (from, bytes) = sub_rx1.recv().await.unwrap();
610            assert_eq!(from, pk2);
611            assert_eq!(bytes, payload);
612        });
613    }
614
615    #[test]
616    fn test_multiple_routes() {
617        // Can register multiple subchannels and send messages to each.
618        let executor = deterministic::Runner::default();
619        executor.start(|context| async move {
620            let mut oracle = start_network(context.clone());
621
622            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
623            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
624            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
625
626            let (_, mut rx_a) = handle1.register(10).await.unwrap();
627            let (_, mut rx_b) = handle1.register(20).await.unwrap();
628
629            let (mut tx2_a, _) = handle2.register(10).await.unwrap();
630            let (mut tx2_b, _) = handle2.register(20).await.unwrap();
631
632            let payload_a = Bytes::from_static(b"A");
633            let payload_b = Bytes::from_static(b"B");
634            let _ = tx2_a
635                .send(Recipients::One(pk1.clone()), payload_a.clone(), false)
636                .await
637                .unwrap();
638            let _ = tx2_b
639                .send(Recipients::One(pk1.clone()), payload_b.clone(), false)
640                .await
641                .unwrap();
642
643            let (from_a, bytes_a) = rx_a.recv().await.unwrap();
644            assert_eq!(from_a, pk2);
645            assert_eq!(bytes_a, payload_a);
646
647            let (from_b, bytes_b) = rx_b.recv().await.unwrap();
648            assert_eq!(from_b, pk2);
649            assert_eq!(bytes_b, payload_b);
650        });
651    }
652
653    #[test_traced]
654    fn test_mailbox_capacity_blocks() {
655        // If a single subchannel is full, messages are blocked for all subchannels.
656        let executor = deterministic::Runner::default();
657        executor.start(|context| async move {
658            let mut oracle = start_network(context.clone());
659
660            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
661            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
662            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
663
664            // Register the subchannels.
665            let (tx1, _) = handle1.register(99).await.unwrap();
666            let (tx2, _) = handle1.register(100).await.unwrap();
667            let (_, mut rx1) = handle2.register(99).await.unwrap();
668            let (_, mut rx2) = handle2.register(100).await.unwrap();
669
670            // Send 10 messages to each subchannel from pk1 to pk2.
671            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
672
673            // Try receiving all messages from the second subchannel.
674            expect_n_messages(&mut rx2, CAPACITY).await;
675
676            // Try receiving from the first subchannel.
677            expect_n_messages(&mut rx1, CAPACITY * 2).await;
678
679            // The second subchannel should be unblocked and receive the rest of the messages.
680            expect_n_messages(&mut rx2, CAPACITY).await;
681        });
682    }
683
684    #[test]
685    fn test_drop_a_full_subchannel() {
686        // Drops the subchannel receiver while the sender is blocked.
687        let executor = deterministic::Runner::default();
688        executor.start(|context| async move {
689            let mut oracle = start_network(context.clone());
690
691            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
692            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
693            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
694
695            // Register the subchannels.
696            let (tx1, _) = handle1.register(99).await.unwrap();
697            let (tx2, _) = handle1.register(100).await.unwrap();
698            let (_, rx1) = handle2.register(99).await.unwrap();
699            let (_, mut rx2) = handle2.register(100).await.unwrap();
700
701            // Send 10 messages to each subchannel from pk1 to pk2.
702            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
703
704            // Try receiving all messages from the second subchannel.
705            expect_n_messages(&mut rx2, CAPACITY).await;
706
707            // Drop the first subchannel, erroring the sender and dropping it.
708            drop(rx1);
709
710            // The second subchannel should be unblocked and receive the rest of the messages.
711            expect_n_messages(&mut rx2, CAPACITY).await;
712        });
713    }
714
715    #[test]
716    fn test_drop_messages_for_unregistered_subchannel() {
717        // Messages are dropped if the subchannel they are for is not registered.
718        let executor = deterministic::Runner::default();
719        executor.start(|context| async move {
720            let mut oracle = start_network(context.clone());
721
722            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
723            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
724            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
725
726            // Register the subchannels.
727            let (tx1, _) = handle1.register(1).await.unwrap();
728            let (tx2, _) = handle1.register(2).await.unwrap();
729            // Do not register the first subchannel on the second peer.
730            let (_, mut rx2) = handle2.register(2).await.unwrap();
731
732            // Send 10 messages to each subchannel from pk1 to pk2.
733            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
734
735            // Try receiving all messages from the second subchannel.
736            expect_n_messages(&mut rx2, CAPACITY * 2).await;
737        });
738    }
739
740    #[test]
741    fn test_backup_for_unregistered_subchannel() {
742        // Messages are forwarded to the backup channel if the subchannel they are for
743        // is not registered.
744        let executor = deterministic::Runner::default();
745        executor.start(|context| async move {
746            let mut oracle = start_network(context.clone());
747
748            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
749            let (pk2, mut handle2, mut backup2, _) =
750                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
751            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
752
753            // Register the subchannels.
754            let (tx1, _) = handle1.register(1).await.unwrap();
755            let (tx2, _) = handle1.register(2).await.unwrap();
756            // Do not register the first subchannel on the second peer.
757            let (_, mut rx2) = handle2.register(2).await.unwrap();
758
759            // Send 10 messages to each subchannel from pk1 to pk2.
760            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
761
762            // Try receiving all messages from the second subchannel and backup channel.
763            // All 20 messages sent should be received.
764            expect_n_messages_with_backup(&mut rx2, &mut backup2, CAPACITY * 2, CAPACITY * 2).await;
765        });
766    }
767
768    #[test]
769    fn test_backup_for_unregistered_subchannel_response() {
770        // Messages are forwarded to the backup channel if the subchannel they are for
771        // is not registered.
772        let executor = deterministic::Runner::default();
773        executor.start(|context| async move {
774            let mut oracle = start_network(context.clone());
775
776            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
777            let (pk2, _handle2, mut backup2, mut global_sender2) =
778                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
779            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
780
781            // Register the subchannels.
782            let (tx1, mut rx1) = handle1.register(1).await.unwrap();
783            // Do not register any subchannels on the second peer.
784
785            // Send 1 message to each subchannel from pk1 to pk2.
786            send_burst(&mut [tx1], 1).await;
787
788            // Get the message from pk2's backup channel and respond.
789            let (subchannel, (from, _)) = backup2.next().await.unwrap();
790            assert_eq!(subchannel, 1);
791            assert_eq!(from, pk1);
792            global_sender2
793                .send(
794                    subchannel,
795                    Recipients::One(pk1),
796                    b"TEST".to_vec().into(),
797                    true,
798                )
799                .await
800                .unwrap();
801
802            // Receive the response with pk1's receiver.
803            let (from, bytes) = rx1.recv().await.unwrap();
804            assert_eq!(from, pk2);
805            assert_eq!(bytes.as_ref(), b"TEST");
806        });
807    }
808
809    #[test]
810    fn test_message_dropped_for_closed_subchannel() {
811        // Messages are dropped if the subchannel they are for is registered, but has been closed.
812        //
813        // NOTE: This case should be exceedingly rare in practice due to `SubReceiver` deregistering
814        // the subchannel on drop, but is included for completeness.
815        let executor = deterministic::Runner::default();
816        executor.start(|context| async move {
817            let mut oracle = start_network(context.clone());
818
819            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
820            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
821            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
822
823            // Register the subchannels.
824            let (tx1, _) = handle1.register(1).await.unwrap();
825            let (tx2, _) = handle1.register(2).await.unwrap();
826            let (_, mut rx1) = handle2.register(1).await.unwrap();
827            let (_, mut rx2) = handle2.register(2).await.unwrap();
828
829            // Send 10 messages to subchannel 1 from pk1 to pk2.
830            send_burst(&mut [tx1.clone()], CAPACITY * 2).await;
831
832            // Try receiving all messages from the first subchannel.
833            expect_n_messages(&mut rx1, CAPACITY * 2).await;
834
835            // Send 10 messages to subchannel 2 from pk1 to pk2.
836            send_burst(&mut [tx2.clone()], CAPACITY * 2).await;
837
838            // Try receiving all messages from the first subchannel.
839            expect_n_messages(&mut rx2, CAPACITY * 2).await;
840
841            // Explicitly close the underlying receiver for the first subchannel.
842            rx1.receiver.close();
843
844            // Send 10 messages to each subchannel from pk1 to pk2.
845            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
846
847            // Try receiving all messages from the second subchannel.
848            expect_n_messages(&mut rx2, CAPACITY * 2).await;
849        });
850    }
851
852    #[test]
853    fn test_dropped_backup_channel_doesnt_block() {
854        let executor = deterministic::Runner::default();
855        executor.start(|context| async move {
856            let mut oracle = start_network(context.clone());
857
858            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
859            let (pk2, mut handle2, backup2, _) =
860                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
861            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
862
863            // Explicitly drop the backup receiver.
864            drop(backup2);
865
866            // Register the subchannels.
867            let (tx1, _) = handle1.register(1).await.unwrap();
868            let (tx2, _) = handle1.register(2).await.unwrap();
869            // Do not register the first subchannel on the second peer.
870            let (_, mut rx2) = handle2.register(2).await.unwrap();
871
872            // Send 10 messages to each subchannel from pk1 to pk2.
873            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;
874
875            // Try receiving all messages from the second subchannel.
876            expect_n_messages(&mut rx2, CAPACITY * 2).await;
877        });
878    }
879
880    #[test]
881    fn test_duplicate_registration() {
882        // Returns an error if the subchannel is already registered.
883        let executor = deterministic::Runner::default();
884        executor.start(|context| async move {
885            let mut oracle = start_network(context.clone());
886
887            let (_pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
888
889            // Register the subchannel.
890            let (_, _rx) = handle1.register(7).await.unwrap();
891
892            // Registering again should return an error.
893            assert!(matches!(
894                handle1.register(7).await,
895                Err(Error::AlreadyRegistered(_))
896            ));
897        });
898    }
899
900    #[test]
901    fn test_register_after_deregister() {
902        // Can register a channel after it has been deregistered.
903        let executor = deterministic::Runner::default();
904        executor.start(|context| async move {
905            let mut oracle = start_network(context.clone());
906
907            let (_, mut handle) = create_peer(&context, &mut oracle, 0).await;
908            let (_, rx) = handle.register(7).await.unwrap();
909            drop(rx);
910
911            // Registering again should not return an error.
912            handle.register(7).await.unwrap();
913        });
914    }
915}