1use crate::{Blocker, CheckedSender, Receiver, Recipients, Sender};
4use commonware_actor::{mailbox, Feedback, Unreliable};
5use commonware_codec::{Codec, Error};
6use commonware_cryptography::PublicKey;
7use commonware_macros::select_loop;
8use commonware_parallel::Strategy;
9use commonware_runtime::{
10 iobuf::EncodeExt, spawn_cell, BufferPool, ContextCell, Handle, Metrics, Spawner,
11};
12use commonware_utils::futures::Pool;
13use std::{collections::VecDeque, num::NonZeroUsize, time::SystemTime};
14
/// Wraps a [`Sender`] and a [`Receiver`] so they exchange values of the [`Codec`] type `V`.
///
/// # Parameters
/// - `config`: codec configuration handed to the [`WrappedReceiver`].
/// - `pool`: buffer pool handed to the [`WrappedSender`].
/// - `sender` / `receiver`: the underlying channel halves being wrapped.
///
/// # Returns
/// The `(WrappedSender, WrappedReceiver)` pair built from the arguments.
pub const fn wrap<S: Sender, R: Receiver, V: Codec>(
    config: V::Cfg,
    pool: BufferPool,
    sender: S,
    receiver: R,
) -> (WrappedSender<S, V>, WrappedReceiver<R, V>) {
    (
        WrappedSender::new(pool, sender),
        WrappedReceiver::new(config, receiver),
    )
}
27
/// A received message: the public key `P` reported by the underlying receiver, paired with the
/// result of decoding the received bytes as `V`.
pub type WrappedMessage<P, V> = (P, Result<V, Error>);
30
/// A [`Sender`] wrapper that encodes messages of type `V` before handing them to the wrapped
/// sender.
#[derive(Clone)]
pub struct WrappedSender<S: Sender, V: Codec> {
    /// Buffer pool passed to `encode_with_pool` when encoding outgoing messages.
    pool: BufferPool,
    /// Underlying sender that receives the encoded bytes.
    sender: S,
    /// Ties the wrapper to the message type `V` without storing a value of it.
    _phantom_v: std::marker::PhantomData<V>,
}
38
39impl<S: Sender, V: Codec> WrappedSender<S, V> {
40 pub const fn new(pool: BufferPool, sender: S) -> Self {
42 Self {
43 pool,
44 sender,
45 _phantom_v: std::marker::PhantomData,
46 }
47 }
48
49 pub fn send(
51 &mut self,
52 recipients: Recipients<S::PublicKey>,
53 message: V,
54 priority: bool,
55 ) -> Vec<S::PublicKey> {
56 let encoded = message.encode_with_pool(&self.pool);
57 self.sender.send(recipients, encoded, priority)
58 }
59
60 pub fn check(
63 &mut self,
64 recipients: Recipients<S::PublicKey>,
65 ) -> Result<CheckedWrappedSender<'_, S, V>, SystemTime> {
66 self.sender
67 .check(recipients)
68 .map(|checked| CheckedWrappedSender {
69 pool: &self.pool,
70 sender: checked,
71 _phantom_v: std::marker::PhantomData,
72 })
73 }
74}
75
/// A checked sender (as produced by the wrapped sender's `check`) that encodes a message of type
/// `V` before sending it.
#[derive(Debug)]
pub struct CheckedWrappedSender<'a, S: Sender, V: Codec> {
    /// Borrowed buffer pool passed to `encode_with_pool` when encoding the message.
    pool: &'a BufferPool,
    /// The wrapped sender's checked handle.
    sender: S::Checked<'a>,
    /// Ties the wrapper to the message type `V` without storing a value of it.
    _phantom_v: std::marker::PhantomData<V>,
}
83
84impl<'a, S: Sender, V: Codec> CheckedWrappedSender<'a, S, V> {
85 pub fn recipients(&self) -> Vec<S::PublicKey> {
86 self.sender.recipients()
87 }
88
89 pub fn send(self, message: V, priority: bool) -> Unreliable<Feedback> {
90 let encoded = message.encode_with_pool(self.pool);
91 self.sender.send(encoded, priority)
92 }
93}
94
/// A [`Receiver`] wrapper that decodes incoming bytes as values of type `V`.
pub struct WrappedReceiver<R: Receiver, V: Codec> {
    /// Codec configuration passed to `V::decode_cfg` for every received message.
    config: V::Cfg,
    /// Underlying receiver that yields `(public key, bytes)` pairs.
    receiver: R,
}
100
101impl<R: Receiver, V: Codec> WrappedReceiver<R, V> {
102 pub const fn new(config: V::Cfg, receiver: R) -> Self {
104 Self { config, receiver }
105 }
106
107 pub async fn recv(&mut self) -> Result<WrappedMessage<R::PublicKey, V>, R::Error> {
109 let (pk, bytes) = self.receiver.recv().await?;
110 let decoded = match V::decode_cfg(bytes.as_ref(), &self.config) {
111 Ok(decoded) => decoded,
112 Err(e) => {
113 return Ok((pk, Err(e)));
114 }
115 };
116 Ok((pk, Ok(decoded)))
117 }
118}
119
/// A successfully decoded message: the public key it arrived with and the decoded value.
///
/// This is the item type carried by the mailbox between [`WrappedBackgroundReceiver`] and
/// [`BackgroundReceiver`].
struct Decoded<P: PublicKey, V>(P, V);
136
/// Mailbox policy for [`Decoded`] messages.
impl<P: PublicKey, V> mailbox::UnreliablePolicy for Decoded<P, V> {
    type Overflow = VecDeque<Self>;

    /// Always returns `false` and leaves the overflow queue untouched.
    ///
    /// NOTE(review): presumably this tells the mailbox to drop a message that does not fit
    /// rather than buffer it; confirm against the `UnreliablePolicy` documentation.
    fn handle(_overflow: &mut Self::Overflow, _message: Self) -> bool {
        false
    }
}
144
/// Consumer half returned by [`WrappedBackgroundReceiver::new`]; yields decoded messages.
pub struct BackgroundReceiver<P: PublicKey, V> {
    /// Mailbox receiver for [`Decoded`] messages.
    receiver: mailbox::UnreliableReceiver<Decoded<P, V>>,
}
149
150impl<P: PublicKey, V> BackgroundReceiver<P, V> {
151 pub async fn recv(&mut self) -> Option<(P, V)> {
153 self.receiver
154 .recv()
155 .await
156 .map(|Decoded(peer, value)| (peer, value))
157 }
158}
159
/// Background actor that receives raw messages, decodes them in spawned tasks, forwards the
/// successfully decoded values to a [`BackgroundReceiver`], and blocks peers whose messages fail
/// to decode.
pub struct WrappedBackgroundReceiver<E, P, B, R, V>
where
    E: Spawner,
    P: PublicKey,
    B: Blocker<PublicKey = P>,
    R: Receiver<PublicKey = P>,
    V: Codec + Send,
{
    /// Runtime context used to spawn the actor and its decode tasks.
    context: ContextCell<E>,
    /// Underlying receiver that yields `(peer, bytes)` pairs.
    receiver: R,
    /// Codec configuration cloned into every decode task.
    codec_config: V::Cfg,
    /// Used to block a peer whose message fails to decode.
    blocker: B,
    /// Mailbox sender on which decoded messages are enqueued.
    sender: mailbox::UnreliableSender<Decoded<P, V>>,
    /// Maximum number of decode tasks kept in flight (always at least 1).
    max_concurrency: usize,
}
175
176impl<E, P, B, R, V> WrappedBackgroundReceiver<E, P, B, R, V>
177where
178 E: Spawner + Metrics,
179 P: PublicKey,
180 B: Blocker<PublicKey = P>,
181 R: Receiver<PublicKey = P>,
182 V: Codec + Send + 'static,
183{
184 pub fn new(
190 context: E,
191 receiver: R,
192 codec_config: V::Cfg,
193 blocker: B,
194 channel_capacity: NonZeroUsize,
195 strategy: &impl Strategy,
196 ) -> (Self, BackgroundReceiver<P, V>) {
197 let (tx, rx) = mailbox::new_unreliable(context.child("mailbox"), channel_capacity);
198 (
199 Self {
200 context: ContextCell::new(context),
201 receiver,
202 codec_config,
203 blocker,
204 sender: tx,
205 max_concurrency: strategy.parallelism_hint().max(1),
206 },
207 BackgroundReceiver { receiver: rx },
208 )
209 }
210
211 pub fn start(mut self) -> Handle<()> {
216 spawn_cell!(self.context, self.run())
217 }
218
219 async fn run(mut self) {
226 let mut decode_pool = Pool::default();
227 let mut receiver_closed = false;
228
229 select_loop! {
230 self.context,
231 on_start => {
232 let mut saw_error = false;
236 while decode_pool.len() >= self.max_concurrency
237 || (receiver_closed && !decode_pool.is_empty())
238 {
239 let Ok(result) = decode_pool.next_completed().await else {
240 saw_error = true;
241 break;
242 };
243 Self::handle_decode_result(&mut self.blocker, &mut self.sender, result);
244 }
245 if saw_error || (receiver_closed && decode_pool.is_empty()) {
246 break;
247 }
248 },
249 on_stopped => {},
250 Ok(result) = decode_pool.next_completed() else break => {
252 Self::handle_decode_result(&mut self.blocker, &mut self.sender, result);
253 },
254 Ok((peer, bytes)) = self.receiver.recv() else {
256 receiver_closed = true;
257 continue;
258 } => {
259 let config = self.codec_config.clone();
260 let handle = self
261 .context
262 .child("decode")
263 .shared(true)
264 .spawn(|_| async move {
265 let result = V::decode_cfg(bytes.as_ref(), &config);
266 (peer, result)
267 });
268 decode_pool.push(handle);
269 },
270 }
271 }
272
273 fn handle_decode_result(
274 blocker: &mut B,
275 sender: &mut mailbox::UnreliableSender<Decoded<P, V>>,
276 result: (P, Result<V, commonware_codec::Error>),
277 ) {
278 let (peer, decode_result) = result;
279 match decode_result {
280 Ok(value) => {
281 let _ = sender.enqueue(Decoded(peer, value));
282 }
283 Err(err) => {
284 crate::block!(blocker, peer, ?err, "received invalid message");
285 }
286 }
287 }
288}
289
290#[cfg(test)]
291mod tests {
292 use super::*;
293 use crate::{
294 simulated::{self, Link, Network, Oracle},
295 Manager as _, Recipients,
296 };
297 use commonware_actor::Feedback;
298 use commonware_codec::Encode;
299 use commonware_cryptography::{
300 ed25519::{PrivateKey, PublicKey},
301 Signer,
302 };
303 use commonware_macros::test_traced;
304 use commonware_parallel::{Sequential, Strategy};
305 use commonware_runtime::{deterministic, Clock as _, IoBuf, Quota, Runner, Supervisor as _};
306 use commonware_utils::{channel::mpsc, ordered::Set, NZUsize};
307 use std::{io, num::NonZeroU32, time::Duration};
308
    // Link used for every simulated connection in these tests: zero latency, zero jitter, and a
    // success rate of 1.0.
    const LINK: Link = Link {
        latency: Duration::from_millis(0),
        jitter: Duration::from_millis(0),
        success_rate: 1.0,
    };
314
    // Quota passed to every `register` call in these tests: `NonZeroU32::MAX` per second.
    const TEST_QUOTA: Quota = Quota::per_second(NonZeroU32::MAX);
316
317 fn start_network(context: deterministic::Context) -> Oracle<PublicKey, deterministic::Context> {
318 let (network, oracle) = Network::new(
319 context.child("network"),
320 simulated::Config {
321 max_size: 1024 * 1024,
322 disconnect_on_block: true,
323 tracked_peer_sets: NZUsize!(1),
324 },
325 );
326 network.start();
327 oracle
328 }
329
330 fn pk(seed: u64) -> PublicKey {
331 PrivateKey::from_seed(seed).public_key()
332 }
333
334 fn track_peers<I>(oracle: &Oracle<PublicKey, deterministic::Context>, index: u64, peers: I)
335 where
336 I: IntoIterator<Item = PublicKey>,
337 {
338 oracle.manager().track(index, Set::from_iter_dedup(peers));
339 }
340
341 async fn link_bidirectional(
342 oracle: &mut Oracle<PublicKey, deterministic::Context>,
343 a: PublicKey,
344 b: PublicKey,
345 ) {
346 oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
347 oracle.add_link(b, a, LINK).await.unwrap();
348 }
349
350 #[derive(Clone, Copy, Debug)]
351 struct HintStrategy(usize);
352
353 impl Strategy for HintStrategy {
354 fn fold_init<I, INIT, T, R, ID, F, RD>(
355 &self,
356 iter: I,
357 init: INIT,
358 identity: ID,
359 fold_op: F,
360 _reduce_op: RD,
361 ) -> R
362 where
363 I: IntoIterator<IntoIter: Send, Item: Send> + Send,
364 INIT: Fn() -> T + Send + Sync,
365 T: Send,
366 R: Send,
367 ID: Fn() -> R + Send + Sync,
368 F: Fn(R, &mut T, I::Item) -> R + Send + Sync,
369 RD: Fn(R, R) -> R + Send + Sync,
370 {
371 let mut init_val = init();
372 iter.into_iter()
373 .fold(identity(), |acc, item| fold_op(acc, &mut init_val, item))
374 }
375
376 fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
377 where
378 A: FnOnce() -> RA + Send,
379 B: FnOnce() -> RB + Send,
380 RA: Send,
381 RB: Send,
382 {
383 (a(), b())
384 }
385
386 fn parallelism_hint(&self) -> usize {
387 self.0
388 }
389 }
390
391 #[derive(Debug)]
392 struct MockReceiver<P: commonware_cryptography::PublicKey> {
393 receiver: mpsc::UnboundedReceiver<crate::Message<P>>,
394 }
395
396 impl<P: commonware_cryptography::PublicKey> crate::Receiver for MockReceiver<P> {
397 type Error = io::Error;
398 type PublicKey = P;
399
400 async fn recv(&mut self) -> Result<crate::Message<Self::PublicKey>, Self::Error> {
401 self.receiver
402 .recv()
403 .await
404 .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
405 }
406 }
407
    // Blocker that ignores every request; used by tests that do not run the simulated network.
    #[derive(Clone, Default)]
    struct NoopBlocker;

    impl crate::Blocker for NoopBlocker {
        type PublicKey = PublicKey;

        // Does nothing with `_peer` and reports `Feedback::Ok`.
        fn block(&mut self, _peer: Self::PublicKey) -> Feedback {
            Feedback::Ok
        }
    }
418
    // A validly encoded `u32` sent by peer 1 is delivered, decoded, on peer 2's background
    // receiver together with peer 1's public key.
    #[test_traced]
    fn test_valid_messages_forwarded() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.child("network"));

            // Two tracked peers, linked in both directions.
            let pk1 = pk(0);
            let pk2 = pk(1);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone()]);
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            // Peer 1 only sends; peer 2 only receives (both on channel 0).
            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();

            // Run the background decoder on peer 2's receiver, using peer 2's control as blocker.
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.child("bg"),
                receiver2,
                (),
                control2.clone(),
                NZUsize!(16),
                &Sequential,
            );
            let _handle = bg.start();

            // Send one encoded value from peer 1 to peer 2.
            let msg: u32 = 42;
            let _ = sender1.send(Recipients::One(pk2.clone()), msg.encode(), true);

            // The decoded value arrives attributed to peer 1.
            let (from, value) = rx.recv().await.unwrap();
            assert_eq!(from, pk1);
            assert_eq!(value, 42u32);
        });
    }
453
    // A peer that sends bytes which fail to decode gets blocked, while a later valid message
    // from a different peer is still delivered.
    #[test_traced]
    fn test_invalid_codec_blocks_peer() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.child("network"));

            // Three tracked peers; initially only peers 1 and 2 are linked.
            let pk1 = pk(0);
            let pk2 = pk(1);
            let pk3 = pk(2);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone(), pk3.clone()]);
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();

            // Peer 2 runs the background decoder with its own control as the blocker.
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.child("bg"),
                receiver2,
                (),
                control2.clone(),
                NZUsize!(16),
                &Sequential,
            );
            let _handle = bg.start();

            // Peer 1 sends a single byte, which this test expects to fail `u32` decoding.
            let invalid = IoBuf::from(vec![0xFFu8]);
            let _ = sender1.send(Recipients::One(pk2.clone()), invalid, true);

            // Bring peer 3 online and link it to peer 2.
            let control3 = oracle.control(pk3.clone());
            link_bidirectional(&mut oracle, pk3.clone(), pk2.clone()).await;
            let (mut sender3, _) = control3.register(0, TEST_QUOTA).await.unwrap();

            // Peer 3 sends a valid value.
            let msg: u32 = 99;
            let _ = sender3.send(Recipients::One(pk2.clone()), msg.encode(), true);

            // Only peer 3's message reaches the consumer.
            let (from, value) = rx.recv().await.unwrap();
            assert_eq!(from, pk3);
            assert_eq!(value, 99u32);

            // Poll until the oracle reports the `(pk2, pk1)` pair as blocked, presumably
            // `(blocker, blocked)`. The loop never exits if the block is never recorded.
            loop {
                let blocked = oracle.blocked().await.unwrap();
                if blocked.contains(&(pk2.clone(), pk1.clone())) {
                    break;
                }

                context.sleep(Duration::from_millis(1)).await;
            }
        });
    }
509
    // Twenty valid messages from one peer are all delivered (order is not asserted).
    #[test_traced]
    fn test_multiple_valid_messages() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.child("network"));

            // Two tracked peers, linked in both directions.
            let pk1 = pk(0);
            let pk2 = pk(1);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone()]);
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();

            // The mailbox capacity equals the number of messages sent below.
            let count = 20;
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.child("bg"),
                receiver2,
                (),
                control2.clone(),
                NZUsize!(20),
                &Sequential,
            );
            let _handle = bg.start();

            // Send the values 0..count from peer 1 to peer 2.
            for i in 0..count {
                let msg: u32 = i;
                let _ = sender1.send(Recipients::One(pk2.clone()), msg.encode(), true);
            }

            // Collect exactly `count` messages, all attributed to peer 1.
            let mut received = Vec::new();
            for _ in 0..count {
                let (from, value) = rx.recv().await.unwrap();
                assert_eq!(from, pk1);
                received.push(value);
            }
            // Sort before comparing so the assertion does not depend on arrival order.
            received.sort();
            assert_eq!(received, (0..count).collect::<Vec<u32>>());
        });
    }
552
    // Fifty valid messages are all delivered when the decoder is driven by the `Sequential`
    // strategy.
    //
    // NOTE(review): presumably `Sequential` reports a parallelism hint of 1, so at most one
    // decode is in flight at a time; confirm against `commonware_parallel::Sequential`. The test
    // asserts only delivery, not the concurrency level.
    #[test_traced]
    fn test_concurrency_bounded_by_strategy() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.child("network"));

            // Two tracked peers, linked in both directions.
            let pk1 = pk(0);
            let pk2 = pk(1);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone()]);
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;

            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();

            // The mailbox capacity equals the number of messages sent below.
            let count = 50u32;
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.child("bg"),
                receiver2,
                (),
                control2.clone(),
                NZUsize!(50),
                &Sequential,
            );
            let _handle = bg.start();

            // Send the values 0..count from peer 1 to peer 2.
            for i in 0..count {
                let _ = sender1.send(Recipients::One(pk2.clone()), i.encode(), true);
            }

            // Collect exactly `count` messages, all attributed to peer 1.
            let mut received = Vec::new();
            for _ in 0..count {
                let (from, value) = rx.recv().await.unwrap();
                assert_eq!(from, pk1);
                received.push(value);
            }
            // Sort before comparing so the assertion does not depend on arrival order.
            received.sort();
            assert_eq!(received, (0..count).collect::<Vec<u32>>());
        });
    }
597
    // When an undecodable message from peer 1 is interleaved with valid messages from peer 3,
    // both valid messages are delivered, peer 1 ends up blocked, and peer 3 never does.
    #[test_traced]
    fn test_invalid_among_valid_only_blocks_offender() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.child("network"));

            // Three tracked peers; peers 1 and 3 are each linked to peer 2.
            let pk1 = pk(0);
            let pk2 = pk(1);
            let pk3 = pk(2);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            let control3 = oracle.control(pk3.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone(), pk3.clone()]);
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
            link_bidirectional(&mut oracle, pk3.clone(), pk2.clone()).await;

            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();
            let (mut sender3, _) = control3.register(0, TEST_QUOTA).await.unwrap();

            // Peer 2 runs the background decoder with its own control as the blocker.
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.child("bg"),
                receiver2,
                (),
                control2.clone(),
                NZUsize!(16),
                &Sequential,
            );
            let _handle = bg.start();

            // Valid message from peer 3.
            let _ = sender3.send(Recipients::One(pk2.clone()), 10u32.encode(), true);

            // One-byte payload from peer 1, which this test expects to fail `u32` decoding.
            let _ = sender1.send(Recipients::One(pk2.clone()), IoBuf::from(vec![0xFF]), true);

            // Second valid message from peer 3.
            let _ = sender3.send(Recipients::One(pk2.clone()), 20u32.encode(), true);

            // Exactly the two valid messages arrive, both attributed to peer 3.
            let mut values = Vec::new();
            for _ in 0..2 {
                let (from, value) = rx.recv().await.unwrap();
                assert_eq!(from, pk3);
                values.push(value);
            }
            values.sort();
            assert_eq!(values, vec![10u32, 20]);

            // Poll until the `(pk2, pk1)` pair is reported as blocked, checking on every poll
            // that `(pk2, pk3)` is not. The pairs are presumably `(blocker, blocked)`.
            loop {
                let blocked = oracle.blocked().await.unwrap();
                assert!(!blocked.contains(&(pk2.clone(), pk3.clone())));
                if blocked.contains(&(pk2.clone(), pk1.clone())) {
                    break;
                }

                context.sleep(Duration::from_millis(1)).await;
            }
        });
    }
659
660 #[test_traced]
661 fn test_decoded_messages_drop_when_receiver_full() {
662 let executor = deterministic::Runner::default();
663 executor.start(|context| async move {
664 let sender = pk(0);
665 let (tx, receiver) = mpsc::unbounded_channel();
666
667 for i in 0..2u32 {
668 tx.send((sender.clone(), IoBuf::from(i.encode())))
669 .expect("mock receiver should be open");
670 }
671 drop(tx);
672
673 let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
674 context.child("bg"),
675 MockReceiver { receiver },
676 (),
677 NoopBlocker,
678 NZUsize!(1),
679 &Sequential,
680 );
681 let handle = bg.start();
682 handle.await.expect("background receiver should complete");
683
684 let (from, value) = rx.recv().await.unwrap();
685 assert_eq!(from, sender);
686 assert_eq!(value, 0);
687 assert!(rx.recv().await.is_none());
688 });
689 }
690
691 #[test_traced]
692 fn test_drain_decode_pool_after_receiver_closure() {
693 let executor = deterministic::Runner::default();
694 executor.start(|context| async move {
695 let sender = pk(0);
696 let (tx, receiver) = mpsc::unbounded_channel();
697 let count = 64u32;
698
699 for i in 0..count {
700 tx.send((sender.clone(), IoBuf::from(i.encode())))
701 .expect("mock receiver should be open");
702 }
703 drop(tx);
704
705 let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
706 context.child("bg"),
707 MockReceiver { receiver },
708 (),
709 NoopBlocker,
710 NZUsize!(64),
711 &HintStrategy(8),
712 );
713 let _handle = bg.start();
714
715 let mut values = Vec::new();
716 while let Some((from, value)) = rx.recv().await {
717 assert_eq!(from, sender);
718 values.push(value);
719 }
720 values.sort_unstable();
721
722 assert_eq!(values, (0..count).collect::<Vec<u32>>());
723 });
724 }
725}