1use crate::{Channel, CheckedSender, LimitedSender, Message, Receiver, Recipients, Sender};
12use bytes::{Buf, Bytes};
13use commonware_codec::{varint::UInt, Encode, Error as CodecError, ReadExt};
14use commonware_macros::select_loop;
15use commonware_runtime::{spawn_cell, ContextCell, Handle, Spawner};
16use commonware_utils::channels::fallible::FallibleExt;
17use futures::{
18 channel::{mpsc, oneshot},
19 SinkExt, StreamExt,
20};
21use std::{collections::HashMap, fmt::Debug, time::SystemTime};
22use thiserror::Error;
23use tracing::debug;
24
25#[derive(Error, Debug)]
27pub enum Error {
28 #[error("subchannel already registered: {0}")]
29 AlreadyRegistered(Channel),
30 #[error("muxer is closed")]
31 Closed,
32 #[error("recv failed")]
33 RecvFailed,
34}
35
36pub fn parse(mut bytes: Bytes) -> Result<(Channel, Bytes), CodecError> {
38 let subchannel: Channel = UInt::read(&mut bytes)?.into();
39 Ok((subchannel, bytes))
40}
41
42enum Control<R: Receiver> {
44 Register {
45 subchannel: Channel,
46 sender: oneshot::Sender<mpsc::Receiver<Message<R::PublicKey>>>,
47 },
48 Deregister {
49 subchannel: Channel,
50 },
51}
52
53type Routes<P> = HashMap<Channel, mpsc::Sender<Message<P>>>;
55
56type BackupResponse<P> = (Channel, Message<P>);
59
60pub struct Muxer<E: Spawner, S: Sender, R: Receiver> {
62 context: ContextCell<E>,
63 sender: S,
64 receiver: R,
65 mailbox_size: usize,
66 control_rx: mpsc::UnboundedReceiver<Control<R>>,
67 routes: Routes<R::PublicKey>,
68 backup: Option<mpsc::Sender<BackupResponse<R::PublicKey>>>,
69}
70
71impl<E: Spawner, S: Sender, R: Receiver> Muxer<E, S, R> {
72 pub fn new(context: E, sender: S, receiver: R, mailbox_size: usize) -> (Self, MuxHandle<S, R>) {
75 Self::builder(context, sender, receiver, mailbox_size).build()
76 }
77
78 pub fn builder(
80 context: E,
81 sender: S,
82 receiver: R,
83 mailbox_size: usize,
84 ) -> MuxerBuilder<E, S, R> {
85 let (control_tx, control_rx) = mpsc::unbounded();
86 let mux = Self {
87 context: ContextCell::new(context),
88 sender,
89 receiver,
90 mailbox_size,
91 control_rx,
92 routes: HashMap::new(),
93 backup: None,
94 };
95
96 let mux_handle = MuxHandle {
97 sender: mux.sender.clone(),
98 control_tx,
99 };
100
101 MuxerBuilder { mux, mux_handle }
102 }
103
104 pub fn start(mut self) -> Handle<Result<(), R::Error>> {
106 spawn_cell!(self.context, self.run().await)
107 }
108
109 pub async fn run(mut self) -> Result<(), R::Error> {
114 select_loop! {
115 self.context,
116 on_stopped => {
117 debug!("context shutdown, stopping muxer");
118 },
119 control = self.control_rx.next() => {
122 match control {
123 Some(Control::Register { subchannel, sender }) => {
124 if self.routes.contains_key(&subchannel) {
126 continue;
127 }
128
129 let (tx, rx) = mpsc::channel(self.mailbox_size);
131 self.routes.insert(subchannel, tx);
132 let _ = sender.send(rx);
133 },
134 Some(Control::Deregister { subchannel }) => {
135 self.routes.remove(&subchannel);
137 },
138 None => {
139 return Ok(());
142 }
143 }
144 },
145 message = self.receiver.recv() => {
147 let (pk, bytes) = message?;
149 let (subchannel, bytes) = match parse(bytes) {
150 Ok(parsed) => parsed,
151 Err(_) => {
152 debug!(?pk, "invalid message: missing subchannel");
153 continue;
154 }
155 };
156
157 let Some(sender) = self.routes.get_mut(&subchannel) else {
159 if let Some(backup) = &mut self.backup {
161 if let Err(e) = backup.try_send((subchannel, (pk, bytes))) {
162 debug!(?subchannel, ?e, "failed to send message to backup channel");
163 }
164 }
165
166 continue;
169 };
170
171 if let Err(e) = sender.try_send((pk, bytes)) {
174 if e.is_disconnected() {
176 self.routes.remove(&subchannel);
178 debug!(?subchannel, "subchannel receiver dropped, removing route");
179 } else {
180 debug!(?subchannel, "subchannel full, dropping message");
182 }
183 }
184 }
185 }
186
187 Ok(())
188 }
189}
190
191#[derive(Clone)]
193pub struct MuxHandle<S: Sender, R: Receiver> {
194 sender: S,
195 control_tx: mpsc::UnboundedSender<Control<R>>,
196}
197
198impl<S: Sender, R: Receiver> MuxHandle<S, R> {
199 pub async fn register(
204 &mut self,
205 subchannel: Channel,
206 ) -> Result<(SubSender<S>, SubReceiver<R>), Error> {
207 let (tx, rx) = oneshot::channel();
208 self.control_tx
209 .send(Control::Register {
210 subchannel,
211 sender: tx,
212 })
213 .await
214 .map_err(|_| Error::Closed)?;
215 let receiver = rx.await.map_err(|_| Error::AlreadyRegistered(subchannel))?;
216
217 Ok((
218 SubSender {
219 subchannel,
220 inner: GlobalSender::new(self.sender.clone()),
221 },
222 SubReceiver {
223 receiver,
224 control_tx: Some(self.control_tx.clone()),
225 subchannel,
226 },
227 ))
228 }
229}
230
231#[derive(Clone, Debug)]
233pub struct SubSender<S: Sender> {
234 inner: GlobalSender<S>,
235 subchannel: Channel,
236}
237
238impl<S: Sender> LimitedSender for SubSender<S> {
239 type PublicKey = S::PublicKey;
240 type Checked<'a> = CheckedGlobalSender<'a, S>;
241
242 async fn check(
243 &mut self,
244 recipients: Recipients<Self::PublicKey>,
245 ) -> Result<Self::Checked<'_>, SystemTime> {
246 self.inner
247 .check(recipients)
248 .await
249 .map(|checked| checked.with_subchannel(self.subchannel))
250 }
251}
252
253pub struct SubReceiver<R: Receiver> {
255 receiver: mpsc::Receiver<Message<R::PublicKey>>,
256 control_tx: Option<mpsc::UnboundedSender<Control<R>>>,
257 subchannel: Channel,
258}
259
260impl<R: Receiver> Receiver for SubReceiver<R> {
261 type Error = Error;
262 type PublicKey = R::PublicKey;
263
264 async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Self::Error> {
265 self.receiver.next().await.ok_or(Error::RecvFailed)
266 }
267}
268
269impl<R: Receiver> Debug for SubReceiver<R> {
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 write!(f, "SubReceiver({})", self.subchannel)
272 }
273}
274
275impl<R: Receiver> Drop for SubReceiver<R> {
276 fn drop(&mut self) {
277 let control_tx = self
279 .control_tx
280 .take()
281 .expect("SubReceiver::drop called twice");
282
283 control_tx.send_lossy(Control::Deregister {
285 subchannel: self.subchannel,
286 });
287 }
288}
289
290#[derive(Clone, Debug)]
292pub struct GlobalSender<S: Sender> {
293 inner: S,
294}
295
296impl<S: Sender> GlobalSender<S> {
297 pub const fn new(inner: S) -> Self {
299 Self { inner }
300 }
301
302 pub async fn send(
304 &mut self,
305 subchannel: Channel,
306 recipients: Recipients<S::PublicKey>,
307 payload: impl Buf + Send,
308 priority: bool,
309 ) -> Result<Vec<S::PublicKey>, <S::Checked<'_> as CheckedSender>::Error> {
310 match self.check(recipients).await {
311 Ok(checked) => {
312 checked
313 .with_subchannel(subchannel)
314 .send(payload, priority)
315 .await
316 }
317 Err(_) => Ok(Vec::new()),
318 }
319 }
320}
321
322impl<S: Sender> LimitedSender for GlobalSender<S> {
323 type PublicKey = S::PublicKey;
324 type Checked<'a> = CheckedGlobalSender<'a, S>;
325
326 async fn check(
327 &mut self,
328 recipients: Recipients<Self::PublicKey>,
329 ) -> Result<Self::Checked<'_>, SystemTime> {
330 self.inner
331 .check(recipients)
332 .await
333 .map(|checked| CheckedGlobalSender {
334 subchannel: None,
335 inner: checked,
336 })
337 }
338}
339
340pub struct CheckedGlobalSender<'a, S: Sender> {
342 subchannel: Option<Channel>,
343 inner: S::Checked<'a>,
344}
345
346impl<'a, S: Sender> CheckedGlobalSender<'a, S> {
347 pub const fn with_subchannel(mut self, subchannel: Channel) -> Self {
349 self.subchannel = Some(subchannel);
350 self
351 }
352}
353
354impl<'a, S: Sender> CheckedSender for CheckedGlobalSender<'a, S> {
355 type PublicKey = S::PublicKey;
356 type Error = <S::Checked<'a> as CheckedSender>::Error;
357
358 async fn send(
359 self,
360 message: impl Buf + Send,
361 priority: bool,
362 ) -> Result<Vec<Self::PublicKey>, Self::Error> {
363 let subchannel = UInt(self.subchannel.expect("subchannel not set"));
364 self.inner
365 .send(subchannel.encode().chain(message), priority)
366 .await
367 }
368}
369
/// Finishes a muxer builder, yielding its parts.
pub trait Builder {
    /// The tuple produced by [`Builder::build`]: the muxer, its handle, and any
    /// optional extras that were enabled.
    type Output;

    /// Consumes the builder and returns its parts.
    fn build(self) -> Self::Output;
}
378
379pub struct MuxerBuilder<E: Spawner, S: Sender, R: Receiver> {
381 mux: Muxer<E, S, R>,
382 mux_handle: MuxHandle<S, R>,
383}
384
385impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilder<E, S, R> {
386 type Output = (Muxer<E, S, R>, MuxHandle<S, R>);
387
388 fn build(self) -> Self::Output {
389 (self.mux, self.mux_handle)
390 }
391}
392
393impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilder<E, S, R> {
394 pub fn with_backup(mut self) -> MuxerBuilderWithBackup<E, S, R> {
396 let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
397 self.mux.backup = Some(tx);
398
399 MuxerBuilderWithBackup {
400 mux: self.mux,
401 mux_handle: self.mux_handle,
402 backup_rx: rx,
403 }
404 }
405
406 pub fn with_global_sender(self) -> MuxerBuilderWithGlobalSender<E, S, R> {
408 let global_sender = GlobalSender::new(self.mux.sender.clone());
409
410 MuxerBuilderWithGlobalSender {
411 mux: self.mux,
412 mux_handle: self.mux_handle,
413 global_sender,
414 }
415 }
416}
417
418pub struct MuxerBuilderWithBackup<E: Spawner, S: Sender, R: Receiver> {
420 mux: Muxer<E, S, R>,
421 mux_handle: MuxHandle<S, R>,
422 backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
423}
424
425impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithBackup<E, S, R> {
426 pub fn with_global_sender(self) -> MuxerBuilderAllOpts<E, S, R> {
428 let global_sender = GlobalSender::new(self.mux.sender.clone());
429
430 MuxerBuilderAllOpts {
431 mux: self.mux,
432 mux_handle: self.mux_handle,
433 backup_rx: self.backup_rx,
434 global_sender,
435 }
436 }
437}
438
439impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithBackup<E, S, R> {
440 type Output = (
441 Muxer<E, S, R>,
442 MuxHandle<S, R>,
443 mpsc::Receiver<BackupResponse<R::PublicKey>>,
444 );
445
446 fn build(self) -> Self::Output {
447 (self.mux, self.mux_handle, self.backup_rx)
448 }
449}
450
451pub struct MuxerBuilderWithGlobalSender<E: Spawner, S: Sender, R: Receiver> {
453 mux: Muxer<E, S, R>,
454 mux_handle: MuxHandle<S, R>,
455 global_sender: GlobalSender<S>,
456}
457
458impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithGlobalSender<E, S, R> {
459 pub fn with_backup(mut self) -> MuxerBuilderAllOpts<E, S, R> {
461 let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
462 self.mux.backup = Some(tx);
463
464 MuxerBuilderAllOpts {
465 mux: self.mux,
466 mux_handle: self.mux_handle,
467 backup_rx: rx,
468 global_sender: self.global_sender,
469 }
470 }
471}
472
473impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithGlobalSender<E, S, R> {
474 type Output = (Muxer<E, S, R>, MuxHandle<S, R>, GlobalSender<S>);
475
476 fn build(self) -> Self::Output {
477 (self.mux, self.mux_handle, self.global_sender)
478 }
479}
480
481pub struct MuxerBuilderAllOpts<E: Spawner, S: Sender, R: Receiver> {
483 mux: Muxer<E, S, R>,
484 mux_handle: MuxHandle<S, R>,
485 backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
486 global_sender: GlobalSender<S>,
487}
488
489impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderAllOpts<E, S, R> {
490 type Output = (
491 Muxer<E, S, R>,
492 MuxHandle<S, R>,
493 mpsc::Receiver<BackupResponse<R::PublicKey>>,
494 GlobalSender<S>,
495 );
496
497 fn build(self) -> Self::Output {
498 (
499 self.mux,
500 self.mux_handle,
501 self.backup_rx,
502 self.global_sender,
503 )
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        simulated::{self, Link, Network, Oracle},
        Recipients,
    };
    use bytes::Bytes;
    use commonware_cryptography::{
        ed25519::{PrivateKey, PublicKey},
        Signer,
    };
    use commonware_macros::{select, test_traced};
    use commonware_runtime::{deterministic, Metrics, Quota, Runner};
    use std::{num::NonZeroU32, time::Duration};

    // Zero-latency, zero-jitter, always-successful link between simulated peers.
    const LINK: Link = Link {
        latency: Duration::from_millis(0),
        jitter: Duration::from_millis(0),
        success_rate: 1.0,
    };
    // Mailbox size passed to every muxer created by the helpers below.
    const CAPACITY: usize = 5usize;

    // Quota large enough that sends in these tests should not be rate limited.
    const TEST_QUOTA: Quota = Quota::per_second(NonZeroU32::MAX);

    /// Starts a simulated network on `context` and returns its oracle.
    fn start_network(context: deterministic::Context) -> Oracle<PublicKey, deterministic::Context> {
        let (network, oracle) = Network::new(
            context.with_label("network"),
            simulated::Config {
                max_size: 1024 * 1024,
                disconnect_on_block: true,
                tracked_peer_sets: None,
            },
        );
        network.start();
        oracle
    }

    /// Deterministic ed25519 public key derived from `seed`.
    fn pk(seed: u64) -> PublicKey {
        PrivateKey::from_seed(seed).public_key()
    }

    /// Adds `LINK` in both directions between `a` and `b`.
    async fn link_bidirectional(
        oracle: &mut Oracle<PublicKey, deterministic::Context>,
        a: PublicKey,
        b: PublicKey,
    ) {
        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
        oracle.add_link(b, a, LINK).await.unwrap();
    }

    /// Registers peer `seed` on network channel 0, starts a plain muxer over it, and
    /// returns the peer's key and mux handle.
    async fn create_peer(
        context: &deterministic::Context,
        oracle: &mut Oracle<PublicKey, deterministic::Context>,
        seed: u64,
    ) -> (
        PublicKey,
        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
    ) {
        let pubkey = pk(seed);
        let (sender, receiver) = oracle
            .control(pubkey.clone())
            .register(0, TEST_QUOTA)
            .await
            .unwrap();
        let (mux, handle) = Muxer::new(context.with_label("mux"), sender, receiver, CAPACITY);
        mux.start();
        (pubkey, handle)
    }

    /// Like `create_peer`, but the muxer is built with both the backup channel and
    /// a global sender, which are returned alongside the handle.
    async fn create_peer_with_backup_and_global_sender(
        context: &deterministic::Context,
        oracle: &mut Oracle<PublicKey, deterministic::Context>,
        seed: u64,
    ) -> (
        PublicKey,
        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
        mpsc::Receiver<BackupResponse<PublicKey>>,
        GlobalSender<simulated::Sender<PublicKey, deterministic::Context>>,
    ) {
        let pubkey = pk(seed);
        let (sender, receiver) = oracle
            .control(pubkey.clone())
            .register(0, TEST_QUOTA)
            .await
            .unwrap();
        let (mux, handle, backup, global_sender) =
            Muxer::builder(context.with_label("mux"), sender, receiver, CAPACITY)
                .with_backup()
                .with_global_sender()
                .build();
        mux.start();
        (pubkey, handle, backup, global_sender)
    }

    /// Sends `count` one-byte payloads (0, 1, ...) to all peers on each sender in
    /// `txs`, interleaving the senders per payload.
    async fn send_burst<S: Sender>(txs: &mut [SubSender<S>], count: usize) {
        for i in 0..count {
            let payload = Bytes::from(vec![i as u8]);
            for tx in txs.iter_mut() {
                let _ = tx
                    .send(Recipients::All, payload.clone(), false)
                    .await
                    .unwrap();
            }
        }
    }

    /// Receives `n` messages from `rx`, panicking if the subchannel closes first.
    /// Does not check whether further messages are pending.
    async fn expect_n_messages(
        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
        n: usize,
    ) {
        let mut count = 0;
        loop {
            select! {
                res = rx.recv() => {
                    res.expect("should have received message");
                    count += 1;
                },
            }

            if count >= n {
                break;
            }
        }
        assert_eq!(n, count);
    }

    /// Receives from both `rx` and `backup_rx` until at least `n` and `n_backup`
    /// messages respectively have arrived, then asserts the exact counts.
    async fn expect_n_messages_with_backup(
        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
        backup_rx: &mut mpsc::Receiver<BackupResponse<PublicKey>>,
        n: usize,
        n_backup: usize,
    ) {
        let mut count_std = 0;
        let mut count_backup = 0;
        loop {
            select! {
                res = rx.recv() => {
                    res.expect("should have received message");
                    count_std += 1;
                },
                res = backup_rx.next() => {
                    res.expect("should have received message");
                    count_backup += 1;
                },
            }

            if count_std >= n && count_backup >= n_backup {
                break;
            }
        }
        assert_eq!(n, count_std);
        assert_eq!(n_backup, count_backup);
    }

    // A message sent on subchannel 7 by one peer arrives, with sender and payload
    // intact, on the other peer's subchannel-7 receiver.
    #[test]
    fn test_basic_routing() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (_, mut sub_rx1) = handle1.register(7).await.unwrap();
            let (mut sub_tx2, _) = handle2.register(7).await.unwrap();

            let payload = Bytes::from_static(b"hello");
            let _ = sub_tx2
                .send(Recipients::One(pk1.clone()), payload.clone(), false)
                .await
                .unwrap();
            let (from, bytes) = sub_rx1.recv().await.unwrap();
            assert_eq!(from, pk2);
            assert_eq!(bytes, payload);
        });
    }

    // Two subchannels on the same muxer each receive only their own payload.
    #[test]
    fn test_multiple_routes() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (_, mut rx_a) = handle1.register(10).await.unwrap();
            let (_, mut rx_b) = handle1.register(20).await.unwrap();

            let (mut tx2_a, _) = handle2.register(10).await.unwrap();
            let (mut tx2_b, _) = handle2.register(20).await.unwrap();

            let payload_a = Bytes::from_static(b"A");
            let payload_b = Bytes::from_static(b"B");
            let _ = tx2_a
                .send(Recipients::One(pk1.clone()), payload_a.clone(), false)
                .await
                .unwrap();
            let _ = tx2_b
                .send(Recipients::One(pk1.clone()), payload_b.clone(), false)
                .await
                .unwrap();

            let (from_a, bytes_a) = rx_a.recv().await.unwrap();
            assert_eq!(from_a, pk2);
            assert_eq!(bytes_a, payload_a);

            let (from_b, bytes_b) = rx_b.recv().await.unwrap();
            assert_eq!(from_b, pk2);
            assert_eq!(bytes_b, payload_b);
        });
    }

    // Sends twice CAPACITY messages per subchannel without reading, then reads
    // CAPACITY from each: the muxer drops overflow instead of blocking. Note that
    // `expect_n_messages` does not assert that nothing beyond CAPACITY was kept.
    #[test_traced]
    fn test_mailbox_capacity_drops_when_full() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(99).await.unwrap();
            let (tx2, _) = handle1.register(100).await.unwrap();
            let (_, mut rx1) = handle2.register(99).await.unwrap();
            let (_, mut rx2) = handle2.register(100).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            expect_n_messages(&mut rx1, CAPACITY).await;
            expect_n_messages(&mut rx2, CAPACITY).await;
        });
    }

    // Dropping the receiver for subchannel 99 must not disturb delivery on 100.
    #[test]
    fn test_drop_subchannel_receiver_deregisters_route() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(99).await.unwrap();
            let (tx2, _) = handle1.register(100).await.unwrap();
            let (_, rx1) = handle2.register(99).await.unwrap();
            let (_, mut rx2) = handle2.register(100).await.unwrap();

            drop(rx1);

            send_burst(&mut [tx1, tx2], CAPACITY).await;

            expect_n_messages(&mut rx2, CAPACITY).await;
        });
    }

    // Peer 2 never registers subchannel 1 (and has no backup): those messages are
    // dropped while subchannel 2 keeps receiving.
    #[test]
    fn test_drop_messages_for_unregistered_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY).await;

            expect_n_messages(&mut rx2, CAPACITY).await;
        });
    }

    // With a backup channel, messages for the unregistered subchannel 1 land on the
    // backup receiver while subchannel 2 receives normally.
    #[test]
    fn test_backup_for_unregistered_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2, mut backup2, _) =
                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY).await;

            expect_n_messages_with_backup(&mut rx2, &mut backup2, CAPACITY, CAPACITY).await;
        });
    }

    // A message received on the backup channel can be answered on its subchannel
    // through the global sender, and the reply reaches the original sender.
    #[test]
    fn test_backup_for_unregistered_subchannel_response() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, _handle2, mut backup2, mut global_sender2) =
                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, mut rx1) = handle1.register(1).await.unwrap();
            send_burst(&mut [tx1], 1).await;

            let (subchannel, (from, _)) = backup2.next().await.unwrap();
            assert_eq!(subchannel, 1);
            assert_eq!(from, pk1);
            global_sender2
                .send(subchannel, Recipients::One(pk1), &b"TEST"[..], true)
                .await
                .unwrap();

            let (from, bytes) = rx1.recv().await.unwrap();
            assert_eq!(from, pk2);
            assert_eq!(bytes.as_ref(), b"TEST");
        });
    }

    // Closing subchannel 1's mailbox (without dropping the SubReceiver) must not
    // block delivery on subchannel 2.
    #[test]
    fn test_message_dropped_for_closed_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            let (_, mut rx1) = handle2.register(1).await.unwrap();
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            send_burst(&mut [tx1.clone()], CAPACITY).await;
            expect_n_messages(&mut rx1, CAPACITY).await;

            send_burst(&mut [tx2.clone()], CAPACITY).await;
            expect_n_messages(&mut rx2, CAPACITY).await;

            rx1.receiver.close();

            send_burst(&mut [tx1, tx2], CAPACITY).await;

            expect_n_messages(&mut rx2, CAPACITY).await;
        });
    }

    // Dropping the backup receiver must not stall the muxer: subchannel 2 still
    // receives while subchannel 1 traffic has nowhere to go.
    #[test]
    fn test_dropped_backup_channel_doesnt_block() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2, backup2, _) =
                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            drop(backup2);

            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY).await;

            expect_n_messages(&mut rx2, CAPACITY).await;
        });
    }

    // Registering a subchannel that is still held fails with AlreadyRegistered.
    #[test]
    fn test_duplicate_registration() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (_pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;

            let (_, _rx) = handle1.register(7).await.unwrap();

            assert!(matches!(
                handle1.register(7).await,
                Err(Error::AlreadyRegistered(_))
            ));
        });
    }

    // After the SubReceiver is dropped, the same subchannel can be registered again.
    #[test]
    fn test_register_after_deregister() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (_, mut handle) = create_peer(&context, &mut oracle, 0).await;
            let (_, rx) = handle.register(7).await.unwrap();
            drop(rx);

            handle.register(7).await.unwrap();
        });
    }
}