1use crate::{Channel, CheckedSender, LimitedSender, Message, Receiver, Recipients, Sender};
12use bytes::{BufMut, Bytes, BytesMut};
13use commonware_codec::{varint::UInt, EncodeSize, Error as CodecError, ReadExt, Write};
14use commonware_macros::select_loop;
15use commonware_runtime::{spawn_cell, ContextCell, Handle, Spawner};
16use futures::{
17 channel::{mpsc, oneshot},
18 SinkExt, StreamExt,
19};
20use std::{collections::HashMap, fmt::Debug, time::SystemTime};
21use thiserror::Error;
22use tracing::debug;
23
24#[derive(Error, Debug)]
26pub enum Error {
27 #[error("subchannel already registered: {0}")]
28 AlreadyRegistered(Channel),
29 #[error("muxer is closed")]
30 Closed,
31 #[error("recv failed")]
32 RecvFailed,
33}
34
35pub fn parse(mut bytes: Bytes) -> Result<(Channel, Bytes), CodecError> {
37 let subchannel: Channel = UInt::read(&mut bytes)?.into();
38 Ok((subchannel, bytes))
39}
40
41enum Control<R: Receiver> {
43 Register {
44 subchannel: Channel,
45 sender: oneshot::Sender<mpsc::Receiver<Message<R::PublicKey>>>,
46 },
47 Deregister {
48 subchannel: Channel,
49 },
50}
51
52type Routes<P> = HashMap<Channel, mpsc::Sender<Message<P>>>;
54
55type BackupResponse<P> = (Channel, Message<P>);
58
59pub struct Muxer<E: Spawner, S: Sender, R: Receiver> {
61 context: ContextCell<E>,
62 sender: S,
63 receiver: R,
64 mailbox_size: usize,
65 control_rx: mpsc::UnboundedReceiver<Control<R>>,
66 routes: Routes<R::PublicKey>,
67 backup: Option<mpsc::Sender<BackupResponse<R::PublicKey>>>,
68}
69
70impl<E: Spawner, S: Sender, R: Receiver> Muxer<E, S, R> {
71 pub fn new(context: E, sender: S, receiver: R, mailbox_size: usize) -> (Self, MuxHandle<S, R>) {
74 Self::builder(context, sender, receiver, mailbox_size).build()
75 }
76
77 pub fn builder(
79 context: E,
80 sender: S,
81 receiver: R,
82 mailbox_size: usize,
83 ) -> MuxerBuilder<E, S, R> {
84 let (control_tx, control_rx) = mpsc::unbounded();
85 let mux = Self {
86 context: ContextCell::new(context),
87 sender,
88 receiver,
89 mailbox_size,
90 control_rx,
91 routes: HashMap::new(),
92 backup: None,
93 };
94
95 let mux_handle = MuxHandle {
96 sender: mux.sender.clone(),
97 control_tx,
98 };
99
100 MuxerBuilder { mux, mux_handle }
101 }
102
103 pub fn start(mut self) -> Handle<Result<(), R::Error>> {
105 spawn_cell!(self.context, self.run().await)
106 }
107
108 pub async fn run(mut self) -> Result<(), R::Error> {
113 select_loop! {
114 self.context,
115 on_stopped => {
116 debug!("context shutdown, stopping muxer");
117 },
118 control = self.control_rx.next() => {
121 match control {
122 Some(Control::Register { subchannel, sender }) => {
123 if self.routes.contains_key(&subchannel) {
125 continue;
126 }
127
128 let (tx, rx) = mpsc::channel(self.mailbox_size);
130 self.routes.insert(subchannel, tx);
131 let _ = sender.send(rx);
132 },
133 Some(Control::Deregister { subchannel }) => {
134 self.routes.remove(&subchannel);
136 },
137 None => {
138 return Ok(());
141 }
142 }
143 },
144 message = self.receiver.recv() => {
146 let (pk, bytes) = message?;
148 let (subchannel, bytes) = match parse(bytes) {
149 Ok(parsed) => parsed,
150 Err(_) => {
151 debug!(?pk, "invalid message: missing subchannel");
152 continue;
153 }
154 };
155
156 let Some(sender) = self.routes.get_mut(&subchannel) else {
158 if let Some(backup) = &mut self.backup {
160 if let Err(e) = backup.send((subchannel, (pk, bytes))).await {
161 debug!(?subchannel, ?e, "failed to send message to backup channel");
162 }
163 }
164
165 continue;
168 };
169
170 if let Err(e) = sender.send((pk, bytes)).await {
172 self.routes.remove(&subchannel);
174
175 debug!(?subchannel, ?e, "failed to send message to subchannel");
177
178 }
181 }
182 }
183
184 Ok(())
185 }
186}
187
188#[derive(Clone)]
190pub struct MuxHandle<S: Sender, R: Receiver> {
191 sender: S,
192 control_tx: mpsc::UnboundedSender<Control<R>>,
193}
194
195impl<S: Sender, R: Receiver> MuxHandle<S, R> {
196 pub async fn register(
201 &mut self,
202 subchannel: Channel,
203 ) -> Result<(SubSender<S>, SubReceiver<R>), Error> {
204 let (tx, rx) = oneshot::channel();
205 self.control_tx
206 .send(Control::Register {
207 subchannel,
208 sender: tx,
209 })
210 .await
211 .map_err(|_| Error::Closed)?;
212 let receiver = rx.await.map_err(|_| Error::AlreadyRegistered(subchannel))?;
213
214 Ok((
215 SubSender {
216 subchannel,
217 inner: GlobalSender::new(self.sender.clone()),
218 },
219 SubReceiver {
220 receiver,
221 control_tx: Some(self.control_tx.clone()),
222 subchannel,
223 },
224 ))
225 }
226}
227
228#[derive(Clone, Debug)]
230pub struct SubSender<S: Sender> {
231 inner: GlobalSender<S>,
232 subchannel: Channel,
233}
234
235impl<S: Sender> LimitedSender for SubSender<S> {
236 type PublicKey = S::PublicKey;
237 type Checked<'a> = CheckedGlobalSender<'a, S>;
238
239 async fn check(
240 &mut self,
241 recipients: Recipients<Self::PublicKey>,
242 ) -> Result<Self::Checked<'_>, SystemTime> {
243 self.inner
244 .check(recipients)
245 .await
246 .map(|checked| checked.with_subchannel(self.subchannel))
247 }
248}
249
250pub struct SubReceiver<R: Receiver> {
252 receiver: mpsc::Receiver<Message<R::PublicKey>>,
253 control_tx: Option<mpsc::UnboundedSender<Control<R>>>,
254 subchannel: Channel,
255}
256
257impl<R: Receiver> Receiver for SubReceiver<R> {
258 type Error = Error;
259 type PublicKey = R::PublicKey;
260
261 async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Self::Error> {
262 self.receiver.next().await.ok_or(Error::RecvFailed)
263 }
264}
265
266impl<R: Receiver> Debug for SubReceiver<R> {
267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268 write!(f, "SubReceiver({})", self.subchannel)
269 }
270}
271
272impl<R: Receiver> Drop for SubReceiver<R> {
273 fn drop(&mut self) {
274 let control_tx = self
276 .control_tx
277 .take()
278 .expect("SubReceiver::drop called twice");
279
280 let _ = control_tx.unbounded_send(Control::Deregister {
282 subchannel: self.subchannel,
283 });
284 }
285}
286
287#[derive(Clone, Debug)]
289pub struct GlobalSender<S: Sender> {
290 inner: S,
291}
292
293impl<S: Sender> GlobalSender<S> {
294 pub const fn new(inner: S) -> Self {
296 Self { inner }
297 }
298
299 pub async fn send(
301 &mut self,
302 subchannel: Channel,
303 recipients: Recipients<S::PublicKey>,
304 payload: Bytes,
305 priority: bool,
306 ) -> Result<Vec<S::PublicKey>, <S::Checked<'_> as CheckedSender>::Error> {
307 match self.check(recipients).await {
308 Ok(checked) => {
309 checked
310 .with_subchannel(subchannel)
311 .send(payload, priority)
312 .await
313 }
314 Err(_) => Ok(Vec::new()),
315 }
316 }
317}
318
319impl<S: Sender> LimitedSender for GlobalSender<S> {
320 type PublicKey = S::PublicKey;
321 type Checked<'a> = CheckedGlobalSender<'a, S>;
322
323 async fn check(
324 &mut self,
325 recipients: Recipients<Self::PublicKey>,
326 ) -> Result<Self::Checked<'_>, SystemTime> {
327 self.inner
328 .check(recipients)
329 .await
330 .map(|checked| CheckedGlobalSender {
331 subchannel: None,
332 inner: checked,
333 })
334 }
335}
336
337pub struct CheckedGlobalSender<'a, S: Sender> {
339 subchannel: Option<Channel>,
340 inner: S::Checked<'a>,
341}
342
343impl<'a, S: Sender> CheckedGlobalSender<'a, S> {
344 pub const fn with_subchannel(mut self, subchannel: Channel) -> Self {
346 self.subchannel = Some(subchannel);
347 self
348 }
349}
350
351impl<'a, S: Sender> CheckedSender for CheckedGlobalSender<'a, S> {
352 type PublicKey = S::PublicKey;
353 type Error = <S::Checked<'a> as CheckedSender>::Error;
354
355 async fn send(
356 self,
357 message: Bytes,
358 priority: bool,
359 ) -> Result<Vec<Self::PublicKey>, Self::Error> {
360 let subchannel = UInt(self.subchannel.expect("subchannel not set"));
361 let mut buf = BytesMut::with_capacity(subchannel.encode_size() + message.len());
362 subchannel.write(&mut buf);
363 buf.put_slice(&message);
364 self.inner.send(buf.freeze(), priority).await
365 }
366}
367
/// Finalizes a muxer builder into the components it was configured with.
pub trait Builder {
    /// The tuple of components returned by [Builder::build].
    type Output;

    /// Consumes the builder and returns its components.
    fn build(self) -> <Self as Builder>::Output;
}
376
377pub struct MuxerBuilder<E: Spawner, S: Sender, R: Receiver> {
379 mux: Muxer<E, S, R>,
380 mux_handle: MuxHandle<S, R>,
381}
382
383impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilder<E, S, R> {
384 type Output = (Muxer<E, S, R>, MuxHandle<S, R>);
385
386 fn build(self) -> Self::Output {
387 (self.mux, self.mux_handle)
388 }
389}
390
391impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilder<E, S, R> {
392 pub fn with_backup(mut self) -> MuxerBuilderWithBackup<E, S, R> {
394 let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
395 self.mux.backup = Some(tx);
396
397 MuxerBuilderWithBackup {
398 mux: self.mux,
399 mux_handle: self.mux_handle,
400 backup_rx: rx,
401 }
402 }
403
404 pub fn with_global_sender(self) -> MuxerBuilderWithGlobalSender<E, S, R> {
406 let global_sender = GlobalSender::new(self.mux.sender.clone());
407
408 MuxerBuilderWithGlobalSender {
409 mux: self.mux,
410 mux_handle: self.mux_handle,
411 global_sender,
412 }
413 }
414}
415
416pub struct MuxerBuilderWithBackup<E: Spawner, S: Sender, R: Receiver> {
418 mux: Muxer<E, S, R>,
419 mux_handle: MuxHandle<S, R>,
420 backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
421}
422
423impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithBackup<E, S, R> {
424 pub fn with_global_sender(self) -> MuxerBuilderAllOpts<E, S, R> {
426 let global_sender = GlobalSender::new(self.mux.sender.clone());
427
428 MuxerBuilderAllOpts {
429 mux: self.mux,
430 mux_handle: self.mux_handle,
431 backup_rx: self.backup_rx,
432 global_sender,
433 }
434 }
435}
436
437impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithBackup<E, S, R> {
438 type Output = (
439 Muxer<E, S, R>,
440 MuxHandle<S, R>,
441 mpsc::Receiver<BackupResponse<R::PublicKey>>,
442 );
443
444 fn build(self) -> Self::Output {
445 (self.mux, self.mux_handle, self.backup_rx)
446 }
447}
448
449pub struct MuxerBuilderWithGlobalSender<E: Spawner, S: Sender, R: Receiver> {
451 mux: Muxer<E, S, R>,
452 mux_handle: MuxHandle<S, R>,
453 global_sender: GlobalSender<S>,
454}
455
456impl<E: Spawner, S: Sender, R: Receiver> MuxerBuilderWithGlobalSender<E, S, R> {
457 pub fn with_backup(mut self) -> MuxerBuilderAllOpts<E, S, R> {
459 let (tx, rx) = mpsc::channel(self.mux.mailbox_size);
460 self.mux.backup = Some(tx);
461
462 MuxerBuilderAllOpts {
463 mux: self.mux,
464 mux_handle: self.mux_handle,
465 backup_rx: rx,
466 global_sender: self.global_sender,
467 }
468 }
469}
470
471impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderWithGlobalSender<E, S, R> {
472 type Output = (Muxer<E, S, R>, MuxHandle<S, R>, GlobalSender<S>);
473
474 fn build(self) -> Self::Output {
475 (self.mux, self.mux_handle, self.global_sender)
476 }
477}
478
479pub struct MuxerBuilderAllOpts<E: Spawner, S: Sender, R: Receiver> {
481 mux: Muxer<E, S, R>,
482 mux_handle: MuxHandle<S, R>,
483 backup_rx: mpsc::Receiver<BackupResponse<R::PublicKey>>,
484 global_sender: GlobalSender<S>,
485}
486
487impl<E: Spawner, S: Sender, R: Receiver> Builder for MuxerBuilderAllOpts<E, S, R> {
488 type Output = (
489 Muxer<E, S, R>,
490 MuxHandle<S, R>,
491 mpsc::Receiver<BackupResponse<R::PublicKey>>,
492 GlobalSender<S>,
493 );
494
495 fn build(self) -> Self::Output {
496 (
497 self.mux,
498 self.mux_handle,
499 self.backup_rx,
500 self.global_sender,
501 )
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        simulated::{self, Link, Network, Oracle},
        Recipients,
    };
    use bytes::Bytes;
    use commonware_cryptography::{
        ed25519::{PrivateKey, PublicKey},
        Signer,
    };
    use commonware_macros::{select, test_traced};
    use commonware_runtime::{deterministic, Metrics, Quota, Runner};
    use std::{num::NonZeroU32, time::Duration};

    /// A lossless, zero-latency link so test outcomes depend only on muxer behavior.
    const LINK: Link = Link {
        latency: Duration::from_millis(0),
        jitter: Duration::from_millis(0),
        success_rate: 1.0,
    };
    /// Mailbox size passed to every muxer under test.
    const CAPACITY: usize = 5usize;

    /// A quota of `NonZeroU32::MAX` per second, intended to keep rate limiting out of the tests.
    const TEST_QUOTA: Quota = Quota::per_second(NonZeroU32::MAX);

    /// Starts a simulated network and returns its oracle.
    fn start_network(context: deterministic::Context) -> Oracle<PublicKey, deterministic::Context> {
        let (network, oracle) = Network::new(
            context.with_label("network"),
            simulated::Config {
                max_size: 1024 * 1024,
                disconnect_on_block: true,
                tracked_peer_sets: None,
            },
        );
        network.start();
        oracle
    }

    /// Deterministic ed25519 public key derived from `seed`.
    fn pk(seed: u64) -> PublicKey {
        PrivateKey::from_seed(seed).public_key()
    }

    /// Adds links in both directions between `a` and `b`.
    async fn link_bidirectional(
        oracle: &mut Oracle<PublicKey, deterministic::Context>,
        a: PublicKey,
        b: PublicKey,
    ) {
        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
        oracle.add_link(b, a, LINK).await.unwrap();
    }

    /// Registers a peer on network channel 0 and starts a default muxer on top of it.
    async fn create_peer(
        context: &deterministic::Context,
        oracle: &mut Oracle<PublicKey, deterministic::Context>,
        seed: u64,
    ) -> (
        PublicKey,
        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
    ) {
        let pubkey = pk(seed);
        let (sender, receiver) = oracle
            .control(pubkey.clone())
            .register(0, TEST_QUOTA)
            .await
            .unwrap();
        let (mux, handle) = Muxer::new(context.with_label("mux"), sender, receiver, CAPACITY);
        mux.start();
        (pubkey, handle)
    }

    /// Like `create_peer`, but the muxer also has a backup channel and a global sender.
    async fn create_peer_with_backup_and_global_sender(
        context: &deterministic::Context,
        oracle: &mut Oracle<PublicKey, deterministic::Context>,
        seed: u64,
    ) -> (
        PublicKey,
        MuxHandle<impl Sender<PublicKey = PublicKey>, impl Receiver<PublicKey = PublicKey>>,
        mpsc::Receiver<BackupResponse<PublicKey>>,
        GlobalSender<simulated::Sender<PublicKey, deterministic::Context>>,
    ) {
        let pubkey = pk(seed);
        let (sender, receiver) = oracle
            .control(pubkey.clone())
            .register(0, TEST_QUOTA)
            .await
            .unwrap();
        let (mux, handle, backup, global_sender) =
            Muxer::builder(context.with_label("mux"), sender, receiver, CAPACITY)
                .with_backup()
                .with_global_sender()
                .build();
        mux.start();
        (pubkey, handle, backup, global_sender)
    }

    /// Sends `count` one-byte payloads (0, 1, ...) to all peers, rotating through `txs` for each.
    async fn send_burst<S: Sender>(txs: &mut [SubSender<S>], count: usize) {
        for i in 0..count {
            let payload = Bytes::from(vec![i as u8]);
            for tx in txs.iter_mut() {
                let _ = tx
                    .send(Recipients::All, payload.clone(), false)
                    .await
                    .unwrap();
            }
        }
    }

    /// Receives `n` messages from `rx`, panicking if a receive fails.
    async fn expect_n_messages(
        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
        n: usize,
    ) {
        let mut count = 0;
        loop {
            select! {
                res = rx.recv() => {
                    res.expect("should have received message");
                    count += 1;
                },
            }

            if count >= n {
                break;
            }
        }
        assert_eq!(n, count);
    }

    /// Receives until at least `n` messages arrive on `rx` and `n_backup` on `backup_rx`,
    /// then asserts the counts are exact.
    async fn expect_n_messages_with_backup(
        rx: &mut SubReceiver<impl Receiver<PublicKey = PublicKey>>,
        backup_rx: &mut mpsc::Receiver<BackupResponse<PublicKey>>,
        n: usize,
        n_backup: usize,
    ) {
        let mut count_std = 0;
        let mut count_backup = 0;
        loop {
            select! {
                res = rx.recv() => {
                    res.expect("should have received message");
                    count_std += 1;
                },
                res = backup_rx.next() => {
                    res.expect("should have received message");
                    count_backup += 1;
                },
            }

            if count_std >= n && count_backup >= n_backup {
                break;
            }
        }
        assert_eq!(n, count_std);
        assert_eq!(n_backup, count_backup);
    }

    /// A message sent on a subchannel arrives on the peer's matching subchannel receiver.
    #[test]
    fn test_basic_routing() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (_, mut sub_rx1) = handle1.register(7).await.unwrap();
            let (mut sub_tx2, _) = handle2.register(7).await.unwrap();

            // Peer 2 -> peer 1 on subchannel 7.
            let payload = Bytes::from_static(b"hello");
            let _ = sub_tx2
                .send(Recipients::One(pk1.clone()), payload.clone(), false)
                .await
                .unwrap();
            let (from, bytes) = sub_rx1.recv().await.unwrap();
            assert_eq!(from, pk2);
            assert_eq!(bytes, payload);
        });
    }

    /// Messages on different subchannels are delivered to their own receivers.
    #[test]
    fn test_multiple_routes() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (_, mut rx_a) = handle1.register(10).await.unwrap();
            let (_, mut rx_b) = handle1.register(20).await.unwrap();

            let (mut tx2_a, _) = handle2.register(10).await.unwrap();
            let (mut tx2_b, _) = handle2.register(20).await.unwrap();

            let payload_a = Bytes::from_static(b"A");
            let payload_b = Bytes::from_static(b"B");
            let _ = tx2_a
                .send(Recipients::One(pk1.clone()), payload_a.clone(), false)
                .await
                .unwrap();
            let _ = tx2_b
                .send(Recipients::One(pk1.clone()), payload_b.clone(), false)
                .await
                .unwrap();

            let (from_a, bytes_a) = rx_a.recv().await.unwrap();
            assert_eq!(from_a, pk2);
            assert_eq!(bytes_a, payload_a);

            let (from_b, bytes_b) = rx_b.recv().await.unwrap();
            assert_eq!(from_b, pk2);
            assert_eq!(bytes_b, payload_b);
        });
    }

    /// A subchannel whose mailbox is not being read stalls routing to other subchannels
    /// (the muxer awaits mailbox space) until it is drained.
    #[test_traced]
    fn test_mailbox_capacity_blocks() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            // Peer 1 only sends; peer 2 receives on both subchannels.
            let (tx1, _) = handle1.register(99).await.unwrap();
            let (tx2, _) = handle1.register(100).await.unwrap();
            let (_, mut rx1) = handle2.register(99).await.unwrap();
            let (_, mut rx2) = handle2.register(100).await.unwrap();

            // Interleaved sends: 99, 100, 99, 100, ...
            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            // Only part of subchannel 100's traffic is read before 99 is drained.
            expect_n_messages(&mut rx2, CAPACITY).await;

            // Draining 99 unblocks the muxer...
            expect_n_messages(&mut rx1, CAPACITY * 2).await;

            // ...so the rest of 100's traffic can be delivered.
            expect_n_messages(&mut rx2, CAPACITY).await;
        });
    }

    /// Dropping the receiver of a full subchannel removes its route and unblocks the others.
    #[test]
    fn test_drop_a_full_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(99).await.unwrap();
            let (tx2, _) = handle1.register(100).await.unwrap();
            let (_, rx1) = handle2.register(99).await.unwrap();
            let (_, mut rx2) = handle2.register(100).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            expect_n_messages(&mut rx2, CAPACITY).await;

            // Instead of draining 99, drop it: sends into it fail and its route is removed.
            drop(rx1);

            expect_n_messages(&mut rx2, CAPACITY).await;
        });
    }

    /// Without a backup channel, messages for unregistered subchannels are silently dropped.
    #[test]
    fn test_drop_messages_for_unregistered_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            // Peer 2 never registers subchannel 1.
            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            expect_n_messages(&mut rx2, CAPACITY * 2).await;
        });
    }

    /// With a backup channel, messages for unregistered subchannels are delivered there.
    #[test]
    fn test_backup_for_unregistered_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2, mut backup2, _) =
                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            // Subchannel 1 is unregistered on peer 2, so its traffic goes to backup2.
            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            expect_n_messages_with_backup(&mut rx2, &mut backup2, CAPACITY * 2, CAPACITY * 2).await;
        });
    }

    /// A message seen on the backup channel can be answered on its subchannel via the global
    /// sender.
    #[test]
    fn test_backup_for_unregistered_subchannel_response() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, _handle2, mut backup2, mut global_sender2) =
                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, mut rx1) = handle1.register(1).await.unwrap();
            send_burst(&mut [tx1], 1).await;

            // Peer 2 has no route for subchannel 1, so the message lands in its backup channel.
            let (subchannel, (from, _)) = backup2.next().await.unwrap();
            assert_eq!(subchannel, 1);
            assert_eq!(from, pk1);
            // Reply on the same subchannel without registering it.
            global_sender2
                .send(
                    subchannel,
                    Recipients::One(pk1),
                    b"TEST".to_vec().into(),
                    true,
                )
                .await
                .unwrap();

            let (from, bytes) = rx1.recv().await.unwrap();
            assert_eq!(from, pk2);
            assert_eq!(bytes.as_ref(), b"TEST");
        });
    }

    /// Closing a subchannel's mailbox causes its messages to be dropped without blocking others.
    #[test]
    fn test_message_dropped_for_closed_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            let (_, mut rx1) = handle2.register(1).await.unwrap();
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            // Both subchannels work before the close.
            send_burst(&mut [tx1.clone()], CAPACITY * 2).await;

            expect_n_messages(&mut rx1, CAPACITY * 2).await;

            send_burst(&mut [tx2.clone()], CAPACITY * 2).await;

            expect_n_messages(&mut rx2, CAPACITY * 2).await;

            // Close the mailbox while keeping the SubReceiver alive, so no Deregister is sent.
            rx1.receiver.close();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            expect_n_messages(&mut rx2, CAPACITY * 2).await;
        });
    }

    /// A dropped backup receiver does not stall routing to registered subchannels.
    #[test]
    fn test_dropped_backup_channel_doesnt_block() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2, backup2, _) =
                create_peer_with_backup_and_global_sender(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            // Sends into the backup channel now fail (and are only logged by the muxer).
            drop(backup2);

            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            expect_n_messages(&mut rx2, CAPACITY * 2).await;
        });
    }

    /// Registering a subchannel that already has a live receiver fails.
    #[test]
    fn test_duplicate_registration() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (_pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;

            // Bound to `_rx` (not `_`) so the receiver stays alive and keeps the route.
            let (_, _rx) = handle1.register(7).await.unwrap();

            assert!(matches!(
                handle1.register(7).await,
                Err(Error::AlreadyRegistered(_))
            ));
        });
    }

    /// Dropping a subchannel receiver frees the subchannel for re-registration.
    #[test]
    fn test_register_after_deregister() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (_, mut handle) = create_peer(&context, &mut oracle, 0).await;
            let (_, rx) = handle.register(7).await.unwrap();
            drop(rx);

            handle.register(7).await.unwrap();
        });
    }
}