1use crate::{Channel, Message, Receiver, Recipients, Sender};
12use bytes::{BufMut, Bytes, BytesMut};
13use commonware_codec::{varint::UInt, EncodeSize, ReadExt, Write};
14use commonware_macros::select;
15use commonware_runtime::{Handle, Spawner};
16use futures::{
17 channel::{mpsc, oneshot},
18 SinkExt, StreamExt,
19};
20use std::{collections::HashMap, fmt::Debug};
21use thiserror::Error;
22use tracing::debug;
23
/// Errors returned by [MuxHandle] and [SubReceiver].
#[derive(Error, Debug)]
pub enum Error {
    /// `MuxHandle::register` was called for a subchannel that already has a route.
    #[error("subchannel already registered: {0}")]
    AlreadyRegistered(Channel),
    /// The muxer's control channel is closed, i.e. the muxer is no longer running.
    #[error("muxer is closed")]
    Closed,
    /// `SubReceiver::recv` found the subchannel mailbox exhausted (its sending
    /// side, held by the muxer, is gone and no buffered messages remain).
    #[error("recv failed")]
    RecvFailed,
}
34
/// Requests sent to the [Muxer] over its control channel.
enum Control<R: Receiver> {
    /// Register `subchannel`. The muxer replies on `sender` with the receiving
    /// end of the new subchannel mailbox; if the subchannel is already taken,
    /// `sender` is dropped without a reply.
    Register {
        subchannel: Channel,
        sender: oneshot::Sender<mpsc::Receiver<Message<R::PublicKey>>>,
    },
    /// Remove the route for `subchannel` (sent when a [SubReceiver] is dropped).
    Deregister {
        subchannel: Channel,
    },
}
45
/// Mailbox senders for every registered subchannel, keyed by subchannel id.
type Routes<P> = HashMap<Channel, mpsc::Sender<Message<P>>>;
48
/// Demultiplexes messages from a single [Receiver] into per-subchannel mailboxes.
///
/// Each incoming message is expected to start with a varint-encoded subchannel
/// id; the remainder is forwarded to the mailbox registered for that id.
pub struct Muxer<E: Spawner, S: Sender, R: Receiver> {
    /// Runtime context used to spawn the muxer task.
    context: E,
    /// Underlying sender; a clone is handed to the [MuxHandle] created in `new`.
    sender: S,
    /// Underlying receiver whose messages are routed by subchannel.
    receiver: R,
    /// Buffer size used for the control channel and every subchannel mailbox.
    mailbox_size: usize,
    /// Register/deregister requests from handles and sub-receivers.
    control_rx: mpsc::Receiver<Control<R>>,
    /// Currently registered subchannels.
    routes: Routes<R::PublicKey>,
}
58
59impl<E: Spawner, S: Sender, R: Receiver> Muxer<E, S, R> {
60 pub fn new(
63 context: E,
64 sender: S,
65 receiver: R,
66 mailbox_size: usize,
67 ) -> (Self, MuxHandle<E, S, R>) {
68 let (control_tx, control_rx) = mpsc::channel(mailbox_size);
69 let mux = Self {
70 context: context.clone(),
71 sender,
72 receiver,
73 mailbox_size,
74 control_rx,
75 routes: HashMap::new(),
76 };
77
78 let handle = MuxHandle {
79 context,
80 sender: mux.sender.clone(),
81 control_tx,
82 };
83
84 (mux, handle)
85 }
86
87 pub fn start(mut self) -> Handle<Result<(), R::Error>> {
89 self.context.spawn_ref()(self.run())
90 }
91
92 pub async fn run(mut self) -> Result<(), R::Error> {
97 loop {
98 select! {
99 control = self.control_rx.next() => {
101 match control {
102 Some(Control::Register { subchannel, sender }) => {
103 if self.routes.contains_key(&subchannel) {
105 continue;
106 }
107
108 let (tx, rx) = mpsc::channel(self.mailbox_size);
110 self.routes.insert(subchannel, tx);
111 let _ = sender.send(rx);
112 },
113 Some(Control::Deregister { subchannel }) => {
114 self.routes.remove(&subchannel);
116 },
117 None => {
118 return Ok(());
121 }
122 }
123 },
124 message = self.receiver.recv() => {
126 let (pk, mut bytes) = message?;
127
128 let subchannel: Channel = match UInt::read(&mut bytes) {
130 Ok(v) => v.into(),
131 Err(_) => {
132 debug!(?pk, "invalid message: missing subchannel");
133 continue;
134 }
135 };
136
137 let Some(sender) = self.routes.get_mut(&subchannel) else {
139 continue;
141 };
142
143 if let Err(e) = sender.send((pk, bytes)).await {
145 self.routes.remove(&subchannel);
147
148 debug!(?subchannel, ?e, "failed to send message to subchannel");
150 }
151 }
152 }
153 }
154 }
155}
156
157#[derive(Clone)]
159pub struct MuxHandle<E: Spawner, S: Sender, R: Receiver> {
160 context: E,
161 sender: S,
162 control_tx: mpsc::Sender<Control<R>>,
163}
164
165impl<E: Spawner, S: Sender, R: Receiver> MuxHandle<E, S, R> {
166 pub async fn register(
171 &mut self,
172 subchannel: Channel,
173 ) -> Result<(SubSender<S>, SubReceiver<E, R>), Error> {
174 let (tx, rx) = oneshot::channel();
175 self.control_tx
176 .send(Control::Register {
177 subchannel,
178 sender: tx,
179 })
180 .await
181 .map_err(|_| Error::Closed)?;
182 let receiver = rx.await.map_err(|_| Error::AlreadyRegistered(subchannel))?;
183
184 Ok((
185 SubSender {
186 subchannel,
187 inner: self.sender.clone(),
188 },
189 SubReceiver {
190 context: self.context.clone(),
191 receiver,
192 control_tx: Some(self.control_tx.clone()),
193 subchannel,
194 },
195 ))
196 }
197}
198
/// Sender bound to a single subchannel.
///
/// Every payload is prefixed with the varint-encoded subchannel id before it
/// is handed to the wrapped sender.
#[derive(Clone, Debug)]
pub struct SubSender<S: Sender> {
    /// Wrapped sender that carries the framed payloads.
    inner: S,
    /// Subchannel id written in front of every payload.
    subchannel: Channel,
}
205
206impl<S: Sender> Sender for SubSender<S> {
207 type Error = S::Error;
208 type PublicKey = S::PublicKey;
209
210 async fn send(
211 &mut self,
212 recipients: Recipients<S::PublicKey>,
213 payload: Bytes,
214 priority: bool,
215 ) -> Result<Vec<S::PublicKey>, S::Error> {
216 let subchannel = UInt(self.subchannel);
217 let mut buf = BytesMut::with_capacity(subchannel.encode_size() + payload.len());
218 subchannel.write(&mut buf);
219 buf.put_slice(&payload);
220 self.inner.send(recipients, buf.freeze(), priority).await
221 }
222}
223
/// Receiver for the messages the muxer routes to a single subchannel.
///
/// Dropping it asks the muxer to deregister the subchannel.
pub struct SubReceiver<E: Spawner, R: Receiver> {
    /// Runtime context, used to spawn the deregistration if it cannot be sent
    /// immediately on drop.
    context: E,
    /// Mailbox filled by the muxer for this subchannel.
    receiver: mpsc::Receiver<Message<R::PublicKey>>,
    /// Control channel to the muxer; `Option` only so `drop` can move it out.
    control_tx: Option<mpsc::Sender<Control<R>>>,
    /// Subchannel this receiver is registered for.
    subchannel: Channel,
}
231
232impl<E: Spawner, R: Receiver> Receiver for SubReceiver<E, R> {
233 type Error = Error;
234 type PublicKey = R::PublicKey;
235
236 async fn recv(&mut self) -> Result<Message<Self::PublicKey>, Self::Error> {
237 self.receiver.next().await.ok_or(Error::RecvFailed)
238 }
239}
240
241impl<E: Spawner, R: Receiver> Debug for SubReceiver<E, R> {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 write!(f, "SubReceiver({})", self.subchannel)
244 }
245}
246
247impl<E: Spawner, R: Receiver> Drop for SubReceiver<E, R> {
248 fn drop(&mut self) {
249 let mut control_tx = self
251 .control_tx
252 .take()
253 .expect("SubReceiver::drop called twice");
254
255 let subchannel = self.subchannel;
257 if control_tx
258 .try_send(Control::Deregister { subchannel })
259 .is_ok()
260 {
261 return;
262 }
263
264 self.context.spawn_ref()(async move {
266 let _ = control_tx.send(Control::Deregister { subchannel }).await;
267 });
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        simulated::{Config as SimConfig, Link, Network, Oracle},
        Recipients,
    };
    use bytes::Bytes;
    use commonware_cryptography::{ed25519::PrivateKey, PrivateKeyExt, Signer};
    use commonware_macros::{select, test_traced};
    use commonware_runtime::{deterministic, Clock, Metrics, Runner};
    use std::time::Duration;

    type Pk = commonware_cryptography::ed25519::PublicKey;

    /// Zero-latency, zero-jitter, lossless link used between test peers.
    const LINK: Link = Link {
        latency: Duration::from_millis(0),
        jitter: Duration::from_millis(0),
        success_rate: 1.0,
    };
    /// Mailbox size passed to every muxer created by `create_peer`.
    const CAPACITY: usize = 5usize;

    /// Starts a simulated network (1 MiB max message size) and returns its oracle.
    fn start_network(context: deterministic::Context) -> Oracle<Pk> {
        let (network, oracle) = Network::new(
            context.with_label("network"),
            SimConfig {
                max_size: 1024 * 1024,
            },
        );
        network.start();
        oracle
    }

    /// Derives a deterministic ed25519 public key from `seed`.
    fn pk(seed: u64) -> Pk {
        PrivateKey::from_seed(seed).public_key()
    }

    /// Adds `LINK` in both directions between `a` and `b`.
    async fn link_bidirectional(oracle: &mut Oracle<Pk>, a: Pk, b: Pk) {
        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
        oracle.add_link(b, a, LINK).await.unwrap();
    }

    /// Registers `pk(seed)` with the oracle, wraps the resulting sender/receiver
    /// in a started `Muxer`, and returns the peer's key and mux handle.
    async fn create_peer<E: Spawner>(
        context: &E,
        oracle: &mut Oracle<Pk>,
        seed: u64,
    ) -> (
        Pk,
        MuxHandle<E, impl Sender<PublicKey = Pk>, impl Receiver<PublicKey = Pk>>,
    ) {
        let pubkey = pk(seed);
        let (sender, receiver) = oracle.register(pubkey.clone(), 0).await.unwrap();
        let (mux, handle) = Muxer::new(context.clone(), sender, receiver, CAPACITY);
        mux.start();
        (pubkey, handle)
    }

    /// Sends `count` one-byte payloads (`0`, `1`, ...) to all peers, interleaving
    /// the senders for each payload (tx0, tx1, ..., then the next payload).
    async fn send_burst<S: Sender>(txs: &mut [SubSender<S>], count: usize) {
        for i in 0..count {
            let payload = Bytes::from(vec![i as u8]);
            for tx in txs.iter_mut() {
                let _ = tx
                    .send(Recipients::All, payload.clone(), false)
                    .await
                    .unwrap();
            }
        }
    }

    /// Receives from `rx` until no message arrives for 100ms, then asserts that
    /// exactly `n` messages were received.
    async fn expect_n_messages<E: Spawner + Clock>(
        rx: &mut SubReceiver<E, impl Receiver<PublicKey = Pk>>,
        n: usize,
        context: &E,
    ) {
        let mut count = 0;
        loop {
            select! {
                res = rx.recv() => {
                    res.expect("should have received message");
                    count += 1;
                },
                _ = context.sleep(Duration::from_millis(100)) => { break; },
            }
        }
        assert_eq!(n, count);
    }

    /// A message sent on a subchannel arrives at the matching subchannel of the peer.
    #[test]
    fn test_basic_routing() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (_, mut sub_rx1) = handle1.register(7).await.unwrap();
            let (mut sub_tx2, _) = handle2.register(7).await.unwrap();

            let payload = Bytes::from_static(b"hello");
            let _ = sub_tx2
                .send(Recipients::One(pk1.clone()), payload.clone(), false)
                .await
                .unwrap();
            // The subchannel prefix is stripped before delivery.
            let (from, bytes) = sub_rx1.recv().await.unwrap();
            assert_eq!(from, pk2);
            assert_eq!(bytes, payload);
        });
    }

    /// Messages on different subchannels are delivered to their own receivers.
    #[test]
    fn test_multiple_routes() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (_, mut rx_a) = handle1.register(10).await.unwrap();
            let (_, mut rx_b) = handle1.register(20).await.unwrap();

            let (mut tx2_a, _) = handle2.register(10).await.unwrap();
            let (mut tx2_b, _) = handle2.register(20).await.unwrap();

            let payload_a = Bytes::from_static(b"A");
            let payload_b = Bytes::from_static(b"B");
            let _ = tx2_a
                .send(Recipients::One(pk1.clone()), payload_a.clone(), false)
                .await
                .unwrap();
            let _ = tx2_b
                .send(Recipients::One(pk1.clone()), payload_b.clone(), false)
                .await
                .unwrap();

            let (from_a, bytes_a) = rx_a.recv().await.unwrap();
            assert_eq!(from_a, pk2);
            assert_eq!(bytes_a, payload_a);

            let (from_b, bytes_b) = rx_b.recv().await.unwrap();
            assert_eq!(from_b, pk2);
            assert_eq!(bytes_b, payload_b);
        });
    }

    /// A full subchannel mailbox blocks the muxer, so other subchannels stop
    /// receiving until the full one is drained (head-of-line blocking).
    #[test_traced]
    fn test_mailbox_capacity_blocks() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(99).await.unwrap();
            let (tx2, _) = handle1.register(100).await.unwrap();
            let (_, mut rx1) = handle2.register(99).await.unwrap();
            let (_, mut rx2) = handle2.register(100).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            // While subchannel 99 is full and unread, only the first batch for 100 arrives.
            expect_n_messages(&mut rx2, CAPACITY, &context).await;

            // Draining 99 unblocks the muxer...
            expect_n_messages(&mut rx1, CAPACITY * 2, &context).await;

            // ...which then delivers the rest of 100.
            expect_n_messages(&mut rx2, CAPACITY, &context).await;
        });
    }

    /// Dropping a full, blocked subchannel receiver unblocks delivery to the others.
    #[test]
    fn test_drop_a_full_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(99).await.unwrap();
            let (tx2, _) = handle1.register(100).await.unwrap();
            let (_, rx1) = handle2.register(99).await.unwrap();
            let (_, mut rx2) = handle2.register(100).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            // Let the muxer fill subchannel 99 and block on it.
            context.sleep(Duration::from_millis(100)).await;

            expect_n_messages(&mut rx2, CAPACITY, &context).await;

            // Dropping the blocked receiver makes the muxer's pending send fail,
            // which removes the route and lets the remaining messages for 100 through.
            drop(rx1);

            expect_n_messages(&mut rx2, CAPACITY, &context).await;
        });
    }

    /// Messages for a subchannel the receiving peer never registered are dropped
    /// without blocking delivery to registered subchannels.
    #[test]
    fn test_drop_messages_for_unregistered_subchannel() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;
            let (pk2, mut handle2) = create_peer(&context, &mut oracle, 1).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (tx1, _) = handle1.register(1).await.unwrap();
            let (tx2, _) = handle1.register(2).await.unwrap();
            // Peer 2 only registers subchannel 2.
            let (_, mut rx2) = handle2.register(2).await.unwrap();

            send_burst(&mut [tx1, tx2], CAPACITY * 2).await;

            context.sleep(Duration::from_millis(100)).await;

            expect_n_messages(&mut rx2, CAPACITY * 2, &context).await;
        });
    }

    /// Registering the same subchannel twice on one handle fails.
    #[test]
    fn test_duplicate_registration() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (_pk1, mut handle1) = create_peer(&context, &mut oracle, 0).await;

            // Keep the first receiver alive so the subchannel stays registered.
            let (_, _rx) = handle1.register(7).await.unwrap();

            assert!(matches!(
                handle1.register(7).await,
                Err(Error::AlreadyRegistered(_))
            ));
        });
    }

    /// Dropping a receiver frees its subchannel for a new registration.
    #[test]
    fn test_register_after_deregister() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());

            let (_, mut handle) = create_peer(&context, &mut oracle, 0).await;
            let (_, rx) = handle.register(7).await.unwrap();
            drop(rx);

            handle.register(7).await.unwrap();
        });
    }
}